Fashion-MNIST 数据集由不同种类衣物的图片组成,包括鞋子、T 恤、裙子等。
| 项目 | 数量 |
|---|---|
| 类别数 | 10 |
| 训练集图片 | 60000 |
| 测试集图片 | 10000 |
以下各代码块按顺序拼接即可组成一个完整的 demo。
Keras 自带加载 Fashion-MNIST 的接口:
import os

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.layers as layers
# Function to show pictures.
def poltImage(images_arr):
    """Display a batch of images side by side in one row, then show the figure.

    The (misspelled) name is kept because the rest of the script calls it.

    Args:
        images_arr: array-like indexed along axis 0, one image per entry;
            ``images_arr.shape[0]`` images are drawn. An empty batch draws
            nothing.
    """
    numbers = images_arr.shape[0]
    if numbers == 0:
        # plt.subplots(1, 0) raises; there is simply nothing to draw.
        return
    # squeeze=False always yields a 2-D array of Axes. Without it a single
    # image gives back a bare Axes object, which has no .flatten().
    _, axes = plt.subplots(1, numbers, figsize=(10, 10), squeeze=False)
    for img, ax in zip(images_arr, axes.flatten()):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# Load the dataset through Keras' built-in Fashion-MNIST loader.
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Preview the first five training images.
poltImage(x_train[:5])

# Rescale pixel values by 1/255 (assumes 8-bit pixel intensities), so the
# network sees inputs in [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0
# Function to build the model.
def build_model(input_shape=(28, 28), num_classes=10):
    """Build (but do not compile) a small convolutional image classifier.

    Args:
        input_shape: shape of one input image WITHOUT a channel axis.
            Defaults to (28, 28), the shape the original model hard-coded.
        num_classes: number of output classes; defaults to 10.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose final layer emits a
        softmax probability vector of length ``num_classes``.
    """
    input_shape = tuple(input_shape)
    model = tf.keras.models.Sequential([
        # Declare the input with an Input object. Keras 3 deprecates passing
        # `input_shape=` to an ordinary layer inside a Sequential model.
        layers.Input(shape=input_shape),
        # Append a single channel axis so Conv2D receives (H, W, 1).
        layers.Reshape((*input_shape, 1)),
        layers.Conv2D(64, (5, 5), activation='relu'),
        layers.MaxPool2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        # Collapse each feature map to one value, making a 128-d vector.
        layers.GlobalAveragePooling2D(),
        layers.Dense(500, use_bias=False, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])
    return model
model = build_model()
model.summary()
model.compile(
    # Pass optimizer/loss/metric *instances*: don't drop the call parentheses.
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    # Use the sparse accuracy metric here. tf.keras.metrics.Accuracy() checks
    # exact equality between labels and predictions, which does not work
    # against the softmax probability vector this model outputs.
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
# Equivalent string shorthand:
# model.compile(optimizer='adam',
#               loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])

# Tune epochs freely. More epochs usually help up to a point; beyond that the
# model may start to overfit.
model.fit(x_train, y_train, epochs=20)

# Keras 3 (TF >= 2.16) requires weight files to end in ".weights.h5"; the old
# TF-checkpoint path "./fashion_mnist/ckpt" raises ValueError there. The old
# checkpoint writer created the directory itself, so create it explicitly.
weights_path = './fashion_mnist/ckpt.weights.h5'
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
model.save_weights(weights_path)

model1 = build_model()
model1.load_weights(weights_path)
新建的 model1 只加载了权重、尚未编译,因此在调用 evaluate 之前必须重新 compile:
# model1 was freshly built and only had weights loaded into it, so it must be
# compiled before it can be evaluated on the test split.
model1.compile(metrics=['accuracy'],
               loss=tf.keras.losses.SparseCategoricalCrossentropy(),
               optimizer='adam')
model1.evaluate(x_test, y_test)