ResNet18 (PyTorch Reimplementation)

ResNet18 reimplementation + training (reference book: Dive into Deep Learning)

1. Network Structure

1.1 Residual Block
[Figure 1: Structure of a residual block]
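As implemented in the code below, a residual block passes the input through two 3×3 convolutions, each followed by batch normalization, and adds the original input back before the final activation: Y = ReLU(BN(Conv(ReLU(BN(Conv(X))))) + X). When the block changes the number of channels or the spatial size, a 1×1 convolution with the same stride is applied to X on the shortcut path so that the two tensors can be added.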
1.2 Full ResNet-18
[Figure 2: Full ResNet-18 architecture]
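The full network built below consists of a stem (7×7 convolution, batch normalization, max pooling), four stages of two residual blocks each with 64, 128, 256, and 512 channels (every stage after the first halves the spatial resolution in its first block), followed by global average pooling, flattening, and a fully connected layer with 10 outputs for Fashion-MNIST.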

2. Code Reimplementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from d2l import torch as d2l  # utility package from Dive into Deep Learning
# Load the dataset: d2l.load_data_fashion_mnist returns train/test DataLoaders
# with Fashion-MNIST images resized to 224x224.
def load_data(fashion_mnist=True):
    train_data, test_data = None, None
    if fashion_mnist:
        train_data, test_data = d2l.load_data_fashion_mnist(batch_size=128, resize=224)
    else:
        pass
    return train_data, test_data

# Define the model
# Residual block: two 3x3 convolutions with batch norm; an optional 1x1
# convolution (conv1_1=True) reshapes the shortcut when the channel count
# or stride changes.
class residual_block(nn.Module):
    def __init__(self, input_channels, output_channels, conv1_1=True, strides=1):
        super(residual_block, self).__init__()
        self.conv1_1 = conv1_1
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=strides, padding=1)
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 convolution on the shortcut path, only created when it is used
        self.conv3 = nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=strides) if conv1_1 else None

    def forward(self, x):
        Y = self.relu(self.bn1(self.conv1(x)))
        Y = self.bn2(self.conv2(Y))
        if self.conv1_1:
            x = self.conv3(x)  # match the shortcut shape to the main path
        Y += x
        return self.relu(Y)
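
# Quick sanity check (a minimal sketch, not in the original post): a residual
# block with a 1x1 shortcut and stride 2 halves the spatial size and changes
# the channel count.
blk = residual_block(3, 6, conv1_1=True, strides=2)
print(blk(torch.randn(4, 3, 56, 56)).shape)  # expected: torch.Size([4, 6, 28, 28])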

# A ResNet stage: num_blocks residual blocks. The first block of every stage
# except the first one (is_b2=True) downsamples with stride 2 and uses a 1x1
# shortcut convolution to change the channel count.
class Res_block(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, is_b2=False):
        super(Res_block, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.is_b2 = is_b2
        res = []
        for i in range(num_blocks):
            if i == 0 and not self.is_b2:
                res.append(residual_block(in_channels, out_channels, conv1_1=True, strides=2))
            else:
                res.append(residual_block(out_channels, out_channels, conv1_1=False, strides=1))
        self.seq = nn.Sequential(*res)
    def forward(self,x):
        return self.seq(x)
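
# Quick sanity check (a minimal sketch, not in the original post): a stage with
# out_channels != in_channels downsamples once and widens the feature maps.
stage = Res_block(64, 128, num_blocks=2)
print(stage(torch.randn(1, 64, 56, 56)).shape)  # expected: torch.Size([1, 128, 28, 28])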

class Resnet_18(nn.Module):
    def __init__(self, inputchannel):
        super(Resnet_18, self).__init__()
        # Stem: 7x7 conv + BN + ReLU + max pooling. The stride-4 conv followed
        # by a stride-1 pool gives the same 56x56 output as the usual
        # stride-2 conv + stride-2 pool: (224,224) --> (56,56)
        self.b1 = nn.Sequential(
            nn.Conv2d(inputchannel, 64, kernel_size=7, stride=4, padding=3),
            nn.BatchNorm2d(num_features=64, eps=1e-5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=1, padding=1)
        )
        self.net = nn.Sequential(
            self.b1,
            Res_block(64, 64, 2, is_b2=True),   # (56,56), 64 channels
            Res_block(64, 128, 2),              # (28,28), 128 channels
            Res_block(128, 256, 2),             # (14,14), 256 channels
            Res_block(256, 512, 2),             # (7,7), 512 channels
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 10)                  # 10 Fashion-MNIST classes
        )

    def forward(self, x):
        return self.net(x)

    def show_net(self):
        return self.net

Resnet = Resnet_18(inputchannel=1)
print(Resnet)

# Trace the output shape of every top-level module with a dummy input
X = torch.randn((1, 1, 224, 224), dtype=torch.float32)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
for net in Resnet.show_net():
    X = net(X)
    print(net.__class__.__name__, " output_shape: ", X.shape)
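
# With the layer sizes above, the trace should look roughly like:
# Sequential          output_shape:  torch.Size([1, 64, 56, 56])
# Res_block           output_shape:  torch.Size([1, 64, 56, 56])
# Res_block           output_shape:  torch.Size([1, 128, 28, 28])
# Res_block           output_shape:  torch.Size([1, 256, 14, 14])
# Res_block           output_shape:  torch.Size([1, 512, 7, 7])
# AdaptiveAvgPool2d   output_shape:  torch.Size([1, 512, 1, 1])
# Flatten             output_shape:  torch.Size([1, 512])
# Linear              output_shape:  torch.Size([1, 10])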

# Train a fresh model on Fashion-MNIST with the d2l training loop
Resnet1 = Resnet_18(inputchannel=1)
train_data, test_data = load_data(fashion_mnist=True)
epochs = 10
lr = 0.01
d2l.train_ch6(Resnet1, train_data, test_data, epochs, lr, device)
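
If the d2l package is not available, a rough equivalent of the training loop can be written with plain PyTorch. The sketch below (the helper name train_plain is hypothetical and not from the original post) uses SGD with cross-entropy loss and reports test accuracy after each epoch:

def train_plain(net, train_iter, test_iter, num_epochs, lr, device):
    # Minimal stand-in for d2l.train_ch6: SGD + cross-entropy, test accuracy per epoch
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        net.train()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(net(X), y)
            loss.backward()
            optimizer.step()
        net.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for X, y in test_iter:
                X, y = X.to(device), y.to(device)
                correct += (net(X).argmax(dim=1) == y).sum().item()
                total += y.numel()
        print(f"epoch {epoch + 1}, test acc {correct / total:.3f}")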

Reference: Dive into Deep Learning (Mu Li)
