tensorflow2.0入门 fashion_mnist实战

github源码下载link

fashion_mnist 数据集

fashion_mnist是不同种类衣物的图片,有鞋子,T恤,裙子等

名目 数量
种类 10
训练集图片 60000
测试集图片 10000

fashion_mnist分类demo

以下各个代码块相连,可以作为完整demo

加载fashion_mnist数据集,并归一化

keras自带加载fashion_mnist的接口

import tensorflow as tf
import tensorflow.keras.layers as layers
import matplotlib.pyplot as plt

# funciton to show pictures
def poltImage(images_arr):
    numbers = images_arr.shape[0]
    fig, axes = plt.subplots(1, numbers, figsize=(10,10))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# to load dataset
fashion_mnist = tf.keras.datasets.fashion_mnist

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

poltImage(x_train[:5])

# normalization
x_train, x_test = x_train/255.0, x_test/255.0


建立模型
# function to build model
def build_model():
    model = tf.keras.models.Sequential([
        layers.Reshape((28, 28, 1), input_shape=(28, 28)),
        layers.Conv2D(64, (5,5), activation='relu'),
        layers.MaxPool2D((2,2)),
        layers.Conv2D(128, (3,3), activation='relu'),
        layers.GlobalAveragePooling2D(),
        layers.Dense(500, use_bias=False, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    return model

model = build_model()
model.summary()

编译训练并保存参数
model.compile(
    optimizer=tf.keras.optimizers.Adam(), # function 的括号不要丢了
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    # 这里不能使用 tf.keras.metrics.Accuracy()
)  
# 与
# model.compile(optimizer='adam',
#               loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])
# 等价
model.fit(x_train, y_train, epochs=20) # 训练的epochs随便调,一定范围内越高越好
model.save_weights('./fashion_mnist/ckpt')
新建模型并加载参数
model1 = build_model()
model1.load_weights('./fashion_mnist/ckpt')
测试集表现

为了使新建的model1 work,必须重新compile

model1.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)
model1.evaluate(x_test, y_test)
运行结果

测试集成绩 93.45%。
tensorflow2.0入门 fashion_mnist实战_第1张图片
20个epochs结束时,准确率还在上升,epochs增大还可以继续提高准确率。

你可能感兴趣的:(tensorflow入门)