模型特点:前两个卷积层各包含3个部分:卷积、非线性激活函数(Tanh)、池化(Average Pooling);第三个卷积层只做卷积,其输出展平后送入全连接层。
class LeNet5(nn.Module):
    """LeNet-5 style CNN assembled from ``nn.Sequential`` stages.

    Two conv -> Tanh -> average-pool stages are followed by a third 5x5
    convolution and a two-layer fully connected classifier.  The per-layer
    shape comments below assume a 28x28 input (for example MNIST); the
    classifier expects exactly 120 features after flattening, so the third
    convolution must reduce the feature map to 1x1.

    Args:
        in_channel: Number of channels in the input images.
        output: Number of values produced per sample (class scores).
    """

    def __init__(self, in_channel: int, output: int) -> None:
        super().__init__()
        # Stage 1: padding=2 with a 5x5 kernel keeps the spatial size, the
        # 2x2 average pool then halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=6, kernel_size=5, stride=1, padding=2),  # (6, 28, 28)
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  # (6, 14, 14)
        )
        # Stage 2: unpadded 5x5 conv shrinks each side by 4, then pooling halves it.
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),  # (16, 10, 10)
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  # (16, 5, 5)
        )
        # Stage 3: a 5x5 conv over the 5x5 map leaves one value per channel.
        # NOTE(review): no activation follows this layer, so it composes
        # linearly with the first Linear layer below -- confirm this is intended.
        self.layer3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)  # (120, 1, 1)
        # Classifier head: 120 -> 84 -> `output`.
        self.layer4 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=output),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch through the network and return one score row per sample."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Keep the batch dimension and collapse everything else for the Linear layers.
        x = torch.flatten(input=x, start_dim=1)
        x = self.layer4(x)
        return x
import numpy as np
import torch
import torch.nn as nn
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
# Use the first CUDA device when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Mini-batch sizes for the training and test loaders.
train_batch_size, test_batch_size = 12, 48

# Convert each image to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Download the data (requested by the training split only) and load both splits.
train_set = mnist.MNIST(root="./mnist_data", train=True, transform=transform, download=True)
test_set = mnist.MNIST(root="./mnist_data", train=False, transform=transform)

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(dataset=train_set, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=test_batch_size, shuffle=False)
# # 抽样查看图片
# examples = enumerate(test_loader)
# batch_idex, (example_data, example_label) = next(examples)
# sample_set = np.array(example_data)
#
# for i in range(6):
# plt.subplot(2, 3, i + 1)
# plt.imshow(sample_set[i][0])
# plt.title("Ground Truth: {}".format(example_label[i]))
# plt.show()
class LeNet5(nn.Module):
    """LeNet-5 style CNN assembled from ``nn.Sequential`` stages.

    Two conv -> Tanh -> average-pool stages are followed by a third 5x5
    convolution and a two-layer fully connected classifier.  The per-layer
    shape comments below assume a 28x28 input (for example MNIST); the
    classifier expects exactly 120 features after flattening, so the third
    convolution must reduce the feature map to 1x1.

    Args:
        in_channel: Number of channels in the input images.
        output: Number of values produced per sample (class scores).
    """

    def __init__(self, in_channel: int, output: int) -> None:
        super().__init__()
        # Stage 1: padding=2 with a 5x5 kernel keeps the spatial size, the
        # 2x2 average pool then halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=6, kernel_size=5, stride=1, padding=2),  # (6, 28, 28)
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  # (6, 14, 14)
        )
        # Stage 2: unpadded 5x5 conv shrinks each side by 4, then pooling halves it.
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),  # (16, 10, 10)
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  # (16, 5, 5)
        )
        # Stage 3: a 5x5 conv over the 5x5 map leaves one value per channel.
        # NOTE(review): no activation follows this layer, so it composes
        # linearly with the first Linear layer below -- confirm this is intended.
        self.layer3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)  # (120, 1, 1)
        # Classifier head: 120 -> 84 -> `output`.
        self.layer4 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=output),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch through the network and return one score row per sample."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Keep the batch dimension and collapse everything else for the Linear layers.
        x = torch.flatten(input=x, start_dim=1)
        x = self.layer4(x)
        return x
