Implementing the ResNet Residual Network in Keras

ResNet comes in five standard configurations. Each begins with a single convolutional layer and a max-pooling layer, followed by four stages of residual blocks, and ends with a fully connected layer.

The five variants differ mainly in the type of residual block and the number of blocks per stage: resnet18 and resnet34 use the basic block, built from 3x3 convolutions, while resnet50, resnet101, and resnet152 use the bottleneck block; the block count per stage also differs, as summarized in the table below.
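
Network     Block type    Blocks per stage    Filters per stage
resnet18    basic         2, 2, 2, 2          64, 128, 256, 512
resnet34    basic         3, 4, 6, 3          64, 128, 256, 512
resnet50    bottleneck    3, 4, 6, 3          [64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]
resnet101   bottleneck    3, 4, 23, 3         same filters as resnet50
resnet152   bottleneck    3, 8, 36, 3         same filters as resnet50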

In the code, the basic unit is a convolution followed by batch normalization, bundled into the conv2d_bn helper, with ReLU applied afterwards by the caller. The shortcut function handles the case where x and f(x) have different shapes: if the channel counts differ, a 1x1 convolution first projects x to the same number of channels as f(x). Apart from the first layer's 7x7 convolution, the main convolutions are all 3x3 (the bottleneck additionally uses 1x1 convolutions), and the kernel sizes inside basic_block and bottleneck are fixed. Note that this implementation keeps stride 1 inside every residual block, so the feature map is only downsampled by the initial convolution and max pooling, as the summary below confirms.

from keras.layers import *  # Conv2D, BatchNormalization, ReLU, Add, Input, MaxPool2D, GlobalAvgPool2D, Dense
from keras.models import Model


def conv2d_bn(x, kernel_num, kernel_size=3, strides=1, layer_name='', padding_mode='same'):
    # Basic building unit: convolution followed by batch normalization.
    # The ReLU activation is applied separately by the caller.
    conv1 = Conv2D(kernel_num, kernel_size, strides=strides, padding=padding_mode, name=layer_name + '_conv1')(x)
    batch1 = BatchNormalization(name=layer_name + '_bn1')(conv1)
    return batch1


def shortcut(fx, x, padding_mode='same', layer_name=''):
    # Residual connection. If f(x) changed the channel count, project x
    # with a 1x1 convolution so the two tensors can be added elementwise.
    layer_name += '_shortcut'
    if x.shape[-1] != fx.shape[-1]:
        k = int(fx.shape[-1])
        identity = conv2d_bn(x, kernel_num=k, kernel_size=1, padding_mode=padding_mode, layer_name=layer_name)
    else:
        identity = x
    return Add(name=layer_name + '_add')([identity, fx])


def bottleneck(inputs, kernel_num, strides=1, layer_name='bottleneck', padding_mode='same'):
    # Bottleneck block: 1x1 (reduce) -> 3x3 -> 1x1 (expand) convolutions.
    # kernel_num is a triple such as [64, 64, 256].
    k1, k2, k3 = kernel_num
    conv1 = conv2d_bn(inputs, kernel_num=k1, kernel_size=1, strides=strides, padding_mode=padding_mode, layer_name=layer_name + '_1')
    relu1 = ReLU(name=layer_name + '_relu1')(conv1)
    conv2 = conv2d_bn(relu1, kernel_num=k2, kernel_size=3, strides=strides, padding_mode=padding_mode, layer_name=layer_name + '_2')
    relu2 = ReLU(name=layer_name + '_relu2')(conv2)
    conv3 = conv2d_bn(relu2, kernel_num=k3, kernel_size=1, strides=strides, padding_mode=padding_mode, layer_name=layer_name + '_3')
    shortcut_add = shortcut(fx=conv3, x=inputs, layer_name=layer_name)
    relu3 = ReLU(name=layer_name + '_relu3')(shortcut_add)
    return relu3


def basic_block(inputs, kernel_num=64, strides=1, layer_name='basic', padding_mode='same'):
    # Basic block: two 3x3 convolutions with the same number of filters.
    conv1 = conv2d_bn(inputs, kernel_num=kernel_num, strides=strides, kernel_size=3,
                      layer_name=layer_name + '_1', padding_mode=padding_mode)
    relu1 = ReLU(name=layer_name + '_relu1')(conv1)
    conv2 = conv2d_bn(relu1, kernel_num=kernel_num, strides=strides, kernel_size=3,
                      layer_name=layer_name + '_2', padding_mode=padding_mode)
    relu2 = ReLU(name=layer_name + '_relu2')(conv2)
    shortcut_add = shortcut(fx=relu2, x=inputs, layer_name=layer_name)
    relu3 = ReLU(name=layer_name + '_relu3')(shortcut_add)
    return relu3


def make_layer(inputs, block, block_num, kernel_num, layer_name=''):
    # Stack block_num identical blocks. Every block uses stride 1 here,
    # so the spatial size does not change within a stage.
    x = inputs
    for i in range(1, block_num + 1):
        x = block(x, kernel_num=kernel_num, strides=1, layer_name=layer_name + str(i), padding_mode='same')
    return x


def ResNet(input_shape, nclass, net_name='resnet18'):
    """
    :param input_shape: shape of the input tensor, e.g. (224, 224, 3)
    :param nclass: number of output classes
    :param net_name: one of resnet18/34/50/101/152; unknown names fall back to resnet18
    :return: the Keras Model
    """
    bottleneck_kernels = [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]
    block_setting = {
        'resnet18':  {'block': basic_block, 'block_num': [2, 2, 2, 2],  'kernel_num': [64, 128, 256, 512]},
        'resnet34':  {'block': basic_block, 'block_num': [3, 4, 6, 3],  'kernel_num': [64, 128, 256, 512]},
        'resnet50':  {'block': bottleneck,  'block_num': [3, 4, 6, 3],  'kernel_num': bottleneck_kernels},
        'resnet101': {'block': bottleneck,  'block_num': [3, 4, 23, 3], 'kernel_num': bottleneck_kernels},
        'resnet152': {'block': bottleneck,  'block_num': [3, 8, 36, 3], 'kernel_num': bottleneck_kernels},
    }
    net_name = net_name if net_name in block_setting else 'resnet18'
    block_num = block_setting[net_name]['block_num']
    kernel_num = block_setting[net_name]['kernel_num']
    block = block_setting[net_name]['block']

    input_ = Input(shape=input_shape)
    conv1 = conv2d_bn(input_, 64, kernel_size=7, strides=2, layer_name='first_conv')
    pool1 = MaxPool2D(pool_size=(3, 3), strides=2, padding='same', name='pool1')(conv1)

    conv = pool1
    for i in range(4):
        conv = make_layer(conv, block=block, block_num=block_num[i], kernel_num=kernel_num[i], layer_name='layer' + str(i + 1))

    pool2 = GlobalAvgPool2D(name='globalavgpool')(conv)
    output_ = Dense(nclass, activation='softmax', name='dense')(pool2)

    model = Model(inputs=input_, outputs=output_, name=net_name)
    model.summary()

    return model


if __name__ == '__main__':
    model = ResNet(input_shape=(30, 30, 3), nclass=3, net_name='resnet152')
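
As a quick sanity check, the projection shortcut can be exercised in isolation (a minimal sketch; the tensor names here are illustrative, not part of the original script):

# Build an f(x) with more channels than x; shortcut() should insert a 1x1 projection.
x = Input(shape=(8, 8, 64))
fx = conv2d_bn(x, kernel_num=128, kernel_size=3, layer_name='demo')
y = shortcut(fx=fx, x=x, layer_name='demo')
print(y.shape)  # (None, 8, 8, 128): x was projected to 128 channels before the Add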

The summary of the ResNet18 network is as follows:

Model: "ResNet18"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 30, 30, 3)    0                                            
__________________________________________________________________________________________________
first_conv_conv1 (Conv2D)       (None, 15, 15, 64)   9472        input_1[0][0]                    
__________________________________________________________________________________________________
first_conv_bn1 (BatchNormalizat (None, 15, 15, 64)   256         first_conv_conv1[0][0]           
__________________________________________________________________________________________________
pool1 (MaxPooling2D)            (None, 8, 8, 64)     0           first_conv_bn1[0][0]             
__________________________________________________________________________________________________
layer11_1_conv1 (Conv2D)        (None, 8, 8, 64)     36928       pool1[0][0]                      
__________________________________________________________________________________________________
layer11_1_bn1 (BatchNormalizati (None, 8, 8, 64)     256         layer11_1_conv1[0][0]            
__________________________________________________________________________________________________
layer11_relu1 (ReLU)            (None, 8, 8, 64)     0           layer11_1_bn1[0][0]              
__________________________________________________________________________________________________
layer11_2_conv1 (Conv2D)        (None, 8, 8, 64)     36928       layer11_relu1[0][0]              
__________________________________________________________________________________________________
layer11_2_bn1 (BatchNormalizati (None, 8, 8, 64)     256         layer11_2_conv1[0][0]            
__________________________________________________________________________________________________
layer11_relu2 (ReLU)            (None, 8, 8, 64)     0           layer11_2_bn1[0][0]              
__________________________________________________________________________________________________
layer11_shortcut_add (Add)      (None, 8, 8, 64)     0           pool1[0][0]                      
                                                                 layer11_relu2[0][0]              
__________________________________________________________________________________________________
layer11_relu3 (ReLU)            (None, 8, 8, 64)     0           layer11_shortcut_add[0][0]       
__________________________________________________________________________________________________
layer12_1_conv1 (Conv2D)        (None, 8, 8, 64)     36928       layer11_relu3[0][0]              
__________________________________________________________________________________________________
layer12_1_bn1 (BatchNormalizati (None, 8, 8, 64)     256         layer12_1_conv1[0][0]            
__________________________________________________________________________________________________
layer12_relu1 (ReLU)            (None, 8, 8, 64)     0           layer12_1_bn1[0][0]              
__________________________________________________________________________________________________
layer12_2_conv1 (Conv2D)        (None, 8, 8, 64)     36928       layer12_relu1[0][0]              
__________________________________________________________________________________________________
layer12_2_bn1 (BatchNormalizati (None, 8, 8, 64)     256         layer12_2_conv1[0][0]            
__________________________________________________________________________________________________
layer12_relu2 (ReLU)            (None, 8, 8, 64)     0           layer12_2_bn1[0][0]              
__________________________________________________________________________________________________
layer12_shortcut_add (Add)      (None, 8, 8, 64)     0           layer11_relu3[0][0]              
                                                                 layer12_relu2[0][0]              
__________________________________________________________________________________________________
layer12_relu3 (ReLU)            (None, 8, 8, 64)     0           layer12_shortcut_add[0][0]       
__________________________________________________________________________________________________
layer21_1_conv1 (Conv2D)        (None, 8, 8, 126)    72702       layer12_relu3[0][0]              
__________________________________________________________________________________________________
layer21_1_bn1 (BatchNormalizati (None, 8, 8, 126)    504         layer21_1_conv1[0][0]            
__________________________________________________________________________________________________
layer21_relu1 (ReLU)            (None, 8, 8, 126)    0           layer21_1_bn1[0][0]              
__________________________________________________________________________________________________
layer21_2_conv1 (Conv2D)        (None, 8, 8, 126)    143010      layer21_relu1[0][0]              
__________________________________________________________________________________________________
layer21_shortcut_conv1 (Conv2D) (None, 8, 8, 126)    8190        layer12_relu3[0][0]              
__________________________________________________________________________________________________
layer21_2_bn1 (BatchNormalizati (None, 8, 8, 126)    504         layer21_2_conv1[0][0]            
__________________________________________________________________________________________________
layer21_shortcut_bn1 (BatchNorm (None, 8, 8, 126)    504         layer21_shortcut_conv1[0][0]     
__________________________________________________________________________________________________
layer21_relu2 (ReLU)            (None, 8, 8, 126)    0           layer21_2_bn1[0][0]              
__________________________________________________________________________________________________
layer21_shortcut_add (Add)      (None, 8, 8, 126)    0           layer21_shortcut_bn1[0][0]       
                                                                 layer21_relu2[0][0]              
__________________________________________________________________________________________________
layer21_relu3 (ReLU)            (None, 8, 8, 126)    0           layer21_shortcut_add[0][0]       
__________________________________________________________________________________________________
layer22_1_conv1 (Conv2D)        (None, 8, 8, 126)    143010      layer21_relu3[0][0]              
__________________________________________________________________________________________________
layer22_1_bn1 (BatchNormalizati (None, 8, 8, 126)    504         layer22_1_conv1[0][0]            
__________________________________________________________________________________________________
layer22_relu1 (ReLU)            (None, 8, 8, 126)    0           layer22_1_bn1[0][0]              
__________________________________________________________________________________________________
layer22_2_conv1 (Conv2D)        (None, 8, 8, 126)    143010      layer22_relu1[0][0]              
__________________________________________________________________________________________________
layer22_2_bn1 (BatchNormalizati (None, 8, 8, 126)    504         layer22_2_conv1[0][0]            
__________________________________________________________________________________________________
layer22_relu2 (ReLU)            (None, 8, 8, 126)    0           layer22_2_bn1[0][0]              
__________________________________________________________________________________________________
layer22_shortcut_add (Add)      (None, 8, 8, 126)    0           layer21_relu3[0][0]              
                                                                 layer22_relu2[0][0]              
__________________________________________________________________________________________________
layer22_relu3 (ReLU)            (None, 8, 8, 126)    0           layer22_shortcut_add[0][0]       
__________________________________________________________________________________________________
layer31_1_conv1 (Conv2D)        (None, 8, 8, 256)    290560      layer22_relu3[0][0]              
__________________________________________________________________________________________________
layer31_1_bn1 (BatchNormalizati (None, 8, 8, 256)    1024        layer31_1_conv1[0][0]            
__________________________________________________________________________________________________
layer31_relu1 (ReLU)            (None, 8, 8, 256)    0           layer31_1_bn1[0][0]              
__________________________________________________________________________________________________
layer31_2_conv1 (Conv2D)        (None, 8, 8, 256)    590080      layer31_relu1[0][0]              
__________________________________________________________________________________________________
layer31_shortcut_conv1 (Conv2D) (None, 8, 8, 256)    32512       layer22_relu3[0][0]              
__________________________________________________________________________________________________
layer31_2_bn1 (BatchNormalizati (None, 8, 8, 256)    1024        layer31_2_conv1[0][0]            
__________________________________________________________________________________________________
layer31_shortcut_bn1 (BatchNorm (None, 8, 8, 256)    1024        layer31_shortcut_conv1[0][0]     
__________________________________________________________________________________________________
layer31_relu2 (ReLU)            (None, 8, 8, 256)    0           layer31_2_bn1[0][0]              
__________________________________________________________________________________________________
layer31_shortcut_add (Add)      (None, 8, 8, 256)    0           layer31_shortcut_bn1[0][0]       
                                                                 layer31_relu2[0][0]              
__________________________________________________________________________________________________
layer31_relu3 (ReLU)            (None, 8, 8, 256)    0           layer31_shortcut_add[0][0]       
__________________________________________________________________________________________________
layer32_1_conv1 (Conv2D)        (None, 8, 8, 256)    590080      layer31_relu3[0][0]              
__________________________________________________________________________________________________
layer32_1_bn1 (BatchNormalizati (None, 8, 8, 256)    1024        layer32_1_conv1[0][0]            
__________________________________________________________________________________________________
layer32_relu1 (ReLU)            (None, 8, 8, 256)    0           layer32_1_bn1[0][0]              
__________________________________________________________________________________________________
layer32_2_conv1 (Conv2D)        (None, 8, 8, 256)    590080      layer32_relu1[0][0]              
__________________________________________________________________________________________________
layer32_2_bn1 (BatchNormalizati (None, 8, 8, 256)    1024        layer32_2_conv1[0][0]            
__________________________________________________________________________________________________
layer32_relu2 (ReLU)            (None, 8, 8, 256)    0           layer32_2_bn1[0][0]              
__________________________________________________________________________________________________
layer32_shortcut_add (Add)      (None, 8, 8, 256)    0           layer31_relu3[0][0]              
                                                                 layer32_relu2[0][0]              
__________________________________________________________________________________________________
layer32_relu3 (ReLU)            (None, 8, 8, 256)    0           layer32_shortcut_add[0][0]       
__________________________________________________________________________________________________
layer41_1_conv1 (Conv2D)        (None, 8, 8, 512)    1180160     layer32_relu3[0][0]              
__________________________________________________________________________________________________
layer41_1_bn1 (BatchNormalizati (None, 8, 8, 512)    2048        layer41_1_conv1[0][0]            
__________________________________________________________________________________________________
layer41_relu1 (ReLU)            (None, 8, 8, 512)    0           layer41_1_bn1[0][0]              
__________________________________________________________________________________________________
layer41_2_conv1 (Conv2D)        (None, 8, 8, 512)    2359808     layer41_relu1[0][0]              
__________________________________________________________________________________________________
layer41_shortcut_conv1 (Conv2D) (None, 8, 8, 512)    131584      layer32_relu3[0][0]              
__________________________________________________________________________________________________
layer41_2_bn1 (BatchNormalizati (None, 8, 8, 512)    2048        layer41_2_conv1[0][0]            
__________________________________________________________________________________________________
layer41_shortcut_bn1 (BatchNorm (None, 8, 8, 512)    2048        layer41_shortcut_conv1[0][0]     
__________________________________________________________________________________________________
layer41_relu2 (ReLU)            (None, 8, 8, 512)    0           layer41_2_bn1[0][0]              
__________________________________________________________________________________________________
layer41_shortcut_add (Add)      (None, 8, 8, 512)    0           layer41_shortcut_bn1[0][0]       
                                                                 layer41_relu2[0][0]              
__________________________________________________________________________________________________
layer41_relu3 (ReLU)            (None, 8, 8, 512)    0           layer41_shortcut_add[0][0]       
__________________________________________________________________________________________________
layer42_1_conv1 (Conv2D)        (None, 8, 8, 512)    2359808     layer41_relu3[0][0]              
__________________________________________________________________________________________________
layer42_1_bn1 (BatchNormalizati (None, 8, 8, 512)    2048        layer42_1_conv1[0][0]            
__________________________________________________________________________________________________
layer42_relu1 (ReLU)            (None, 8, 8, 512)    0           layer42_1_bn1[0][0]              
__________________________________________________________________________________________________
layer42_2_conv1 (Conv2D)        (None, 8, 8, 512)    2359808     layer42_relu1[0][0]              
__________________________________________________________________________________________________
layer42_2_bn1 (BatchNormalizati (None, 8, 8, 512)    2048        layer42_2_conv1[0][0]            
__________________________________________________________________________________________________
layer42_relu2 (ReLU)            (None, 8, 8, 512)    0           layer42_2_bn1[0][0]              
__________________________________________________________________________________________________
layer42_shortcut_add (Add)      (None, 8, 8, 512)    0           layer41_relu3[0][0]              
                                                                 layer42_relu2[0][0]              
__________________________________________________________________________________________________
layer42_relu3 (ReLU)            (None, 8, 8, 512)    0           layer42_shortcut_add[0][0]       
__________________________________________________________________________________________________
globalavgpool (GlobalAveragePoo (None, 512)          0           layer42_relu3[0][0]              
__________________________________________________________________________________________________
dense (Dense)                   (None, 3)            1539        globalavgpool[0][0]              
==================================================================================================
Total params: 11,172,285
Trainable params: 11,162,705
Non-trainable params: 9,580
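
Because ResNet returns the model, training follows the standard Keras workflow. A minimal sketch on random stand-in data (the optimizer, loss, and dummy shapes are illustrative assumptions, not from the original post):

import numpy as np

model = ResNet(input_shape=(30, 30, 3), nclass=3, net_name='resnet18')
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

x_train = np.random.rand(16, 30, 30, 3).astype('float32')  # random images
y_train = np.eye(3)[np.random.randint(0, 3, size=16)]      # random one-hot labels

model.fit(x_train, y_train, batch_size=8, epochs=1)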

