1.加载Fashion MNIST数据集
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)
2.2.0
# Load the Fashion MNIST dataset through Keras; load_data() returns the
# training and test splits, each unpacked here into images and labels.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
训练集共有 6 万张图片,每张图片为 28×28 像素
# Print the dimensions of the training image array.
print(train_images.shape)
(60000, 28, 28)
2.画一张图
# Render the first training image.
plt.imshow(train_images[0])
# Build the model layer by layer: flatten -> hidden dense -> output dense.
model = keras.Sequential()
# Input layer: every training image is 28x28, hence input_shape=(28, 28).
model.add(keras.layers.Flatten(input_shape=(28, 28)))
# Hidden layer with 128 units and the ReLU activation function.
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
# Output layer: ten units (one per class) with softmax activation.
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))
model.summary()
# Parameter count of the hidden layer: each of the 128 units has 784 weights
# plus 1 bias term (the bias plays the role of an intercept).
(784 + 1) * 128
# Parameter count of the output layer: each of the 10 units has 128 weights
# plus 1 bias term.
(128 + 1) * 10
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
_________________________________________________________________
dense (Dense) (None, 128) 100480
_________________________________________________________________
dense_1 (Dense) (None, 10) 1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
1290
3.训练和评估模型
# Normalize the data (divide by 255.0) so that training works better.
train_images, test_images = train_images / 255.0, test_images / 255.0
# Compile the model.
model.compile(
    optimizer=tf.optimizers.Adam(),
    loss=tf.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)
# Train the model for 5 epochs.
model.fit(x=train_images, y=train_labels, epochs=5)
Epoch 1/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.5292 - accuracy: 0.8149
Epoch 2/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3807 - accuracy: 0.8639
Epoch 3/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3450 - accuracy: 0.8755
Epoch 4/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3213 - accuracy: 0.8825
Epoch 5/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3061 - accuracy: 0.8876
<tensorflow.python.keras.callbacks.History at 0x1d62490af10>
# Evaluate the model on the test images and labels.
model.evaluate(test_images,test_labels)
313/313 [==============================] - 0s 1ms/step - loss: 0.3635 - accuracy: 0.8690
[0.3634686768054962, 0.8690000176429749]
# Run inference on the test images once and keep the result for later use.
predictions = model.predict(test_images)
# Display the stored array instead of calling predict() a second time,
# which would repeat the whole inference pass just to show the same values.
predictions
array([[4.08061032e-06, 7.08236669e-09, 3.36254260e-08, ...,
1.42676504e-02, 1.72525415e-06, 9.61163998e-01],
[3.67407029e-06, 9.00608662e-14, 9.99114692e-01, ...,
3.96537495e-16, 8.66414673e-09, 1.64345019e-13],
[1.12768330e-05, 9.99987841e-01, 6.11865669e-09, ...,
5.98114232e-15, 1.78030135e-09, 1.77714416e-15],
...,
[5.27661177e-04, 5.20954391e-10, 2.40708818e-04, ...,
2.22873410e-07, 9.98919129e-01, 4.67528860e-09],
[6.14244800e-06, 9.98748422e-01, 2.09969443e-07, ...,
1.91866835e-11, 1.46072097e-07, 5.69366720e-10],
[7.92981227e-05, 1.63143454e-08, 1.21346839e-05, ...,
1.69552374e-03, 1.03444676e-04, 5.51283892e-06]], dtype=float32)
# Model output for the first test image (one value per output unit).
predictions[0]
array([4.0806103e-06, 7.0823667e-09, 3.3625426e-08, 3.8752265e-08,
1.4857939e-07, 2.4561029e-02, 1.2778094e-06, 1.4267650e-02,
1.7252542e-06, 9.6116400e-01], dtype=float32)
# Index of the largest output value for the first test image, i.e. the
# predicted class. It is printed explicitly because a bare expression that is
# not the last statement of a notebook cell is evaluated but never displayed.
print(np.argmax(predictions[0]))
# Show the first test image.
plt.imshow(test_images[0])
class myCallback(tf.keras.callbacks.Callback):
    """Keras callback that stops training once the epoch loss drops below 0.4."""

    def on_epoch_end(self, epoch, logs=None):
        """Request that training stop when the reported loss is below 0.4.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Mapping of metric names to values for the epoch; may be None.
        """
        # Avoid a mutable default argument, and tolerate a missing 'loss'
        # entry instead of comparing None with a float (a TypeError).
        loss = (logs or {}).get('loss')
        if loss is not None and loss < 0.4:
            print('loss太低了,我取消训练了')
            self.model.stop_training = True


# Instance passed to model.fit(..., callbacks=[callbacks]).
callbacks = myCallback()
# Load the dataset through the handle created here; previously `mnist` was
# assigned but unused and the load relied on `fashion_mnist` from an earlier cell.
mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the freshly loaded data (divide by 255.0).
train_images = train_images / 255.0
test_images = test_images / 255.0
# Build the model.
model = keras.Sequential([
    # Input layer: every training image is 28x28, hence input_shape=(28, 28).
    keras.layers.Flatten(input_shape=(28, 28)),
    # Hidden layer with the ReLU activation function.
    keras.layers.Dense(128, activation=tf.nn.relu),
    # Output layer: ten classes.
    keras.layers.Dense(10, activation=tf.nn.softmax),
])
# Compile the model.
model.compile(
    optimizer=tf.optimizers.Adam(),
    loss=tf.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)
# Train the model; the callback is given the chance to stop training
# at the end of each epoch.
model.fit(train_images, train_labels, epochs=5, callbacks=[callbacks])