这篇是 PyTorch 学习之路的第七篇，用于记录用 PyTorch 实现 CIFAR-10 分类的代码。
（书上的代码冗余较多。）
下面示例所用的数据集位于：C:\Users\22130\Learning_Pytorch\dataset（代码中通过相对路径 ./dataset 引用）。代码分为两部分：第一部分训练并测试模型，第二部分加载已保存的模型参数 ./model/model.pth，在测试集上单独评估。
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# ---------------- Hyper-parameters ----------------
train_batch_size = 4
test_batch_size = 4
num_workers = 0  # DataLoader worker processes; 0 loads batches in the main process
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
lr = 0.001
momentum = 0.9

# ---------------- Dataset ----------------
# ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 per channel
# then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# Build the train split first, then the test split; download=True fetches the
# archive into ./dataset when it is not already there.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10('./dataset', train=split_is_train,
                                 transform=transform, download=True)
    for split_is_train in (True, False)
)
# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                          shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size,
                         shuffle=False, num_workers=num_workers)
# ---------------- Data visualization ----------------
import matplotlib.pyplot as plt
import numpy as np

plt.figure()


def imshow(img):
    """Show a normalized image tensor with matplotlib.

    Assumes ``img`` is a CPU tensor laid out channel-first (C, H, W); the
    (1, 2, 0) transpose moves channels last, as ``plt.imshow`` expects.
    Dividing by 2 and adding 0.5 undoes Normalize(mean=0.5, std=0.5).
    """
    pixels = (img / 2 + 0.5).numpy()
    plt.imshow(pixels.transpose(1, 2, 0))
    plt.show()


# Pull the first batch from the training loader and display it as one grid image.
examples = enumerate(train_loader)
idx, (examples_data, examples_target) = next(examples)  # examples_target: class indices 0-9
imshow(torchvision.utils.make_grid(examples_data))

# Inspect the example batch.
print('--------------测试examples------------')
print(f'examples_target.shape:{examples_target.shape}')
print(f'examples_target[0]:{examples_target[0]}')
print(f'examples_data.shape:{examples_data.shape}')
# ---------------- Network definition ----------------
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class CNNNet(nn.Module):
    """Two-conv-layer CNN for 3x32x32 CIFAR-10 images.

    Shape trace for a 32x32 input:
        conv1 (5x5) -> 16x28x28 -> pool1 -> 16x14x14
        conv2 (3x3) -> 36x12x12 -> pool2 -> 36x6x6 -> flatten to 1296
        fc1 -> 128 -> fc2 -> 10 raw class scores (logits)
    """

    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(36 * 6 * 6, 128)  # 1296 flattened features
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, 10) logits."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Keep the batch dimension explicit; view(-1, 1296) could silently fold
        # the batch axis if the spatial size were ever different.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Return raw logits: CrossEntropyLoss applies log-softmax itself, and a
        # ReLU here would clamp negative logits to 0 and zero their gradient.
        return self.fc2(x)
# nn.Module.to moves the parameters to `device` and returns the same module.
model = CNNNet().to(device)
print('--------------查看网络结构-----------')
print(model)
# ---------------- Training ----------------
print('-----训练优化器-------')
import os
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
print("----------正式训练模型---------")
losses = []       # mean training loss per epoch
acces = []        # training accuracy per epoch, as Python floats (plottable)
eval_losses = []
eval_acces = []
for epoch in range(10):
    train_acc = 0
    train_loss = 0    # running loss, reset after every 2000-batch report
    epoch_loss = 0.0  # loss summed over the whole epoch
    num_correct = 0
    num_seen = 0
    model.train()
    for i, (img, label) in enumerate(train_loader):
        img, label = img.to(device), label.to(device)
        # Clear gradients from the previous step.
        optimizer.zero_grad()
        # Forward pass, backward pass, parameter update.
        out = model(img)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        # Loss bookkeeping.
        train_loss += loss.item()
        epoch_loss += loss.item()
        # Accuracy bookkeeping: .item() turns the (possibly CUDA) count into a
        # Python int, so the accuracy history can be plotted directly.
        _, pred = out.max(1)
        num_correct += (pred == label).sum().item()
        num_seen += label.size(0)
        if i % 2000 == 1999:
            print('[%d,%5d] loss : %.3f' % (epoch + 1, i + 1, train_loss / 2000))
            train_loss = 0.0
    losses.append(epoch_loss / len(train_loader) if len(train_loader) else 0.0)
    # Divide by the samples actually seen rather than len(loader) * batch_size,
    # which overcounts when the final batch is partial.
    train_acc = num_correct / num_seen if num_seen else 0.0
    acces.append(train_acc)

