利用闲暇时间写了resnet18 的实现代码,可能存在错误,看官可以给与指正。
pytorch中给与了resnet的实现模型,可以供小白调用,这里不赘述方法。下面所有代码的实现都是使用pytorch框架书写,采用python语言。
网络上搜索到的resne18的网络结构图如下。resnet18只看图中左侧网络结构就可以。(ps:使用的是简书上一个博主的图,如有冒犯,请谅解)
接下来,根据如图的网络结构进行搭建网络。通过观察网络结构,发现在网络结构中存在两种不同基础块,第一种是实现标注跳跃连接的部分,如下:在这个块中具体实现工作流程如下图:
实现方法如下:
import torch.nn as nn
import torch.nn.functional as F
class basic_block(nn.Module):
'''定义了带实线部分的残差块'''
def __init__(self,in_channels):
super(basic_block, self).__init__()
self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1)
self.conv2 = nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1)
def forward(self, x):
y = F.relu(self.conv1(x))
y = self.conv2(y)
return F.relu(x+y)
第二种是带有虚线标注的跳跃连接部分,第一种结构是在通道数不变的情况下,进行的残差结构运算,第二种的跳跃连接结构,通道数发生了改变,于是把它单独做成一个基础块,如下图:
实现代码如下:
class basic_block2(nn.Module):
'''定义了带虚线部分的残差块'''
def __init__(self,in_channels,out_channels):
super(basic_block2, self).__init__()
self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=2)
self.conv2 = nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=2,padding=1)
self.conv3 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
def forward(self, x):
z = self.conv1(x)
y = F.relu(self.conv2(x))
y = self.conv3(y)
return F.relu(y+z)
这样我们就有了两种残差结构块,然后按照文章开头给出的网络结构顺序连接起来就行了,实现代码如下:
class resnet_test(nn.Module):
'''按照网络结构图直接连接,确定好通道数量就可以'''
def __init__(self):
super(resnet_test, self).__init__()
self.conv1 = nn.Conv2d(3,64,kernel_size=7, stride=2, padding=3)
self.maxp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.resn1 = basic_block(64)
self.resn2 = basic_block(64)
self.resn3 = basic_block2(64,128)
self.resn4 = basic_block(128)
self.rest5 = basic_block2(128,256)
self.rest6 = basic_block(256)
self.rest7 = basic_block2(256,512)
self.rest8 = basic_block(512)
self.avgp1 = nn.AvgPool2d(7)
self.fullc = nn.Linear(512,1000)
def forward(self,x) :
in_size = x.size(0)
x = self.maxp1(F.relu(self.conv1(x)))
x = self.resn1(x)
x = self.resn2(x)
x = self.resn3(x)
x = self.resn4(x)
x = self.resn5(x)
x = self.resn6(x)
x = self.resn7(x)
x = self.resn8(x)
x = self.avgp1(F.relu(x))
x = x.view(in_size,-1)
x = self.fullc(x)
return F.softmax(x,dim=1) ###使用softmax激活函数进行得分计算
这样我们就得到了自己手敲的一个resnet18网络,虽然步骤繁冗,但是小白级别的初学者容易看懂,欢迎交流,第一次写博文,不喜勿喷。