论文:《ImageNet Classification with Deep Convolutional Neural Networks》(AlexNet 原始论文)
Alexnet的网络结构如图所示
cifar10数据集一共包含10个类别的RGB彩色图像,每个类别分别有6000张,共60000张大小为32x32的图片。其中50000个作为训练集,10000个作为测试集。
代码如下:
注意这里将cifar10图片上采样为224×224,以接近原始版本Alexnet的输入尺寸;输出类别数由1000修改为10。
# 导入模块
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.transforms.functional import InterpolationMode
import random
import matplotlib.pyplot as plt
# Download CIFAR-10 and define the preprocessing pipeline.
# Images are upsampled from 32x32 to 224x224 so the network sees the
# original AlexNet input size.
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    # Per-channel mean/std computed on the CIFAR-10 training split.
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.247, 0.2435, 0.2616]),
])
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
test_images = datasets.CIFAR10('./', train=False, download=True, transform=transform)
# Mini-batch size of 256; only the training loader shuffles.
train_data = DataLoader(train_images, batch_size=256, shuffle=True, num_workers=2)
test_data = DataLoader(test_images, batch_size=256, num_workers=2)
class Model(nn.Module):
    """AlexNet adapted for CIFAR-10.

    Expects 224x224 RGB input (the 32x32 CIFAR-10 images are upsampled
    beforehand); the classifier head emits 10 logits instead of
    ImageNet's 1000.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor and classifier kept in one Sequential so the
        # forward pass is a single call.
        self.net = nn.Sequential(
            # Stage 1: 3 -> 96 channels with a large stride-4 kernel.
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Stage 2: 96 -> 256 channels.
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Stage 3: three 3x3 convolutions, then a final pool
            # leaving a 256 x 5 x 5 feature map.
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Classifier: two dropout-regularized hidden layers and a
            # 10-way output for CIFAR-10.
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 10),
        )

    def forward(self, X):
        """Return class logits for a batch of shape (N, 3, 224, 224)."""
        return self.net(X)
# Parameter initialization
def initial(layer):
    """Xavier-initialize the weights of conv and linear layers.

    Intended for use with ``nn.Module.apply``; layers of any other type
    are left untouched.  Biases keep their default initialization.
    """
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        # Operate on the parameter directly: nn.init functions already
        # run in-place under no_grad.  Mutating ``.data`` is a legacy
        # pattern that bypasses autograd bookkeeping.
        nn.init.xavier_normal_(layer.weight)
# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network on the target device and Xavier-initialize every
# conv/linear layer.
net = Model().to(device)
net.apply(initial)

# Training hyperparameters.
epochs = 17  # arbitrarily chosen number of epochs
lr = 0.01    # learning rate
criterion = nn.CrossEntropyLoss()
# SGD with momentum and weight decay, as in the AlexNet paper.
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
# Training and evaluation loop.
# Per-epoch metrics for later plotting.
train_loss, test_loss, train_acc, test_acc = [], [], [], []
n_train, n_test = len(train_images), len(test_images)
for i in range(epochs):
    # ---- training ----
    net.train()
    temp_loss, temp_correct = 0.0, 0
    for X, y in train_data:
        X = X.to(device)
        y = y.to(device)
        y_hat = net(X)
        loss = criterion(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate Python numbers via .item(): accumulating the loss
        # tensor itself would retain the autograd graph of every batch
        # for the whole epoch and steadily leak memory.
        label_hat = torch.argmax(y_hat, dim=1)
        temp_correct += (label_hat == y).sum().item()
        temp_loss += loss.item()
    print(f'epoch:{i+1} train loss:{temp_loss/len(train_data):.3f}, '
          f'train acc:{temp_correct/n_train*100:.2f}%', end='\t')
    train_loss.append(temp_loss / len(train_data))
    train_acc.append(temp_correct / n_train)

    # ---- evaluation ----
    temp_loss, temp_correct = 0.0, 0
    net.eval()
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in test_data:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)
            label_hat = torch.argmax(y_hat, dim=1)
            temp_correct += (label_hat == y).sum().item()
            temp_loss += loss.item()
    print(f'test loss:{temp_loss/len(test_data):.3f}, '
          f'test acc:{temp_correct/n_test*100:.2f}%')
    test_loss.append(temp_loss / len(test_data))
    test_acc.append(temp_correct / n_test)
输出为:
epoch:1 train loss:1.791, train Aacc:33.77% test loss:1.544, test acc:44.03%
epoch:2 train loss:1.285, train Aacc:53.90% test loss:1.144, test acc:60.33%
epoch:3 train loss:1.009, train Aacc:64.55% test loss:0.882, test acc:69.84%
epoch:4 train loss:0.839, train Aacc:70.70% test loss:0.807, test acc:72.48%
epoch:5 train loss:0.701, train Aacc:75.82% test loss:0.727, test acc:75.22%
epoch:6 train loss:0.618, train Aacc:78.51% test loss:0.660, test acc:77.59%
epoch:7 train loss:0.529, train Aacc:81.80% test loss:0.639, test acc:77.91%
epoch:8 train loss:0.446, train Aacc:84.68% test loss:0.667, test acc:78.64%
epoch:9 train loss:0.383, train Aacc:86.66% test loss:0.591, test acc:80.92%
epoch:10 train loss:0.328, train Aacc:88.60% test loss:0.621, test acc:80.45%
epoch:11 train loss:0.266, train Aacc:90.57% test loss:0.632, test acc:81.27%
epoch:12 train loss:0.211, train Aacc:92.75% test loss:0.629, test acc:82.48%
epoch:13 train loss:0.178, train Aacc:93.75% test loss:0.655, test acc:81.68%
epoch:14 train loss:0.142, train Aacc:95.10% test loss:0.648, test acc:82.74%
epoch:15 train loss:0.130, train Aacc:95.39% test loss:0.726, test acc:83.01%
epoch:16 train loss:0.100, train Aacc:96.57% test loss:0.831, test acc:81.86%
epoch:17 train loss:0.090, train Aacc:96.89% test loss:0.688, test acc:82.73%
最后我们可视化一下模型预测的性能,我们在测试集上随机挑选12张图片输入网络进行预测。
# Visualize predictions on 12 randomly chosen test images.
plt.figure(figsize=(16, 14))
net.eval()
with torch.no_grad():  # inference only — skip building autograd graphs
    for i in range(12):
        # Draw an index instead of materializing all 10k (image, label)
        # pairs with list(zip(...)) on every iteration.
        idx = random.randrange(len(test_images.data))
        img_data, label_id = test_images.data[idx], test_images.targets[idx]
        img = transforms.ToPILImage()(img_data)
        predict_id = torch.argmax(net(transform(img).unsqueeze(0).to(device)))
        predict = test_images.classes[predict_id]
        label = test_images.classes[label_id]
        plt.subplot(3, 4, i+1)
        plt.imshow(img)
        plt.title(f'truth:{label}\npredict:{predict}')
plt.show()  # without this a plain (non-notebook) script shows nothing