深度学习-fashion_mnist预测

一、导入库函数,加载fashion_mnist数据集

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

(x_train,y_train),(x_test,y_test) = tf.keras.datasets.fashion_mnist.load_data()

二、数据探索

plt.matshow(x_train[320])#任意特征数
mnist = ['T恤(T-shirt)','裤子(Trouser)','套头衫(Pullover)','连衣裙(Dress)','外套(Coat)','凉鞋(Sandal)','衬衫(Shirt)','运动鞋(Sneaker)','包(Bag)','靴子(Ankle boot)']#多对数据集索引命名
print("样本标签为:{} 样本内容为:{}".format(y_train[320],mnist[y_train[320]]))#输出样本并输出该特征图

深度学习-fashion_mnist预测_第1张图片

输出训练集特征320对应标签

深度学习-fashion_mnist预测_第2张图片

三、数据预处理

x_train,x_test = x_train/255.0,x_test/255.0#归一化

y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

四、构建模型

#构建模型
model=tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),#按行,二维变成一维
     tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(64,activation='tanh'),
     tf.keras.layers.Dense(10,activation='softmax')#不可更改
])
print(model.summary())#打印模型

输出模型

深度学习-fashion_mnist预测_第3张图片

#设置优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),#学习率,Adam,SGD
             loss = tf.keras.losses.categorical_crossentropy,#损失函数,交叉熵不可改,one-hot编码
              #loss = tf.keras.losses.sparse_categorical_crossentropy             
             metrics=['acc'])

五、模型训练

model.fit(x_train,y_train,epochs=29)#训练29轮

六、模型评估

深度学习-fashion_mnist预测_第4张图片

七、模型预测

np.set_printoptions(precision=3,suppress=True)
pred = model.predict(x_test[:9])
pred

np.argmax(y_test[:9],axis=1)

pred = np.argmax(model.predict(x_test[3:20]),axis=1)
print(pred)
np.argmax(y_test[3:20],axis=1)

y_test[9]

预测到标签为7的是一个运动鞋

深度学习-fashion_mnist预测_第5张图片

你可能感兴趣的:(深度学习,机器学习,python)