深度学习与Pytorch入门实战(十二)实现ResNet-18并在Cifar-10数据集上进行验证

ResNet图解

nn.Module详解

1. Pytorch上搭建ResNet-18

1.1 ResNet block子模块

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    Residual block: two 3x3 conv + batch-norm layers with a shortcut connection.

    Args:
        ch_in: number of channels of the input feature map.
        ch_out: number of channels of the output feature map.
        stride: stride of the first 3x3 convolution; values > 1 shrink the
            spatial size. The projection shortcut uses the same stride.

    Shape:
        input  [b, ch_in, h, w]
        output [b, ch_out, h', w'], where h', w' are set by `stride`
        (unchanged when stride == 1).
    """
    def __init__(self, ch_in, ch_out, stride=1):
        super().__init__()

        # Main path, first conv: may change channel count and spatial size.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # Main path, second conv: keeps channel count and spatial size.
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default (an empty Sequential returns its input).
        self.extra = nn.Sequential()
        # The shortcut must match the main path in BOTH channel count and
        # spatial size. Checking the channels alone is not enough: with
        # ch_in == ch_out and stride > 1 the identity shortcut would keep the
        # input resolution, and the addition in forward() would either fail or
        # silently broadcast to the wrong shape.
        if stride != 1 or ch_out != ch_in:
            # Project x: [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1,
                          stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """Return relu(shortcut(x) + main_path(x))."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Residual addition; the ReLU is applied after the sum.
        out = self.extra(x) + out
        out = F.relu(out)
        return out

1.2 ResNet18主模块

class ResNet18(nn.Module):
    """
    Main network: a conv/batch-norm stem, four ResBlk stages, global average
    pooling and a linear classifier with 10 outputs.

    Input is an image batch [b, 3, h, w]; output is logits [b, 10].
    """
    def __init__(self):
        super().__init__()

        # Stem: 3 -> 64 channels, 3x3 kernel applied with stride 3, no padding.
        stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0)
        stem_bn = nn.BatchNorm2d(64)
        self.conv1 = nn.Sequential(stem_conv, stem_bn)

        # Four residual stages, every one built with stride 2.
        self.blk1 = ResBlk(64, 128, stride=2)   # channels  64 -> 128
        self.blk2 = ResBlk(128, 256, stride=2)  # channels 128 -> 256
        self.blk3 = ResBlk(256, 512, stride=2)  # channels 256 -> 512
        self.blk4 = ResBlk(512, 512, stride=2)  # channels 512 -> 512

        # Classifier over the 512 pooled features, 10 classes.
        self.outlayer = nn.Linear(512, 10)

    def forward(self, x):
        """Map an image batch [b, 3, h, w] to class logits [b, 10]."""
        feat = F.relu(self.conv1(x))

        # Channels go 64 -> 128 -> 256 -> 512 -> 512 through the four stages.
        for stage in (self.blk1, self.blk2, self.blk3, self.blk4):
            feat = stage(feat)

        # Adaptive pooling to (1, 1) yields [b, 512, 1, 1] whatever the
        # incoming feature-map size is.
        pooled = F.adaptive_avg_pool2d(feat, [1, 1])
        # Flatten to [b, 512] for the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.outlayer(flat)

测试:

# Smoke test for a single block: 64 -> 128 channels, stride 4 turns 32x32 into 8x8.
blk = ResBlk(64, 128, stride=4)
tmp = torch.randn(2, 64, 32, 32)
out = blk(tmp)
print('block:', out.shape)                # block: torch.Size([2, 128, 8, 8])

# Smoke test for the whole network on a random batch of two 3x32x32 images.
x = torch.randn(2, 3, 32, 32)
model = ResNet18()
out = model(x)
print('resnet:', out.shape)               # resnet: torch.Size([2, 10])
block: torch.Size([2, 128, 8, 8])
resnet: torch.Size([2, 10])

2. 训练Cifar-10数据集

  • 所选数据集为Cifar-10,该数据集共有60000张带标签的彩色图像,这些图像尺寸32*32,分为10个类,每类6000张图。

  • 其中50000张用于训练,每个类5000张;另外10000张用于测试,每个类1000张。

import  torch
from    torch.utils.data import DataLoader
from    torchvision import datasets,transforms
from    torch import nn, optim

from    resnet import ResNet18


def main():
    """Train ResNet18 on CIFAR-10 and report test accuracy after every epoch.

    Downloads the training split into the local ``cifar`` directory if it is
    not there yet, trains with Adam (lr=1e-3) and cross-entropy loss, and
    after each epoch prints the loss of the last training batch followed by
    the accuracy on the test split.
    """
    batchsz = 128

    # Both splits use the same preprocessing, so build the pipeline once.
    # NOTE(review): these mean/std values look like the usual ImageNet
    # statistics rather than CIFAR-10's own - confirm this is intended.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Training set
    cifar_train = datasets.CIFAR10('cifar', train=True, download=True,
                                   transform=transform)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    # Test set (read from the same 'cifar' directory)
    cifar_test = datasets.CIFAR10('cifar', train=False,
                                  transform=transform)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Peek at one batch to check the shapes. Use the built-in next(): the
    # Python 2 style `.next()` method is not part of the iterator protocol.
    x, label = next(iter(cifar_train))
    # x: torch.Size([128, 3, 32, 32])  label: torch.Size([128])
    print('x:', x.shape, 'label:', label.shape)

    # Model
    model = ResNet18()

    # Loss function and optimizer
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    # Training loop
    for epoch in range(1000):

        model.train()                               # training mode
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            # label: [b]

            logits = model(x)                       # logits: [b, 10]
            loss = criteon(logits, label)           # scalar

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last batch of this epoch, not an epoch average.
        print(epoch, 'loss:', loss.item())

        model.eval()                                # evaluation mode
        with torch.no_grad():

            total_correct = 0                       # number of correct predictions
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32]
                # label: [b]

                logits = model(x)                   # [b, 10]
                pred = logits.argmax(dim=1)         # [b]

                # [b] vs [b] => scalar
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
  • torch.no_grad(): 是一个上下文管理器,被该语句 wrap 起来的部分将不会 track 梯度。

  • 同时 torch.no_grad() 还可以作为一个装饰器。

  • 比如,在网络测试的函数前加上 @torch.no_grad() 装饰器,整个函数体内都不会 track 梯度:

# Decorator form of torch.no_grad(): it wraps the whole function body.
# NOTE(review): the name `eval` shadows Python's built-in eval(); this is an
# illustrative stub only.
@torch.no_grad()
def eval():
	...

训练太慢了,这里只训练了一个epoch,输出如下:

view code
Files already downloaded and verified
x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
ResNet18(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (blk1): ResBlk(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (blk2): ResBlk(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (blk3): ResBlk(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (blk4): ResBlk(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential()
  )
  (outlayer): Linear(in_features=512, out_features=10, bias=True)
)
0 loss: 1.0541729927062988
0 test acc: 0.5873

你可能感兴趣的:(深度学习与Pytorch入门实战(十二)实现ResNet-18并在Cifar-10数据集上进行验证)