PyTorch 学习笔记(八):图像增强、ResNet完成Cifar10分类

一. 图像增强的方法

一直以来,图像识别这一计算机视觉的核心问题都面临很多挑战,同一个物体在不同情况下都会得出不同的结论。对于一张图片,我们看到的是一些物体,而计算机看到的是一些像素点。

如果拍摄照片的照相机位置发生了改变,那么拍得的图片对于我们而言,变化很小,但是对于计算机而言,图片得像素变化是很大得。拍照时得光照条件也是很重要的一个影响因素:光照太弱,照片里的物体会和背景融为一体,它们的像素点就会很接近,计算机就无法正确识别出物体。除此之外,物体本身的变形也会对计算机识别造成障碍,比如一只猫是趴着的,计算机能够识别它,但如果猫换个姿势,变成躺着的状态,那么计算机就无法识别了。最后,物体本身会隐藏在一些遮蔽物中,这样物体只呈现出局部的信息,计算机也难以识别。

针对这些问题,我们希望可以对原始图片进行增强,在一定程度上解决部分问题。在PyTorch中已经内置了一些图像增强的方法,不需要再繁琐地去实现,只需要简单的调用。

torchvision.transforms包括所有图像增强的方法:

  • 第一个函数是 Scale,对图片的尺寸进行缩小或者放大;
  • 第二个函数是 CenterCrop,对图像正中心进行给定大小的裁剪;
  • 第三个函数是 RandomCrop,对图片进行给定大小的随机裁剪;
  • 第四个函数是 RandomHorizaontalFlip,对图片进行概率为0.5的随机水平翻转;
  • 第五个函数是 RandomSizedCrop,首先对图片进行随机尺寸的裁剪,然后再对裁剪的图片进行一个随机比例的缩放,最后将图片变成给定的大小,这在InceptionNet中比较流行;
  • 最后一个是 pad,对图片进行边界零填充;

上面介绍了PyTorch内置的一些图像增强的方法,还有更多的增强方法,可以使用OpenCV或者PIL等第三方图形库实现。在网络的训练的过程中图形增强是一种常见、默认的做法,对多任务进行图像增强之后能够在一定程度上提升任务的准确率。

二. 实现 CIFAR-10 分类

cifar 10数据集有60000张图片,每张图片都是 32x32 的三通道的彩色图,一共是10个类别,每种类别有6000张图片。下面实现ResNet来处理cifar 10数据集,完成图像分类。

注意的是下面的代码只对训练图片进行图像增强,提高其泛化能力,对于测试集,仅对其中心化,不做其他的图像增强。

import torch
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchstat import stat
from torch.autograd import Variable

# 读数据
def get_data():
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform, download=True)
    print(len(train_dataset))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, drop_last=True)
    return train_loader, test_loader

def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)

# Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# 构建网络
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16        # 64, 3, 32, 32
        self.conv = conv3x3(3, 16)       # 64, 16, 32, 32
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(True)
        self.layer1 = self.make_layer(block, 16, layers[0])    # 64, 16, 32, 32
        self.layer2 = self.make_layer(block, 32, layers[0], 2)   # 64, 32, 16, 16
        self.layer3 = self.make_layer(block, 64, layers[1], 2)    # 64, 64, 8, 8
        self.avg_pool = nn.AvgPool2d(8)        # 64, 64, 1, 1
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channles, blocks, stride=1):
        downsample = None
        if out_channles != self.in_channels or stride != 1:
            downsample = nn.Sequential(conv3x3(self.in_channels, out_channles, stride=stride), nn.BatchNorm2d(out_channles))
        layers = []
        layers.append(block(self.in_channels, out_channles, stride, downsample))
        self.in_channels = out_channles
        for i in range(1, blocks):
            layers.append(block(out_channles, out_channles))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out



if __name__ == "__main__":
    # 超参数配置
    batch_size = 64
    learning_rate = 1e-2
    num_epoches = 100
    # 训练图片的预处理方式
    train_transform = transforms.Compose([transforms.Scale(40), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32),
                                          transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # 加载数据集
    train_dataset, test_dataset = get_data()
    # 构建模型
    # model = ResNet(ResidualBlock, [3, 4])
    model = torch.load('resnet_model.pth')
    stat(model, (3, 32, 32))
    if torch.cuda.is_available():
        model = model.cuda()
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    schedule_lr = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    # 开始训练
    for i in range(num_epoches):
        j = 0
        for img, label in train_dataset:
            model.train()
            schedule_lr.step()
            img = Variable(img)
            label = Variable(label)
            # forward
            out = model(img)
            loss = criterion(out, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # print
            print("epoch= {},j= {}, loss is {}".format(i, j, loss))
            #print(list(model.children())[-1].weight)
            j += 1
            if j % 100 == 0:
                torch.save(model, './resnet_model.pth')

    # test
    model.eval()
    count = 0
    print(len(test_dataset))
    for img, label in test_dataset:
        img = Variable(img)
        out = model(img)
        _, predict = torch.max(out, 1)
        if predict == label:
            count += 1
    print(count / len(test_dataset))

你可能感兴趣的:(pytorch框架)