PyTorch实现CIFAR10实战
步骤
代码
训练代码
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from module import *
import torchvision
import torch.nn
# Train the CIFAR-10 classifier from module.py, evaluate it on the test split
# after every epoch, log the test metrics to TensorBoard and save a checkpoint
# per epoch.
writer = SummaryWriter('../shizhan')

# Both splits are read from the same local directory. `download` is not
# requested, so the dataset files are expected to already be there.
data_root = r'F:\研究生\深度学习项目练习\b站PyTorch深度学习\torchvision\data'
train_data = torchvision.datasets.CIFAR10(
    root=data_root,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)
test_data = torchvision.datasets.CIFAR10(
    root=data_root,
    train=False,
    transform=torchvision.transforms.ToTensor(),
)

train_len = len(train_data)
test_len = len(test_data)
print("训练集的长度为:{}".format(train_len))
print("测试集的长度为:{}".format(test_len))

train_loader = DataLoader(train_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

net = Module()
# Use the fully qualified name: `torch.nn` is imported explicitly by this
# file, whereas a bare `nn` would only exist as a side effect of the
# star-import from `module`.
loss_fn = torch.nn.CrossEntropyLoss()
learing_rate = 1e-2
optimzer = torch.optim.SGD(net.parameters(), lr=learing_rate)

total_train_step = 0  # number of optimizer steps taken so far
total_test_step = 0   # number of completed evaluation passes (TensorBoard x-axis)
epoch = 10

for i in range(epoch):
    print("------第{}轮数训练开始------".format(i+1))

    # Re-enter training mode at the start of every epoch: the evaluation
    # phase below switches the network to eval mode, and it would otherwise
    # stay there for all subsequent epochs.
    net.train()
    for data in train_loader:
        img, target = data
        output = net(img)
        loss = loss_fn(output, target)

        optimzer.zero_grad()
        loss.backward()
        optimzer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            # .item() prints the plain number instead of the tensor repr.
            print("训练次数{},Loss:{}".format(total_train_step, loss.item()))

    # Evaluation pass over the whole test split, without gradient tracking.
    net.eval()
    total_test_loss = 0
    total_test_accuracy = 0  # running count of correctly classified samples
    with torch.no_grad():
        for data in test_loader:
            img, target = data
            output = net(img)
            loss = loss_fn(output, target)
            # Accumulate Python floats rather than tensors.
            total_test_loss += loss.item()
            accuracy = (output.argmax(1) == target).sum()
            total_test_accuracy += accuracy.item()

    # Fraction of correct predictions; the same value is printed and logged
    # so the console output and the TensorBoard curve agree.
    test_accuracy = total_test_accuracy / test_len
    print("在整个测试集上的损失率为:{}".format(total_test_loss))
    print("在整个测试集上的正确率为:{}".format(test_accuracy))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", test_accuracy, total_test_step)
    total_test_step += 1

    # One checkpoint per epoch; the whole module object is serialized.
    torch.save(net, r"F:\研究生\深度学习项目练习\b站PyTorch深度学习\net\shizhanModelFile\net_{}.pth".format(i))
    print("模型已保存")

writer.close()
模型代码
import torch.nn
from torch import nn
from torch.nn import Conv2d,Sequential,MaxPool2d,Flatten,Linear
class Module(nn.Module):
    """Small convolutional classifier producing 10 output scores per sample.

    The network is three ``Conv2d(kernel 5, padding 2) -> MaxPool2d(2)``
    stages followed by ``Flatten`` and two ``Linear`` layers. There are no
    activation functions between the layers.

    Each convolution keeps the spatial size and each pooling step halves it,
    so the final feature map has 64 channels at 1/8 of the input resolution.
    ``Linear(1024, 64)`` therefore requires a 64 x 4 x 4 feature map, i.e.
    3-channel inputs of size 32 x 32.
    """

    def __init__(self):
        super().__init__()
        # Layer order (and therefore the state-dict keys ``model.<index>.*``)
        # must stay as is so that previously saved checkpoints keep loading.
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Run a batch ``x`` of shape (N, 3, 32, 32) through the network.

        Returns:
            Tensor of shape (N, 10) with the raw outputs of the last
            ``Linear`` layer.
        """
        return self.model(x)
测试结果