import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
#归一化
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#加载数据
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=4)
#定义CIFAR10分类器
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#定义网络
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1=nn.Conv2d(3,6,5)#第一个卷积层
self.pool=nn.MaxPool2d(2,2)#池化函数
self.conv2=nn.Conv2d(6,16,5)#第二个卷积层
self.fc1=nn.Linear(16*5*5,120)#第一个全连接层
self.fc2=nn.Linear(120,84)#第二个全连接层
self.fc3=nn.Linear(84,10)#第三个全连接层
def forward(self,x):
x=self.pool(F.relu(self.conv1(x)))#第一次卷积后池化
x=self.pool(F.relu(self.conv2(x)))#第二次卷积后池化
x=x.view(-1,16*5*5)#调整张量的空间结构,与全连接层连接
x=F.relu(self.fc1(x))#第一层全连接层
x=F.relu(self.fc2(x))#第二层全连接层
x=self.fc3(x)#第三层全连接层
return x
if __name__=="__main__":
net=Net()#生成网络
criterion=nn.CrossEntropyLoss()#定义交叉熵损失函数
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)#优化函数
for epoch in range(2):#迭代训练
running_loss=0.0
for i,data in enumerate(trainloader,0):#为可遍历对象生成索引
inputs,lable=data
optimizer.zero_grad()#初始化梯度
output=net(inputs)
loss=criterion(output,lable)#计算交叉熵
loss.backward()#反馈
optimizer.step()#迭代
running_loss+=loss.item()
if i % 2000==1999:
print('[%d,%5d] loss:%.3f'%(epoch+1,i+1,running_loss/2000))
print("Finished Training")
correct=0
total=0
with torch.no_grad():#测试集传播反馈梯度
for data in trainloader:
images,lables=data
output=net(images)
num,predict=torch.max(output.data,1)
total+=lables.size(0)
correct+=(predict==lables).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
如果在函数中没有使用 if __name__=="__main__" 则将 num_workers 的参数值设为0【使用主线程计算】,否则会出现运行时错误。