PyTorch复现LeNet

PyTorch复现LeNet

  • LeNet
    • 创新点
    • LeNet 使用PyTorch实现
      • LeNet.py
      • train.py
      • predict.py

LeNet

LeNet是一个较为简单的卷积神经网络。通过巧妙的设计,利用卷积、参数共享、池化等操作提取特征,避免了大量的计算成本,最后使用全连接神经网络进行分类识别。
PyTorch复现LeNet_第1张图片

创新点

现在大多数的卷积神经网络都是基于LeNet的框架(卷积层、池化层、全连接层)。

PyTorch复现LeNet_第2张图片

LeNet 使用PyTorch实现

采用PyTorch框架,使用CIFAR10训练集,训练100个epoch。

LeNet.py

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    The layer sizes assume a 3-channel 32x32 input (so the feature map entering
    ``fc1`` is 32 x 5 x 5) and produce 10 raw class scores per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10).

        Args:
            x: image batch of shape (batch, 3, 32, 32).

        Raises:
            RuntimeError: if the spatial size of ``x`` does not lead to a
                32 x 5 x 5 feature map (raised by ``fc1``).
        """
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        # Flatten per sample and keep the batch axis fixed. view(-1, 800) would
        # silently fold surplus features into the batch dimension when the
        # input size is wrong; this way fc1 reports the mismatch instead.
        x = x.flatten(1)             # output(32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10); no softmax, raw scores
        return x

打印LeNet模型如下
PyTorch复现LeNet_第3张图片

input → conv1 → relu → pool1 → conv2 → relu → pool2 → fc1 → relu → fc2 → relu → fc3

train.py

import torch
import torchvision
import torch.nn as nn
from LeNet import LeNet

import torch.optim as optim
import torchvision.transforms as transforms


def main():
    """Train LeNet on CIFAR-10 for 100 epochs and save its weights.

    Every 500 training steps, prints the mean training loss over those steps
    and the accuracy on one fixed batch of 5000 test images. When training
    finishes, the model's ``state_dict`` is written to ``./Lenet.pth``.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # Training split (50000 images per the original note). download=False
    # means the dataset must already be present under ./data.
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0)

    # Test split (10000 images per the original note); only the first batch
    # of 5000 is used for the periodic accuracy check.
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000, shuffle=False, num_workers=0)

    # Use the builtin next(): DataLoader iterators in current PyTorch have no
    # .next() method, so the old call raises AttributeError.
    val_data_iter = iter(val_loader)
    val_image, val_label = next(val_data_iter)

    net = LeNet()

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(100):

        # Accumulates loss between reports; it is also reset at each epoch
        # start, so steps after the last report of an epoch are not printed.
        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % 500 == 499:
                # Evaluation only: no gradients needed.
                with torch.no_grad():
                    outputs = net(val_image)
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)

                    print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0

    print('Finished Training')

    # Same file name that predict.py loads ('Lenet.pth'); the previous
    # './LeNEt.pth' spelling did not match on case-sensitive filesystems.
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()

predict.py

import torch
import torchvision.transforms as transforms

from PIL import Image

from LeNet import LeNet
import onnx

def main():
    """Classify ``./1.jpg`` with trained LeNet weights and export the model to ONNX.

    Loads the weights from ``Lenet.pth``, prints the model and the predicted
    class name, and writes ``lenet.onnx`` using the image tensor as the
    example input.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    net = LeNet()

    print(net)
    net.load_state_dict(torch.load('Lenet.pth'))
    # Inference mode for both the prediction and the ONNX export.
    net.eval()

    # Force 3 channels: a grayscale, palette or RGBA image would otherwise
    # fail in the 3-channel Normalize / first conv layer. The context manager
    # closes the file once the tensor has been produced.
    with Image.open('./1.jpg') as image_file:
        im = transform(image_file.convert('RGB'))
    im = torch.unsqueeze(im, dim=0)  # add the batch dimension

    with torch.no_grad():
        outputs = net(im)
        # .item() yields a plain Python int; int() on a 1-element ndarray is
        # deprecated in recent NumPy releases.
        predict = torch.max(outputs, dim=1)[1].item()

    torch.onnx.export(net, im, "lenet.onnx")
    print(classes[predict])


if __name__ == '__main__':
    main()

你可能感兴趣的:(深度学习,python,pytorch,深度学习,神经网络)