import torch
from torch import nn
import numpy as np
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
class AlexNet(nn.Module):
def __init__(self):
super().__init__()
# 卷积层1:输入通道为3,输出通道为64, 卷积核大小5*5
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, 5),
nn.ReLU(True),
)
# pooling层1: 3*3池化,步长为2
self.max_pool1 = nn.MaxPool2d(3, 2)
# 卷积层2: 输入通道为64, 输出通道为64,卷积核大小为5*5,
self.conv2 = nn.Sequential(
nn.Conv2d(64, 64, 5),
nn.ReLU(True),
)
# pooling层1: 3*3池化,步长为2
self.max_pool2 = nn.MaxPool2d(3, 2)
# 输入1024, 输出384
self.fc1 = nn.Sequential(
nn.Linear(1024, 384),
nn.ReLU(True)
)
# 输入384, 输出192
self.fc2 = nn.Sequential(
nn.Linear(384, 192),
nn.ReLU(True)
)
# 输入192, 输出10
self.fc3 = nn.Linear(192, 10)
def forward(self, x):
x = self.conv1(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = self.max_pool2(x)
x = x.view(x.shape[0], -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
""""
alexnet = AlexNet()
input_demo = Variable(torch.zeros(1, 3, 32, 32))
output_demo = alexnet(input_demo)
"""
def data_tf(x):
x = np.array(x, dtype="float32") / 255
x =(x - 0.5) / 0.5
# print(x.shape) # 图片格式为32*32*3
x = x.transpose((2, 0, 1))
# print(x.shape) # 转换成PyTorch支持的格式3*32*32
x = torch.from_numpy(x)
return x
train_set = CIFAR10("./data_cifar10", train=True, transform=data_tf, download=True)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10("./data_cifar10", train=False, transform=data_tf, download=True)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
net = AlexNet()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)
criterion = nn.CrossEntropyLoss()
i =0
for e in range(20):
losses = 0
acces = 0
net.train()
for im, label in train_data:
i = i + 1
im = Variable(im)
lable =Variable(label)
out = net(im)
loss = criterion(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses = losses + loss.data
_, pred = out.max(1)
acc = float((pred == label).sum().data) / im.shape[0]
acces = acces + acc
print("interation=", i, "loss = ", loss, "acc=", acc)
print("epoch :{}, Train Loss:{:.6f}, Train ACC:{:.6f}"
.format(e+1, losses / len(train_data), acces / len(train_data)))
# Uses the CIFAR10 dataset; because the images are only 32*32, the kernel
# sizes and the overall architecture were simplified relative to AlexNet.