——学习自B站up主教学视频视频链接
MNIST数据集是新手入门深度学习计算机视觉的必经之路,数据集为多张图片,其中为手写数字如下图所示:
首先导入需要使用的相应包
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
在pytorch中自带数据集,我们只需使用pytorch内置函数在网上下载即可。细致步骤包括定义转换器将图片转化成tensor形式,并做归一化处理。这里做归一化处理是由于网络对输入为0-1之间的数字训练得到的结果较好。
#定义batch_size
batch_size = 64
#定义转换器对数据集转换成tensor形式,并用Normalize函数将数值进行归一化
transform = transforms.Compose([
transforms.ToTensor(),
#这里的两个数字分别为给定数据集的数据大小平均值和标准差,为大众计算结果,这样使得训练效果较好
transforms.Normalize((0.137,), (0.3081,))
])
#加载训练集:train参数(true表示数据为训练集,false表示为测试集)
train_dataset = datasets.MNIST(root='dataset/mnist/',
train=True,
download=True,
transform=transform)
#生成DataLoader对象
train_loader = DataLoader(train_dataset,
shuffle=True,
batch_size=batch_size)
test_dataset = datasets.MNIST(root='dataset/mnist/',
train=False,
download=True,
transform=transform)
test_loader = DataLoader(test_dataset,
shuffle=False,
batch_size=batch_size)
这里我们使用五层线性全连接模型,将784(28*28)个列转化为10个属性,分别表示0-9的可能性。
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(784, 512)
self.l2 = torch.nn.Linear(512, 256)
self.l3 = torch.nn.Linear(256, 128)
self.l4 = torch.nn.Linear(128, 64)
self.l5 = torch.nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784)
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))
x = F.relu(self.l4(x))
return self.l5(x)
由于是分类问题,我们采用交叉熵作为损失函数,并使用pytorch内置的SGD作为优化器。
#损失函数
criterion = torch.nn.CrossEntropyLoss()
#优化器选择SGD
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.5)
相应代码说明代码注释中给出。
def train(epoch):
#运行损失
running_loss = 0.0
#这里按行进行遍历去训练,包括前向、反向以及优化
for batch_idx, data in enumerate(train_loader, 0):
inputs, targets = data
#这里注意需要先将grad梯度置为0
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
#每三百次输出训练情况
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
def test():
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, dim=1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy on test set: %d %%' % (100 * correct / total))
if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()
Accuracy on test set: 98 %
[9, 300] loss: 0.013
[9, 600] loss: 0.027
[9, 900] loss: 0.048
Accuracy on test set: 98 %
[10, 300] loss: 0.008
[10, 600] loss: 0.018
[10, 900] loss: 0.030
Accuracy on test set: 97 %
上述给出结果并不完全,由于大小太多不全给出,我们可以看到精度大多都在97-98%,对于数字图片的分类效果很好。