# Build the network for 1-channel inputs and 10 outputs, and move it to the
# selected device.
model = LeNet5(1, 10)
model.to(device)

# Optimizer hyper-parameters.
lr = 0.01
num_epoches = 20
momentum = 0.8

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Per-epoch metrics measured on the test loader.
eval_losses = []
eval_acces = []

for epoch in range(num_epoches):
    # Step decay: multiply the learning rate by 0.1 every 5 epochs.  Epoch 0
    # is skipped so that training actually starts at the configured `lr`
    # instead of being decayed before the first update.
    if epoch > 0 and epoch % 5 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1

    # ---- training pass ----
    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        predict = model(imgs)
        loss = criterion(predict, labels)
        # back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- evaluation pass ----
    loss_sum = 0.0
    correct_num = 0
    sample_num = 0
    model.eval()
    # Gradients are not needed for evaluation; disabling them avoids building
    # an autograd graph for every test batch.
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            predict = model(imgs)
            loss = criterion(predict, labels)
            batch_num = imgs.shape[0]
            # `criterion` returns the batch mean, so weight it by the batch
            # size; otherwise a smaller final batch would be over-weighted.
            loss_sum += loss.item() * batch_num
            # count correct predictions
            result = torch.argmax(predict, dim=1)
            correct_num += (result == labels).sum().item()
            sample_num += batch_num

    # Sample-weighted mean loss and accuracy over the whole test set.
    eval_loss = loss_sum / sample_num
    eval_acc = correct_num / sample_num
    eval_losses.append(eval_loss)
    eval_acces.append(eval_acc)
    print('epoch: {}'.format(epoch))
    print('loss: {}'.format(eval_loss))
    print('accurate rate: {}'.format(eval_acc))
    print('\n')

# Plot the evaluation loss against the epoch index.
plt.title('evaluation loss')
plt.plot(np.arange(len(eval_losses)), eval_losses)
plt.show()
epoch: 0
loss: 0.20932436712157498
accurate rate: 0.9417862838915463
epoch: 1
loss: 0.1124003769263946
accurate rate: 0.9681020733652314
epoch: 2
loss: 0.0809573416740736
accurate rate: 0.9753787878787872
epoch: 3
loss: 0.07089491755452061
accurate rate: 0.9779704944178623
epoch: 4
loss: 0.05831286043338656
accurate rate: 0.9821570972886757
epoch: 5
loss: 0.05560500273351785
accurate rate: 0.9828548644338115
epoch: 6
loss: 0.0542455422597309
accurate rate: 0.9835526315789472
epoch: 7
loss: 0.05367041283908732
accurate rate: 0.9838516746411479
epoch: 8
loss: 0.05298826666370605
accurate rate: 0.9838516746411481
epoch: 9
loss: 0.05252152112530963
accurate rate: 0.9836523125996807
epoch: 10
loss: 0.05247020455629846
accurate rate: 0.9836523125996808
epoch: 11
loss: 0.05242454297127621
accurate rate: 0.9837519936204145
epoch: 12
loss: 0.05237526405083559
accurate rate: 0.9838516746411481
epoch: 13
loss: 0.05233189105290171
accurate rate: 0.9839513556618819
epoch: 14
loss: 0.05222674906053291
accurate rate: 0.9837519936204145
epoch: 15
loss: 0.052228276117072044
accurate rate: 0.9837519936204145
epoch: 16
loss: 0.05222897543727852
accurate rate: 0.9837519936204145
epoch: 17
loss: 0.05222897782574216
accurate rate: 0.9838516746411481
epoch: 18
loss: 0.05222847037079731
accurate rate: 0.9838516746411481
epoch: 19
loss: 0.05222745426054866
accurate rate: 0.9838516746411481