【AlexNet】 训练 CIFAR10 数据集

AlexNet 模型结构
[图 1：AlexNet 模型结构示意图]

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import matplotlib.pyplot as plt
import torch.optim as optim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# ================================= Download, preprocess and load the data =================================

# Directory the CIFAR-10 files are downloaded to / read from (shared by both splits).
data_root = '../input/cifar10'

# Resize every image to 224x224, the input size the model's first Linear layer is
# sized for, then normalise per channel. The mean/std values are presumably
# CIFAR-10 channel statistics (NOTE(review): confirm their source).
transform = transforms.Compose([transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616])
                                ])

train_set = torchvision.datasets.CIFAR10(data_root, train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(data_root, train=False, download=True, transform=transform)

# Only the training data needs shuffling. Evaluation metrics do not depend on the
# order, so a fixed order keeps test passes deterministic.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)
# Class names, indexed by the integer labels the datasets yield.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ================================= 查看图片 =================================

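# data_iter = iter(train_loader)
# sample_images, sample_labels = next(data_iter)  # use built-in next(); newer DataLoader iterators have no .next()
#
# # Undo transforms.Normalize (x * std + mean) so images display with their original colours.
# mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
# std = torch.tensor([0.247, 0.2435, 0.2616]).view(3, 1, 1)
#
# plt.figure(figsize=(7, 7))
#
# for i in range(9):
#     plt.subplot(3, 3, i+1)
#     img = (sample_images[i] * std + mean).clamp(0, 1)
#     plt.title(classes[sample_labels[i]])
#     plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
# plt.show()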


# ================================= 定义模型 =================================

class Alexnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 10))

    def forward(self, X):
        return self.net(X)


# ================================= 训练模型 =================================

# 参数初始化
def initial(layer):
    """Xavier-normal initialise the weight of a Linear or Conv2d layer.

    Meant to be passed to ``nn.Module.apply``, which calls it once per
    sub-module. Modules of any other type are left untouched. Biases are not
    modified; they keep whatever values the layer constructor gave them.

    Args:
        layer: A module visited by ``Module.apply``.
    """
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        # nn.init functions already run under torch.no_grad(), so initialise the
        # Parameter directly instead of going through the legacy ``.data`` alias.
        nn.init.xavier_normal_(layer.weight)


model = Alexnet().to(device)
model.apply(initial)

epochs = 20
lr = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
# Divide the learning rate by 10 after every 5 completed epochs. A hand-rolled
# "(epoch + 1) % 5 == 0" check at the start of an epoch fires at epochs 4, 9, 14
# and 19 (0-based): the first stage gets only 4 epochs and the last epoch runs at
# an extra-decayed rate.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
log_interval = 50  # print the training loss every this many batches

# Per-epoch test-set mean loss and accuracy, kept for plotting after training.
eval_losses = []
eval_acces = []

for epoch in range(epochs):

    # ---- training pass ----
    model.train()
    for step, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        predict = model(imgs)
        loss = criterion(predict, labels)

        # back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % log_interval == 0:
            # .item() prints a plain float rather than the tensor repr (device, grad_fn)
            print('epoch {}   loss: {}'.format(epoch, loss.item()))

    # Advance the LR schedule once per completed epoch (after optimizer.step()).
    scheduler.step()

    # ---- evaluation pass ----
    eval_loss = 0.0
    correct = 0
    seen = 0
    model.eval()
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            predict = model(imgs)
            batch_size = imgs.shape[0]

            # record loss: criterion returns the batch mean, so weight it by the
            # batch size to stop the short final batch from skewing the average
            eval_loss += criterion(predict, labels).item() * batch_size

            # record accurate rate: count correct predictions across all samples
            correct += (torch.argmax(predict, dim=1) == labels).sum().item()
            seen += batch_size

    seen = max(seen, 1)  # guard against an empty test loader
    epoch_loss = eval_loss / seen
    epoch_acc = correct / seen
    eval_losses.append(epoch_loss)
    eval_acces.append(epoch_acc)

    print('epoch: {}'.format(epoch))
    print('loss: {}'.format(epoch_loss))
    print('accurate rate: {}'.format(epoch_acc))
    print('\n')

# Plot the per-epoch test-set loss and accuracy recorded by the training loop.
# The x axis uses the same 0-based epoch numbers that the training loop prints.
epoch_axis = np.arange(len(eval_losses))
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
ax_loss.plot(epoch_axis, eval_losses)
ax_loss.set_title('evaluation loss')
ax_loss.set_xlabel('epoch')
ax_acc.plot(epoch_axis, eval_acces)
ax_acc.set_title('evaluation accuracy')
ax_acc.set_xlabel('epoch')
fig.tight_layout()
plt.show()

相关标签：经典网络学习、深度学习、PyTorch、人工智能