Resnet 50网络架构的实现
- 具体代码: https://github.com/xiaoaleiBLUE/computer_vision
- 前几天实现了 Resnet 18 网络结构,今天又看了一下 Resnet 50 网络结构。昨天晚上利用一点时间手动实现了一下,程序没有报错,应该是把 Resnet 50 网络结构实现出来了。于是今天参照别人画的图,把 Resnet 50 网络结构也画了出来。
文章目录
- Resnet 50网络架构的实现
- 一、Resnet 50网络
-
- 二、实现的基本思想
-
- 1.layer_1
- 2.layer_2, layer_3,...
- 三、Resnet 50网络代码
一、Resnet 50网络
1.网络架构
2.网络架构具体参数整理
二、实现的基本思想
1.layer_1
- 对于 layer_1 的 3 个 Bottleneck,其中的 Conv 模块可以复用;同时只有第一个 Bottleneck 在残差连接(shortcut)上进行 Conv 操作,其余 Bottleneck 的残差连接不进行 Conv 操作。
- 代码实现
class ResBlock_1(nn.Module):
    """Stride-1 bottleneck residual block (used for ``layer_1`` of ResNet-50).

    The residual branch is ``1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU
    -> 1x1 conv -> BN``. Its output is added to the shortcut and the sum is
    passed through a final ReLU. All convolutions use stride 1, so the spatial
    size of the input is preserved.

    Args:
        down_sample: When truthy, the shortcut is a ``1x1 conv + BN``
            projection from ``in_channels`` to ``out_channels``. Otherwise the
            input itself is used as the shortcut, which requires
            ``in_channels == out_channels``.
        in_channels: Number of channels of the input tensor.
        min_channels: Number of channels inside the bottleneck.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, down_sample, in_channels, min_channels, out_channels):
        super(ResBlock_1, self).__init__()
        self.down_sample = down_sample
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, min_channels, 1, 1, 0),
            nn.BatchNorm2d(min_channels),
            nn.ReLU(),
            nn.Conv2d(min_channels, min_channels, 3, 1, 1),
            nn.BatchNorm2d(min_channels),
            nn.ReLU(),
            # The expanding convolution of a bottleneck is 1x1 (not 3x3).
            nn.Conv2d(min_channels, out_channels, 1, 1, 0),
            # No ReLU here: the activation is applied after the residual add.
            nn.BatchNorm2d(out_channels),
        )
        if down_sample:
            self.shortcut_conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
            self.shortcut_bn = nn.BatchNorm2d(out_channels)
        else:
            # Identity shortcut: do not register parameters that would never
            # be used in forward().
            self.shortcut_conv = None
            self.shortcut_bn = None
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        if self.down_sample:
            shortcut = self.shortcut_bn(self.shortcut_conv(x))
        else:
            shortcut = x
        return self.relu(self.conv_1(x) + shortcut)
- 输入一个张量查看模型,看看 shape 是否和网络结构图保持一致
# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the first bottleneck of layer_1 and print its per-layer summary for a
# 64-channel 56x56 input.
resblock = ResBlock_1(True, 64, 64, 256)
resblock = resblock.to(device)
summary(resblock, (64, 56, 56))
2.layer_2, layer_3,…
- 仔细观察 layer_2, layer_3, … 我们发现各基本单元的卷积参数 (k, s, p) 是一致的,只是 Conv 输出的通道数不一致。
-
- 编写基本单元模块代码
class ResBlock_2(nn.Module):
    """Bottleneck residual block with optional stride-2 down-sampling.

    The residual branch is ``1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU
    -> 1x1 conv -> BN``. Its output is added to the shortcut and the sum is
    passed through a final ReLU.

    Args:
        down_sample: When truthy, the 3x3 convolution uses stride 2 and the
            shortcut is a stride-2 ``1x1 conv + BN`` projection from
            ``in_channels`` to ``out_channels``. Otherwise every stride is 1
            and the input itself is used as the shortcut, which requires
            ``in_channels == out_channels``.
        in_channels: Number of channels of the input tensor.
        min_channels: Number of channels inside the bottleneck.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, down_sample, in_channels, min_channels, out_channels):
        super(ResBlock_2, self).__init__()
        self.down_sample = down_sample
        # Only the 3x3 convolution carries the spatial stride.
        stride = 2 if down_sample else 1
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels, min_channels, 1, 1, 0),
            nn.BatchNorm2d(min_channels),
            nn.ReLU(),
            nn.Conv2d(min_channels, min_channels, 3, stride, 1),
            nn.BatchNorm2d(min_channels),
            nn.ReLU(),
            nn.Conv2d(min_channels, out_channels, 1, 1, 0),
            # No ReLU here: the activation is applied after the residual add.
            nn.BatchNorm2d(out_channels),
        )
        if down_sample:
            self.shortcut_conv = nn.Conv2d(in_channels, out_channels, 1, 2, 0)
            self.shortcut_bn = nn.BatchNorm2d(out_channels)
        else:
            # Identity shortcut: do not register parameters that would never
            # be used in forward().
            self.shortcut_conv = None
            self.shortcut_bn = None
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        if self.down_sample:
            shortcut = self.shortcut_bn(self.shortcut_conv(x))
        else:
            shortcut = x
        return self.relu(self.conv_2(x) + shortcut)
