import tensorflow as tf
import time
from tensorflow import keras
from tensorflow.keras import layers

print("TensorFlow version:", tf.__version__)
# Connect to the TPU cluster. This must run before any TF ops are executed
# and before the TPUStrategy below is created.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
print("Device:", tpu.master())
# Distribution strategy that replicates computation across all TPU cores;
# num_replicas_in_sync reports how many cores participate (8 in the log below).
strategy = tf.distribute.TPUStrategy(tpu)
print(f"Number of replicas: {strategy.num_replicas_in_sync}")
# TensorFlow version: 2.4.1
# 2022-10-24 04:57:03.512464: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> 10.0.0.2:8470} 2022-10-24 04:57:03.512704: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job localhost -> {0 -> localhost:30020} 2022-10-24 04:57:03.516894: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> 10.0.0.2:8470} 2022-10-24 04:57:03.516945: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job localhost -> {0 -> localhost:30020}
# Device: grpc://10.0.0.2:8470 Number of replicas: 8
# Load CIFAR-10 and rescale pixel intensities from [0, 255] to [0, 1].
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0
def residual_block(x, filters, pooling=False):
    """Two Conv-BN-ReLU stages with a skip connection added at the end.

    When ``pooling`` is True the main path is downsampled 2x by max-pooling
    and the skip path is downsampled with a strided 1x1 convolution so the
    two branches stay shape-compatible.
    """
    shortcut = x
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    if pooling:
        x = layers.MaxPooling2D(2, padding="same")(x)
        # Strided 1x1 conv matches both the spatial size and channel count
        # of the pooled main path.
        shortcut = layers.Conv2D(filters, 1, strides=2)(shortcut)
    elif filters != shortcut.shape[-1]:
        # Channel counts differ: project the skip path with a 1x1 conv.
        shortcut = layers.Conv2D(filters, 1)(shortcut)
    return layers.add([x, shortcut])