# Persist the trained weights at the path the evaluation script loads from.
os.makedirs('./model', exist_ok=True)
torch.save(model.state_dict(), './model/model.pth')

# Plot the training accuracy per epoch.
plt.title('Train Acc')
plt.plot(np.arange(len(acces)), acces)
plt.legend(['Train Acc'], loc='upper right')
plt.show()
# ---------------- Evaluation on the test set (runs once, after training) ----------------
eval_loss = 0
eval_acc = 0
class_correct = [0.0] * 10
class_total = [0.0] * 10
total = 0
# Reset: the training loop leaves its running count in this name. Adding to it
# inflated the reported test accuracy.
num_correct = 0
model.eval()
with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to(device), label.to(device)
        out = model(img)
        # criterion returns the mean loss of this batch; sum it across batches.
        eval_loss += criterion(out, label).item()
        _, pred = out.max(1)
        # 1-D bool mask; without squeeze() it stays indexable for a batch of 1.
        c = pred == label
        num_correct += c.sum().item()
        total += label.size(0)
        # Per-class tallies over the real batch size (not a hard-coded 4).
        for lbl, hit in zip(label.tolist(), c.tolist()):
            class_correct[lbl] += hit  # True/False counts as 1/0
            class_total[lbl] += 1
# Overall accuracy = correct predictions / samples; loss = mean of batch losses.
eval_acc = num_correct / total if total else 0.0
mean_eval_loss = eval_loss / len(test_loader) if len(test_loader) else 0.0
eval_losses.append(mean_eval_loss)
eval_acces.append(eval_acc)
print("total:{}".format(total))
print("len(test_loader):{}".format(len(test_loader)))
for name, correct, count in zip(classes, class_correct, class_total):
    print("accuracy of {}:{}%".format(name, 100 * correct / count if count else 0.0))
print("----------------")
print('epoch:{}, eval_loss:{:.4f},eval_acc:{:.4f}'.format(epoch, mean_eval_loss, eval_acc))
print("Accuracy of the network on the 10000 test images:%d %%" % (100 * eval_acc))
# ===== Standalone evaluation script: test the weights saved in ./model/model.pth =====
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

print(torch.cuda.device_count())  # number of visible CUDA devices

# ---------------- Hyper-parameters ----------------
train_batch_size = 4
test_batch_size = 4
num_workers = 0  # DataLoader worker processes; 0 loads batches in the main process
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
lr = 0.001
momentum = 0.9

# ---------------- Dataset (test split only; no training happens here) ----------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
test_dataset = torchvision.datasets.CIFAR10(
    './dataset', train=False, transform=transform, download=True)
test_loader = DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

# ---------------- Network support ----------------
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class CNNNet(nn.Module):
    """Two-conv-layer CNN for 3x32x32 CIFAR-10 images.

    Shape trace for a 32x32 input:
        conv1 (5x5) -> 16x28x28 -> pool1 -> 16x14x14
        conv2 (3x3) -> 36x12x12 -> pool2 -> 36x6x6 -> flatten to 1296
        fc1 -> 128 -> fc2 -> 10 raw class scores (logits)

    Parameter names must match the training script so load_state_dict succeeds.
    """

    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(36 * 6 * 6, 128)  # 1296 flattened features
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, 10) logits."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Keep the batch dimension explicit instead of view(-1, 1296).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Raw logits: a final ReLU would map all negative scores to 0, making
        # argmax fall back to class 0 whenever every score is negative.
        return self.fc2(x)
model = CNNNet()
# Load the trained parameters. map_location remaps stored tensors onto `device`,
# so weights saved from a GPU run also load on a CPU-only machine.
model.load_state_dict(torch.load('./model/model.pth', map_location=device))
model = model.to(device)
print("load success")
print('--------------查看网络结构-----------')
print(model)
# ---------------- Evaluate the loaded model on the test set ----------------
class_correct = [0.0] * 10
class_total = [0.0] * 10
num_correct = 0
total = 0
model.eval()
with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to(device), label.to(device)
        out = model(img)
        # Predicted class = index of the largest logit.
        _, pred = out.max(1)
        c = pred == label
        num_correct += c.sum().item()
        total += label.size(0)
        # Per-class tallies over the real batch size (not a hard-coded 4),
        # so a short final batch cannot raise IndexError.
        for lbl, hit in zip(label.tolist(), c.tolist()):
            class_correct[lbl] += hit  # True/False counts as 1/0
            class_total[lbl] += 1
# This is accuracy (correct / samples), not precision, so label it 准确率.
print("准确率为:{}".format(num_correct / total if total else 0.0))
for name, correct, count in zip(classes, class_correct, class_total):
    print("accuracy of {}:{}%".format(name, 100 * correct / count if count else 0.0))