本文将介绍利用PyTorch实现卷积神经网络LeNet-5,关于卷积神经网络LeNet-5的介绍,可以参考:手写数字识别问题(3)——详解卷积神经网络LeNet-5。
首先,需要定义一下Reshape类,将图像转为(X,1,28,28)的形式,其中X为图像的数量,1* 28* 28为图像格式,1为通道数。
#将x转为1*28*28的数据
class Reshape(torch.nn.Module):
def forward(self,x):
return x.view(-1,1,28,28)
然后,定义LeNet-5网络结构:
#LeNet-5网络结构
net=nn.Sequential(
Reshape(),nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),
nn.Linear(16*5*5,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,10))
可以初始化一个1* 1* 28* 28的torch张量对模型进行检查,查看其相应层的输出,代码如下:
#检查模型
x=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:
x=layer(x)
print(layer.__class__.__name__,'output shape:\t',x.shape)
运行结果:
Reshape output shape: torch.Size([1, 1, 28, 28])
Conv2d output shape: torch.Size([1, 6, 28, 28])
Sigmoid output shape: torch.Size([1, 6, 28, 28])
AvgPool2d output shape: torch.Size([1, 6, 14, 14])
Conv2d output shape: torch.Size([1, 16, 10, 10])
Sigmoid output shape: torch.Size([1, 16, 10, 10])
AvgPool2d output shape: torch.Size([1, 16, 5, 5])
Flatten output shape: torch.Size([1, 400])
Linear output shape: torch.Size([1, 120])
Sigmoid output shape: torch.Size([1, 120])
Linear output shape: torch.Size([1, 84])
Sigmoid output shape: torch.Size([1, 84])
Linear output shape: torch.Size([1, 10])
在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。 第一个卷积层使用 2 个像素的填充,来补偿 5×5 卷积核导致的特征减少。 相反,第二个卷积层没有填充,因此高度和宽度都减少了 4 个像素。 随着层叠的上升,通道的数量从输入时的 1 个,增加到第一个卷积层之后的 6 个,再到第二个卷积层之后的 16 个。 同时,每个汇聚层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。
定义加载数据集函数:
#定义加载数据集函数
def load_data_fashion_mnist(batch_size):
'''下载MNIST数据集然后加载到内存中'''
train_dataset=datasets.MNIST(root='../data',train=True,transform=transforms.ToTensor(),download=True)
test_dataset=datasets.MNIST(root='../data',train=False,transform=transforms.ToTensor(),download=True)
return (data.DataLoader(train_dataset,batch_size,shuffle=True),
data.DataLoader(test_dataset,batch_size,shuffle=False))
然后,调用该函数,设置batch_size为64:
#LeNet-5在MNIST数据集上的表现
batch_size=64
train_iter,test_iter=load_data_fashion_mnist(batch_size=batch_size)
torchvision.utils的make_grid()函数将数据集可视化,首先需要定义imshow()函数显示图像:
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
解释一下这句话:plt.imshow(np.transpose(npimg, (1, 2, 0)))。因为在plt.imshow在现实的时候输入的是(imagesize,imagesize,channels)【如:(1,28,28)】,而定义的imshow()函数中,参数img的格式为(channels,imagesize,imagesize)【如:(28,28,1)】,这两者的格式不一致,我们需要调用一次np.transpose函数,即np.transpose(npimg,(1,2,0)),将npimg的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),进行格式的转换后方可进行显示。
iter(train_iter).next()[0]语句可以输出一个batch_size的图像,这里是64,我们可以查看一下它的size:
iter(train_iter).next()[0].size()
运行结果:
torch.Size([64, 1, 28, 28])
然后利用torchvision.utils.make_grid()函数输出图像,再进行可视化显示:
imshow(torchvision.utils.make_grid(iter(train_iter).next()[0]))
注意,make_grid()函数默认网格中每行显示的图像数为8,它会自动推导每列数据,由于总数据为64,因此每列也为8,数据集一个batch_size的可视化如下:
设置损失函数、优化器及开始训练代码如下:
#损失函数
loss_function=nn.CrossEntropyLoss()
#优化器
optimizer=torch.optim.Adam(net.parameters())
#开始训练
num_epochs=10
train_loss = []
for epoch in range(num_epochs):
for batch_idx, (x, y) in enumerate(train_iter):
# x = x.view(x.size(0), 28 * 28)
out = net(x)
y_onehot =F.one_hot(y,num_classes=10).float() # 转为one-hot编码
loss = loss_function(out, y_onehot) # 均方差
# 清零梯度
optimizer.zero_grad()
loss.backward()
# w' = w -lr * grad
optimizer.step()
train_loss.append(loss.item())
if batch_idx % 10 == 0:
print(epoch, batch_idx, loss.item())
本次训练十次,训练完成后,绘制损失曲线:
#绘制损失曲线
plt.figure(figsize=(16,8))
plt.grid(True,linestyle='--',alpha=0.5)
plt.plot(train_loss,label='loss')
plt.legend(loc="best")
测试准确率:
total_correct = 0
for batch_idx,(x,y) in enumerate(test_iter):
# x = x.view(x.size(0),28*28)
out = net(x)
pred = out.argmax(dim=1)
correct = pred.eq(y).sum().float().item()
total_correct += correct
total_num = len(test_iter.dataset)
test_acc=total_correct/total_num
print(total_correct,total_num)
print("test acc:",test_acc)
准确率为:98.19%,测试集准确率98.19%,说明LeNet-5的网络性能非常好。
完整代码可以参考:https://download.csdn.net/download/didi_ya/40736609
ok,以上便是本文的全部内容了,看完了之后记得一定要亲自独立动手实践一下呀~
参考链接: