Reproducing a Classic Network in PyTorch: AlexNet

import torch
from torch import nn
import numpy as np
from torchvision.datasets import CIFAR10

class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv block 1: 3 input channels, 64 output channels, 5x5 kernel
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 5),
            nn.ReLU(True),
        )
        # Pooling layer 1: 3x3 max pooling, stride 2
        self.max_pool1 = nn.MaxPool2d(3, 2)

        # Conv block 2: 64 input channels, 64 output channels, 5x5 kernel
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, 5),
            nn.ReLU(True),
        )
        # Pooling layer 2: 3x3 max pooling, stride 2
        self.max_pool2 = nn.MaxPool2d(3, 2)
        # Fully connected: 1024 in (the flattened 64 * 4 * 4 feature map), 384 out
        self.fc1 = nn.Sequential(
            nn.Linear(1024, 384),
            nn.ReLU(True)
        )
        # Fully connected: 384 in, 192 out
        self.fc2 = nn.Sequential(
            nn.Linear(384, 192),
            nn.ReLU(True)
        )
        # Fully connected: 192 in, 10 out (one logit per CIFAR10 class)
        self.fc3 = nn.Linear(192, 10)


    def forward(self, x):
        x = self.conv1(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = self.max_pool2(x)

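        # flatten the conv features to (batch_size, 64 * 4 * 4) = (batch_size, 1024)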
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

"""
Quick sanity check of the forward pass:

alexnet = AlexNet()
input_demo = torch.zeros(1, 3, 32, 32)
output_demo = alexnet(input_demo)  # shape: torch.Size([1, 10])
"""

def data_tf(x):
    x = np.array(x, dtype="float32") / 255  # scale pixel values to [0, 1]
    x = (x - 0.5) / 0.5                     # normalize to [-1, 1]
    x = x.transpose((2, 0, 1))              # HWC (32, 32, 3) -> CHW (3, 32, 32), the layout PyTorch expects
    x = torch.from_numpy(x)
    return x
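
For reference, the same preprocessing can also be written with torchvision.transforms. A minimal equivalent sketch (the name data_tf_alt is illustrative, assuming the same 0.5 mean/std per channel):

from torchvision import transforms

data_tf_alt = transforms.Compose([
    transforms.ToTensor(),                                   # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per-channel (x - 0.5) / 0.5
])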

train_set = CIFAR10("./data_cifar10", train=True, transform=data_tf, download=True)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10("./data_cifar10", train=False, transform=data_tf, download=True)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = AlexNet()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)
criterion = nn.CrossEntropyLoss()

i = 0

for e in range(20):
    losses = 0
    acces = 0
    net.train()
    for im, label in train_data:
        i = i + 1
        out = net(im)
        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses = losses + loss.item()

        _, pred = out.max(1)
        acc = (pred == label).sum().item() / im.shape[0]
        acces = acces + acc
        print("iteration =", i, "loss =", loss.item(), "acc =", acc)
    print("epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}"
          .format(e + 1, losses / len(train_data), acces / len(train_data)))
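
The test loader defined above is never used in the original loop. A minimal evaluation sketch to run after training (the eval_* names are illustrative):

net.eval()
eval_losses = 0
eval_acces = 0
with torch.no_grad():  # no gradients needed during evaluation
    for im, label in test_data:
        out = net(im)
        loss = criterion(out, label)
        eval_losses = eval_losses + loss.item()
        _, pred = out.max(1)
        eval_acces = eval_acces + (pred == label).sum().item() / im.shape[0]
print("Test Loss: {:.6f}, Test Acc: {:.6f}"
      .format(eval_losses / len(test_data), eval_acces / len(test_data)))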

The CIFAR10 dataset is used here. Because its images are only 32*32, the kernel sizes and overall structure are simplified relative to the original AlexNet: with 5*5 convolutions and 3*3/stride-2 pooling, the spatial size shrinks 32 -> 28 -> 13 -> 9 -> 4, so the flattened input to fc1 is 64 * 4 * 4 = 1024.
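
To check that arithmetic, a quick shape trace through the layers (a sketch; demo is just an illustrative instance):

demo = AlexNet()
x = torch.zeros(1, 3, 32, 32)  # one fake CIFAR10 image
x = demo.conv1(x)      # torch.Size([1, 64, 28, 28])
x = demo.max_pool1(x)  # torch.Size([1, 64, 13, 13])
x = demo.conv2(x)      # torch.Size([1, 64, 9, 9])
x = demo.max_pool2(x)  # torch.Size([1, 64, 4, 4]) -> 64 * 4 * 4 = 1024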
