ResNet图解
nn.Module详解
1. 在 PyTorch 上搭建 ResNet-18
1.1 ResNet block子模块
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """
    Basic ResNet block: two 3x3 Conv + BatchNorm layers plus a shortcut.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first 3x3 conv (and of the projection
            shortcut); a value > 1 downsamples the spatial dimensions.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # Second conv keeps both the channel count and the spatial size.
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default (an empty Sequential returns its input).
        self.extra = nn.Sequential()
        # The main path changes the tensor shape whenever the channel count
        # differs OR the stride is not 1. In either case project x with a
        # strided 1x1 conv so the residual addition in forward() is truly
        # element-wise instead of failing or silently broadcasting
        # (e.g. a [b, c, 2, 2] shortcut added to a [b, c, 1, 1] main path).
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w'] (same shape as main path)
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        Args:
            x: input batch of shape [b, ch_in, h, w].

        Returns:
            relu(main_path(x) + shortcut(x)) with shape [b, ch_out, h', w'].
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual connection: shortcut and main path have identical shapes.
        out = self.extra(x) + out
        out = F.relu(out)
        return out
1.2 ResNet18主模块
class ResNet18(nn.Module):
    """
    Small ResNet-style image classifier built from four ResBlk stages.

    Note: despite the name, this network stacks one residual block per stage
    (four in total), whereas the canonical ResNet-18 uses two per stage.

    Args:
        num_classes: number of output logits (default 10, e.g. CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem: [b, 3, h, w] => [b, 64, h//3, w//3]
        # (3x3 conv, stride 3, no padding; 32x32 input -> 10x10)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64),
        )
        # Four residual stages. Each stride-2 block halves the spatial size
        # (rounding up): for a 32x32 input the maps go 10 -> 5 -> 3 -> 2 -> 1.
        self.blk1 = ResBlk(64, 128, stride=2)   # [b, 64, h, w]  => [b, 128, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)  # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)  # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)  # [b, 512, h, w] => [b, 512, h/2, w/2]
        # Fully connected head on the globally pooled 512-d feature vector.
        self.outlayer = nn.Linear(512 * 1 * 1, num_classes)

    def forward(self, x):
        """
        Args:
            x: image batch of shape [b, 3, h, w].

        Returns:
            Unnormalized class logits of shape [b, num_classes].
        """
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 512, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # Whatever the incoming feature-map size, adaptive pooling to (1, 1)
        # makes the head independent of the input resolution.
        x = F.adaptive_avg_pool2d(x, [1, 1])  # [b, 512, h', w'] => [b, 512, 1, 1]
        # Flatten: [b, 512, 1, 1] => [b, 512]
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x
测试:
if __name__ == '__main__':
    # Shape smoke test. Guarded so it runs only when this file is executed
    # directly; importing ResNet18 from another script (e.g. the training
    # script's `from resnet import ResNet18`) then has no side effects.
    blk = ResBlk(64, 128, stride=4)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)  # block: torch.Size([2, 128, 8, 8])
    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)  # resnet: torch.Size([2, 10])
block: torch.Size([2, 128, 8, 8])
resnet: torch.Size([2, 10])
2. 训练 CIFAR-10 数据集
- 所选数据集为 CIFAR-10,该数据集共有 60000 张带标签的 32×32 彩色图像,分为 10 个类别,每类 6000 张。
- 其中 50000 张用于训练(每类 5000 张),另外 10000 张用于测试(每类 1000 张)。
import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from torch import nn, optim
from resnet import ResNet18
def main(epochs=1000, batchsz=128, lr=1e-3):
    """
    Train ResNet18 on CIFAR-10 and print the test accuracy after every epoch.

    Args:
        epochs: number of training epochs (default 1000).
        batchsz: mini-batch size for both data loaders (default 128).
        lr: Adam learning rate (default 1e-3).
    """
    # Use the GPU when one is available; otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Shared preprocessing for the train and test sets. NOTE: these mean/std
    # values are the ImageNet statistics, not CIFAR-10's own; kept unchanged.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Training set
    cifar_train = datasets.CIFAR10('cifar', train=True, download=True,
                                   transform=transform)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    # Test set: order does not affect accuracy, so no shuffling is needed.
    cifar_test = datasets.CIFAR10('cifar', train=False, download=True,
                                  transform=transform)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    # Python 3 iterators (and recent PyTorch DataLoader iterators) have no
    # `.next()` method; use the built-in next() instead.
    x, label = next(iter(cifar_train))
    # x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
    print('x:', x.shape, 'label:', label.shape)

    # Model
    model = ResNet18().to(device)
    # Loss function and optimizer
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    for epoch in range(epochs):
        model.train()  # training mode (BatchNorm uses batch statistics)
        loss = None
        for x, label in cifar_train:
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)
            logits = model(x)  # [b, 10]
            loss = criteon(logits, label)  # scalar
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Loss of the last mini-batch of this epoch (skipped if none ran).
        if loss is not None:
            print(epoch, 'loss:', loss.item())

        model.eval()  # evaluation mode (BatchNorm uses running statistics)
        with torch.no_grad():
            total_correct = 0  # number of correct predictions
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)  # [b, 10]
                pred = logits.argmax(dim=1)  # [b]
                # [b] vs [b] => scalar count of matches
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
- `torch.no_grad()`:是一个上下文管理器,被该语句包裹(wrap)起来的代码不会追踪(track)梯度。
- 同时,`torch.no_grad()` 还可以作为装饰器使用。
- 比如,在网络测试函数的定义前加上该装饰器:
@torch.no_grad()
def eval():
    """Placeholder for a network evaluation routine.

    The ``@torch.no_grad()`` decorator disables gradient tracking for
    everything executed inside this function. NOTE(review): the name
    shadows the built-in ``eval``; prefer e.g. ``evaluate`` in real code.
    """
    pass
训练速度太慢,这里只训练了一个 epoch,结果如下:
view code
Files already downloaded and verified
x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
ResNet18(
(conv1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(3, 3))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blk1): ResBlk(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(blk2): ResBlk(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(blk3): ResBlk(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(blk4): ResBlk(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential()
)
(outlayer): Linear(in_features=512, out_features=10, bias=True)
)
0 loss: 1.0541729927062988
0 test acc: 0.5873