采用PyTorch构建网络,主要需要定义网络结构(Net)和定义前向传播函数(forward)。
以LeNet-5为例,网络图和代码如下:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data as Data
import torch.optim as optim
from matplotlib import pyplot as plt
import numpy as np
import os
# Ensure the checkpoint directory exists. makedirs(exist_ok=True) is
# race-free, unlike the exists()+mkdir pair, which raises FileExistsError
# if the directory appears between the check and the create.
os.makedirs('./models/', exist_ok=True)
class lenet5(nn.Module):
    """LeNet-5 classifier for 28x28 single-channel MNIST images.

    forward() returns raw class logits of shape (batch, 10) — NOT
    probabilities. The training loss is nn.CrossEntropyLoss, which applies
    log-softmax internally; the original code applied F.softmax before
    CrossEntropyLoss, double-applying softmax and flattening the gradients.
    Predictions via argmax are unaffected by this fix.
    """

    def __init__(self):
        super(lenet5, self).__init__()
        # 1x28x28 -> 6x24x24 (5x5 conv, stride 1, no padding)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        # 6x24x24 -> 6x12x12
        self.pool1 = nn.AvgPool2d(kernel_size=2)
        # 6x12x12 -> 16x8x8
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        # 16x8x8 -> 16x4x4
        self.pool2 = nn.AvgPool2d(2)
        # NOTE: with the classic 32x32 LeNet input this would be 5*5*16, but
        # torchvision MNIST images are 28x28, hence 4*4*16. Alternatively,
        # resize the inputs to 32x32 in the transform pipeline.
        self.fc1 = nn.Linear(4 * 4 * 16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) float tensor to (batch, 10) logits."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten to (batch_size, 16*4*4); -1 infers the feature size.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Return unnormalized logits: CrossEntropyLoss expects raw scores.
        return self.out(x)
net = lenet5()
print(net)  # print the network architecture
# Load the MNIST dataset (downloaded to ./datasets/ on first run).
BATCH_SIZE = 256
# NOTE(review): an input resize (e.g. to 32x32) could also be composed here.
transformImg = transforms.ToTensor()
train_dataset = datasets.MNIST(root="./datasets/", train=True, transform=transformImg, download=True)
test_dataset = datasets.MNIST(root="./datasets/",train=False, transform=transformImg)
# NOTE(review): num_workers=8 without an `if __name__ == "__main__":` guard
# can deadlock or error on platforms that spawn workers (Windows/macOS) — confirm.
train_loader = Data.DataLoader(dataset=train_dataset,batch_size = BATCH_SIZE,shuffle=True, num_workers=8)
test_loader = Data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False, num_workers=8)
# CrossEntropyLoss applies log-softmax internally, so it expects raw logits.
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.1)
# Training loop: one full pass over train_loader per epoch.
num_epochs = 100
loss_list = []  # (epoch, loss) pairs, recorded once per epoch, for plotting
for epoch in range(num_epochs):
    for X, label in train_loader:
        y = net(X)                   # forward pass -> class logits
        loss_value = loss(y, label)  # CrossEntropyLoss on logits
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()
    # .item() extracts a plain float: appending/printing the raw tensor (as
    # the original did) keeps each epoch's autograd graph alive — a memory
    # leak — and later hands tensors to matplotlib.
    last_loss = loss_value.item()
    print("iter: ", epoch, ", loss: ", last_loss)
    loss_list.append([epoch, last_loss])
# Evaluate on the held-out test set and checkpoint the model.
correct = 0
_sum = 0
net.eval()  # no dropout/BN in this net, but standard hygiene before inference
with torch.no_grad():  # inference only: skip autograd graph construction
    for test_x, test_label in test_loader:
        logits = net(test_x.float())
        # Predicted class = argmax over the 10 logit scores.
        predict_ys = torch.argmax(logits, dim=-1)
        correct += (predict_ys == test_label).sum().item()
        _sum += test_label.shape[0]
net.train()
accuracy = correct / _sum
print('accuracy: {:.2f}'.format(accuracy))
# NOTE(review): this pickles the whole module object; saving
# net.state_dict() is the more portable PyTorch convention.
torch.save(net, './models/mnist_{:.2f}.pkl'.format(accuracy))
# Plot the recorded per-epoch training loss.
epochs = [entry[0] for entry in loss_list]
losses = [entry[1] for entry in loss_list]
plt.plot(epochs, losses)
plt.show()