PyTorch Basics (6): Building a ResNets Network Model

The previous post covered how to build an Inception network. This post shows how to build a ResNets network with PyTorch.

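The previous post left one question open: what is a 1×1 convolution kernel for? It serves two purposes. First, it reduces the number of parameters. Second, it compresses the number of channels, which cuts the amount of computation in the layers that follow.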
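To make this concrete, here is a back-of-the-envelope comparison. The layer sizes (a 28×28 feature map, 192 input channels, 32 output channels, 'same' padding) are assumptions chosen for illustration, and conv_params / conv_mults are hypothetical helper names, not PyTorch functions. The sketch compares a direct 5×5 convolution with a 1×1 convolution that first compresses the channels to 16.

```python
# Hypothetical helpers (not from the original post, not PyTorch APIs).
def conv_params(in_c, out_c, k):
    """Weights plus biases of a Conv2d(in_c, out_c, kernel_size=k)."""
    return in_c * out_c * k * k + out_c


def conv_mults(out_h, out_w, in_c, out_c, k):
    """Multiplications needed to produce an out_c x out_h x out_w output (biases ignored)."""
    return out_h * out_w * out_c * in_c * k * k


# Direct 5x5 convolution: 192 -> 32 channels on a 28x28 map ('same' padding).
direct_params = conv_params(192, 32, 5)
direct_mults = conv_mults(28, 28, 192, 32, 5)

# Bottleneck: a 1x1 convolution compresses 192 -> 16 channels,
# then the 5x5 convolution maps 16 -> 32 channels.
bottleneck_params = conv_params(192, 16, 1) + conv_params(16, 32, 5)
bottleneck_mults = conv_mults(28, 28, 192, 16, 1) + conv_mults(28, 28, 16, 32, 5)

print(f"params: {direct_params} vs {bottleneck_params}")  # params: 153632 vs 15920
print(f"mults:  {direct_mults} vs {bottleneck_mults}")    # mults:  120422400 vs 12443648
```

In this setting the 1×1 bottleneck cuts both the parameter count and the multiplications by roughly a factor of ten, while the output shape (32×28×28) stays the same.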

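In theory, a deeper network should keep improving as layers are added. In practice, without residual connections, the deeper the network, the harder it is for the optimizer to train. The strength of ResNets is that they make very deep networks trainable: the shortcut connections help alleviate the gradient problems (chiefly vanishing gradients) that plague deep networks, while still delivering good performance.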

1、ResNets structure diagram

[Figure 1: structure of a ResNets residual block]

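As the figure shows, a residual block adds the original input x to the output of its layers right before the final activation: instead of computing ReLU(F(x)), it computes ReLU(F(x) + x). This shortcut gives the gradient a direct path back to earlier layers, which helps keep it from vanishing.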
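A one-line derivation shows why: if a block outputs H(x) = F(x) + x, then dH/dx = dF/dx + 1, so even when dF/dx is tiny the gradient passed back is not. The sketch below checks this numerically on a hypothetical scalar toy model, not the network in this post: each layer is F(h) = 0.5·tanh(h), stacked 30 times, and the gradient with respect to the input is accumulated by the chain rule. The layer function, depth and input value are assumptions chosen for illustration, and the final ReLU of the real block is left out.

```python
import math


def dF(h):
    """Derivative of the toy layer F(h) = 0.5 * tanh(h), i.e. 0.5 * (1 - tanh(h)**2)."""
    return 0.5 * (1.0 - math.tanh(h) ** 2)


def input_gradient(x, depth, residual):
    """Chain-rule gradient d(output)/d(x) through `depth` stacked toy layers."""
    h, grad = x, 1.0
    for _ in range(depth):
        if residual:
            grad *= 1.0 + dF(h)        # d(h + F(h))/dh = 1 + F'(h)
            h = h + 0.5 * math.tanh(h)
        else:
            grad *= dF(h)              # d(F(h))/dh = F'(h)
            h = 0.5 * math.tanh(h)
    return grad


plain = input_gradient(0.1, depth=30, residual=False)
res = input_gradient(0.1, depth=30, residual=True)
print(f"plain: {plain:.3e}, residual: {res:.3e}")
```

Every plain factor F'(h) is at most 0.5, so the plain gradient is at most 0.5^30 (about 9e-10), while every residual factor 1 + F'(h) is at least 1, so the residual gradient cannot shrink below 1.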

2、Importing libraries and preparing the data

import torch

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

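# Preprocessing (not data augmentation): convert images to tensors and
# normalize them with the MNIST mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])

# Build the datasets (download=True fetches MNIST only if it is not already in root)
train_dataset = datasets.MNIST(
    root='../dataset/mnist',
    download=True,
    train=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='../dataset/mnist',
    download=True,
    train=False,
    transform=transform
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True
)

# No need to shuffle the test set: accuracy does not depend on the order
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False
)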

This code is shared by every experiment in this series, so I won't explain it again.
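For reference, 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of the MNIST training pixels. The sketch below mirrors, in plain Python, what the transform pipeline does to a single pixel of the one MNIST channel: ToTensor() scales raw values from 0..255 to 0.0..1.0, then Normalize computes (x - mean) / std. The helper names are hypothetical and only illustrate the arithmetic.

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def to_tensor_scale(raw):
    """Like ToTensor() on a uint8 pixel: map 0..255 to 0.0..1.0."""
    return raw / 255.0


def normalize(pixel):
    """Like Normalize((0.1307,), (0.3081,)) on a single-channel pixel."""
    return (pixel - MNIST_MEAN) / MNIST_STD


black = normalize(to_tensor_scale(0))
white = normalize(to_tensor_scale(255))
print(f"black -> {black:.4f}, white -> {white:.4f}")  # black -> -0.4242, white -> 2.8215
```

After normalization the pixels are roughly zero-mean with unit variance over the training set, which tends to make optimization better behaved.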

3、Building the ResidualNet (key part)

# Residual block
class ResidualBlock(torch.nn.Module):
    def __init__(self,channels):
        super(ResidualBlock,self).__init__()
        self.channels = channels
        # 'same' convolutions: kernel_size=3 with padding=1 keeps height and width unchanged
        self.conv1 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        self.conv2 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)

    def forward(self,x):

        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x+y)

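In this code, the input x first goes through a convolution and a ReLU activation, giving y; y then goes through a second convolution; finally, the original x is added to y and the sum is passed through ReLU.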

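For x and y to be addable they must have exactly the same shape. That is why both convolutions in the block are 'same' convolutions (3×3 kernel, padding 1), which keep the height and width of the feature map unchanged, and why both map channels input channels to channels output channels. The code is also highly reusable: the number of channels is passed in at construction time, so the block can be dropped into other networks as a standalone module.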
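The shape argument can be checked with the standard Conv2d output-size formula (dilation 1): out = (size + 2·padding - kernel_size) // stride + 1. The sketch below is plain Python; conv_out_size is a hypothetical helper, and the 12×12 input size is the feature map that rblock1 receives in this post's network.

```python
def conv_out_size(size, kernel_size, padding=0, stride=1):
    """Output height/width of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


h = 12  # rblock1 receives a 16 x 12 x 12 feature map in this post's network

# Two 3x3 convolutions, as in ResidualBlock.forward
y_same = conv_out_size(conv_out_size(h, 3, padding=1), 3, padding=1)
y_valid = conv_out_size(conv_out_size(h, 3, padding=0), 3, padding=0)
print(y_same, y_valid)  # 12 8

# With stride 1 and an odd kernel size k, padding (k - 1) // 2 keeps the size
same_padding_ok = all(
    conv_out_size(h, k, padding=(k - 1) // 2) == h for k in (1, 3, 5, 7)
)
print(same_padding_ok)  # True
```

With the unpadded variant, y would be 16×8×8 while x is 16×12×12, so x + y would fail with a shape-mismatch error.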

4、Building the network model (key part)

# Build the network model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        self.conv1 = torch.nn.Conv2d(1,16,kernel_size=5)
        self.conv2 = torch.nn.Conv2d(16,32,kernel_size=5)

        # Max pooling (2x2, stride 2)
        self.mp = torch.nn.MaxPool2d(2)

        self.rblock1 = ResidualBlock(channels=16)
        self.rblock2 = ResidualBlock(channels=32)

        # 32 channels x 4 x 4 after the second pooling = 512 input features
        self.fc = torch.nn.Linear(512,10)

    def forward(self,x):
        batch_size = x.size(0)
        x = self.mp(F.relu(self.conv1(x)))
        x = self.rblock1(x)
        x = self.mp(F.relu(self.conv2(x)))
        x = self.rblock2(x)
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x

model = Net()

Code is less intuitive than a diagram of the network, so I drew a simple one.

[Figure 2: diagram of the network architecture]

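The code above follows this diagram directly. A 1×28×28 MNIST image becomes 16×24×24 after conv1 (5×5, no padding) and 16×12×12 after max pooling; rblock1 keeps it at 16×12×12; conv2 gives 32×8×8 and pooling 32×4×4; rblock2 keeps the shape, and flattening yields 32×4×4 = 512 features, which is why the fully connected layer is Linear(512,10).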
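The 512 can be recomputed with the output-size formulas for Conv2d and MaxPool2d (dilation 1, no padding in the pooling layer; MaxPool2d's stride defaults to its kernel size). This is a plain-Python sketch; conv_out and pool_out are hypothetical helper names.

```python
def conv_out(size, kernel_size, padding=0, stride=1):
    """Output height/width of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_out(size, kernel_size=2):
    """Output height/width of MaxPool2d(kernel_size) with the default stride."""
    return (size - kernel_size) // kernel_size + 1


size = 28                                     # input:   1 x 28 x 28
size = conv_out(size, 5)                      # conv1:  16 x 24 x 24
size = pool_out(size)                         # pool:   16 x 12 x 12
size = conv_out(conv_out(size, 3, 1), 3, 1)   # rblock1: 16 x 12 x 12
size = conv_out(size, 5)                      # conv2:  32 x 8 x 8
size = pool_out(size)                         # pool:   32 x 4 x 4
size = conv_out(conv_out(size, 3, 1), 3, 1)   # rblock2: 32 x 4 x 4
flat_features = 32 * size * size
print(flat_features)  # 512
```

If you change the input size or kernel sizes, this number must be recomputed; another common check is to pass a dummy batch such as torch.zeros(1, 1, 28, 28) through the layers and print the resulting shape.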

5、Training and testing the model

model = Net()

# Loss function
criterion = torch.nn.CrossEntropyLoss()
# Optimizer
optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

# Train for one epoch
def train(epoch):
    running_loss = 0
    for batchix,datas in enumerate(train_loader,0):
        inputs,target = datas
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs,target)
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        # Print the average loss every 300 batches
        if batchix%300==299:
            print('[%d,%3d] loss: %.3f'%(epoch+1,batchix+1,running_loss/300))
            running_loss = 0

# Evaluate accuracy on the test set
def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for data in test_loader:
            inputs,label = data
            output = model(inputs)
            # The index of the largest output in each row is the predicted class
            _,pre = torch.max(output,dim=1)
            total += label.size(0)
            correct += (pre==label).sum().item()

    print('Accuracy: %.3f %%'%(correct/total*100))

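This code is highly reusable. Every time I study a new network architecture I write it out again, both to get a better feel for the architecture and to practice these common routines.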
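The accuracy bookkeeping in test() relies on torch.max(output, dim=1), which returns the maximum value and its index for each row; the index is the predicted class. The plain-Python sketch below mirrors that logic on a hypothetical batch of 3 images and 4 classes (the logits and labels are made up for illustration).

```python
outputs = [                 # hypothetical logits: 3 images, 4 classes
    [0.1, 2.0, -1.0, 0.3],
    [3.0, 0.5, 0.2, -0.7],
    [0.0, 0.1, 0.4, 1.5],
]
labels = [1, 2, 3]


def argmax(row):
    """Index of the largest value in a row."""
    return max(range(len(row)), key=lambda i: row[i])


preds = [argmax(row) for row in outputs]               # like `_, pre = torch.max(output, dim=1)`
correct = sum(p == t for p, t in zip(preds, labels))   # like `(pre == label).sum().item()`
total = len(labels)                                    # like `label.size(0)`
print(preds, correct, total, f"{correct / total * 100:.3f}")  # [1, 0, 3] 2 3 66.667
```

Summing correct and total over all batches before dividing, as test() does, gives the accuracy over the whole test set rather than an average of per-batch accuracies.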

6、Complete code

import torch

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

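# Preprocessing (not data augmentation): convert images to tensors and
# normalize them with the MNIST mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])

# Build the datasets (download=True fetches MNIST only if it is not already in root)
train_dataset = datasets.MNIST(
    root='../dataset/mnist',
    download=True,
    train=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='../dataset/mnist',
    download=True,
    train=False,
    transform=transform
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True
)

# No need to shuffle the test set: accuracy does not depend on the order
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False
)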

# Residual block
class ResidualBlock(torch.nn.Module):
    def __init__(self,channels):
        super(ResidualBlock,self).__init__()
        self.channels = channels
        # 'same' convolutions: kernel_size=3 with padding=1 keeps height and width unchanged
        self.conv1 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        self.conv2 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)

    def forward(self,x):

        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x+y)

# Build the network model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        self.conv1 = torch.nn.Conv2d(1,16,kernel_size=5)
        self.conv2 = torch.nn.Conv2d(16,32,kernel_size=5)

        # Max pooling (2x2, stride 2)
        self.mp = torch.nn.MaxPool2d(2)

        self.rblock1 = ResidualBlock(channels=16)
        self.rblock2 = ResidualBlock(channels=32)

        # 32 channels x 4 x 4 after the second pooling = 512 input features
        self.fc = torch.nn.Linear(512,10)

    def forward(self,x):
        batch_size = x.size(0)
        x = self.mp(F.relu(self.conv1(x)))
        x = self.rblock1(x)
        x = self.mp(F.relu(self.conv2(x)))
        x = self.rblock2(x)
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x

model = Net()

# Loss function
criterion = torch.nn.CrossEntropyLoss()
# Optimizer
optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

# Train for one epoch
def train(epoch):
    running_loss = 0
    for batchix,datas in enumerate(train_loader,0):
        inputs,target = datas
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs,target)
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        # Print the average loss every 300 batches
        if batchix%300==299:
            print('[%d,%3d] loss: %.3f'%(epoch+1,batchix+1,running_loss/300))
            running_loss = 0

# Evaluate accuracy on the test set
def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for data in test_loader:
            inputs,label = data
            output = model(inputs)
            # The index of the largest output in each row is the predicted class
            _,pre = torch.max(output,dim=1)
            total += label.size(0)
            correct += (pre==label).sum().item()

    print('Accuracy: %.3f %%'%(correct/total*100))


if __name__=='__main__':
    for epoch in range(3):
        train(epoch)
        test()

Output:

[Figure 3: output of the run]
