import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
# Define the basic residual block
class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:  # project the shortcut when the block reduces spatial resolution and changes the channel count
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu = nn.ReLU(inplace=True)  # note: forward below calls F.relu directly and does not use this module
        # With inplace=True the activation overwrites its input tensor instead of producing a new output.
        # The computed result is unchanged; in-place computation saves (GPU) memory and the time spent
        # repeatedly allocating and freeing it, at the cost of overwriting the original variable -- fine to
        # use as long as it does not cause errors.

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)
Check the plain residual block: the input and output shapes are identical.
blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape
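Since the channel counts match and strides=1, the block preserves the input shape, so Y.shape should be:
# torch.Size([4, 3, 6, 6])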
Check the projection residual block: it increases the number of output channels while halving the input's height and width.
blk = Residual(3, 6, use_1x1conv=True, strides=2)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape
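With use_1x1conv=True and strides=2, both the main path and the shortcut are downsampled, so Y.shape should be:
# torch.Size([4, 6, 3, 3])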
# Define a ResNet stage
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    """Build one of the network's large residual stages (the model has 5 stages in total)."""
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # Except in the first stage, the first residual block of each stage is a projection block:
            # it first shrinks the input feature map and increases the channel count.
            blk.append(
                Residual(input_channels, num_channels, use_1x1conv=True, strides=2)
            )
        else:
            # The first stage, and the remaining stacked plain residual blocks in every stage,
            # keep the input and output shapes identical.
            blk.append(
                Residual(num_channels, num_channels)
            )
    return blk
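As a quick sanity check of the stage logic (illustrative only): a non-first stage should halve the spatial size while switching to the new channel count.
stage = nn.Sequential(*resnet_block(64, 128, 2))
print(stage(torch.rand(1, 64, 56, 56)).shape)  # expected: torch.Size([1, 128, 28, 28])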
# Define the ResNet network model
b1 = nn.Sequential(  # input shape: [1, 1, 224, 224]
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),  # (224 - 7 + 2*3)//2 + 1 = 112
    nn.BatchNorm2d(64), nn.ReLU(),                          # [1, 64, 112, 112]
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)        # [1, 64, 56, 56]
)
b2 = nn.Sequential(
    # *list unpacks the list, passing its elements as separate positional arguments,
    # e.g. nn.Sequential(*[a, b]) is equivalent to nn.Sequential(a, b)
    *resnet_block(64, 64, 2, first_block=True)  # [1, 64, 56, 56], [1, 64, 56, 56]
)
b3 = nn.Sequential(
    *resnet_block(64, 128, 2)   # [1, 128, 28, 28], [1, 128, 28, 28]
)
b4 = nn.Sequential(
    *resnet_block(128, 256, 2)  # [1, 256, 14, 14], [1, 256, 14, 14]
)
b5 = nn.Sequential(
    *resnet_block(256, 512, 2)  # [1, 512, 7, 7], [1, 512, 7, 7]
)
net = nn.Sequential(
    b1,
    b2,
    b3,
    b4,
    b5,
    nn.AdaptiveAvgPool2d((1, 1)),  # [1, 512, 1, 1]
    nn.Flatten(),                  # [1, 512*1*1] = [1, 512]
    nn.Linear(512, 10)             # [1, 512] --> [1, 10]
)
X = torch.randn(1, 1, 224, 224)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
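For a 1×224×224 single-channel input, the printed shapes should match the per-stage comments above:
# Sequential output shape:        torch.Size([1, 64, 56, 56])
# Sequential output shape:        torch.Size([1, 64, 56, 56])
# Sequential output shape:        torch.Size([1, 128, 28, 28])
# Sequential output shape:        torch.Size([1, 256, 14, 14])
# Sequential output shape:        torch.Size([1, 512, 7, 7])
# AdaptiveAvgPool2d output shape: torch.Size([1, 512, 1, 1])
# Flatten output shape:           torch.Size([1, 512])
# Linear output shape:            torch.Size([1, 10])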
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
lr, num_epochs = 0.05, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
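An optional sanity check (not required for training; uses only standard PyTorch calls) is to inspect the number of trainable parameters:
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{num_params / 1e6:.2f}M trainable parameters')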