目录
-
- 1. 加载需要的包
- 2. 载入数据集
- 3. 搭建网络模型
- 4. 模型训练
- 5. 模型测试
- 6. 训练过程可视化
1. 加载需要的包
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses, datasets, Sequential
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D, GlobalMaxPool2D
from tensorflow.keras.optimizers import RMSprop
import matplotlib.pyplot as plt
import numpy as np
2. 载入数据集
# Load the MNIST train/test splits (images plus integer class labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Append a single channel axis so Conv2D receives (samples, 28, 28, 1).
# -1 lets NumPy infer the sample count instead of hard-coding 60000/10000,
# so the reshape stays valid for any split size (e.g. a subsampled dataset).
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# Convert to float32 and scale pixel values from [0, 255] into [0, 1].
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# One-hot encode the 10 digit classes, as required by the
# 'categorical_crossentropy' loss used when the model is compiled.
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)
3. 搭建网络模型
# Small CNN: conv(32, 3x3) -> ReLU -> 2x2 max-pool -> conv(64, 3x3) -> ReLU
# -> global max pooling -> dense(10) -> softmax over the 10 digit classes.
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
# GlobalMaxPool2D already reduces each feature map to one value, giving a
# flat (batch, 64) tensor, so no Flatten layer is needed before Dense.
model.add(GlobalMaxPool2D())
model.add(Dense(10))
model.add(Activation('softmax'))

# summary() prints the layer table itself and returns None, so it is called
# directly rather than wrapped in print() (which would emit a stray "None").
model.summary()

# Labels are one-hot encoded, hence categorical (not sparse) cross-entropy.
model.compile(loss='categorical_crossentropy', optimizer=RMSprop(), metrics=['accuracy'])
data:image/s3,"s3://crabby-images/fb44b/fb44b8aed1373ace3839a0657507fe35391c9130" alt="搭建CNN网络训练mnist数据集_第1张图片"
4. 模型训练
# Train for 10 epochs with mini-batches of 128 samples, logging progress.
# NOTE: the test split is passed as validation data, so the 'val_*' entries
# recorded in history are test-set metrics; nothing here tunes on them.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=10,
    verbose=1,
    validation_data=(x_test, y_test),
)
data:image/s3,"s3://crabby-images/22b40/22b40bb3c8cdf1ffaaee2767f2e37e0bcabbd0cb" alt="搭建CNN网络训练mnist数据集_第2张图片"
5. 模型测试
# Evaluate on the test split. The model was compiled with one extra metric,
# so the result holds the loss first and the accuracy second.
score = model.evaluate(x_test, y_test, verbose=1)
test_loss, test_accuracy = score[0], score[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)
data:image/s3,"s3://crabby-images/064de/064deb652cdee95f8913a0b63b1b8a9fca04ceb6" alt="在这里插入图片描述"
6. 训练过程可视化
# Plot per-epoch training accuracy against the validation accuracy. The
# validation data given to fit() is the test split, hence the 'Test' label.
# Keras numbers epochs from 1 in its progress log, so plot against a 1-based
# epoch axis instead of matplotlib's default 0-based point index.
epoch_axis = range(1, len(history.history['accuracy']) + 1)
plt.plot(epoch_axis, history.history['accuracy'])
plt.plot(epoch_axis, history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
data:image/s3,"s3://crabby-images/7e629/7e6299fd8f48ca2ed3479d5138bf89f03fbe85f8" alt="搭建CNN网络训练mnist数据集_第3张图片"
# Plot per-epoch training loss against the validation loss. The validation
# data given to fit() is the test split, hence the 'Test' label.
# Use a 1-based epoch axis to match the epoch numbers Keras logs.
epoch_axis = range(1, len(history.history['loss']) + 1)
plt.plot(epoch_axis, history.history['loss'])
plt.plot(epoch_axis, history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
# When loss decreases during training, the curves start high on the left,
# so the upper-right corner is less likely to hide them than upper-left.
plt.legend(['Train', 'Test'], loc='upper right')
plt.show()
data:image/s3,"s3://crabby-images/60cde/60cde8642d86e9a2b6ea6c862ceadccb0eeaa334" alt="搭建CNN网络训练mnist数据集_第4张图片"