**LeNet5 is the convolutional neural network proposed by Yann LeCun et al. in 1998; it consists of 2 convolutional layers and 3 fully connected layers.** Its release established the now-standard template of stacking convolution and pooling layers and ending the network with fully connected layers.
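For the classic 32×32 input, the feature-map sizes explain where the 400 input features of the first fully connected layer come from: a 5×5 convolution without padding shrinks 32 to 28, 2×2 average pooling halves it to 14, the second 5×5 convolution gives 10, and the final pooling gives 5, so the flattened vector has 16 × 5 × 5 = 400 elements.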
import torch
import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # Two convolutional layers, each followed by 2x2 average pooling.
        self.conv2d1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.avgpool1 = nn.AvgPool2d(kernel_size=2)
        self.conv2d2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.avgpool2 = nn.AvgPool2d(kernel_size=2)
        # Three fully connected layers: 400 -> 120 -> 84 -> 10.
        self.classify = nn.Sequential(nn.Linear(in_features=400, out_features=120),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(in_features=120, out_features=84),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(in_features=84, out_features=10))
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv2d1(x))   # (B, 1, 32, 32) -> (B, 6, 28, 28)
        x = self.avgpool1(x)             # -> (B, 6, 14, 14)
        x = self.relu(self.conv2d2(x))   # -> (B, 16, 10, 10)
        x = self.avgpool2(x)             # -> (B, 16, 5, 5)
        x = x.view(x.size(0), -1)        # flatten to (B, 400)
        x = self.classify(x)
        return x


if __name__ == "__main__":
    x = torch.randn(8, 1, 32, 32)  # dummy batch of 8 single-channel 32x32 images
    net = LeNet5()
    y = net(x)
    print(y.shape)  # torch.Size([8, 10])
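As an optional sanity check (a minimal sketch, assuming the script above has just been run so `net` is in scope), the parameter count of this network can be verified by appending one line to the `__main__` block:

    # Sum the sizes of all trainable tensors; the LeNet5 defined above has 61,706 parameters.
    print(sum(p.numel() for p in net.parameters() if p.requires_grad))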
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import torch.nn.functional as F
from tensorboardX import SummaryWriter
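# Note: tensorboardX is a third-party package (pip install tensorboardX); recent PyTorch
# versions also ship torch.utils.tensorboard with the same SummaryWriter interface.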
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # padding=2 pads the 28x28 MNIST images up to the 32x32 input LeNet5 was designed for.
        self.conv2d1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.avgpool1 = nn.AvgPool2d(kernel_size=2)
        self.conv2d2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.avgpool2 = nn.AvgPool2d(kernel_size=2)
        self.classify = nn.Sequential(nn.Linear(in_features=400, out_features=120),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(in_features=120, out_features=84),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(in_features=84, out_features=10))
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv2d1(x))   # (B, 1, 28, 28) -> (B, 6, 28, 28)
        x = self.avgpool1(x)             # -> (B, 6, 14, 14)
        x = self.relu(self.conv2d2(x))   # -> (B, 16, 10, 10)
        x = self.avgpool2(x)             # -> (B, 16, 5, 5)
        x = x.view(x.size(0), -1)        # flatten to (B, 400)
        x = self.classify(x)
        return x
class Linear(nn.Module):
    """A plain fully connected (MLP) baseline for comparison with LeNet5."""
    def __init__(self):
        super(Linear, self).__init__()
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, 64)
        self.linear5 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # flatten 28x28 images into 784-dim vectors
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear4(x))
        x = self.linear5(x)
        return x
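# Note: the Linear baseline above is defined for comparison but not used below;
# to train it instead of LeNet5, replace the "model = LeNet5()" line with "model = Linear()".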
# 0.1307 and 0.3081 are the mean and std of the MNIST training set.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data/mnist',
                               train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='./data/mnist',
                              train=False,
                              transform=transform,
                              download=True)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=1)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=1)
LR = 0.01
model = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.5)
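# Optional sketch (assumes a CUDA-enabled PyTorch build): to train on GPU instead of CPU,
# move the model and each batch to the device, e.g.
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)
# and inside train()/test(): x, y = x.to(device), y.to(device)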
def train(epoch):
    model.train()
    running_loss = 0.0
    total = 0
    for idx, data in enumerate(tqdm(train_loader)):
        x, y = data
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        total += y.size(0)
        # Weight the batch-mean loss by batch size so running_loss / total is a per-sample average.
        running_loss += loss.item() * y.size(0)
    print('[epoch %d] loss: %.3f' % (epoch + 1, running_loss / total))
    return running_loss / total
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(test_loader):
            x, y = data
            y_hat = model(x)
            _, predicted = torch.max(y_hat, dim=1)  # index of the highest score = predicted class
            total += y.size(0)
            correct += (predicted == y).sum().item()
    print('Accuracy on test set: %.2f %%' % (100 * correct / total))
    return 100 * correct / total
if __name__ == "__main__":
    EPOCH = 10
    writer = SummaryWriter()
    for epoch in range(EPOCH):
        loss = train(epoch)
        writer.add_scalars('Train/Loss', {'trainloss': loss}, epoch)
        acc = test()
        writer.add_scalars('Test/Acc', {'acc': acc}, epoch)
    writer.close()
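With the default tensorboardX settings, the logs are written under a `runs/` directory; assuming TensorBoard is installed, the loss and accuracy curves can then be viewed with `tensorboard --logdir runs`.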