Implementing a ResNet-34 residual convolutional network in TensorFlow

A full residual network in about 40 lines of code; give it a like if you find it useful.

from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import *



from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession


# Limit GPU memory usage: allocate memory on demand instead of all at once
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
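For reference, the same memory-growth behaviour can also be enabled with the TF2-native config API instead of the compat.v1 session above (a small alternative sketch, not part of the original post; use one approach or the other, not both):

import tensorflow as tf

# Alternative to the compat.v1 InteractiveSession: enable memory growth per GPU
# so TensorFlow allocates memory as needed rather than grabbing it all up front
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)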

'''
Implement ResNet-34
'''

class ResidualUnit(keras.layers.Layer):
    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
        # Main path: two 3x3 convolutions, each followed by batch normalization
        self.main_layers = [
            keras.layers.Conv2D(filters, 3, strides=strides,
                                padding="same", use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters, 3, strides=1,
                                padding="same", use_bias=False),
            keras.layers.BatchNormalization()]
        # Skip path: a 1x1 convolution is only needed when the spatial size and
        # filter count change, i.e. when strides > 1
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = [
                keras.layers.Conv2D(filters, 1, strides=strides,
                                    padding="same", use_bias=False),
                keras.layers.BatchNormalization()]

    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs  # identity (or 1x1-conv) shortcut
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        return self.activation(Z + skip_Z)
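As a quick sanity check (not part of the original post), a single ResidualUnit can be called on a dummy tensor to confirm that strides=2 halves the spatial resolution while the shortcut path keeps the shapes compatible:

# Hypothetical shape check, assuming the ResidualUnit class defined above
import tensorflow as tf

x = tf.random.normal([1, 56, 56, 64])      # dummy feature map
same = ResidualUnit(64, strides=1)(x)      # identity shortcut
down = ResidualUnit(128, strides=2)(x)     # 1x1-conv shortcut
print(same.shape)   # (1, 56, 56, 64)
print(down.shape)   # (1, 28, 28, 128)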



# Stem: 7x7 convolution followed by max pooling, as in the original ResNet
model = Sequential()
model.add(Conv2D(64, 7, strides=2, input_shape=[224, 224, 3], padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPool2D(pool_size=3, strides=2, padding='same'))

# 16 residual units in four stages (3 + 4 + 6 + 3); strides=2 whenever the
# filter count increases, which halves the feature map size
prev_filters = 64
for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
    strides = 1 if filters == prev_filters else 2
    model.add(ResidualUnit(filters, strides=strides))
    prev_filters = filters

# Head: global average pooling + softmax classifier (Flatten is a no-op here,
# since GlobalAvgPool2D already outputs a 2-D tensor)
model.add(keras.layers.GlobalAvgPool2D())
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10, activation="softmax"))

# print([64]*3+[128]*4 +[256]*6+[512]*3)  #[64, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 512, 512, 512]
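To actually train the network it still needs to be compiled and fitted; the sketch below is a minimal example, assuming a 10-class dataset of 224x224 RGB images (x_train and y_train are placeholders, not defined in the original post):

# Minimal training sketch (assumes x_train / y_train exist with shape
# [N, 224, 224, 3] and integer labels 0-9; not part of the original code)
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="adam",
              metrics=["accuracy"])
model.summary()  # verify the layer stack and parameter count
# model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.1)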
