1. 下载并导入数据集,,需要将数据集导入至代码所在文件夹中,并命名为data。本次展示所用的数据集是cifar数据集
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
traindata = torchvision.datasets.CIFAR10(
root='./data', train=True, download=False, transform=trans)
trainloader = torch.utils.data.DataLoader(
traindata, batch_size=2048, shuffle=True, num_workers=2)
testdata = torchvision.datasets.CIFAR10(
root='./data', train=False, download=False, transform=trans)
testloader = torch.utils.data.DataLoader(
testdata, batch_size=2048, shuffle=True, num_workers=0)
2. 建立LeNet网络,网络有些调整。LeNet详细的网络参数可以参考其论文,我这边做了一些调整,是为了验证不同参数对结果的影响。
#网络:LeNet5 改版
net = torch.nn.Sequential(
nn.Conv2d(3, 6, kernel_size=5), nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=3), nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10),
)
3. 设置我们的损失函数为交叉熵损失,优化算法为SGDwithMomentum,学习率为1e-3
device = torch.device("cuda:0")
print(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
net = net.to(device)
4. 开始训练
for epoch in range(50):
net.train()
for i, (X, Y) in enumerate(trainloader, 0):
optimizer.zero_grad()
X, Y = X.to(device), Y.to(device)
gc.collect()
torch.cuda.empty_cache()
outputs = net(X)
loss = criterion(outputs, Y.long())
loss.backward()
optimizer.step()
print(epoch, loss.item())
net.eval()
with torch.no_grad():
# test
total_correct = 0
total_num = 0
for text_X, test_Y in testloader:
text_X, test_Y = text_X.to(device), test_Y.to(device)
logits = net(text_X)
pred = logits.argmax(dim=1)
total_correct += torch.eq(pred, test_Y).float().sum().item()
total_num += text_X.size(0)
acc = total_correct / total_num
print("epoch: ", epoch, "acc: ", acc)
5. 效果展示:目前迭代了13轮,效果很差哈哈哈哈。在后面的训练中,我们可以通过减小学习率,来不断的优化
0 2.302259922027588
epoch: 0 acc: 0.1289
1 2.301879405975342
epoch: 1 acc: 0.1347
2 2.305971384048462
epoch: 2 acc: 0.1357
3 2.2990949153900146
epoch: 3 acc: 0.1353
4 2.2980239391326904
epoch: 4 acc: 0.1359
5 2.2999300956726074
epoch: 5 acc: 0.1348
6 2.29781436920166
epoch: 6 acc: 0.1348
7 2.297546625137329
epoch: 7 acc: 0.1354
8 2.2953741550445557
epoch: 8 acc: 0.1356
9 2.293961763381958
epoch: 9 acc: 0.1364
10 2.2932660579681396
epoch: 10 acc: 0.1381
11 2.290314197540283
epoch: 11 acc: 0.1419
12 2.2890193462371826
epoch: 12 acc: 0.1478
13 2.287813186645508
epoch: 13 acc: 0.1549