一级目录
import torch
import torchvision
import torchvision.transforms as transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
transform = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5], [0.5])
]
)
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=128,shuffle=True)
testloader = torch.utils.data.DataLoader(testset,batch_size=128,shuffle=True)
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.shape)
print(labels.shape)
imshow(torchvision.utils.make_grid(images))
定义模型
import torch.nn as nn
from torch import nn,optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
class CNNnet(torch.nn.Module):
def __init__(self):
super(CNNnet,self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),
torch.nn.BatchNorm2d(16),
torch.nn.ReLU()
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(16,32,3,2,1),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU()
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(32,64,3,2,1),
torch.nn.BatchNorm2d(64),
torch.nn.ReLU()
)
self.conv4 = torch.nn.Sequential(
torch.nn.Conv2d(64,64,2,2,0),
torch.nn.BatchNorm2d(64),
torch.nn.ReLU()
)
self.mlp1 = torch.nn.Linear(2*2*64,100)
self.mlp2 = torch.nn.Linear(100,10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.mlp1(x.view(x.size(0),-1))
x = self.mlp2(x)
return x
model = CNNnet()
model = model.cuda()
print(model)
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
训练网络
loss_count = []
for epoch in range(10):
running_loss = 0.0
for step, (x,y) in enumerate(trainloader, 0):
inputs = Variable(x).to(device)
labels = Variable(y).to(device)
outputs = model(inputs)
loss = loss_func(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_count.append(loss)
running_loss += loss.item()
if step % 100 == 99:
print('[%d, %5d] loss: %.3f' % (epoch + 1, step + 1, running_loss / 100))
[1, 100] loss: 1.021
[1, 200] loss: 1.337
[1, 300] loss: 1.564
[1, 400] loss: 1.734
[2, 100] loss: 0.114
[2, 200] loss: 0.225
[2, 300] loss: 0.314
[2, 400] loss: 0.401
[3, 100] loss: 0.066
[3, 200] loss: 0.134
[3, 300] loss: 0.205
[3, 400] loss: 0.268
[4, 100] loss: 0.048
[4, 200] loss: 0.100
[4, 300] loss: 0.153
[4, 400] loss: 0.207
[5, 100] loss: 0.042
[5, 200] loss: 0.083
[5, 300] loss: 0.122
[5, 400] loss: 0.169
[6, 100] loss: 0.034
[6, 200] loss: 0.067
[6, 300] loss: 0.101
[6, 400] loss: 0.135
[7, 100] loss: 0.027
[7, 200] loss: 0.054
[7, 300] loss: 0.081
[7, 400] loss: 0.116
[8, 100] loss: 0.022
[8, 200] loss: 0.047
[8, 300] loss: 0.074
[8, 400] loss: 0.101
[9, 100] loss: 0.016
[9, 200] loss: 0.040
[9, 300] loss: 0.064
[9, 400] loss: 0.092
[10, 100] loss: 0.015
[10, 200] loss: 0.033
[10, 300] loss: 0.050
[10, 400] loss: 0.074
loss_count1=[]
for i in loss_count:
loss_count1.append(i.cpu().item())
plt.figure('PyTorch_CNN_Loss')
plt.plot(loss_count1,label='Loss')
plt.legend()
plt.show()
correct = 0
total = 0
count = 0
with torch.no_grad():
for images, labels in testloader:
images = Variable(images).to(device)
labels = Variable(labels).to(device)
pre_labels = model(images)
_, pred = torch.max(pre_labels, 1)
correct += (pred == labels).sum().item()
total += labels.size(0)
count += 1
print("在第{0}个batch中的Acc为:{1}" .format(count, correct/total))
accuracy = float(correct) / total
print("====================== Result =============================")
print('测试集上平均Acc = {:.5f}'.format(accuracy))
print("测试集共样本{0}个,分为{1}个batch,预测正确{2}个".format(total, count, correct))
在第1个batch中的Acc为:1.0
在第2个batch中的Acc为:0.9921875
在第3个batch中的Acc为:0.9895833333333334
在第4个batch中的Acc为:0.990234375
在第5个batch中的Acc为:0.990625
在第6个batch中的Acc为:0.9908854166666666
在第7个batch中的Acc为:0.9899553571428571
在第8个batch中的Acc为:0.9892578125
在第9个batch中的Acc为:0.9904513888888888
在第10个batch中的Acc为:0.98984375
在第11个batch中的Acc为:0.9893465909090909
在第12个batch中的Acc为:0.990234375
在第13个batch中的Acc为:0.9897836538461539
在第14个batch中的Acc为:0.9905133928571429
在第15个batch中的Acc为:0.9911458333333333
在第16个batch中的Acc为:0.9912109375
在第17个batch中的Acc为:0.9912683823529411
在第18个batch中的Acc为:0.9917534722222222
在第19个batch中的Acc为:0.9921875
在第20个batch中的Acc为:0.9921875
在第21个batch中的Acc为:0.9925595238095238
在第22个batch中的Acc为:0.9925426136363636
在第23个batch中的Acc为:0.9918478260869565
在第24个batch中的Acc为:0.9915364583333334
在第25个batch中的Acc为:0.991875
在第26个batch中的Acc为:0.9921875
在第27个batch中的Acc为:0.9918981481481481
在第28个batch中的Acc为:0.9916294642857143
在第29个batch中的Acc为:0.9916487068965517
在第30个batch中的Acc为:0.9916666666666667
在第31个batch中的Acc为:0.9919354838709677
在第32个batch中的Acc为:0.99169921875
在第33个batch中的Acc为:0.9914772727272727
在第34个batch中的Acc为:0.9914981617647058
在第35个batch中的Acc为:0.9917410714285714
在第36个batch中的Acc为:0.9913194444444444
在第37个batch中的Acc为:0.9915540540540541
在第38个batch中的Acc为:0.9915707236842105
在第39个batch中的Acc为:0.991386217948718
在第40个batch中的Acc为:0.99140625
在第41个batch中的Acc为:0.991234756097561
在第42个batch中的Acc为:0.9908854166666666
在第43个batch中的Acc为:0.9910973837209303
在第44个batch中的Acc为:0.9909446022727273
在第45个batch中的Acc为:0.9907986111111111
在第46个batch中的Acc为:0.990828804347826
在第47个batch中的Acc为:0.9906914893617021
在第48个batch中的Acc为:0.99072265625
在第49个batch中的Acc为:0.990593112244898
在第50个batch中的Acc为:0.99078125
在第51个batch中的Acc为:0.9909620098039216
在第52个batch中的Acc为:0.9911358173076923
在第53个batch中的Acc为:0.9910082547169812
在第54个batch中的Acc为:0.9910300925925926
在第55个batch中的Acc为:0.9907670454545454
在第56个batch中的Acc为:0.9906529017857143
在第57个batch中的Acc为:0.9906798245614035
在第58个batch中的Acc为:0.9907058189655172
在第59个batch中的Acc为:0.9904661016949152
在第60个batch中的Acc为:0.990625
在第61个batch中的Acc为:0.9907786885245902
在第62个batch中的Acc为:0.9909274193548387
在第63个batch中的Acc为:0.9910714285714286
在第64个batch中的Acc为:0.9910888671875
在第65个batch中的Acc为:0.9909855769230769
在第66个batch中的Acc为:0.9906486742424242
在第67个batch中的Acc为:0.9907882462686567
在第68个batch中的Acc为:0.9909237132352942
在第69个batch中的Acc为:0.990828804347826
在第70个batch中的Acc为:0.9908482142857142
在第71个batch中的Acc为:0.9909771126760564
在第72个batch中的Acc为:0.9909939236111112
在第73个batch中的Acc为:0.9910102739726028
在第74个batch中的Acc为:0.9911317567567568
在第75个batch中的Acc为:0.9911458333333333
在第76个batch中的Acc为:0.9910567434210527
在第77个batch中的Acc为:0.9909699675324676
在第78个batch中的Acc为:0.9909855769230769
在第79个batch中的Acc为:0.9909
====================== Result =============================
测试集上平均Acc = 0.99090
测试集共样本10000个,分为79个batch,预测正确9909个