import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Record which TF build produced the results below (logs show 2.6.4).
print(f"TensorFlow version: {tf.__version__}")
TensorFlow version: 2.6.4
# Load CIFAR-10: 50k train / 10k test images of shape (32, 32, 3),
# labels are integer class ids 0-9 (shape (n, 1)).
cifar10 = keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170500096/170498071 [==============================] - 2s 0us/step 170508288/170498071 [==============================] - 2s 0us/step
# In-graph augmentation: random horizontal flips plus small rotations
# (up to 0.1 * 2*pi). Active only during training.
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
])
Build the model
# Stem: augment, scale pixels to [0, 1], then two 64-filter
# conv -> batch-norm -> ReLU stacks at full 32x32 resolution.
inputs = keras.Input(shape=(32, 32, 3))
x = data_augmentation(inputs)
x = layers.Rescaling(1.0 / 255)(x)
for _ in range(2):
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
def residual_block(x, filters, pooling=False):
    """Two conv-BN-ReLU layers plus a shortcut connection.

    If ``pooling`` is True the main path is downsampled 2x with max
    pooling and the shortcut uses a strided 1x1 conv to match shapes.
    Otherwise, if the channel count changes, the shortcut is projected
    with a plain 1x1 conv. Returns the elementwise sum of both paths.
    """
    shortcut = x
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    if pooling:
        x = layers.MaxPooling2D(2, padding="same")(x)
        # Strided 1x1 conv halves the shortcut's spatial dims to match.
        shortcut = layers.Conv2D(filters, 1, strides=2)(shortcut)
    elif filters != shortcut.shape[-1]:
        # Channel count changed without pooling: project the shortcut.
        shortcut = layers.Conv2D(filters, 1)(shortcut)
    return layers.add([x, shortcut])
# Three downsampling residual stages: 32x32x128 -> 16x16x256 -> 8x8x512
# spatial maps shrink 2x per stage while channels double.
for width in (128, 256, 512):
    x = residual_block(x, filters=width, pooling=True)

