PyTorch CNN for CIFAR-10 Classification

An attempt at classifying CIFAR-10 with a deeper convolutional network.

import torch
import torchvision
import torchvision.transforms as transforms

BATCH_SIZE = 64
EPOCHES = 50
NUM_WORKERS = 4
LEARNING_RATE = 0.005

# Data transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the training and test sets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

# Class labels
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Files already downloaded and verified
Files already downloaded and verified
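
The Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform simply maps each channel from [0, 1] to roughly [-1, 1] via (x - 0.5) / 0.5. If you prefer to normalize with the dataset's actual statistics, here is a minimal sketch of computing them (run it yourself to confirm; the values in the comment are only approximate):

# Sketch: compute per-channel mean/std of the CIFAR-10 training images,
# using a ToTensor-only copy of the dataset so Normalize does not interfere.
raw_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True,
                                       transform=transforms.ToTensor())
raw_loader = torch.utils.data.DataLoader(raw_set, batch_size=1000)

channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
n_pixels = 0
for imgs, _ in raw_loader:
    channel_sum += imgs.sum(dim=[0, 2, 3])
    channel_sq_sum += (imgs ** 2).sum(dim=[0, 2, 3])
    n_pixels += imgs.numel() // 3
mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
print(mean, std)   # approximately (0.49, 0.48, 0.45) and (0.25, 0.24, 0.26)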

Next, we define the network.

import torch.nn as nn
import torch.nn.functional as F

# Reference: https://www.jianshu.com/p/016a23bc6554
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.MaxPool = nn.MaxPool2d(2, 2)
        self.AvgPool = nn.AvgPool2d(4, 4)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))               # (3,32,32) -> (16,32,32)
        x = self.MaxPool(F.relu(self.conv2(x))) # (16,32,32) -> (32,16,16)
        x = F.relu(self.conv3(x))               # (32,16,16) -> (64,16,16)
        x = self.MaxPool(F.relu(self.conv4(x))) # (64,16,16) -> (128,8,8)
        x = self.MaxPool(F.relu(self.conv5(x))) # (128,8,8) -> (256,4,4)
        x = self.AvgPool(x)                     # (256,4,4) -> (256,1,1)
        x = x.view(-1, 256)                     # flatten to (256,)
        x = self.fc3(self.fc2(self.fc1(x)))     # (256) -> (32); note: no activation between these FC layers
        x = self.fc4(x)                         # (10)
        return x
        
net = Net()
if torch.cuda.is_available():
    net = net.cuda()
print(net)
Net(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (MaxPool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (AvgPool): AvgPool2d(kernel_size=4, stride=4, padding=0)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=10, bias=True)
)
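
For a quick sanity check of the shape bookkeeping in forward, we can push a dummy batch through the network. This is just a verification sketch, not part of the original notebook:

# Sketch: feed a random batch and confirm the output is (batch, 10) class scores.
with torch.no_grad():
    dummy = torch.randn(2, 3, 32, 32)
    if torch.cuda.is_available():
        dummy = dummy.cuda()
    print(net(dummy).shape)   # expected: torch.Size([2, 10])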

Let's first take a look at one batch of data.

import matplotlib.pyplot as plt
import numpy as np

dataiter = iter(trainloader)
image, label = next(dataiter)
print(image.shape)
print(label.shape)

def imshow(img):
    img = img / 2 + 0.5  # un-normalize back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

imshow(torchvision.utils.make_grid(image))

# Print the labels: 8 per row, 8 rows
for i in range(8):
    print(" ".join("%5s" % classes[label[i*8+j]] for j in range(8)))
torch.Size([64, 3, 32, 32])
torch.Size([64])

[Figure 1: the 8x8 grid of sample training images displayed by imshow]

 bird   dog truck   cat   car   dog horse  ship
 deer  ship   dog  bird   car   cat plane  deer
 deer  deer   car truck plane   dog  deer  ship
 bird  bird horse truck truck  ship  deer   dog
truck  frog plane   car  bird   cat   car plane
 ship horse plane truck   car  deer horse  ship
horse truck  ship  ship   dog   dog  deer  ship
plane  frog   dog  bird plane  bird  ship  ship

We use cross-entropy as the loss function and Adam as the optimizer.

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
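
Note that nn.CrossEntropyLoss combines LogSoftmax and NLLLoss, so the network should output raw, unnormalized scores (as forward above does) and the targets are integer class indices, not one-hot vectors. A tiny illustration with made-up numbers:

# Sketch: CrossEntropyLoss expects logits of shape (batch, num_classes)
# and integer targets of shape (batch,).
logits = torch.tensor([[2.0, 0.5, -1.0]])   # scores for 3 hypothetical classes
target = torch.tensor([0])                  # the true class index
print(criterion(logits, target))            # a scalar loss tensor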


Define the training and test functions.

def train(path):
    losses = []
    acces = []
    test_acc = []
    print("------train start------")
    for epoch in range(1, EPOCHES+1):
        train_loss = 0
        train_acc = 0
        net.train()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            # Forward pass, loss, backward pass, parameter update
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the loss
            train_loss += loss.item()

            # Accumulate the training accuracy of this batch
            _, pred = outputs.max(1)
            num_correct = (pred == labels).sum().item()
            acc = num_correct / labels.size(0)
            train_acc += acc

        # End of epoch: record the averaged loss and accuracies
        losses.append(train_loss / len(trainloader))
        acces.append(train_acc / len(trainloader))
        test_acc.append(test())

        print("epoch: {}, Train Loss: {:.6f}, Train acc: {:.6f}, Test acc: {:.6f}"
              .format(epoch, losses[epoch-1], acces[epoch-1], test_acc[epoch-1]))

    torch.save(net.state_dict(), path)
    print('Finished Training')
    return losses, acces, test_acc

def test():
    print("------test start------")
    net.eval()
    test_acc = 0
    with torch.no_grad():
        for i, data in enumerate(testloader, 0):
            inputs, labels = data

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            # Forward pass only; no gradients are needed for evaluation
            outputs = net(inputs)

            # Accumulate the accuracy of this batch
            _, pred = outputs.max(1)
            acc = (pred == labels).sum().item() / labels.size(0)
            test_acc += acc
    print("test acc: {:.6f}".format(test_acc / len(testloader)))
    return test_acc / len(testloader)

losses = []
acces = []
test_acc = []
path = "CIFAR10_Deep.pth"
losses, acces, test_acc = train(path)
------train start------
------test start------
test acc: 0.346338
epoch: 1, Train Loss: 1.823356, Train acc: 0.271439, Test acc: 0.346338
------test start------
test acc: 0.471537
epoch: 2, Train Loss: 1.486024, Train acc: 0.434043, Test acc: 0.471537
------test start------
test acc: 0.532245
epoch: 3, Train Loss: 1.279489, Train acc: 0.527354, Test acc: 0.532245
------test start------
test acc: 0.568969
epoch: 4, Train Loss: 1.168447, Train acc: 0.574628, Test acc: 0.568969
------test start------
test acc: 0.604100
epoch: 5, Train Loss: 1.096343, Train acc: 0.603201, Test acc: 0.604100
------test start------
test acc: 0.615844
epoch: 6, Train Loss: 1.052584, Train acc: 0.624181, Test acc: 0.615844
------test start------
test acc: 0.609375
epoch: 7, Train Loss: 1.002270, Train acc: 0.639466, Test acc: 0.609375
------test start------
test acc: 0.627986
epoch: 8, Train Loss: 0.965407, Train acc: 0.653613, Test acc: 0.627986
------test start------
test acc: 0.648487
epoch: 9, Train Loss: 0.935815, Train acc: 0.664762, Test acc: 0.648487
------test start------
test acc: 0.647492
epoch: 10, Train Loss: 0.898331, Train acc: 0.679208, Test acc: 0.647492
------test start------
test acc: 0.659136
epoch: 11, Train Loss: 0.866001, Train acc: 0.691196, Test acc: 0.659136
------test start------
test acc: 0.667396
epoch: 12, Train Loss: 0.836056, Train acc: 0.704364, Test acc: 0.667396
------test start------
test acc: 0.670482
epoch: 13, Train Loss: 0.810909, Train acc: 0.712036, Test acc: 0.670482
------test start------
test acc: 0.661027
epoch: 14, Train Loss: 0.785331, Train acc: 0.721947, Test acc: 0.661027
------test start------
test acc: 0.676254
epoch: 15, Train Loss: 0.752643, Train acc: 0.733696, Test acc: 0.676254
------test start------
test acc: 0.667297
epoch: 16, Train Loss: 0.720532, Train acc: 0.743606, Test acc: 0.667297
------test start------
test acc: 0.673069
epoch: 17, Train Loss: 0.702551, Train acc: 0.750959, Test acc: 0.673069
------test start------
test acc: 0.678344
epoch: 18, Train Loss: 0.663924, Train acc: 0.763807, Test acc: 0.678344
------test start------
test acc: 0.671477
epoch: 19, Train Loss: 0.648946, Train acc: 0.770480, Test acc: 0.671477
------test start------
test acc: 0.671079
epoch: 20, Train Loss: 0.625400, Train acc: 0.777873, Test acc: 0.671079
------test start------
test acc: 0.668690
epoch: 21, Train Loss: 0.591782, Train acc: 0.789502, Test acc: 0.668690
------test start------
test acc: 0.678045
epoch: 22, Train Loss: 0.569181, Train acc: 0.796875, Test acc: 0.678045
------test start------
test acc: 0.667496
epoch: 23, Train Loss: 0.545593, Train acc: 0.804867, Test acc: 0.667496
------test start------
test acc: 0.648288
epoch: 24, Train Loss: 0.521233, Train acc: 0.814218, Test acc: 0.648288
------test start------
test acc: 0.660231
epoch: 25, Train Loss: 0.500628, Train acc: 0.821032, Test acc: 0.660231
------test start------
test acc: 0.671576
epoch: 26, Train Loss: 0.495962, Train acc: 0.823030, Test acc: 0.671576
------test start------
test acc: 0.668193
epoch: 27, Train Loss: 0.466635, Train acc: 0.833300, Test acc: 0.668193
------test start------
test acc: 0.646198
epoch: 28, Train Loss: 0.445843, Train acc: 0.839934, Test acc: 0.646198
------test start------
test acc: 0.656449
epoch: 29, Train Loss: 0.434789, Train acc: 0.845269, Test acc: 0.656449
------test start------
test acc: 0.659236
epoch: 30, Train Loss: 0.399209, Train acc: 0.859055, Test acc: 0.659236
------test start------
test acc: 0.659236
epoch: 31, Train Loss: 0.405278, Train acc: 0.855319, Test acc: 0.659236
------test start------
test acc: 0.664013
epoch: 32, Train Loss: 0.383631, Train acc: 0.864110, Test acc: 0.664013
------test start------
test acc: 0.639829
epoch: 33, Train Loss: 0.368625, Train acc: 0.869266, Test acc: 0.639829
------test start------
test acc: 0.654956
epoch: 34, Train Loss: 0.376511, Train acc: 0.865589, Test acc: 0.654956
------test start------
test acc: 0.654260
epoch: 35, Train Loss: 0.337547, Train acc: 0.878976, Test acc: 0.654260
------test start------
test acc: 0.651174
epoch: 36, Train Loss: 0.354066, Train acc: 0.873481, Test acc: 0.651174
------test start------
test acc: 0.627986
epoch: 37, Train Loss: 0.321658, Train acc: 0.884531, Test acc: 0.627986
------test start------
test acc: 0.652269
epoch: 38, Train Loss: 0.336692, Train acc: 0.881694, Test acc: 0.652269
------test start------
test acc: 0.630573
epoch: 39, Train Loss: 0.303990, Train acc: 0.892363, Test acc: 0.630573
------test start------
test acc: 0.653463
epoch: 40, Train Loss: 0.306856, Train acc: 0.891824, Test acc: 0.653463
------test start------
test acc: 0.656748
epoch: 41, Train Loss: 0.302821, Train acc: 0.893143, Test acc: 0.656748
------test start------
test acc: 0.654061
epoch: 42, Train Loss: 0.292990, Train acc: 0.896739, Test acc: 0.654061
------test start------
test acc: 0.656250
epoch: 43, Train Loss: 0.284619, Train acc: 0.899157, Test acc: 0.656250
------test start------
test acc: 0.653762
epoch: 44, Train Loss: 0.301842, Train acc: 0.894581, Test acc: 0.653762
------test start------
test acc: 0.647393
epoch: 45, Train Loss: 0.249959, Train acc: 0.913183, Test acc: 0.647393
------test start------
test acc: 0.646994
epoch: 46, Train Loss: 0.288888, Train acc: 0.899177, Test acc: 0.646994
------test start------
test acc: 0.633459
epoch: 47, Train Loss: 0.275196, Train acc: 0.904572, Test acc: 0.633459
------test start------
test acc: 0.650577
epoch: 48, Train Loss: 0.263865, Train acc: 0.908148, Test acc: 0.650577
------test start------
test acc: 0.640326
epoch: 49, Train Loss: 0.264790, Train acc: 0.907809, Test acc: 0.640326
------test start------
test acc: 0.649781
epoch: 50, Train Loss: 0.236096, Train acc: 0.917439, Test acc: 0.649781
Finished Training
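
train() saved the final weights with torch.save(net.state_dict(), path). A minimal sketch of restoring them later for inference, assuming the same Net class definition is available:

# Sketch: reload the trained weights into a fresh Net instance.
model = Net()
model.load_state_dict(torch.load("CIFAR10_Deep.pth", map_location="cpu"))
model.eval()   # switch to inference mode before evaluating
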
x = np.arange(1, 1+EPOCHES)
plt.plot(x, losses)
plt.title("train losses")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.grid()

[Figure 2: training loss vs. epoch]

plt.plot(x, acces, label="train acc")
plt.plot(x, test_acc, label="test acc")
plt.title("accuracy")
plt.xlabel("epoch")
plt.ylabel("acc")
plt.grid()
plt.legend()
plt.show()

[Figure 3: training and test accuracy vs. epoch]

We can see that although the training accuracy keeps rising, the test accuracy plateaus around 0.65 and stops improving, which suggests the model is overfitting the training set.
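
A common way to narrow this gap would be to augment the training data; here is a minimal sketch, assuming standard random crops and horizontal flips (these transforms were not used in the run above):

# Sketch: augmentation for the training set only; the test set keeps the
# plain ToTensor + Normalize transform.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])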
Now let's take 8 images from the test set and check the model's predictions.

dataiter = iter(testloader)
image, label = next(dataiter)
print(image.shape)
print(label.shape)
imshow(torchvision.utils.make_grid(image[:8]))

if torch.cuda.is_available():
    image = image.cuda()
with torch.no_grad():
    output = net(image)
_, pred = output.max(1)
print("labels:     " + " ".join(classes[label[i]] for i in range(8)))
print("prediction: " + " ".join(classes[pred[i]] for i in range(8)))
torch.Size([64, 3, 32, 32])
torch.Size([64])

[Figure: the first 8 test images]

labels:     cat ship ship plane frog frog car frog
prediction: dog car ship ship frog frog car bird
