1.1 Residual block
1.2 完整 Resnet_18
import torch
import torch.nn as nn
import d2I.torch1 as d2I
import torch.nn.functional as F
# 加载数据集
def load_data(fashion_minst=True, batch_size=128, resize=224):
    """Build the training and test data iterators.

    Args:
        fashion_minst: When truthy, load Fashion-MNIST through
            ``d2I.load_data_fashion_mnist``. The parameter's spelling is
            kept as-is so existing keyword callers continue to work.
        batch_size: Forwarded as the loader's ``batch_size`` argument.
        resize: Forwarded as the loader's ``resize`` argument.

    Returns:
        A ``(train_data, test_data)`` tuple. Both entries are ``None`` when
        ``fashion_minst`` is falsy, because no other dataset is implemented.
    """
    if not fashion_minst:
        # No alternative dataset is wired up; callers receive placeholders.
        return None, None

    train_data, test_data = d2I.load_data_fashion_mnist(
        batch_size=batch_size, resize=resize
    )
    return train_data, test_data
# 定义模型
class residual_block(nn.Module):
    """Residual unit made of two 3x3 convolutions plus a shortcut.

    Main path: 3x3 conv (stride ``strides``) -> BatchNorm -> ReLU ->
    3x3 conv (stride 1) -> BatchNorm. The shortcut is added to that result
    and a final ReLU is applied.

    Args:
        input_channels: Number of channels of the incoming tensor.
        output_channels: Number of channels produced by the unit.
        conv1_1: When truthy, the shortcut passes through a 1x1 convolution
            (stride ``strides``) so it matches the main path's channel count
            and spatial size. When falsy, the input is added unchanged, so
            it must already be addable to the main-path output.
        strides: Stride of the first 3x3 convolution and of the 1x1 shortcut.
    """

    def __init__(self, input_channels, output_channels, conv1_1=True, strides=1):
        super().__init__()
        self.conv1_1 = conv1_1
        self.conv1 = nn.Conv2d(
            input_channels, output_channels,
            kernel_size=3, stride=strides, padding=1,
        )
        self.conv2 = nn.Conv2d(
            output_channels, output_channels,
            kernel_size=3, stride=1, padding=1,
        )
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        # Only allocate the projection when it is actually used; otherwise
        # the module would carry trainable parameters that never receive a
        # gradient.
        if conv1_1:
            self.conv3 = nn.Conv2d(
                input_channels, output_channels, kernel_size=1, stride=strides
            )
        else:
            self.conv3 = None

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        Y = self.relu(self.bn1(self.conv1(x)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            x = self.conv3(x)
        Y += x
        return self.relu(Y)
class Res_block(nn.Module):
    """A stage of ``num_blocks`` consecutive ``residual_block`` units.

    Args:
        in_channels: Channel count of the tensor entering the stage.
        out_channels: Channel count produced by every unit in the stage.
        num_blocks: Number of residual units to stack.
        is_b2: When truthy, the stage keeps the spatial size (stride 1 in
            every unit). Otherwise the first unit uses stride 2 with a 1x1
            projection shortcut.
    """

    def __init__(self, in_channels, out_channels, num_blocks, is_b2=False):
        super().__init__()
        self.inchannel = in_channels
        # Attribute name is misspelled in the original; kept so existing
        # readers of ``outhannel`` do not break.
        self.outhannel = out_channels
        self.num_blocks = num_blocks
        self.is_b2 = is_b2

        res = []
        for i in range(num_blocks):
            if i > 0:
                # Later units always map out_channels -> out_channels.
                res.append(
                    residual_block(out_channels, out_channels, conv1_1=False, strides=1)
                )
            elif is_b2:
                # The first unit must accept in_channels. A projection is
                # needed only when the width changes, so the equal-width case
                # is a plain identity shortcut.
                res.append(
                    residual_block(
                        in_channels,
                        out_channels,
                        conv1_1=in_channels != out_channels,
                        strides=1,
                    )
                )
            else:
                res.append(
                    residual_block(in_channels, out_channels, conv1_1=True, strides=2)
                )
        self.seq = nn.Sequential(*res)

    def forward(self, x):
        """Run ``x`` through the stacked residual units in order."""
        return self.seq(x)
class Resnet_18(nn.Module):
    """ResNet-18 style classifier: stem, four residual stages, linear head.

    Args:
        inputchannel: Number of channels of the input images.
        num_classes: Size of the final linear layer's output (default 10).
    """

    def __init__(self, inputchannel, num_classes=10):
        super().__init__()
        # Stem: conv -> BN -> ReLU -> max-pool. For a 224x224 input the
        # stride-2 convolution yields 112x112 and the stride-2 pool 56x56.
        self.b1 = nn.Sequential(
            nn.Conv2d(inputchannel, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(num_features=64, eps=1e-5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.net = nn.Sequential(
            self.b1,
            Res_block(64, 64, 2, is_b2=True),  # keeps spatial size
            Res_block(64, 128, 2),
            Res_block(128, 256, 2),
            Res_block(256, 512, 2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the output of the final linear layer for the batch ``x``."""
        return self.net(x)

    def show_net(self):
        """Return the top-level ``nn.Sequential`` so callers can iterate its stages."""
        return self.net
def main():
    """Print the model, trace per-stage output shapes, then launch training."""
    Resnet = Resnet_18(inputchannel=1)
    print(Resnet)

    X = torch.randn((1, 1, 224, 224), dtype=torch.float32)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Shape probe only: no gradients are needed, so skip building the
    # autograd graph.
    with torch.no_grad():
        for net in Resnet.show_net():
            X = net(X)
            print(net.__class__.__name__, " output_shape: ", X.shape)

    # Train a fresh model rather than the one used for the shape probe.
    Resnet1 = Resnet_18(inputchannel=1)
    train_data, test_data = load_data(fashion_minst=True)
    epochs = 10
    lr = 0.01
    d2I.train_ch6(Resnet1, train_data, test_data, epochs, lr, device)


# Guard the entry point so that importing this module (including the
# re-import done by spawned data-loading worker processes) does not start
# another training run.
if __name__ == "__main__":
    main()
参考文献:《动手学深度学习》(李沐)