Training with Keras on a TPU

import time

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

print("TensorFlow version:", tf.__version__)

# Connect to the TPU cluster and create a distribution strategy.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
print("Device:", tpu.master())
strategy = tf.distribute.TPUStrategy(tpu)
print(f"Number of replicas: {strategy.num_replicas_in_sync}")
TensorFlow version: 2.4.1
2022-10-24 04:57:03.512464: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> 10.0.0.2:8470}
2022-10-24 04:57:03.512704: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job localhost -> {0 -> localhost:30020}
2022-10-24 04:57:03.516894: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> 10.0.0.2:8470}
2022-10-24 04:57:03.516945: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job localhost -> {0 -> localhost:30020}
Device: grpc://10.0.0.2:8470
Number of replicas: 8
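
With 8 replicas in sync, the global batch size passed to `model.fit` is sharded evenly across the TPU cores by `TPUStrategy`. A quick sketch of the arithmetic, with the replica count taken from the log above and the batch size from the `fit` call below:

```python
# TPUStrategy splits each global batch across the replicas.
num_replicas = 8          # strategy.num_replicas_in_sync, from the log above
global_batch_size = 128   # batch_size passed to model.fit below
per_replica_batch = global_batch_size // num_replicas
print(per_replica_batch)  # 16 samples per TPU core per step
```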
# Load CIFAR-10 and scale pixel values to [0, 1].
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

def residual_block(x, filters, pooling=False):
    """Two Conv-BN-ReLU layers with a shortcut connection.

    When pooling, the spatial size is halved and the shortcut is
    downsampled with a strided 1x1 convolution so shapes match.
    """
    residual = x
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    if pooling:
        x = layers.MaxPooling2D(2, padding="same")(x)
        residual = layers.Conv2D(filters, 1, strides=2)(residual)
    elif filters != residual.shape[-1]:
        residual = layers.Conv2D(filters, 1)(residual)
    x = layers.add([x, residual])
    return x
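

The parameter counts in the `model.summary()` output printed further down follow directly from the Conv2D formula `filters * (kernel * kernel * in_channels + 1)` (weights plus one bias per filter). A small sanity check against a few of the printed values:

```python
def conv2d_params(filters, kernel, in_channels):
    # Each filter holds kernel*kernel*in_channels weights plus one bias.
    return filters * (kernel * kernel * in_channels + 1)

print(conv2d_params(64, 3, 3))    # stem conv on RGB input: 1792
print(conv2d_params(64, 3, 64))   # second stem conv: 36928
print(conv2d_params(128, 1, 64))  # 1x1 shortcut projection: 8320
```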


def build_model():
    # Entry stem: two Conv-BN-ReLU blocks at full 32x32 resolution.
    inputs = keras.Input(shape=(32, 32, 3))
    x = layers.Conv2D(64, 3, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    # Three residual stages, doubling filters and halving spatial size.
    x = residual_block(x, filters=128, pooling=True)
    x = residual_block(x, filters=256, pooling=True)
    x = residual_block(x, filters=512, pooling=True)
    
    x = layers.Conv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.GlobalAveragePooling2D()(x)
    
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    return model


with strategy.scope():
    model = build_model()

model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 32, 32, 64)   1792        input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 32, 32, 64)   256         conv2d_2[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 32, 32, 64)   0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 32, 32, 64)   36928       activation_1[0][0]               
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 32, 32, 64)   256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 32, 32, 64)   0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 32, 32, 128)  73856       activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 32, 32, 128)  512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 32, 32, 128)  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 32, 32, 128)  147584      activation_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 32, 32, 128)  512         conv2d_5[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 32, 32, 128)  0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 16, 16, 128)  0           activation_4[0][0]               
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 16, 16, 128)  8320        activation_2[0][0]               
__________________________________________________________________________________________________
add (Add)                       (None, 16, 16, 128)  0           max_pooling2d[0][0]              
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 16, 16, 256)  295168      add[0][0]                        
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 16, 16, 256)  1024        conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 16, 16, 256)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 16, 16, 256)  590080      activation_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16, 16, 256)  1024        conv2d_8[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 16, 16, 256)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 8, 8, 256)    0           activation_6[0][0]               
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 256)    33024       add[0][0]                        
__________________________________________________________________________________________________
add_1 (Add)                     (None, 8, 8, 256)    0           max_pooling2d_1[0][0]            
                                                                 conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 8, 8, 512)    1180160     add_1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8, 8, 512)    2048        conv2d_10[0][0]                  
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 8, 8, 512)    0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 8, 8, 512)    2359808     activation_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 8, 8, 512)    2048        conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 8, 8, 512)    0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 4, 4, 512)    0           activation_8[0][0]               
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 4, 4, 512)    131584      add_1[0][0]                      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 4, 4, 512)    0           max_pooling2d_2[0][0]            
                                                                 conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 4, 4, 1024)   4719616     add_2[0][0]                      
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 4, 4, 1024)   4096        conv2d_13[0][0]                  
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 4, 4, 1024)   0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1024)         0           activation_9[0][0]               
__________________________________________________________________________________________________
dropout (Dropout)               (None, 1024)         0           global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense (Dense)                   (None, 10)           10250       dropout[0][0]                    
==================================================================================================
Total params: 9,599,946
Trainable params: 9,594,058
Non-trainable params: 5,888
__________________________________________________________________________________________________
since = time.time()
# Keras shards the global batch of 128 across the TPU replicas.
history = model.fit(
    train_images,
    train_labels,
    epochs=50,
    validation_split=0.1,
    batch_size=128
)
time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print("Done!")

Epoch 1/50
352/352 [==============================] - 33s 55ms/step - loss: 1.7561 - accuracy: 0.3916 - val_loss: 1.6708 - val_accuracy: 0.4450
Epoch 2/50
352/352 [==============================] - 8s 22ms/step - loss: 0.9427 - accuracy: 0.6668 - val_loss: 0.9502 - val_accuracy: 0.6812
Epoch 3/50
352/352 [==============================] - 8s 22ms/step - loss: 0.6952 - accuracy: 0.7579 - val_loss: 0.8292 - val_accuracy: 0.7180
Epoch 4/50
352/352 [==============================] - 8s 22ms/step - loss: 0.5425 - accuracy: 0.8150 - val_loss: 0.9087 - val_accuracy: 0.7078
Epoch 5/50
352/352 [==============================] - 8s 22ms/step - loss: 0.4375 - accuracy: 0.8492 - val_loss: 0.6072 - val_accuracy: 0.7856
Epoch 6/50
352/352 [==============================] - 8s 22ms/step - loss: 0.3351 - accuracy: 0.8847 - val_loss: 1.2582 - val_accuracy: 0.6374
Epoch 7/50
352/352 [==============================] - 8s 22ms/step - loss: 0.2739 - accuracy: 0.9049 - val_loss: 0.9219 - val_accuracy: 0.7440
Epoch 8/50
352/352 [==============================] - 9s 25ms/step - loss: 0.2006 - accuracy: 0.9313 - val_loss: 0.9821 - val_accuracy: 0.7718
Epoch 9/50
352/352 [==============================] - 8s 23ms/step - loss: 0.1563 - accuracy: 0.9453 - val_loss: 0.8029 - val_accuracy: 0.7748
Epoch 10/50
352/352 [==============================] - 8s 22ms/step - loss: 0.1255 - accuracy: 0.9573 - val_loss: 0.6687 - val_accuracy: 0.8154
Epoch 11/50
352/352 [==============================] - 8s 22ms/step - loss: 0.1050 - accuracy: 0.9651 - val_loss: 0.6656 - val_accuracy: 0.8192
Epoch 12/50
352/352 [==============================] - 8s 22ms/step - loss: 0.0908 - accuracy: 0.9675 - val_loss: 0.6068 - val_accuracy: 0.8396
Epoch 13/50
352/352 [==============================] - 8s 22ms/step - loss: 0.0748 - accuracy: 0.9745 - val_loss: 1.2495 - val_accuracy: 0.7386
Epoch 14/50
352/352 [==============================] - 8s 22ms/step - loss: 0.0766 - accuracy: 0.9735 - val_loss: 0.8186 - val_accuracy: 0.8200
Epoch 15/50
352/352 [==============================] - 8s 22ms/step - loss: 0.0544 - accuracy: 0.9821 - val_loss: 0.8215 - val_accuracy: 0.8284
Epoch 16/50
352/352 [==============================] - 8s 22ms/step - loss: 0.0597 - accuracy: 0.9797 - val_loss: 1.0178 - val_accuracy: 0.7750
model.evaluate(test_images, test_labels)

313/313 [==============================] - 6s 14ms/step - loss: 0.6225 - accuracy: 0.8776
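
Since no `batch_size` is passed to `model.evaluate`, Keras falls back to its default of 32, so the 10,000 test images take the 313 steps shown above:

```python
import math

test_steps = math.ceil(10_000 / 32)  # default evaluate batch size is 32
print(test_steps)  # 313, matching "313/313" in the evaluate log
```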
