ResNet——深度残差网络（二）

本文基于上一篇介绍的 ResNet 网络结构进行实战：在 CIFAR-100 数据集上训练 ResNet-18。

再贴一下 ResNet 的基本结构，方便与代码对照：

（图 1：ResNet 基本结构示意图）

 

ResNet 的自定义类如下（保存为 resnet.py，供后面的训练脚本通过 `from resnet import resnet18` 导入）：

import tensorflow as tf
from tensorflow import keras

class BasicBlock(keras.layers.Layer):
    """ResNet basic residual block: two 3x3 Conv+BN layers plus a shortcut.

    Args:
        filter_num: Number of output channels of both convolutions.
        stride: Stride of the first convolution. When it is not 1, the
            shortcut becomes a strided 1x1 convolution so that it matches the
            main path's spatial size and channel count.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()

        # Note: padding='same' does not always keep the input size. With
        # stride s, TensorFlow pads so the output size is ceil(input / s),
        # i.e. the window covers the whole input.
        self.conv1 = keras.layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = keras.layers.BatchNormalization()
        self.relu = keras.layers.Activation('relu')

        self.conv2 = keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = keras.layers.BatchNormalization()

        # The attribute keeps its original (misspelled) name so code and
        # checkpoints that refer to it by name keep working.
        # NOTE(review): with stride == 1 the shortcut is the identity, which
        # assumes the input already has `filter_num` channels. This holds for
        # every stride-1 block ResNet.build_resblock creates in this file.
        if stride != 1:
            self.dowmsample = keras.Sequential()
            self.dowmsample.add(keras.layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:
            self.dowmsample = lambda x: x

    def call(self, inputs, training=None):
        """Apply the residual block.

        Args:
            inputs: Input feature map.
            training: Forwarded to both BatchNormalization layers.

        Returns:
            relu(main_path(inputs) + shortcut(inputs)).
        """
        out = self.conv1(inputs)
        # Forward `training` explicitly. BatchNorm then uses batch statistics
        # while training and moving averages at inference, however the block
        # is invoked.
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.dowmsample(inputs)

        # Plain tensor addition: keras.layers.add() would instantiate a new
        # Add layer on every forward pass.
        output = tf.nn.relu(out + identity)

        return output

class ResNet(keras.Model):
    """ResNet classifier built from BasicBlock stages.

    Args:
        layer_dims: Number of BasicBlocks in each of the four stages, e.g.
            [2, 2, 2, 2] for ResNet-18 (four stages, two blocks each).
        num_classes: Width of the final Dense layer, i.e. number of logits.
    """

    def __init__(self, layer_dims, num_classes=100):
        super(ResNet, self).__init__()

        # Stem: 3x3 conv (default 'valid' padding) + BN + ReLU, followed by
        # a stride-1 max pool with 'same' padding.
        self.stem = keras.Sequential([
            keras.layers.Conv2D(64, (3, 3), strides=(1, 1)),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
            keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')
        ])

        # Four stages. Channels go 64 -> 128 -> 256 -> 512, and stages 2-4
        # start with a stride-2 block that downsamples spatially.
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

        # Global average pooling collapses the spatial dimensions, so the
        # Dense classifier's input size does not depend on the input
        # resolution.
        self.avgpool = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Return unnormalized logits of shape (batch, num_classes).

        `training` is forwarded explicitly to every sub-model that contains
        BatchNormalization layers.
        """
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        x = self.avgpool(x)
        x = self.fc(x)

        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        """Stack `blocks` BasicBlocks with `filter_num` channels.

        Only the first block uses `stride`; the rest use stride 1.
        """
        res_blocks = keras.Sequential()
        res_blocks.add(BasicBlock(filter_num, stride))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, 1))

        return res_blocks

def resnet18(num_classes=100):
    """Build a ResNet-18: four stages of two BasicBlocks each.

    Args:
        num_classes: Number of output logits. Defaults to 100 (CIFAR-100),
            so existing `resnet18()` calls behave exactly as before.
    """
    return ResNet([2, 2, 2, 2], num_classes=num_classes)

训练脚本如下（使用 CIFAR-100 数据集）：

import os

# TF_CPP_MIN_LOG_LEVEL is read when TensorFlow's native runtime loads, so it
# must be set before `import tensorflow`. '2' hides INFO and WARNING logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras

from resnet import resnet18

def preprocess(x, y):
    """Normalize one (image, label) pair.

    The image is cast to float32 and rescaled from [0, 255] to [-1, 1].
    The label is cast to int32.
    """
    image = tf.cast(x, dtype=tf.float32)
    image = 2 * image / 255. - 1
    label = tf.cast(y, dtype=tf.int32)
    return image, label

(x, y), (x_test, y_test) = keras.datasets.cifar100.load_data()
# Labels come with a trailing singleton axis; squeeze it off to get 1-D labels.
y = tf.squeeze(y, axis=1)
y_test = tf.squeeze(y_test, axis=1)
print(x.shape, y.shape, x_test.shape, y_test.shape)

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(1000).map(preprocess).batch(64)

# Build the evaluation pipeline from the held-out test split. It was
# previously built from train_db, which measured accuracy on training data.
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(64)

def main():
    """Train ResNet-18 on CIFAR-100 and report test accuracy after each epoch."""
    model = resnet18()
    model.build(input_shape=(None, 32, 32, 3))
    # `lr` is a deprecated alias that newer Keras rejects; use `learning_rate`.
    optimizer = keras.optimizers.Adam(learning_rate=1e-3)
    model.summary()

    for epoch in range(50):
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                # training=True so BatchNorm uses (and updates) batch statistics.
                logits = model(x, training=True)
                # Labels are integer class ids for the 100-way output. The
                # sparse loss avoids a one-hot encoding whose depth must match
                # the model (the old depth=10 was wrong for CIFAR-100).
                loss = keras.losses.sparse_categorical_crossentropy(y, logits, from_logits=True)
                loss = tf.reduce_mean(loss)

            gradient = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradient, model.trainable_variables))

            if step % 100 == 0:
                print(epoch, step, 'loss:', float(loss))

        total_num = 0
        total_correct = 0
        for x, y in test_db:
            # training=False so BatchNorm uses its moving averages.
            logits = model(x, training=False)
            # softmax is monotonic, so argmax over logits gives the same
            # prediction without computing probabilities.
            pred = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)

            correct = tf.reduce_sum(tf.cast(tf.equal(pred, y), dtype=tf.int32))

            total_num += x.shape[0]
            total_correct += int(correct)

        # Report once per epoch, after the whole test set has been evaluated.
        acc = total_correct / total_num
        print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()

`model.summary()` 打印出的网络结构和参数量如下：

（图 2：`model.summary()` 输出的网络结构与参数量）

 

你可能感兴趣的:(Resnet——深度残差网络(二))