def build_model():
    """Build and compile a small ResNet-style CNN for 10-class 32x32 RGB input.

    Returns a compiled ``keras.Model`` using Adam(1e-3) and sparse
    categorical cross-entropy (integer labels).
    """
    inputs = keras.Input(shape=(32, 32, 3))

    # Stem: two plain Conv-BN-ReLU stages at 64 filters, no downsampling.
    x = inputs
    for _ in range(2):
        x = layers.Conv2D(64, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)

    # Three residual stages, each doubling the width and halving the
    # spatial resolution (32 -> 16 -> 8 -> 4).
    for width in (128, 256, 512):
        x = residual_block(x, filters=width, pooling=True)

    # Final widening stage, then global pooling into the classifier head.
    x = layers.Conv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(10, activation="softmax")(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model
# Build and compile the model inside the strategy scope so its variables are
# created under the TPU distribution strategy (replicated across cores).
with strategy.scope():
    model = build_model()
model.summary()
# Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_4 (InputLayer) [(None, 32, 32, 3)] 0 __________________________________________________________________________________________________ conv2d_2 (Conv2D) (None, 32, 32, 64) 1792 input_4[0][0] __________________________________________________________________________________________________ batch_normalization_1 (BatchNor (None, 32, 32, 64) 256 conv2d_2[0][0] __________________________________________________________________________________________________ activation_1 (Activation) (None, 32, 32, 64) 0 batch_normalization_1[0][0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (None, 32, 32, 64) 36928 activation_1[0][0] __________________________________________________________________________________________________ batch_normalization_2 (BatchNor (None, 32, 32, 64) 256 conv2d_3[0][0] __________________________________________________________________________________________________ activation_2 (Activation) (None, 32, 32, 64) 0 batch_normalization_2[0][0] __________________________________________________________________________________________________ conv2d_4 (Conv2D) (None, 32, 32, 128) 73856 activation_2[0][0] __________________________________________________________________________________________________ batch_normalization_3 (BatchNor (None, 32, 32, 128) 512 conv2d_4[0][0] __________________________________________________________________________________________________ activation_3 (Activation) (None, 32, 32, 128) 0 batch_normalization_3[0][0] __________________________________________________________________________________________________ conv2d_5 (Conv2D) (None, 32, 32, 128) 147584 activation_3[0][0] 
# __________________________________________________________________________________________________ batch_normalization_4 (BatchNor (None, 32, 32, 128) 512 conv2d_5[0][0] __________________________________________________________________________________________________ activation_4 (Activation) (None, 32, 32, 128) 0 batch_normalization_4[0][0] __________________________________________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 16, 16, 128) 0 activation_4[0][0] __________________________________________________________________________________________________ conv2d_6 (Conv2D) (None, 16, 16, 128) 8320 activation_2[0][0] __________________________________________________________________________________________________ add (Add) (None, 16, 16, 128) 0 max_pooling2d[0][0] conv2d_6[0][0] __________________________________________________________________________________________________ conv2d_7 (Conv2D) (None, 16, 16, 256) 295168 add[0][0] __________________________________________________________________________________________________ batch_normalization_5 (BatchNor (None, 16, 16, 256) 1024 conv2d_7[0][0] __________________________________________________________________________________________________ activation_5 (Activation) (None, 16, 16, 256) 0 batch_normalization_5[0][0] __________________________________________________________________________________________________ conv2d_8 (Conv2D) (None, 16, 16, 256) 590080 activation_5[0][0] __________________________________________________________________________________________________ batch_normalization_6 (BatchNor (None, 16, 16, 256) 1024 conv2d_8[0][0] __________________________________________________________________________________________________ activation_6 (Activation) (None, 16, 16, 256) 0 batch_normalization_6[0][0] __________________________________________________________________________________________________ max_pooling2d_1 (MaxPooling2D) (None, 8, 8, 256) 0 
# activation_6[0][0] __________________________________________________________________________________________________ conv2d_9 (Conv2D) (None, 8, 8, 256) 33024 add[0][0] __________________________________________________________________________________________________ add_1 (Add) (None, 8, 8, 256) 0 max_pooling2d_1[0][0] conv2d_9[0][0] __________________________________________________________________________________________________ conv2d_10 (Conv2D) (None, 8, 8, 512) 1180160 add_1[0][0] __________________________________________________________________________________________________ batch_normalization_7 (BatchNor (None, 8, 8, 512) 2048 conv2d_10[0][0] __________________________________________________________________________________________________ activation_7 (Activation) (None, 8, 8, 512) 0 batch_normalization_7[0][0] __________________________________________________________________________________________________ conv2d_11 (Conv2D) (None, 8, 8, 512) 2359808 activation_7[0][0] __________________________________________________________________________________________________ batch_normalization_8 (BatchNor (None, 8, 8, 512) 2048 conv2d_11[0][0] __________________________________________________________________________________________________ activation_8 (Activation) (None, 8, 8, 512) 0 batch_normalization_8[0][0] __________________________________________________________________________________________________ max_pooling2d_2 (MaxPooling2D) (None, 4, 4, 512) 0 activation_8[0][0] __________________________________________________________________________________________________ conv2d_12 (Conv2D) (None, 4, 4, 512) 131584 add_1[0][0] __________________________________________________________________________________________________ add_2 (Add) (None, 4, 4, 512) 0 max_pooling2d_2[0][0] conv2d_12[0][0] __________________________________________________________________________________________________ conv2d_13 (Conv2D) (None, 4, 4, 1024) 4719616 add_2[0][0] 
# __________________________________________________________________________________________________ batch_normalization_9 (BatchNor (None, 4, 4, 1024) 4096 conv2d_13[0][0] __________________________________________________________________________________________________ activation_9 (Activation) (None, 4, 4, 1024) 0 batch_normalization_9[0][0] __________________________________________________________________________________________________ global_average_pooling2d (Globa (None, 1024) 0 activation_9[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 1024) 0 global_average_pooling2d[0][0] __________________________________________________________________________________________________ dense (Dense) (None, 10) 10250 dropout[0][0] ================================================================================================== Total params: 9,599,946 Trainable params: 9,594,058 Non-trainable params: 5,888 __________________________________________________________________________________________________
# Train for 50 epochs, holding out 10% of the training set for validation,
# and report the total wall-clock training time.
since = time.time()
history = model.fit(
    train_images,
    train_labels,
    epochs=50,
    validation_split=0.1,
    batch_size=128
)
time_elapsed = time.time() - since
minutes, seconds = divmod(time_elapsed, 60)
print(f'Training complete in {minutes:.0f}m {seconds:.0f}s')
print("Done!")
# Epoch 1/50 352/352 [==============================] - 33s 55ms/step - loss: 1.7561 - accuracy: 0.3916 - val_loss: 1.6708 - val_accuracy: 0.4450 Epoch 2/50 352/352 [==============================] - 8s 22ms/step - loss: 0.9427 - accuracy: 0.6668 - val_loss: 0.9502 - val_accuracy: 0.6812 Epoch 3/50 352/352 [==============================] - 8s 22ms/step - loss: 0.6952 - accuracy: 0.7579 - val_loss: 0.8292 - val_accuracy: 0.7180 Epoch 4/50 352/352 [==============================] - 8s 22ms/step - loss: 0.5425 - accuracy: 0.8150 - val_loss: 0.9087 - val_accuracy: 0.7078 Epoch 5/50 352/352 [==============================] - 8s 22ms/step - loss: 0.4375 - accuracy: 0.8492 - val_loss: 0.6072 - val_accuracy: 0.7856 Epoch 6/50 352/352 [==============================] - 8s 22ms/step - loss: 0.3351 - accuracy: 0.8847 - val_loss: 1.2582 - val_accuracy: 0.6374 Epoch 7/50 352/352 [==============================] - 8s 22ms/step - loss: 0.2739 - accuracy: 0.9049 - val_loss: 0.9219 - val_accuracy: 0.7440 Epoch 8/50 352/352 [==============================] - 9s 25ms/step - loss: 0.2006 - accuracy: 0.9313 - val_loss: 0.9821 - val_accuracy: 0.7718 Epoch 9/50 352/352 [==============================] - 8s 23ms/step - loss: 0.1563 - accuracy: 0.9453 - val_loss: 0.8029 - val_accuracy: 0.7748 Epoch 10/50 352/352 [==============================] - 8s 22ms/step - loss: 0.1255 - accuracy: 0.9573 - val_loss: 0.6687 - val_accuracy: 0.8154 Epoch 11/50 352/352 [==============================] - 8s 22ms/step - loss: 0.1050 - accuracy: 0.9651 - val_loss: 0.6656 - val_accuracy: 0.8192 Epoch 12/50 352/352 [==============================] - 8s 22ms/step - loss: 0.0908 - accuracy: 0.9675 - val_loss: 0.6068 - val_accuracy: 0.8396 Epoch 13/50 352/352 [==============================] - 8s 22ms/step - loss: 0.0748 - accuracy: 0.9745 - val_loss: 1.2495 - val_accuracy: 0.7386 Epoch 14/50 352/352 [==============================] - 8s 22ms/step - loss: 0.0766 - accuracy: 0.9735 - val_loss: 0.8186 - 
# val_accuracy: 0.8200 Epoch 15/50 352/352 [==============================] - 8s 22ms/step - loss: 0.0544 - accuracy: 0.9821 - val_loss: 0.8215 - val_accuracy: 0.8284 Epoch 16/50 352/352 [==============================] - 8s 22ms/step - loss: 0.0597 - accuracy: 0.9797 - val_loss: 1.0178 - val_accuracy: 0.7750
# Report final loss and accuracy on the held-out test set.
model.evaluate(test_images, test_labels)
# 313/313 [==============================] - 6s 14ms/step - loss: 0.6225 - accuracy: 0.8776