摸鱼第一天 白嫖colab实现一个简易的CNN

一直没有好好学,就跟着教程写一遍吧

参考资料

Pytorch官方文档中文版
http://pytorch123.com/


使用CIFAR数据集

  • 图片尺寸为 3通道32*32

主要的使用的函数

  • torch 基础包, utils 还有一些函数用的挺多的
  • torch.nn 神经网络各种模块应有尽有, 卷积,全连接,以及各种损失函数等
  • torch.nn.functional 神经网络使用的各种函数,非线性激活函数
  • torrchvision 各种计算机视觉的工具,以及数据集, 图片的转换函数等
  • torch.optim 优化函数的包

下载数据集

一共分为十类, 类别参照classes

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)
trainset = torchvision.datasets.CIFAR10(root = './data', train=True,
                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                      shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                      download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                      shuffle=False, num_workers=4)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

显示图片

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
  img = img/2 + 0.5  #  变换到没有进行转换的表示
  npimg = img.numpy()  # 从tensor转换到numpy形式的
  print(np.transpose(npimg,(1,2,0)).shape)
  plt.imshow(np.transpose(npimg,(1,2,0)))
  plt.show()

dataiter = iter(trainloader)   # 获取一个迭代器
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%10s' % classes[labels[j]] for j in range(4)))

摸鱼第一天 白嫖colab实现一个简易的CNN_第1张图片

网络的定义

需要继承torch.nn这个类
要注意图片的大小以及卷积之后feature map的通道数和尺寸
可参考 卷积神经网络池化后的特征图大小计算

import torch.nn as nn
import torch.nn.functional as F
import torch

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 3)
    self.conv2 = nn.Conv2d(6, 16, 3)
    self.conv3 = nn.Conv2d(16, 10, 3)

    self.fc1 = nn.Linear(10*4*4, 80)
    self.fc2 = nn.Linear(80, 40)
    self.fc3 = nn.Linear(40, 10)
  
  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)
    x = F.relu(self.conv3(x))

    x = x.view(-1, self.num_flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
    
  def num_flat_features(self, x):
    size = x.size()[1:]  # all dimensions except the batch dimension
    num_features = 1
    for s in size:
      num_features *= s
    return num_features


net = Net()
print(net)

这样一个简单的CNN结构就搭建起来了

摸鱼第一天 白嫖colab实现一个简易的CNN_第2张图片

训练过程

主要是指定如下的内容

  • 损失函数 这里使用的交叉熵 Cross-Entropy
  • 是否使用GPU
  • 参数更新的函数 这里采用随机梯度下降 (SGD)
  • 迭代的次数
  • 学习率
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net = Net()

learningrate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = learningrate)

net = net.to(device)
for epoch in range(10):
  running_loss = 0.0
  for i, data in enumerate(trainloader):
    inputs, labels = data

    inputs = inputs.to(device)
    labels = labels.to(device)

    output = net(inputs)
    # _,output = max(net(inputs),dim=1)
    optimizer.zero_grad()
    loss = criterion(output,labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item() 
    if i%2000 == 1999:
       print('[%d, %5d] loss: %.3f' %
             (epoch + 1, i + 1, running_loss / 2000))
       running_loss = 0.0

print("Training Finished")

查看验证集上的准确率

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10)) 
acc = 0
total = 0
with torch.no_grad():
  for data in testloader:
    images, labels = data
    images = images.to(device)
    labels = labels.to(device)
    outputs = net(images) 
    _, pred = torch.max(outputs, 1)
    c = (pred == labels).squeeze()
    acc += torch.sum(c)
    total += labels.size()[0]
    for i in range(4):
      label = labels[i]            
      class_correct[label] += c[i].item()            
      class_total[label] += 1
for i in range(10):    
  print('Accuracy of %5s : %2d %%' % (        
    classes[i], 100 * class_correct[i] / class_total[i]))
print("Accuracy: %f %%" % (100*acc/total))

摸鱼第一天 白嫖colab实现一个简易的CNN_第3张图片

正确率有60 还算不错

你可能感兴趣的:(摸鱼第一天 白嫖colab实现一个简易的CNN)