CIFAR10数据集 图像分类训练学习笔记

CIFAR10数据集 图像分类训练学习笔记

  • CIFAR10数据集 图像分类训练学习笔记
    • CIFAR10 数据集介绍
    • 数据集的准备
    • 构造模型
    • 训练及测试
    • 鸣谢
      • 参考内容

CIFAR10数据集 图像分类训练学习笔记

学习pytorch 的基本操作后,使用CIFAR-10 的数据集进行分类模型的学习。

CIFAR10 数据集介绍

CIFAR-10数据集是图像分类领域经典的数据集,共包含10个类别的 RGB彩色图片,图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。数据集大小适合在个人电脑上测试效果。 (GTX 2060 训练约十几分钟)
数据集展示
CIFAR10数据集 图像分类训练学习笔记_第1张图片

数据集的准备

Pytorch框架的 “torchvision” 中提供该数据集的下载使用。可使用以下命令进行数据集的加载:

# transform--数据集转换方式集合(包括数据扩充)
transform = torchvision.transforms.Compose([
					torchvision.transforms.ToTensor()])
# dataset--准备数据集
dataset = torchvision.datasets.CIFAR10("../dataset",train= True,
                            transform=transform,download=True)
# dataloader--准备数据容器
dataloader = DataLoader(dataset=dataset,batch_size=64,shuffle=True,drop_last=False)

数据集不做处理时,测试准确率约为81%;transforms增加数据扩充可以提升准确率。
准备数据的文件如下: DATA.py

import torchvision
from torch.utils.data import DataLoader
import torchvision.datasets

# 增加数据集transforms
train_dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32,padding=4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
test_dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32,padding=4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

train_dataset = torchvision.datasets.CIFAR10("../dataset",train= True,
                                             transform=train_dataset_transform,download=True)
test_dataset = torchvision.datasets.CIFAR10("../dataset",train=False,
                                            transform=test_dataset_transform,download=True)

train_dataloader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True,drop_last=False)
test_dataloader = DataLoader(test_dataset,batch_size=32,shuffle=False,drop_last=False)

# 查看图像大小
for data in train_dataloader:
    imgs, targets = data
    print(imgs[0].shape)
    break

构造模型

测试使用经典卷积模型VGG-16,但是该数据集的图像大小和其模型数据集图像大小差异过大,模型训练消耗资源过多,故将其模型改小。模型构建的代码 Model.py如下:

import torch.nn as nn

class vgg16gai(nn.Module):
    def __init__(self,input_channels = 3,output_channels = 10):
        super(vgg16gai, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(input_channels,64,kernel_size=3,stride=1,padding=1), # imgsize = 32*32
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),                                    # imgsize = 16*16
            nn.Conv2d(64,256,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(4,4))
        self.classifier = nn.Sequential(
            nn.Linear(512*4*4,512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512,64),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(64,output_channels)
        )

    def forward(self,x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0),-1)
        x = self.classifier(x)
        return x

训练及测试

训练测试文件 model_test.py 如下:

from torch.utils.tensorboard import SummaryWriter
import torch
from CIFAR10_DATA import *
from Model import vgg16gai
import torch.nn as nn
# 设置有显卡用显卡,无显卡用cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

writer = SummaryWriter("../logs_model")
model = vgg16gai()
model = model.to(device)

epoch = 30                      # 循环轮数
train_step =0                   # 模型训练次数
test_step = 1

# loss 函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# 优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate,momentum=0.9)
# 学习率自适应变换--根据训练轮数学习率
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
         milestones=[int(epoch * 0.4), int(epoch * 0.68),int(epoch * 0.9)],
         				             gamma=0.1, last_epoch=-1,verbose=True)


for i in range(epoch):
    print("---------第 {} 轮训练开始-------------".format(i+1))

    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        output = model(imgs)
        loss = loss_fn(output,targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_step += 1
        if train_step % 100 == 0:
            print("训练次数: {} ,loss: {}  ".format(train_step,loss.item()))
            writer.add_scalar("train_loss",loss.item(),train_step)

    ## 训练一轮后进行测试模型效果
    total_acc = 0                   # 测试准确率
    total_test_loss = 0

    with torch.no_grad():
        for data_test in test_dataloader:
            imgs_test, targets_test = data_test
            imgs_test = imgs_test.to(device)
            targets_test = targets_test.to(device)
            output_test = model(imgs_test)
            loss_test = loss_fn(output_test,targets_test)
            total_test_loss += loss_test
            accuracy = (output_test.argmax(1) == targets_test).sum()
            total_acc += accuracy

        acc_rate = total_acc/len(test_dataset)
        print("测试集的整体loss: {}".format(loss_test))
        print("测试集的正确率: {}".format(acc_rate))
        writer.add_scalar("test_loss",total_test_loss,test_step)
        writer.add_scalar("测试准确率",acc_rate,test_step)
        test_step += 1
    scheduler.step()
    print('\t last_lr:', scheduler.get_last_lr())
writer.close()

打印测试结果
CIFAR10数据集 图像分类训练学习笔记_第2张图片

鸣谢

学习过程中参考了很多内容,由于时间跨度问题没有一一记录下来,这里说声感谢!
分享创造价值!

参考内容

【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

PyTorch深度学习快速入门教程-B站小土堆

你可能感兴趣的:(分类,学习,深度学习)