使用ViT(Vision transformer)来训练Cifar10数据集

使用ViT(Vision transformer)来训练Cifar10数据集

下面的代码是使用ViT训练Cifar10数据集的demo。

"""
特征提取的实例:
利用迁移学习中特征提取的方法来对CIFAR-10数据集实现对10类无体的分类
"""
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
from vit_pytorch import ViT, SimpleViT

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def main():
    # 加载和预处理数据集
    trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(224),  # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小;
         # (即先随机采集,然后对裁剪得到的图像缩放为同一大小) 默认scale=(0.08, 1.0)
         transforms.RandomHorizontalFlip(),  # 以给定的概率随机水平旋转给定的PIL的图像,默认为0.5;
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])])

    trans_valid = transforms.Compose(
        [transforms.Resize(256),  # 是按照比例把图像最小的一个边长放缩到256,另一边按照相同比例放缩。
         transforms.CenterCrop(224),  # 依据给定的size从中心裁剪
         transforms.ToTensor(),  # 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
         # 归一化至[0-1]是直接除以255,若自己的ndarray数据尺度有变化,则需要自行修改。
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])])  # 对数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc

    trainset = torchvision.datasets.CIFAR10(root="./cifar10", train=True, download=True, transform=trans_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./cifar10', train=False,
                                           download=False, transform=trans_valid)
    testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                             shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # 随机获取部分训练数据
    dataiter = iter(trainloader)
    images, labels = dataiter.next()

    # 显示图像
    # imshow(torchvision.utils.make_grid(images[:4]))
    # 打印标签
    print(''.join('%5s ' % classes[labels[j]] for j in range(4)))

    # 使用预训练模型
    model = get_vit_model()

    # 查看总参数及训练参数
    total_params = sum(p.numel() for p in model.parameters())
    print('总参数个数:{}'.format(total_params))

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()  # 损失函数
    # 只需要优化最后一层参数
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)  # 优化器

    # train
    train(model, trainloader, testloader, 30, optimizer, criterion)


# 计算准确率
def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


# 显示图片
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# 定义训练函数
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)  # (bs, 3, h, w)
            label = label.to(device)  # (bs, h, w)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            for im, label in valid_data:
                im = im.to(device)  # (bs, 3, h, w)
                label = label.to(device)  # (bs, h, w)
                output = net(im)
                loss = criterion(output, label)
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


def get_vit_model():
    v = SimpleViT(
        image_size=224,
        patch_size=16,
        num_classes=10,
        dim=256,
        depth=2,
        heads=4,
        mlp_dim=128
    )
    return v


if __name__ == '__main__':
    main()

你可能感兴趣的:(transformer,深度学习,python)