【Pytorch实战4】基于CIFAR10数据集训练一个分类器

参考资料:

《深度学习之pytorch实战计算机视觉》

Pytorch官方教程

Pytorch中文文档

  

   先是数据的导入与预览。

import torch
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#定义图片预处理操作,对载入的数据进行各种变换
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
# 导入数据
data_train = datasets.CIFAR10(root="./data/",
                            transform=transform,
                            train=True,
                            download=True)
data_test = datasets.CIFAR10(root="./data/",
                           transform=transform,
                           train=False)

#数据装载
data_loader_train = torch.utils.data.DataLoader(dataset = data_train,
                                                batch_size = 4,
                                                shuffle = True)
data_loader_test = torch.utils.data.DataLoader(dataset = data_test,
                                               batch_size = 4,
                                               shuffle = True)



#数据预览
images,labels = next((iter(data_loader_train)))#图片维度:batch_size,channel,height,weight
img = torchvision.utils.make_grid(images)#将一个批次的图片弄成网格模式 channel,h,w
#matplotlib待显示的数据维度必须是h,w,c
img = img.numpy().transpose(1,2,0)

img = img*0.5+0.5#去标准化
print([classes[labels[i]] for i in range(4)])
plt.imshow(img)
plt.show()

运行结果:

【Pytorch实战4】基于CIFAR10数据集训练一个分类器_第1张图片

再是网络的搭建

#定义一个卷积神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)#输入通道,输出通道,卷积核大小
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()

然后是模型的训练和测试。

#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

#开始训练
epoch_n = 5
for epoch in range(epoch_n):  # loop over the dataset multiple times

    running_loss = 0.0
    running_correct = 0
    print("Epoch {}/{}".format(epoch,epoch_n))
    print('-'*10)
    for data in data_loader_train:
        #前向传播计算预测结果和loss
        x_train,y_trian = data
        x_train,y_trian = Variable(x_train),Variable(y_trian)
        outputs = net(x_train)
        _,pred = torch.max(outputs,1)#_为最大值,pred为最大值对应的索引
        loss = criterion(outputs,y_trian)
        running_loss += loss.data.item()
        running_correct += torch.sum(pred == y_trian.data)
        #反向传播更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #训练一个epoch后计算一次测试集的准确率
    testing_correct = 0
    for data in data_loader_test:
        x_test,y_test = data
        x_test,y_test = Variable(x_test),Variable(y_test)
        outputs = net(x_test)
        _,pred = torch.max(outputs,1)
        testing_correct += torch.sum(pred == y_test.data)

    print("Loss:{:.4f},Train Accuracy:{:.4f}%,Test Accuracy:{:.4f}%".format(running_loss/len(data_train),
                                                                           100*running_correct/len(data_train),
                                                                           100*testing_correct/len(data_test)))
print('Finished Training')

#用训练好的模型对部分测试集的结果(1个batch)进行预测,并将结果可视化
x_test,y_test = next(iter(data_loader_test))
inputs = Variable(x_test)
outputs = net(inputs)
_,pred = torch.max(outputs,1)

print("Predict Label :",[classes[i] for i in pred.data])
print("Real Label:",[classes[i] for i in y_test])

#测试图片可视化
img = torchvision.utils.make_grid((x_test))
img = img.numpy().transpose(1,2,0)
img = img*0.5+0.5
plt.imshow(img)
plt.show()

下面是运行结果:

【Pytorch实战4】基于CIFAR10数据集训练一个分类器_第2张图片

【Pytorch实战4】基于CIFAR10数据集训练一个分类器_第3张图片

可以看到,训练的损失值一直在下降,而训练集和测试集的准确率也基本一致。随机抽取的四张测试图片有三张预测正确。

如果增加epoch数准确率应该能进一步提升。

通常代码到这里就结束了,如果想要进一步分析模型,可以统计一下测试集各个类的准确率。代码如下

#统计测试集各个类的准确率
class_correct = [0. for i in range(10)]
class_total = [0 for i in range(10)]
for data in data_loader_test:
    x_test,y_test = data
    outputs = net(x_test)
    _,pred = torch.max(outputs,1)
    c = (pred == y_test).squeeze()

    for i in range(4):
        label = y_test[i]
        class_correct[label]+=c[i].item()
        class_total[label] +=1

for i in range(10):
    print('Accuracy of %s: %d %%'%(classes[i],100*class_correct[i]/class_total[i]))

运行结果如下:

【Pytorch实战4】基于CIFAR10数据集训练一个分类器_第4张图片

你可能感兴趣的:(pytorch,pytorch)