TensorFlow 加载数据集(fashion_mnist)

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow_core.python import keras

# 0. Report the interpreter version and the version of every imported library.
print(tf.__version__)
print(sys.version_info)
for module in (mpl, np, sklearn, pd, tf, keras):
    print("{} version:{}".format(module.__name__, module.__version__))

# 1. Load the fashion_mnist dataset through the Keras datasets API.
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()

# 2. Hold out the first 5000 training samples as the validation set and keep
#    the remaining samples for training.
x_valid, y_valid = x_train_all[:5000], y_train_all[:5000]
x_train, y_train = x_train_all[5000:], y_train_all[5000:]

# 3. Print the (images, labels) shapes of the train, validation and test
#    splits, in that order.
for split_x, split_y in (
    (x_train, y_train),
    (x_valid, y_valid),
    (x_test, y_test),
):
    print(split_x.shape, split_y.shape)

# Display names for the integer labels; a label value is used directly as the
# index into this list when titling images.
class_names = [
    "T-shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]


def show_single_img(img_arr):
    """Draw one image with the "binary" colormap and show the figure.

    Args:
        img_arr: image data passed directly to ``plt.imshow``; presumably a
            2-D grayscale array such as a single fashion_mnist sample
            (TODO confirm against callers).
    """
    plt.imshow(img_arr, cmap="binary")
    plt.show()


def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    """Show the first ``n_rows * n_cols`` images in a grid, titled by class.

    Images are laid out row by row: image ``i`` goes to row ``i // n_cols``,
    column ``i % n_cols``.

    Args:
        n_rows: number of grid rows; must be positive.
        n_cols: number of grid columns; must be positive.
        x_data: indexable sequence of images; each is passed to ``plt.imshow``.
        y_data: labels parallel to ``x_data``; each label is used as an index
            into ``class_names`` to build the subplot title.
        class_names: lookup from label value to display name.

    Raises:
        ValueError: if ``x_data`` and ``y_data`` differ in length, if either
            grid dimension is not positive, or if there are fewer images than
            grid cells.
    """
    if len(x_data) != len(y_data):
        raise ValueError(
            "x_data and y_data must have the same length, got %d and %d"
            % (len(x_data), len(y_data))
        )
    if n_rows <= 0 or n_cols <= 0:
        raise ValueError(
            "grid dimensions must be positive, got %r x %r" % (n_rows, n_cols)
        )
    n_cells = n_rows * n_cols
    # An exact fit (n_cells == len(x_data)) is valid: every cell gets an image.
    if n_cells > len(x_data):
        raise ValueError(
            "need at least %d images for a %d x %d grid, got %d"
            % (n_cells, n_rows, n_cols, len(x_data))
        )

    # Allot roughly 1.4 x 1.6 inches per grid cell.
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for index in range(n_cells):
        # subplot positions are 1-based and run row by row, matching `index`.
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(x_data[index], cmap="binary", interpolation="nearest")
        plt.axis("off")
        plt.title(class_names[y_data[index]])

    plt.show()


def main():
    """Preview the training split: one image alone, then a 3 x 5 labelled grid."""
    # 4. Show the first training image on its own.
    first_image = x_train[0]
    show_single_img(first_image)

    # 5. Show a grid of training images titled with their class names.
    grid_rows, grid_cols = 3, 5
    show_imgs(grid_rows, grid_cols, x_train, y_train, class_names)


if __name__ == "__main__":
    # Display the sample images only when this file is executed as a script.
    main()

你可能感兴趣的:(TensorFlow)