Python+Tensorflow学习(三)——fashion_mnist数据集

Python+Tensorflow学习(三)——fashion_mnist数据集

学习视频链接:(强推)TensorFlow官方入门实操课程

其他学习记录:
Python+Tensorflow学习(二)——初试keras
源码如下:

# -*- coding = utf-8 -*-
# @Time : 2021/8/9 16:51
# @Author : 西兰花
# @File : tf02.py
# @Software : PyCharm


from tensorflow import keras    # Keras是tensorflow中的一个高级API
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


"""
例2:利用fashion_mnist数据集识别衣服、鞋子等
"""
# 以load_data()方式引入fashion_mnist数据集
# 并定义了train_images、train_labels、test_images、test_labels四个变量
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# 查看是否加载成功
# print(train_images.shape)
# print(train_labels.shape)
# print(train_images[0])   # 查看某张图片具体的值,其值为该图片的灰度图具体像素值
# print(test_images.shape)
# print(test_labels.shape)
# print(train_labels[0])

# 利用plt显示源图片
# plt.figure()    # 创建plot新窗口
# plt.imshow(train_images[99])     # 选择需要显示的图片
# plt.colorbar()  # 色彩栏
# plt.grid(False)     # 取消栅格
# plt.show()      # 显示绘图——此步不可缺


# 训练模型
# model = keras.Sequential([
#     keras.layers.Flatten(input_shape=(28, 28)),         # 输入层
#     keras.layers.Dense(128, activation=tf.nn.relu),     # 中间层
#     keras.layers.Dense(10, activation=tf.nn.softmax)    # 输出层
# ])


class myCallback(tf.keras.callbacks.Callback):  # 自动终止训练(预防过拟合现象)
    def on_epoch_end(self, epoch, logs={}):
        if logs.get('loss') < 0.4:
            print("\nLoss is low so cancelling training!")
            self.model.stop_training = True


callbacks = myCallback()
mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
training_images_scaled = train_images/255.0
test_images_scaled = test_images/255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# 查看训练模型
# model.summary()

# 测试训练模型
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5)
model.evaluate(test_images, test_labels)

print(np.argmax(model.predict([[test_images[0]/255]])))
print(test_labels[0])


输出结果:
Python+Tensorflow学习(三)——fashion_mnist数据集_第1张图片
Python+Tensorflow学习(三)——fashion_mnist数据集_第2张图片
Python+Tensorflow学习(三)——fashion_mnist数据集_第3张图片
下载fashion_mnist数据集遇到的问题及解决方法:
第一次运行程序时自动下载fashion_mnist数据集,系统加载的链接可能出现下载失败现象;
Python+Tensorflow学习(三)——fashion_mnist数据集_第4张图片
解决方法:
参考文章:关于tensorflow入门keras的Fashion-mnist数据集无法下载的解决方法
第一步:到fashion_mnist数据集官网手动下载数据集,共四个压缩包;
Python+Tensorflow学习(三)——fashion_mnist数据集_第5张图片
第二步:在C盘中找到.keras文件夹(笔者的路径是:C:\Users\Administrator.keras\datasets\fashion-mnist),并将下载四个压缩包放入fashion_mnist文件夹;
Python+Tensorflow学习(三)——fashion_mnist数据集_第6张图片
第三步:重新运行程序即可。
Python+Tensorflow学习(三)——fashion_mnist数据集_第7张图片

你可能感兴趣的:(tensorflow,深度学习,机器学习,神经网络,python)