# Head: one wide conv stack, global average pooling, dropout, softmax.
x = layers.Conv2D(1024, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
keras.utils.plot_model(model, show_shapes=True, to_file='model.png')
Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 32, 32, 3)] 0 __________________________________________________________________________________________________ sequential (Sequential) (None, 32, 32, 3) 0 input_1[0][0] __________________________________________________________________________________________________ rescaling (Rescaling) (None, 32, 32, 3) 0 sequential[0][0] __________________________________________________________________________________________________ conv2d (Conv2D) (None, 32, 32, 64) 1792 rescaling[0][0] __________________________________________________________________________________________________ batch_normalization (BatchNorma (None, 32, 32, 64) 256 conv2d[0][0] __________________________________________________________________________________________________ activation (Activation) (None, 32, 32, 64) 0 batch_normalization[0][0] __________________________________________________________________________________________________ conv2d_1 (Conv2D) (None, 32, 32, 64) 36928 activation[0][0] __________________________________________________________________________________________________ batch_normalization_1 (BatchNor (None, 32, 32, 64) 256 conv2d_1[0][0] __________________________________________________________________________________________________ activation_1 (Activation) (None, 32, 32, 64) 0 batch_normalization_1[0][0] __________________________________________________________________________________________________ conv2d_2 (Conv2D) (None, 32, 32, 128) 73856 activation_1[0][0] __________________________________________________________________________________________________ batch_normalization_2 (BatchNor (None, 32, 32, 128) 512 conv2d_2[0][0] 
__________________________________________________________________________________________________ activation_2 (Activation) (None, 32, 32, 128) 0 batch_normalization_2[0][0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (None, 32, 32, 128) 147584 activation_2[0][0] __________________________________________________________________________________________________ batch_normalization_3 (BatchNor (None, 32, 32, 128) 512 conv2d_3[0][0] __________________________________________________________________________________________________ activation_3 (Activation) (None, 32, 32, 128) 0 batch_normalization_3[0][0] __________________________________________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 16, 16, 128) 0 activation_3[0][0] __________________________________________________________________________________________________ conv2d_4 (Conv2D) (None, 16, 16, 128) 8320 activation_1[0][0] __________________________________________________________________________________________________ add (Add) (None, 16, 16, 128) 0 max_pooling2d[0][0] conv2d_4[0][0] __________________________________________________________________________________________________ conv2d_5 (Conv2D) (None, 16, 16, 256) 295168 add[0][0] __________________________________________________________________________________________________ batch_normalization_4 (BatchNor (None, 16, 16, 256) 1024 conv2d_5[0][0] __________________________________________________________________________________________________ activation_4 (Activation) (None, 16, 16, 256) 0 batch_normalization_4[0][0] __________________________________________________________________________________________________ conv2d_6 (Conv2D) (None, 16, 16, 256) 590080 activation_4[0][0] __________________________________________________________________________________________________ batch_normalization_5 (BatchNor (None, 16, 16, 256) 1024 
conv2d_6[0][0] __________________________________________________________________________________________________ activation_5 (Activation) (None, 16, 16, 256) 0 batch_normalization_5[0][0] __________________________________________________________________________________________________ max_pooling2d_1 (MaxPooling2D) (None, 8, 8, 256) 0 activation_5[0][0] __________________________________________________________________________________________________ conv2d_7 (Conv2D) (None, 8, 8, 256) 33024 add[0][0] __________________________________________________________________________________________________ add_1 (Add) (None, 8, 8, 256) 0 max_pooling2d_1[0][0] conv2d_7[0][0] __________________________________________________________________________________________________ conv2d_8 (Conv2D) (None, 8, 8, 512) 1180160 add_1[0][0] __________________________________________________________________________________________________ batch_normalization_6 (BatchNor (None, 8, 8, 512) 2048 conv2d_8[0][0] __________________________________________________________________________________________________ activation_6 (Activation) (None, 8, 8, 512) 0 batch_normalization_6[0][0] __________________________________________________________________________________________________ conv2d_9 (Conv2D) (None, 8, 8, 512) 2359808 activation_6[0][0] __________________________________________________________________________________________________ batch_normalization_7 (BatchNor (None, 8, 8, 512) 2048 conv2d_9[0][0] __________________________________________________________________________________________________ activation_7 (Activation) (None, 8, 8, 512) 0 batch_normalization_7[0][0] __________________________________________________________________________________________________ max_pooling2d_2 (MaxPooling2D) (None, 4, 4, 512) 0 activation_7[0][0] __________________________________________________________________________________________________ conv2d_10 (Conv2D) (None, 4, 4, 512) 131584 
add_1[0][0] __________________________________________________________________________________________________ add_2 (Add) (None, 4, 4, 512) 0 max_pooling2d_2[0][0] conv2d_10[0][0] __________________________________________________________________________________________________ conv2d_11 (Conv2D) (None, 4, 4, 1024) 4719616 add_2[0][0] __________________________________________________________________________________________________ batch_normalization_8 (BatchNor (None, 4, 4, 1024) 4096 conv2d_11[0][0] __________________________________________________________________________________________________ activation_8 (Activation) (None, 4, 4, 1024) 0 batch_normalization_8[0][0] __________________________________________________________________________________________________ global_average_pooling2d (Globa (None, 1024) 0 activation_8[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 1024) 0 global_average_pooling2d[0][0] __________________________________________________________________________________________________ dense (Dense) (None, 10) 10250 dropout[0][0] ================================================================================================== Total params: 9,599,946 Trainable params: 9,594,058 Non-trainable params: 5,888
# Phase 1: train from scratch at a standard Adam learning rate.
# Sparse loss because labels are integer ids, not one-hot vectors.
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
history = model.fit(
    train_images,
    train_labels,
    batch_size=128,
    epochs=50,
    validation_split=0.1,
)
2022-10-17 04:57:23.637172: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/50
2022-10-17 04:57:26.319428: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8005
352/352 [==============================] - 31s 62ms/step - loss: 1.4778 - accuracy: 0.4690 - val_loss: 2.2820 - val_accuracy: 0.2614 Epoch 2/50 352/352 [==============================] - 21s 59ms/step - loss: 1.0381 - accuracy: 0.6307 - val_loss: 1.0681 - val_accuracy: 0.6328 Epoch 3/50 352/352 [==============================] - 21s 60ms/step - loss: 0.8586 - accuracy: 0.6992 - val_loss: 0.9402 - val_accuracy: 0.6740 Epoch 4/50 352/352 [==============================] - 21s 59ms/step - loss: 0.7355 - accuracy: 0.7434 - val_loss: 0.7672 - val_accuracy: 0.7384 Epoch 5/50 352/352 [==============================] - 21s 59ms/step - loss: 0.6658 - accuracy: 0.7686 - val_loss: 0.7770 - val_accuracy: 0.7348 Epoch 6/50 352/352 [==============================] - 21s 60ms/step - loss: 0.5954 - accuracy: 0.7927 - val_loss: 1.1096 - val_accuracy: 0.6752 Epoch 7/50 352/352 [==============================] - 21s 59ms/step - loss: 0.5413 - accuracy: 0.8120 - val_loss: 0.7354 - val_accuracy: 0.7742 Epoch 8/50 352/352 [==============================] - 21s 59ms/step - loss: 0.5077 - accuracy: 0.8258 - val_loss: 0.5524 - val_accuracy: 0.8126 Epoch 9/50 352/352 [==============================] - 21s 59ms/step - loss: 0.4604 - accuracy: 0.8412 - val_loss: 1.0330 - val_accuracy: 0.6908 Epoch 10/50 352/352 [==============================] - 21s 60ms/step - loss: 0.4329 - accuracy: 0.8502 - val_loss: 0.7664 - val_accuracy: 0.7686 Epoch 11/50 352/352 [==============================] - 21s 59ms/step - loss: 0.4002 - accuracy: 0.8611 - val_loss: 0.4168 - val_accuracy: 0.8568 Epoch 12/50 352/352 [==============================] - 21s 60ms/step - loss: 0.3711 - accuracy: 0.8692 - val_loss: 0.5835 - val_accuracy: 0.8132 Epoch 13/50 352/352 [==============================] - 21s 60ms/step - loss: 0.3472 - accuracy: 0.8796 - val_loss: 0.6577 - val_accuracy: 0.7998 Epoch 14/50 352/352 [==============================] - 21s 59ms/step - loss: 0.3304 - accuracy: 0.8847 - val_loss: 0.4619 - 
val_accuracy: 0.8512 Epoch 15/50 352/352 [==============================] - 21s 60ms/step - loss: 0.3102 - accuracy: 0.8942 - val_loss: 0.4753 - val_accuracy: 0.8454 Epoch 16/50 352/352 [==============================] - 21s 59ms/step - loss: 0.2952 - accuracy: 0.8978 - val_loss: 0.4953 - val_accuracy: 0.8436 Epoch 17/50 352/352 [==============================] - 21s 59ms/step - loss: 0.2738 - accuracy: 0.9046 - val_loss: 0.6217 - val_accuracy: 0.8166 Epoch 18/50 352/352 [==============================] - 21s 59ms/step - loss: 0.2637 - accuracy: 0.9077 - val_loss: 0.3840 - val_accuracy: 0.8744 Epoch 19/50 352/352 [==============================] - 21s 60ms/step - loss: 0.2495 - accuracy: 0.9125 - val_loss: 0.6840 - val_accuracy: 0.7990 Epoch 20/50 352/352 [==============================] - 21s 59ms/step - loss: 0.2312 - accuracy: 0.9192 - val_loss: 0.6300 - val_accuracy: 0.8158 Epoch 21/50 352/352 [==============================] - 21s 59ms/step - loss: 0.2259 - accuracy: 0.9219 - val_loss: 0.5192 - val_accuracy: 0.8496 Epoch 22/50 352/352 [==============================] - 21s 60ms/step - loss: 0.2121 - accuracy: 0.9274 - val_loss: 0.5503 - val_accuracy: 0.8302 Epoch 23/50 352/352 [==============================] - 21s 59ms/step - loss: 0.2015 - accuracy: 0.9302 - val_loss: 0.5611 - val_accuracy: 0.8428 Epoch 24/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1956 - accuracy: 0.9313 - val_loss: 0.7762 - val_accuracy: 0.8052 Epoch 25/50 352/352 [==============================] - 21s 60ms/step - loss: 0.1820 - accuracy: 0.9354 - val_loss: 0.4341 - val_accuracy: 0.8638 Epoch 26/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1756 - accuracy: 0.9378 - val_loss: 0.4195 - val_accuracy: 0.8734 Epoch 27/50 352/352 [==============================] - 21s 60ms/step - loss: 0.1687 - accuracy: 0.9421 - val_loss: 0.3468 - val_accuracy: 0.8976 Epoch 28/50 352/352 [==============================] - 21s 60ms/step - loss: 0.1547 - 
accuracy: 0.9450 - val_loss: 0.5157 - val_accuracy: 0.8568 Epoch 29/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1538 - accuracy: 0.9475 - val_loss: 0.3944 - val_accuracy: 0.8856 Epoch 30/50 352/352 [==============================] - 21s 60ms/step - loss: 0.1435 - accuracy: 0.9497 - val_loss: 0.4488 - val_accuracy: 0.8714 Epoch 31/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1400 - accuracy: 0.9510 - val_loss: 0.3955 - val_accuracy: 0.8882 Epoch 32/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1323 - accuracy: 0.9538 - val_loss: 0.5992 - val_accuracy: 0.8422 Epoch 33/50 352/352 [==============================] - 21s 60ms/step - loss: 0.1316 - accuracy: 0.9532 - val_loss: 0.4549 - val_accuracy: 0.8812 Epoch 34/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1212 - accuracy: 0.9560 - val_loss: 0.3997 - val_accuracy: 0.8838 Epoch 35/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1188 - accuracy: 0.9581 - val_loss: 0.4383 - val_accuracy: 0.8828 Epoch 36/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1177 - accuracy: 0.9600 - val_loss: 0.4237 - val_accuracy: 0.8854 Epoch 37/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1125 - accuracy: 0.9597 - val_loss: 0.5012 - val_accuracy: 0.8764 Epoch 38/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1118 - accuracy: 0.9612 - val_loss: 0.5362 - val_accuracy: 0.8648 Epoch 39/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1034 - accuracy: 0.9639 - val_loss: 0.4779 - val_accuracy: 0.8720 Epoch 40/50 352/352 [==============================] - 21s 60ms/step - loss: 0.1014 - accuracy: 0.9647 - val_loss: 0.4496 - val_accuracy: 0.8814 Epoch 41/50 352/352 [==============================] - 21s 59ms/step - loss: 0.1032 - accuracy: 0.9644 - val_loss: 0.4278 - val_accuracy: 0.8846 Epoch 42/50 352/352 [==============================] - 
21s 60ms/step - loss: 0.0929 - accuracy: 0.9670 - val_loss: 0.6599 - val_accuracy: 0.8530 Epoch 43/50 352/352 [==============================] - 21s 60ms/step - loss: 0.0961 - accuracy: 0.9669 - val_loss: 0.3661 - val_accuracy: 0.9060 Epoch 44/50 352/352 [==============================] - 21s 59ms/step - loss: 0.0870 - accuracy: 0.9692 - val_loss: 0.5591 - val_accuracy: 0.8702 Epoch 45/50 352/352 [==============================] - 21s 59ms/step - loss: 0.0855 - accuracy: 0.9702 - val_loss: 0.6468 - val_accuracy: 0.8612 Epoch 46/50 352/352 [==============================] - 21s 60ms/step - loss: 0.0831 - accuracy: 0.9702 - val_loss: 0.4426 - val_accuracy: 0.8932 Epoch 47/50 352/352 [==============================] - 21s 59ms/step - loss: 0.0867 - accuracy: 0.9693 - val_loss: 0.4761 - val_accuracy: 0.8822 Epoch 48/50 352/352 [==============================] - 21s 59ms/step - loss: 0.0807 - accuracy: 0.9722 - val_loss: 0.5392 - val_accuracy: 0.8790 Epoch 49/50 352/352 [==============================] - 21s 59ms/step - loss: 0.0843 - accuracy: 0.9704 - val_loss: 0.4359 - val_accuracy: 0.8986 Epoch 50/50 352/352 [==============================] - 21s 59ms/step - loss: 0.0799 - accuracy: 0.9718 - val_loss: 0.4709 - val_accuracy: 0.8870
# Phase 2: fine-tune the already-trained weights with a 100x smaller
# learning rate (recompiling resets the optimizer, not the weights).
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
history = model.fit(
    train_images,
    train_labels,
    batch_size=128,
    epochs=20,
    validation_split=0.1,
)
Epoch 1/20 352/352 [==============================] - 22s 60ms/step - loss: 0.0617 - accuracy: 0.9793 - val_loss: 0.3740 - val_accuracy: 0.9054 Epoch 2/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0556 - accuracy: 0.9814 - val_loss: 0.3631 - val_accuracy: 0.9078 Epoch 3/20 352/352 [==============================] - 21s 60ms/step - loss: 0.0482 - accuracy: 0.9838 - val_loss: 0.3543 - val_accuracy: 0.9106 Epoch 4/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0454 - accuracy: 0.9849 - val_loss: 0.3490 - val_accuracy: 0.9126 Epoch 5/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0424 - accuracy: 0.9857 - val_loss: 0.3453 - val_accuracy: 0.9142 Epoch 6/20 352/352 [==============================] - 21s 60ms/step - loss: 0.0392 - accuracy: 0.9873 - val_loss: 0.3451 - val_accuracy: 0.9148 Epoch 7/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0395 - accuracy: 0.9874 - val_loss: 0.3410 - val_accuracy: 0.9156 Epoch 8/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0370 - accuracy: 0.9883 - val_loss: 0.3404 - val_accuracy: 0.9170 Epoch 9/20 352/352 [==============================] - 21s 60ms/step - loss: 0.0370 - accuracy: 0.9889 - val_loss: 0.3370 - val_accuracy: 0.9160 Epoch 10/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0351 - accuracy: 0.9890 - val_loss: 0.3366 - val_accuracy: 0.9160 Epoch 11/20 352/352 [==============================] - 21s 60ms/step - loss: 0.0335 - accuracy: 0.9894 - val_loss: 0.3364 - val_accuracy: 0.9168 Epoch 12/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0323 - accuracy: 0.9898 - val_loss: 0.3370 - val_accuracy: 0.9162 Epoch 13/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0323 - accuracy: 0.9899 - val_loss: 0.3343 - val_accuracy: 0.9174 Epoch 14/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0295 - accuracy: 0.9905 - val_loss: 0.3353 - 
val_accuracy: 0.9172 Epoch 15/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0300 - accuracy: 0.9903 - val_loss: 0.3318 - val_accuracy: 0.9178 Epoch 16/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0312 - accuracy: 0.9898 - val_loss: 0.3318 - val_accuracy: 0.9172 Epoch 17/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0283 - accuracy: 0.9915 - val_loss: 0.3303 - val_accuracy: 0.9182 Epoch 18/20 352/352 [==============================] - 21s 60ms/step - loss: 0.0291 - accuracy: 0.9909 - val_loss: 0.3303 - val_accuracy: 0.9178 Epoch 19/20 352/352 [==============================] - 21s 60ms/step - loss: 0.0274 - accuracy: 0.9917 - val_loss: 0.3304 - val_accuracy: 0.9184 Epoch 20/20 352/352 [==============================] - 21s 59ms/step - loss: 0.0282 - accuracy: 0.9913 - val_loss: 0.3286 - val_accuracy: 0.9198
# Final evaluation on the held-out 10k-image CIFAR-10 test split.
model.evaluate(test_images, test_labels)
313/313 [==============================] - 3s 8ms/step - loss: 0.3310 - accuracy: 0.9180