源码参考:https://github.com/zergtant/pytorch-handbook/blob/master/chapter1/4_cifar10_tutorial.ipynb
在此基础上稍作修改。
CIFAR10 是常用的基础图像数据集,共 10 个类别,训练集有 50000 张图片,测试集有 10000 张图片,每张都是 32×32 分辨率的 RGB 彩色图片。PyTorch 的 torchvision 可以很方便地下载和使用 CIFAR10 数据,代码如下:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
# Hyper-parameters.
BATCH_SIZE = 4
EPOCH = 2

# Load CIFAR10 through torchvision. ToTensor() scales pixel values to [0, 1];
# Normalize(mean=0.5, std=0.5) then computes (x - 0.5) / 0.5 per channel, so
# the tensors the model sees lie in [-1, 1] (not [0, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE: num_workers > 0 starts worker processes; on platforms that spawn
# processes (e.g. Windows) the script needs an `if __name__ == '__main__':`
# guard, or num_workers=0.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

# The test set is not shuffled, so evaluation always visits it in the same order.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

# Human-readable class names; position i names integer label i
# (CIFAR10's label order).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
PyTorch(torchvision)提供了很方便的接口来下载常用数据集,包括 MNIST、CIFAR10 等,并分别得到训练集和测试集。上面代码中的 transforms.Normalize 对每个通道按如下公式做标准化:
output[channel] = (input[channel] - mean[channel]) / std[channel]
(ToTensor 已将像素值缩放到 [0, 1],这里 mean = std = 0.5,因此标准化后的数值范围是 [-1, 1],而不是 [0, 1]。)
# Show one raw training image. trainset.data holds the original,
# un-normalized images as an array, so imshow can display it directly.
plt.imshow(trainset.data[86])
plt.show()

