ResNet comes in the following five variants. All share the same overall layout: an initial convolution layer, a max-pooling layer, four stages of residual blocks, and a final fully connected layer.
The five networks differ mainly in the number of filters and the structure of each stage: ResNet-18 and ResNet-34 use the basic block built from 3x3 convolutions, while ResNet-50, ResNet-101, and ResNet-152 use the bottleneck block; the number of blocks per stage also varies. The number in each name counts the weighted layers: the stem convolution, every convolution inside the residual blocks, and the final fully connected layer.
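As a quick sanity check on the naming, here is a small sketch (the block counts are the standard ones from the ResNet paper; a basic block contains 2 convolutions, a bottleneck block contains 3):

# Layers per variant = stem conv + (convs per block) * (total blocks) + final dense.
settings = {
    'resnet18':  (2, [2, 2, 2, 2]),    # basic block: 2 convs per block
    'resnet34':  (2, [3, 4, 6, 3]),
    'resnet50':  (3, [3, 4, 6, 3]),    # bottleneck: 3 convs per block
    'resnet101': (3, [3, 4, 23, 3]),
    'resnet152': (3, [3, 8, 36, 3]),
}
for name, (convs_per_block, blocks) in settings.items():
    print(name, 1 + convs_per_block * sum(blocks) + 1)
# resnet18 18, resnet34 34, resnet50 50, resnet101 101, resnet152 152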
In the code below, the repeated convolution + batch normalization pair is factored into a single helper, with the ReLU applied afterwards by the caller. The shortcut function handles the case where x and f(x) have different dimensions: if the channel counts differ, a 1x1 convolution first changes the channels of x. All kernel sizes are 3x3 except for the first convolution, which is 7x7; the kernel sizes inside basic_block and bottleneck are fixed.
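To see the projection shortcut in isolation before the full listing, here is a minimal standalone sketch (the shapes are illustrative): when x has 64 channels and f(x) has 128, a 1x1 convolution brings x up to 128 channels so the element-wise add is valid.

from keras.layers import Input, Conv2D, Add

x = Input(shape=(8, 8, 64))              # x: 64 channels
fx = Conv2D(128, 3, padding='same')(x)   # stands in for f(x): 128 channels
identity = Conv2D(128, 1)(x)             # 1x1 projection: (8, 8, 64) -> (8, 8, 128)
out = Add()([identity, fx])              # channel counts now match, so the add works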
from keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, MaxPool2D, GlobalAvgPool2D, Dense
from keras.models import Model
def conv2d_bn(input, kernel_num, kernel_size=3, strides=1, layer_name='', padding_mode='same'):
    # Convolution followed by batch normalization; the ReLU is applied by the caller.
    conv1 = Conv2D(kernel_num, kernel_size, strides=strides, padding=padding_mode,
                   name=layer_name + '_conv1')(input)
    batch1 = BatchNormalization(name=layer_name + '_bn1')(conv1)
    return batch1
def shortcut(fx, x, padding_mode='same', layer_name=''):
    # If x and f(x) have different channel counts, project x with a 1x1
    # convolution first; otherwise use the identity mapping directly.
    layer_name += '_shortcut'
    if x.shape[-1] != fx.shape[-1]:
        k = int(fx.shape[-1])
        identity = conv2d_bn(x, kernel_num=k, kernel_size=1, padding_mode=padding_mode, layer_name=layer_name)
    else:
        identity = x
    return Add(name=layer_name + '_add')([identity, fx])
def bottleneck(input, kernel_num, strides=1, layer_name='bottleneck', padding_mode='same'):
    # 1x1 -> 3x3 -> 1x1 bottleneck block. The stride is applied only to the first
    # convolution so the block downsamples at most once (all calls below use strides=1).
    k1, k2, k3 = kernel_num
    conv1 = conv2d_bn(input, kernel_num=k1, kernel_size=1, strides=strides, padding_mode=padding_mode, layer_name=layer_name + '_1')
    relu1 = ReLU(name=layer_name + '_relu1')(conv1)
    conv2 = conv2d_bn(relu1, kernel_num=k2, kernel_size=3, strides=1, padding_mode=padding_mode, layer_name=layer_name + '_2')
    relu2 = ReLU(name=layer_name + '_relu2')(conv2)
    conv3 = conv2d_bn(relu2, kernel_num=k3, kernel_size=1, strides=1, padding_mode=padding_mode, layer_name=layer_name + '_3')
    shortcut_add = shortcut(fx=conv3, x=input, layer_name=layer_name)
    relu3 = ReLU(name=layer_name + '_relu3')(shortcut_add)
    return relu3
def basic_block(input, kernel_num=64, strides=1, layer_name='basic', padding_mode='same'):
    # Two 3x3 convolutions; the stride is applied only to the first one.
    conv1 = conv2d_bn(input, kernel_num=kernel_num, strides=strides, kernel_size=3,
                      layer_name=layer_name + '_1', padding_mode=padding_mode)
    relu1 = ReLU(name=layer_name + '_relu1')(conv1)
    conv2 = conv2d_bn(relu1, kernel_num=kernel_num, strides=1, kernel_size=3,
                      layer_name=layer_name + '_2', padding_mode=padding_mode)
    relu2 = ReLU(name=layer_name + '_relu2')(conv2)
    shortcut_add = shortcut(fx=relu2, x=input, layer_name=layer_name)
    relu3 = ReLU(name=layer_name + '_relu3')(shortcut_add)
    return relu3
def make_layer(input, block, block_num, kernel_num, layer_name=''):
    # Stack block_num residual blocks. All blocks use strides=1, so in this
    # simplified version all downsampling happens in the stem (first conv + pool).
    x = input
    for i in range(1, block_num + 1):
        x = block(x, kernel_num=kernel_num, strides=1, layer_name=layer_name + str(i), padding_mode='same')
    return x
def ResNet(input_shape, nclass, net_name='resnet18'):
    """
    :param input_shape: shape of the input tensor, e.g. (224, 224, 3)
    :param nclass: number of output classes
    :param net_name: one of 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
    :return: the assembled Keras model
    """
    block_setting = {
        'resnet18': {'block': basic_block, 'block_num': [2, 2, 2, 2],
                     'kernel_num': [64, 128, 256, 512]},
        'resnet34': {'block': basic_block, 'block_num': [3, 4, 6, 3],
                     'kernel_num': [64, 128, 256, 512]},
        'resnet50': {'block': bottleneck, 'block_num': [3, 4, 6, 3],
                     'kernel_num': [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]},
        'resnet101': {'block': bottleneck, 'block_num': [3, 4, 23, 3],
                      'kernel_num': [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]},
        'resnet152': {'block': bottleneck, 'block_num': [3, 8, 36, 3],
                      'kernel_num': [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]},
    }
    # Fall back to resnet18 if an unknown name is given.
    if net_name not in block_setting:
        net_name = 'resnet18'
    block_num = block_setting[net_name]['block_num']
    kernel_num = block_setting[net_name]['kernel_num']
    block = block_setting[net_name]['block']

    input_ = Input(shape=input_shape)
    # Stem: 7x7 convolution with stride 2, then 3x3 max pooling with stride 2.
    conv1 = conv2d_bn(input_, 64, kernel_size=7, strides=2, layer_name='first_conv')
    pool1 = MaxPool2D(pool_size=(3, 3), strides=2, padding='same', name='pool1')(conv1)

    # Four stages of residual blocks.
    conv = pool1
    for i in range(4):
        conv = make_layer(conv, block=block, block_num=block_num[i],
                          kernel_num=kernel_num[i], layer_name='layer' + str(i + 1))

    pool2 = GlobalAvgPool2D(name='globalavgpool')(conv)
    output_ = Dense(nclass, activation='softmax', name='dense')(pool2)

    model = Model(inputs=input_, outputs=output_, name='ResNet18')
    model.summary()
    return model
if __name__ == '__main__':
    model = ResNet(input_shape=(30, 30, 3), nclass=3, net_name='resnet18')
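Because ResNet returns the assembled model, it can be compiled and trained like any Keras model. A minimal sketch with dummy data (the shapes and hyperparameters here are illustrative, not from the original post):

import numpy as np

model = ResNet(input_shape=(32, 32, 3), nclass=10, net_name='resnet34')
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Random data just to show the expected input/label shapes.
x_train = np.random.rand(8, 32, 32, 3).astype('float32')
y_train = np.random.randint(0, 10, size=(8,))
model.fit(x_train, y_train, batch_size=4, epochs=1)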
The summary of the ResNet18 network is as follows:
Model: "ResNet18"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 30, 30, 3) 0
__________________________________________________________________________________________________
first_conv_conv1 (Conv2D) (None, 15, 15, 64) 9472 input_1[0][0]
__________________________________________________________________________________________________
first_conv_bn1 (BatchNormalizat (None, 15, 15, 64) 256 first_conv_conv1[0][0]
__________________________________________________________________________________________________
pool1 (MaxPooling2D) (None, 8, 8, 64) 0 first_conv_bn1[0][0]
__________________________________________________________________________________________________
layer11_1_conv1 (Conv2D) (None, 8, 8, 64) 36928 pool1[0][0]
__________________________________________________________________________________________________
layer11_1_bn1 (BatchNormalizati (None, 8, 8, 64) 256 layer11_1_conv1[0][0]
__________________________________________________________________________________________________
layer11_relu1 (ReLU) (None, 8, 8, 64) 0 layer11_1_bn1[0][0]
__________________________________________________________________________________________________
layer11_2_conv1 (Conv2D) (None, 8, 8, 64) 36928 layer11_relu1[0][0]
__________________________________________________________________________________________________
layer11_2_bn1 (BatchNormalizati (None, 8, 8, 64) 256 layer11_2_conv1[0][0]
__________________________________________________________________________________________________
layer11_relu2 (ReLU) (None, 8, 8, 64) 0 layer11_2_bn1[0][0]
__________________________________________________________________________________________________
layer11_shortcut_add (Add) (None, 8, 8, 64) 0 pool1[0][0]
layer11_relu2[0][0]
__________________________________________________________________________________________________
layer11_relu3 (ReLU) (None, 8, 8, 64) 0 layer11_shortcut_add[0][0]
__________________________________________________________________________________________________
layer12_1_conv1 (Conv2D) (None, 8, 8, 64) 36928 layer11_relu3[0][0]
__________________________________________________________________________________________________
layer12_1_bn1 (BatchNormalizati (None, 8, 8, 64) 256 layer12_1_conv1[0][0]
__________________________________________________________________________________________________
layer12_relu1 (ReLU) (None, 8, 8, 64) 0 layer12_1_bn1[0][0]
__________________________________________________________________________________________________
layer12_2_conv1 (Conv2D) (None, 8, 8, 64) 36928 layer12_relu1[0][0]
__________________________________________________________________________________________________
layer12_2_bn1 (BatchNormalizati (None, 8, 8, 64) 256 layer12_2_conv1[0][0]
__________________________________________________________________________________________________
layer12_relu2 (ReLU) (None, 8, 8, 64) 0 layer12_2_bn1[0][0]
__________________________________________________________________________________________________
layer12_shortcut_add (Add) (None, 8, 8, 64) 0 layer11_relu3[0][0]
layer12_relu2[0][0]
__________________________________________________________________________________________________
layer12_relu3 (ReLU) (None, 8, 8, 64) 0 layer12_shortcut_add[0][0]
__________________________________________________________________________________________________
layer21_1_conv1 (Conv2D) (None, 8, 8, 126) 72702 layer12_relu3[0][0]
__________________________________________________________________________________________________
layer21_1_bn1 (BatchNormalizati (None, 8, 8, 126) 504 layer21_1_conv1[0][0]
__________________________________________________________________________________________________
layer21_relu1 (ReLU) (None, 8, 8, 126) 0 layer21_1_bn1[0][0]
__________________________________________________________________________________________________
layer21_2_conv1 (Conv2D) (None, 8, 8, 126) 143010 layer21_relu1[0][0]
__________________________________________________________________________________________________
layer21_shortcut_conv1 (Conv2D) (None, 8, 8, 126) 8190 layer12_relu3[0][0]
__________________________________________________________________________________________________
layer21_2_bn1 (BatchNormalizati (None, 8, 8, 126) 504 layer21_2_conv1[0][0]
__________________________________________________________________________________________________
layer21_shortcut_bn1 (BatchNorm (None, 8, 8, 126) 504 layer21_shortcut_conv1[0][0]
__________________________________________________________________________________________________
layer21_relu2 (ReLU) (None, 8, 8, 126) 0 layer21_2_bn1[0][0]
__________________________________________________________________________________________________
layer21_shortcut_add (Add) (None, 8, 8, 126) 0 layer21_shortcut_bn1[0][0]
layer21_relu2[0][0]
__________________________________________________________________________________________________
layer21_relu3 (ReLU) (None, 8, 8, 126) 0 layer21_shortcut_add[0][0]
__________________________________________________________________________________________________
layer22_1_conv1 (Conv2D) (None, 8, 8, 126) 143010 layer21_relu3[0][0]
__________________________________________________________________________________________________
layer22_1_bn1 (BatchNormalizati (None, 8, 8, 126) 504 layer22_1_conv1[0][0]
__________________________________________________________________________________________________
layer22_relu1 (ReLU) (None, 8, 8, 126) 0 layer22_1_bn1[0][0]
__________________________________________________________________________________________________
layer22_2_conv1 (Conv2D) (None, 8, 8, 126) 143010 layer22_relu1[0][0]
__________________________________________________________________________________________________
layer22_2_bn1 (BatchNormalizati (None, 8, 8, 126) 504 layer22_2_conv1[0][0]
__________________________________________________________________________________________________
layer22_relu2 (ReLU) (None, 8, 8, 126) 0 layer22_2_bn1[0][0]
__________________________________________________________________________________________________
layer22_shortcut_add (Add) (None, 8, 8, 126) 0 layer21_relu3[0][0]
layer22_relu2[0][0]
__________________________________________________________________________________________________
layer22_relu3 (ReLU) (None, 8, 8, 126) 0 layer22_shortcut_add[0][0]
__________________________________________________________________________________________________
layer31_1_conv1 (Conv2D) (None, 8, 8, 256) 290560 layer22_relu3[0][0]
__________________________________________________________________________________________________
layer31_1_bn1 (BatchNormalizati (None, 8, 8, 256) 1024 layer31_1_conv1[0][0]
__________________________________________________________________________________________________
layer31_relu1 (ReLU) (None, 8, 8, 256) 0 layer31_1_bn1[0][0]
__________________________________________________________________________________________________
layer31_2_conv1 (Conv2D) (None, 8, 8, 256) 590080 layer31_relu1[0][0]
__________________________________________________________________________________________________
layer31_shortcut_conv1 (Conv2D) (None, 8, 8, 256) 32512 layer22_relu3[0][0]
__________________________________________________________________________________________________
layer31_2_bn1 (BatchNormalizati (None, 8, 8, 256) 1024 layer31_2_conv1[0][0]
__________________________________________________________________________________________________
layer31_shortcut_bn1 (BatchNorm (None, 8, 8, 256) 1024 layer31_shortcut_conv1[0][0]
__________________________________________________________________________________________________
layer31_relu2 (ReLU) (None, 8, 8, 256) 0 layer31_2_bn1[0][0]
__________________________________________________________________________________________________
layer31_shortcut_add (Add) (None, 8, 8, 256) 0 layer31_shortcut_bn1[0][0]
layer31_relu2[0][0]
__________________________________________________________________________________________________
layer31_relu3 (ReLU) (None, 8, 8, 256) 0 layer31_shortcut_add[0][0]
__________________________________________________________________________________________________
layer32_1_conv1 (Conv2D) (None, 8, 8, 256) 590080 layer31_relu3[0][0]
__________________________________________________________________________________________________
layer32_1_bn1 (BatchNormalizati (None, 8, 8, 256) 1024 layer32_1_conv1[0][0]
__________________________________________________________________________________________________
layer32_relu1 (ReLU) (None, 8, 8, 256) 0 layer32_1_bn1[0][0]
__________________________________________________________________________________________________
layer32_2_conv1 (Conv2D) (None, 8, 8, 256) 590080 layer32_relu1[0][0]
__________________________________________________________________________________________________
layer32_2_bn1 (BatchNormalizati (None, 8, 8, 256) 1024 layer32_2_conv1[0][0]
__________________________________________________________________________________________________
layer32_relu2 (ReLU) (None, 8, 8, 256) 0 layer32_2_bn1[0][0]
__________________________________________________________________________________________________
layer32_shortcut_add (Add) (None, 8, 8, 256) 0 layer31_relu3[0][0]
layer32_relu2[0][0]
__________________________________________________________________________________________________
layer32_relu3 (ReLU) (None, 8, 8, 256) 0 layer32_shortcut_add[0][0]
__________________________________________________________________________________________________
layer41_1_conv1 (Conv2D) (None, 8, 8, 512) 1180160 layer32_relu3[0][0]
__________________________________________________________________________________________________
layer41_1_bn1 (BatchNormalizati (None, 8, 8, 512) 2048 layer41_1_conv1[0][0]
__________________________________________________________________________________________________
layer41_relu1 (ReLU) (None, 8, 8, 512) 0 layer41_1_bn1[0][0]
__________________________________________________________________________________________________
layer41_2_conv1 (Conv2D) (None, 8, 8, 512) 2359808 layer41_relu1[0][0]
__________________________________________________________________________________________________
layer41_shortcut_conv1 (Conv2D) (None, 8, 8, 512) 131584 layer32_relu3[0][0]
__________________________________________________________________________________________________
layer41_2_bn1 (BatchNormalizati (None, 8, 8, 512) 2048 layer41_2_conv1[0][0]
__________________________________________________________________________________________________
layer41_shortcut_bn1 (BatchNorm (None, 8, 8, 512) 2048 layer41_shortcut_conv1[0][0]
__________________________________________________________________________________________________
layer41_relu2 (ReLU) (None, 8, 8, 512) 0 layer41_2_bn1[0][0]
__________________________________________________________________________________________________
layer41_shortcut_add (Add) (None, 8, 8, 512) 0 layer41_shortcut_bn1[0][0]
layer41_relu2[0][0]
__________________________________________________________________________________________________
layer41_relu3 (ReLU) (None, 8, 8, 512) 0 layer41_shortcut_add[0][0]
__________________________________________________________________________________________________
layer42_1_conv1 (Conv2D) (None, 8, 8, 512) 2359808 layer41_relu3[0][0]
__________________________________________________________________________________________________
layer42_1_bn1 (BatchNormalizati (None, 8, 8, 512) 2048 layer42_1_conv1[0][0]
__________________________________________________________________________________________________
layer42_relu1 (ReLU) (None, 8, 8, 512) 0 layer42_1_bn1[0][0]
__________________________________________________________________________________________________
layer42_2_conv1 (Conv2D) (None, 8, 8, 512) 2359808 layer42_relu1[0][0]
__________________________________________________________________________________________________
layer42_2_bn1 (BatchNormalizati (None, 8, 8, 512) 2048 layer42_2_conv1[0][0]
__________________________________________________________________________________________________
layer42_relu2 (ReLU) (None, 8, 8, 512) 0 layer42_2_bn1[0][0]
__________________________________________________________________________________________________
layer42_shortcut_add (Add) (None, 8, 8, 512) 0 layer41_relu3[0][0]
layer42_relu2[0][0]
__________________________________________________________________________________________________
layer42_relu3 (ReLU) (None, 8, 8, 512) 0 layer42_shortcut_add[0][0]
__________________________________________________________________________________________________
globalavgpool (GlobalAveragePoo (None, 512) 0 layer42_relu3[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 3) 1539 globalavgpool[0][0]
==================================================================================================
Total params: 11,172,285
Trainable params: 11,162,705
Non-trainable params: 9,580
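The parameter counts in the summary can be checked by hand: a Conv2D layer with bias has (k * k * c_in + 1) * c_out parameters, and a BatchNormalization layer has 4 * c parameters (gamma, beta, moving mean, moving variance; the moving statistics are the non-trainable half). For example:

# first_conv_conv1: 7x7 conv, 3 -> 64 channels
print((7 * 7 * 3 + 1) * 64)    # 9472, matches the summary
# layer11_1_conv1: 3x3 conv, 64 -> 64 channels
print((3 * 3 * 64 + 1) * 64)   # 36928
# first_conv_bn1: batch norm over 64 channels
print(4 * 64)                  # 256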