tensorflow fashion_mnist数据集模型训练及预测

✨ 博客主页:小小马车夫的主页
✨ 所属专栏:Tensorflow

文章目录

  • 前言
  • 一、环境
  • 二、fashion_mnist数据集介绍
  • 三、fashion_mnist数据集下载和展示
  • 四、数据预处理
  • 五、构建模型和训练模型
  • 六、模型预测
  • 总结


前言

前面介绍mnist手写数字集训练,本文对fashion_mnist数据集训练和预测进行简要介绍。


一、环境

MacOS: 13.0
Python: 3.9.13
Tensorflow: 2.11.0

二、fashion_mnist数据集介绍

fashion_mnist数据集和mnist数据集类似,都是28x28的灰度图片,区分是fashion_mnist数据集是服装图片,具体分类如下图:

分类 英文描述 中文描述
0 t-shirt T恤
1 trouser 牛仔裤
2 pullover 套衫
3 dress 裙子
4 coat 外套
5 sandal 凉鞋
6 shirt 衬衫
7 sneaker 运动鞋
8 bag
9 ankle boot 短靴

三、fashion_mnist数据集下载和展示

运用tensorflow下载fashion_mnist数据集与mnist类似,代码如下:

import tensorflow as tf
from tensorflow import keras
import numpy as np

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape, train_labels.shape)
print(test_images.shape, test_labels.shape)

输出:

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

可以看到训练集是60000张28x28的灰度图,测试集是10000张28x28的灰度图。
一些样例展示如下:
tensorflow fashion_mnist数据集模型训练及预测_第1张图片

四、数据预处理

数据预处理主要是对图片归一化处理,如下:

train_images=train_images / 255.
test_images = test_images / 255.

五、构建模型和训练模型

模型构建

model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))
model.summary()

模型结构如下:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 128)               100480    
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________

模型训练

class MyCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
    	#loss小于0.25就停止训练
        if logs.get('loss') < 0.25:
            self.model.stop_training = True
callbacks = MyCallback()
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['acc'])
h = model.fit(train_images, train_labels, batch_size=32, epochs=15, validation_data=(test_images_scaled, test_labels), callbacks=[callbacks])

查看结果

Epoch 1/15
1875/1875 [==============================] - 11s 5ms/step - loss: 0.5031 - acc: 0.8239 - val_loss: 0.4201 - val_acc: 0.8499
Epoch 2/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3774 - acc: 0.8648 - val_loss: 0.4333 - val_acc: 0.8482
Epoch 3/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3371 - acc: 0.8773 - val_loss: 0.3662 - val_acc: 0.8667
Epoch 4/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3145 - acc: 0.8845 - val_loss: 0.3697 - val_acc: 0.8667
Epoch 5/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2929 - acc: 0.8921 - val_loss: 0.3404 - val_acc: 0.8794
Epoch 6/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2805 - acc: 0.8958 - val_loss: 0.3453 - val_acc: 0.8793
Epoch 7/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2683 - acc: 0.9009 - val_loss: 0.3452 - val_acc: 0.8778
Epoch 8/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2566 - acc: 0.9032 - val_loss: 0.3370 - val_acc: 0.8820
Epoch 9/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2480 - acc: 0.9065 - val_loss: 0.3482 - val_acc: 0.8789

用图标显示损失曲线和准确率曲线

loss_list = h.history['loss']
acc_list = h.history['acc']
test_loss_list = h.history['val_loss']
test_acc_list = h.history['val_acc']

plt.rcParams['font.sans-serif'] = ['Songti SC']
plt.rcParams['axes.unicode_minus'] = False

plt.figure(figsize=(20, 10))

plt.subplot(221)
plt.ylabel('loss')
plt.plot(loss_list, color='blue', marker='.', label='train_loss')
plt.plot(test_loss_list, color='red', marker='.', label='val_loss')
plt.legend(loc='upper left')
plt.title('损失曲线', fontsize=16)

plt.subplot(222)
plt.ylabel('acc')
plt.plot(acc_list, color='blue', marker='.', label='train_acc')
plt.plot(test_acc_list, color='red', marker='.', label='val_acc')
plt.legend(loc='upper left')
plt.title('准确率曲线', fontsize=16)
plt.show()

输出:
tensorflow fashion_mnist数据集模型训练及预测_第2张图片

六、模型预测

选一个图像进行预测:

image = tf.cast(test_images[1], tf.float32)
image = tf.reshape(image, [1, 28, 28])
np.argmax(model.predict(image))
print(test_labels[1])
plt.imshow(test_images[1])
plt.show()

输出:

1/1 [==============================] - 0s 408ms/step
2

tensorflow fashion_mnist数据集模型训练及预测_第3张图片

总结

本文主要介绍了tensorflow fashion_mnist的下载、训练、预测,模型用的全连接网络。

如果觉得有些帮助或觉得文章还不错,请关注一下博主,你的关注是我持续写作的动力。另外,如果有什么问题,可以在评论区留言,或者私信博主,博主看到后会第一时间进行回复。
【间歇性的努力和蒙混过日子,都是对之前努力的清零】
欢迎转载,转载请注明出处:https://blog.csdn.net/xxm524/article/details/128160073

你可能感兴趣的:(Tensorflow,fashion_mnist,tensorflow,训练,预测)