- 输入一个张量查看模型,看看 shape 是否和网络结构图保持一致
# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the down-sampling bottleneck that opens layer_2 and print its
# per-layer summary for a 256-channel 56x56 input.
resblock = ResBlock_2(True, 256, 128, 512)
resblock = resblock.to(device)
summary(resblock, (256, 56, 56))
三、Resnet 50网络代码
class Resnet_50(nn.Module):
    """ResNet-50 style classifier built from bottleneck residual blocks.

    The network is a 7x7 stride-2 stem with max-pooling (``layer_0``),
    followed by four stages of 3, 4, 6 and 3 bottleneck blocks
    (``layer_1`` .. ``layer_4``), global average pooling, and a linear
    classification head.

    Input: a 3-channel image batch (the article uses 224x224 inputs).
    Output: one logit per class.

    Args:
        num_classes: Size of the output layer. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(Resnet_50, self).__init__()
        # Stem: 7x7/2 convolution followed by 3x3/2 max-pooling.
        self.layer_0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Stage 1 keeps the spatial size; only its first block projects the
        # shortcut (64 -> 256 channels).
        self.layer_1 = nn.Sequential(
            ResBlock_1(True, 64, 64, 256),
            ResBlock_1(False, 256, 64, 256),
            ResBlock_1(False, 256, 64, 256),
        )
        # Stages 2-4: the first block of each stage is built with
        # down_sample=True, the remaining blocks use identity shortcuts.
        self.layer_2 = nn.Sequential(
            ResBlock_2(True, 256, 128, 512),
            ResBlock_2(False, 512, 128, 512),
            ResBlock_2(False, 512, 128, 512),
            ResBlock_2(False, 512, 128, 512),
        )
        self.layer_3 = nn.Sequential(
            ResBlock_2(True, 512, 256, 1024),
            ResBlock_2(False, 1024, 256, 1024),
            ResBlock_2(False, 1024, 256, 1024),
            ResBlock_2(False, 1024, 256, 1024),
            ResBlock_2(False, 1024, 256, 1024),
            ResBlock_2(False, 1024, 256, 1024),
        )
        self.layer_4 = nn.Sequential(
            ResBlock_2(True, 1024, 512, 2048),
            ResBlock_2(False, 2048, 512, 2048),
            ResBlock_2(False, 2048, 512, 2048),
        )
        # Head: global average pooling -> flatten -> linear classifier.
        self.app = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Run the stem, the four residual stages and the classifier head."""
        x = self.layer_0(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.app(x)
        x = self.flatten(x)
        return self.linear(x)
- 输入一个张量查看模型,看看 shape 是否和网络结构图保持一致
# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the full network and print its per-layer summary for a 3x224x224 input.
resblock = Resnet_50()
resblock = resblock.to(device)
summary(resblock, (3, 224, 224))