# Show one shuffled batch as a single image grid.
dataiter = iter(trainloader)
# DataLoader iterators no longer provide a .next() method in recent PyTorch
# releases; the builtin next() works on every version.
images, labels = next(dataiter)
images_comb = torchvision.utils.make_grid(images)
# Undo Normalize(0.5, 0.5): x * 0.5 + 0.5 maps [-1, 1] back to [0, 1] for display.
images_comb_unnor = (images_comb * 0.5 + 0.5).numpy()
# make_grid yields (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(images_comb_unnor, (1, 2, 0)))
plt.show()
Python 的 matplotlib 模块可以很方便地绘图(上面的代码用它显示了一张原始图片和一个 batch 的图片)。下面在交互式环境中查看单个样本:trainset[i] 是由(图像张量, 整数标签)组成的元组:
trainset[0][1]
Out[13]: 6
trainset[0][0]
Out[14]:
tensor([[[-0.5373, -0.6627, -0.6078, ..., 0.2392, 0.1922, 0.1608],
[-0.8745, -1.0000, -0.8588, ..., -0.0353, -0.0667, -0.0431],
[-0.8039, -0.8745, -0.6157, ..., -0.0745, -0.0588, -0.1451],
...,
[ 0.6314, 0.5765, 0.5529, ..., 0.2549, -0.5608, -0.5843],
[ 0.4118, 0.3569, 0.4588, ..., 0.4431, -0.2392, -0.3490],
[ 0.3882, 0.3176, 0.4039, ..., 0.6941, 0.1843, -0.0353]],
[[-0.5137, -0.6392, -0.6235, ..., 0.0353, -0.0196, -0.0275],
[-0.8431, -1.0000, -0.9373, ..., -0.3098, -0.3490, -0.3176],
[-0.8118, -0.9451, -0.7882, ..., -0.3412, -0.3412, -0.4275],
...,
[ 0.3333, 0.2000, 0.2627, ..., 0.0431, -0.7569, -0.7333],
[ 0.0902, -0.0353, 0.1294, ..., 0.1608, -0.5137, -0.5843],
[ 0.1294, 0.0118, 0.1137, ..., 0.4431, -0.0745, -0.2784]],
[[-0.5059, -0.6471, -0.6627, ..., -0.1529, -0.2000, -0.1922],
[-0.8431, -1.0000, -1.0000, ..., -0.5686, -0.6078, -0.5529],
[-0.8353, -1.0000, -0.9373, ..., -0.6078, -0.6078, -0.6706],
...,
[-0.2471, -0.7333, -0.7961, ..., -0.4510, -0.9451, -0.8431],
[-0.2471, -0.6706, -0.7647, ..., -0.2627, -0.7333, -0.7333],
[-0.0902, -0.2627, -0.3176, ..., 0.0980, -0.3412, -0.4353]]])
trainset[0][0].shape
Out[15]: torch.Size([3, 32, 32])
type(trainset[0])
Out[16]: tuple
type(trainset[0])
Out[19]: tuple
type(trainset[0][0])
Out[20]: torch.Tensor
type(trainset[0][1])
Out[21]: int
class CNN_NET(torch.nn.Module):
    """LeNet-style CNN for 3x32x32 CIFAR10 images, producing 10 logits.

    conv(3->6, 5x5) -> ReLU -> maxpool(2) -> conv(6->16, 5x5) -> ReLU ->
    maxpool(2) -> fc(400->120) -> fc(120->84) -> fc(84->10).
    Spatial size for a 32x32 input: 32 -> 28 -> 14 -> 10 -> 5, which gives
    16*5*5 flattened features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=3,
                                     out_channels=6,
                                     kernel_size=5,
                                     stride=1,
                                     padding=0)
        # One pooling module is reused after both conv layers.
        self.pool = torch.nn.MaxPool2d(kernel_size=2,
                                       stride=2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of shape (N, 3, 32, 32) to logits of shape (N, 10)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension: one row of 16*5*5
        # features per sample. Unlike view(-1, 16*5*5), this never silently
        # re-splits the batch when the input size is wrong; fc1 raises instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Raw logits; CrossEntropyLoss applies log-softmax itself.
        x = self.fc3(x)
        return x


net = CNN_NET()
class CNN_NET(torch.nn.Module):
    """Bare template of a ``torch.nn.Module`` subclass.

    Layers are declared in ``__init__`` (elided here with ``...``) and the
    computation goes in ``forward``. As written, ``forward`` returns its
    input unchanged.
    """

    def __init__(self):
        super().__init__()
        ...  # declare the layers here

    def forward(self, x):
        # Placeholder: no layers are applied, so the input passes straight through.
        return x
import torch.optim as optim

# SGD with momentum over every parameter of the network.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Cross-entropy between the predicted logits and the integer class labels.
loss_func = torch.nn.CrossEntropyLoss()

# One constant serves as both the print interval and the averaging divisor,
# so the printed value is the true mean loss over the last PRINT_EVERY batches.
PRINT_EVERY = 2000

for epoch in range(EPOCH):
    running_loss = 0.0
    for step, (b_x, b_y) in enumerate(trainloader):
        outputs = net(b_x)              # forward pass: logits for the batch
        loss = loss_func(outputs, b_y)  # loss between logits and labels
        optimizer.zero_grad()           # clear gradients left by the previous step
        loss.backward()                 # back-propagate to compute gradients
        optimizer.step()                # apply the update to net's parameters

        # Accumulate and periodically print training statistics.
        running_loss += loss.item()
        if step % PRINT_EVERY == PRINT_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, step + 1, running_loss / PRINT_EVERY))
            running_loss = 0.0
print('Finished Training')
[1, 2000] loss: 2.186
[1, 4000] loss: 1.879
[1, 6000] loss: 1.671
[1, 8000] loss: 1.594
[1, 10000] loss: 1.537
[1, 12000] loss: 1.479
[2, 2000] loss: 1.408
[2, 4000] loss: 1.400
[2, 6000] loss: 1.360
[2, 8000] loss: 1.342
[2, 10000] loss: 1.337
[2, 12000] loss: 1.283
correct = 0
total = 0
# Evaluation needs no gradients; disabling autograd saves time and memory.
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        # Predicted class = index of the largest logit in each row.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
class CNN_NET(torch.nn.Module):
    """Wider CNN for 3x32x32 CIFAR10 images, producing 10 logits.

    conv(3->64, 5x5) -> ReLU -> maxpool(3, stride 2) -> conv(64->64, 5x5) ->
    ReLU -> maxpool(3, stride 2) -> fc(1024->384) -> fc(384->192) -> fc(192->10).
    Spatial size for a 32x32 input: 32 -> 28 -> 13 -> 9 -> 4, which gives
    64*4*4 flattened features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=3,
                                     out_channels=64,
                                     kernel_size=5,
                                     stride=1,
                                     padding=0)
        # Overlapping 3x3 pooling with stride 2, reused after both conv layers.
        self.pool = torch.nn.MaxPool2d(kernel_size=3,
                                       stride=2)
        self.conv2 = torch.nn.Conv2d(64, 64, 5)
        self.fc1 = torch.nn.Linear(64 * 4 * 4, 384)
        self.fc2 = torch.nn.Linear(384, 192)
        self.fc3 = torch.nn.Linear(192, 10)

    def forward(self, x):
        """Map a batch of shape (N, 3, 32, 32) to logits of shape (N, 10)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension and flatten the rest. Unlike
        # view(-1, 64*4*4), a wrong input size makes fc1 raise instead of
        # silently regrouping samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Raw logits; CrossEntropyLoss applies log-softmax itself.
        x = self.fc3(x)
        return x
########################更新卷积参数############################
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
# Hyper-parameters.
BATCH_SIZE = 4
EPOCH = 2

# ToTensor() scales pixel values to [0, 1]; Normalize(mean=0.5, std=0.5) then
# computes (x - 0.5) / 0.5 per channel, so the model's inputs lie in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=1)
# Fixed order for the test set so evaluation is repeatable.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=1)
# Human-readable class names; position i names integer label i.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Optional data preview (disabled). DataLoader iterators no longer provide
# .next(), so the builtin next() is used below in case this is re-enabled.
# plt.imshow(trainset.data[86])  # trainset.data holds the raw images as an array
# plt.show()
# dataiter = iter(trainloader)
# images, labels = next(dataiter)
# images_comb = torchvision.utils.make_grid(images)
# images_comb_unnor = (images_comb * 0.5 + 0.5).numpy()
# plt.imshow(np.transpose(images_comb_unnor, (1, 2, 0)))
# plt.show()
class CNN_NET(torch.nn.Module):
    """Wider CNN for 3x32x32 CIFAR10 images, producing 10 logits.

    conv(3->64, 5x5) -> ReLU -> maxpool(3, stride 2) -> conv(64->64, 5x5) ->
    ReLU -> maxpool(3, stride 2) -> fc(1024->384) -> fc(384->192) -> fc(192->10).
    Spatial size for a 32x32 input: 32 -> 28 -> 13 -> 9 -> 4, which gives
    64*4*4 flattened features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=3,
                                     out_channels=64,
                                     kernel_size=5,
                                     stride=1,
                                     padding=0)
        # Overlapping 3x3 pooling with stride 2, reused after both conv layers.
        self.pool = torch.nn.MaxPool2d(kernel_size=3,
                                       stride=2)
        self.conv2 = torch.nn.Conv2d(64, 64, 5)
        self.fc1 = torch.nn.Linear(64 * 4 * 4, 384)
        self.fc2 = torch.nn.Linear(384, 192)
        self.fc3 = torch.nn.Linear(192, 10)

    def forward(self, x):
        """Map a batch of shape (N, 3, 32, 32) to logits of shape (N, 10)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension and flatten the rest. Unlike
        # view(-1, 64*4*4), a wrong input size makes fc1 raise instead of
        # silently regrouping samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Raw logits; CrossEntropyLoss applies log-softmax itself.
        x = self.fc3(x)
        return x


net = CNN_NET()
import torch.optim as optim

# SGD with momentum over every parameter of the network.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
loss_func = torch.nn.CrossEntropyLoss()

# One constant serves as both the print interval and the averaging divisor,
# so the printed value is the true mean loss over the last PRINT_EVERY batches.
PRINT_EVERY = 2000

for epoch in range(EPOCH):
    running_loss = 0.0
    for step, data in enumerate(trainloader):
        b_x, b_y = data
        # Call the module itself rather than net.forward(...) so that any
        # registered forward hooks also run.
        outputs = net(b_x)
        loss = loss_func(outputs, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate and periodically print training statistics.
        running_loss += loss.item()
        if step % PRINT_EVERY == PRINT_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, step + 1, running_loss / PRINT_EVERY))
            running_loss = 0.0
print('Finished Training')
# Preview one training batch together with the network's predictions for it.
dataiter = iter(trainloader)
# DataLoader iterators no longer provide .next() in recent PyTorch releases;
# the builtin next() works on every version.
images, labels = next(dataiter)
images_comb = torchvision.utils.make_grid(images)
# Undo Normalize(0.5, 0.5): x * 0.5 + 0.5 maps [-1, 1] back to [0, 1] for display.
images_comb_unnor = (images_comb * 0.5 + 0.5).numpy()
# make_grid yields (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(images_comb_unnor, (1, 2, 0)))
plt.show()

# Inference only: no_grad() avoids building an autograd graph for this pass.
with torch.no_grad():
    predicts = net(images)
_, predicted = torch.max(predicts, 1)
print('GroundTruth:', ' '.join(classes[int(j)] for j in labels))
print('Predicted:  ', ' '.join(classes[int(j)] for j in predicted))
# ---------------- Test-set accuracy ----------------
correct = 0
total = 0
# Evaluation needs no gradients; disabling autograd saves time and memory.
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        # Predicted class = index of the largest logit in each row.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))