LeNet 是论文《Gradient-Based Learning Applied to Document Recognition》中提出的网络，是 CNN 的开山鼻祖，在手写数字识别任务上取得了当时最先进的结果。
按照上面对网络结构的分解，在 PyTorch 中构建如下网络：
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/sigmoid/max-pool stages, then three linear layers.

    The first linear layer takes 16 * 4 * 4 flattened features, which is what
    the convolutional stack produces for a single-channel 28x28 input
    (spatial size 28 -> 24 -> 12 -> 8 -> 4). The network outputs 10 raw
    (unnormalized) scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before, so parameter
        # initialization and state_dict keys ("conv.0.weight", "fc.4.bias", ...)
        # stay the same.
        conv_layers = [
            nn.Conv2d(1, 6, 5),   # 1 input channel -> 6 feature maps, 5x5 kernel
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),   # halve the spatial size
            nn.Conv2d(6, 16, 5),  # 6 -> 16 feature maps, 5x5 kernel
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),
        ]
        self.conv = nn.Sequential(*conv_layers)

        fc_layers = [
            nn.Linear(16 * 4 * 4, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10),    # one score per class
        ]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, X):
        """Map a batch ``X`` to class scores of shape ``(X.shape[0], 10)``."""
        batch = X.shape[0]
        # Flatten every sample's feature maps into one vector for the dense head.
        flat = self.conv(X).view(batch, -1)
        return self.fc(flat)
def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    """Train ``net`` with cross-entropy loss and print per-epoch metrics.

    Args:
        net: the model to train; it is moved to ``device`` before training.
        train_iter: iterable of ``(X, y)`` mini-batches, iterated once per epoch.
        test_iter: passed to ``evaluate_accuracy`` after every epoch.
        batch_size: unused; kept so existing call sites keep working.
        optimizer: optimizer whose ``zero_grad``/``step`` drive the updates
            (presumably built over ``net.parameters()`` -- confirm at call site).
        device: device the model and every batch are moved to.
        num_epochs: number of passes over ``train_iter``.

    Raises:
        ValueError: if ``train_iter`` yields no batches in some epoch
            (e.g. a one-shot generator that is exhausted after epoch 1).
    """
    net = net.to(device)
    print("training on ", device)
    loss = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # All running statistics are per-epoch. batch_count must be reset here
        # too; otherwise the mean loss below is divided by the cumulative
        # number of batches and is wrong from the second epoch on.
        train_l_sum, train_acc_sum, n, batch_count = 0.0, 0.0, 0, 0
        start = time.time()
        # Back to training mode in case evaluate_accuracy switched to eval mode.
        net.train()
        for X, y in train_iter:
            # Move the batch to the same device as the model.
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            # .item() copies the scalar to the host; no explicit .cpu() needed.
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        if batch_count == 0:
            raise ValueError("train_iter yielded no batches in epoch %d" % (epoch + 1))
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,
                 test_acc, time.time() - start))