本文将训练一个Cifar-10分类的网络,并进行随机提取几张图像进行验证
# 导入所需要的模块
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch import optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from model import LeNet5,AlexNet,VGG16,GoogLeNet,ResNet
# 设置超参数
batch_size = 128
learning_rate = 1e-3
epochsize = 30
# 训练集下载
cifar_traindata = datasets.CIFAR10('E:/学习/机器学习/数据集/cifar', True, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=True)
cifar_train = DataLoader(cifar_traindata, batch_size=batch_size, shuffle=True)
# 测试集下载
cifar_testdata = datasets.CIFAR10('E:/学习/机器学习/数据集/cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=True)
cifar_test = DataLoader(cifar_testdata, batch_size=batch_size, shuffle=True)
# 查看相关参数
# real_image, label = iter(cifar_train).next()
# print('real_image:', real_image.shape, 'label:', label.shape)
# 利用GPU加速
device = torch.device('cuda')
# 定义模型
#model = LeNet5().to(device)
#model = AlexNet().to(device)
#model = VGG16().to(device)
model = GoogLeNet().to(device)
#model = ResNet().to(device)
# 导入参数
# model.load_state_dict(torch.load('E:/PyCharm/workspace/Cifar10/GoogLeNet.mdl'))
# 一般来说,分类任务使用CrossEntropyLoss;回归任务使用MSELoss
criteon = nn.CrossEntropyLoss().to(device)
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# print(model)
# 进行迭代训练
for epoch in range(epochsize):
# 训练过程
model.train()
# total_num = 0
for batchidx, (image, imagelabel) in enumerate(cifar_train):
# image.shape:[batch_size, 3, 32, 32]
image, imagelabel = image.to(device), imagelabel.to(device)
# category.shape:{batchbatch_size, 10}
category = model(image)
# category: [batch_size, 10]
# imagelabel:[batch_size]
# 计算损失
loss = criteon(category, imagelabel)
# 反向更新训练
optimizer.zero_grad()
loss.backward()
optimizer.step()
# total_num += image.size(0)
# print( 'total_num:', total_num)
print(epoch, 'loss:', loss.item())
# 测试过程
model.eval()
# 不进行计算图构建
with torch.no_grad():
total_connect = 0 # 总的正确个数
total_num = 0 # 总的当前测试个数
for (image, imagelabel) in cifar_test:
# image.shape:[batch_size, 3, 32, 32]
image, imagelabel = image.to(device), imagelabel.to(device)
# category.shape:{batchbatch_size, 10}
category = model(image)
# 得到最大值的索引
pred = category.argmax(dim=1)
# _, pred = category.max(dim=1)
# 计算每一次正确的个数
total_connect += torch.eq(pred, imagelabel).detach().float().sum().item()
total_num += image.size(0)
# 计算一次训练之后计算率
acc = total_connect / total_num
print('epoch:', epoch, 'test_acc:', acc)
# 保存网络结构
torch.save(model.state_dict(), 'GoogLeNet.mdl')
# 导入所需要的模块
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch import optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from model import AlexNet
# 设置超参数
batch_size = 128
learning_rate = 1e-3
epochsize = 50
# 测试集下载
cifar_testdata = datasets.CIFAR10('E:/学习/机器学习/数据集/cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=True)
cifar_test = DataLoader(cifar_testdata, batch_size=batch_size, shuffle=True)
Files already downloaded and verified
image, label = iter(cifar_test).next()
# label.shape # torch.Size([128])
label
tensor([7, 1, 3, 1, 1, 9, 7, 9, 5, 2, 9, 0, 8, 6, 5, 0, 4, 6, 4, 2, 4, 1, 4, 4,
5, 8, 8, 5, 1, 0, 4, 6, 7, 5, 5, 8, 9, 8, 6, 1, 4, 1, 2, 3, 2, 9, 3, 3,
0, 4, 4, 7, 1, 9, 2, 6, 7, 4, 3, 1, 1, 1, 0, 3, 9, 7, 1, 7, 8, 5, 5, 8,
3, 2, 2, 5, 4, 7, 4, 7, 7, 9, 3, 0, 4, 9, 8, 0, 9, 8, 6, 3, 3, 0, 8, 4,
0, 3, 9, 5, 2, 5, 4, 1, 3, 2, 6, 1, 5, 9, 6, 4, 2, 2, 0, 9, 2, 7, 1, 8,
5, 7, 4, 5, 6, 0, 0, 8])
# 显示一个batch_size的照片
real_image = torchvision.utils.make_grid(image)
real_image = real_image.numpy()
plt.imshow(np.transpose(real_image,(1,2,0)))
# 显示单张图片
photo = torchvision.utils.make_grid(image[1])
photo = photo.numpy()
plt.imshow(np.transpose(photo,(1,2,0)))
device = torch.device('cuda')
net = AlexNet().to(device)
net.load_state_dict(torch.load('E:/PyCharm/workspace/Cifar10/AlexNet.mdl'))
label
tensor([7, 1, 3, 1, 1, 9, 7, 9, 5, 2, 9, 0, 8, 6, 5, 0, 4, 6, 4, 2, 4, 1, 4, 4,
5, 8, 8, 5, 1, 0, 4, 6, 7, 5, 5, 8, 9, 8, 6, 1, 4, 1, 2, 3, 2, 9, 3, 3,
0, 4, 4, 7, 1, 9, 2, 6, 7, 4, 3, 1, 1, 1, 0, 3, 9, 7, 1, 7, 8, 5, 5, 8,
3, 2, 2, 5, 4, 7, 4, 7, 7, 9, 3, 0, 4, 9, 8, 0, 9, 8, 6, 3, 3, 0, 8, 4,
0, 3, 9, 5, 2, 5, 4, 1, 3, 2, 6, 1, 5, 9, 6, 4, 2, 2, 0, 9, 2, 7, 1, 8,
5, 7, 4, 5, 6, 0, 0, 8])
以上是一些图片的标签分类,下面进行预测
# 选择第0张图片进行验证
test = image[0] # torch.Size([3, 32, 32])
test = test.unsqueeze(0) # torch.Size([1, 3, 32, 32])
test = test.to(device)
pred = net(test) # torch.Size([1, 10])
result = F.softmax(pred) # 求概率
result.max(dim=1)
torch.return_types.max(
values=tensor([1.0000], device='cuda:0', grad_fn=),
indices=tensor([7], device='cuda:0'))
可以看见网络预测输出是7,label[0],预测结果正确
# 选择第1张图片进行验证
test = image[1] # torch.Size([3, 32, 32])
test = test.unsqueeze(0) # torch.Size([1, 3, 32, 32])
test = test.to(device)
pred = net(test) # torch.Size([1, 10])
result = F.softmax(pred) # 求概率
result.max(dim=1)
torch.return_types.max(
values=tensor([0.8652], device='cuda:0', grad_fn=),
indices=tensor([1], device='cuda:0'))
可以看见网络预测输出是1,label[1],预测结果正确
# 选择第2张图片进行验证
test = image[2] # torch.Size([3, 32, 32])
test = test.unsqueeze(0) # torch.Size([1, 3, 32, 32])
test = test.to(device)
pred = net(test) # torch.Size([1, 10])
result = F.softmax(pred) # 求概率
result.max(dim=1)
torch.return_types.max(
values=tensor([0.9720], device='cuda:0', grad_fn=),
indices=tensor([3], device='cuda:0'))
可以看见网络预测输出是3,与label[2]相同,预测结果正确
随意测试一下,训练后的网络前3张图片的预测均正确