Pytorch学习笔记(十六)Image and Video - Transfer Learning for Computer Vision Tutorial

这篇博客瞄准的是 pytorch 官方教程中 Image and Video 章节的 Transfer Learning for Computer Vision Tutorial 部分。

  • 官网链接:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
完整网盘链接: https://pan.baidu.com/s/1L9PVZ-KRDGVER-AJnXOvlQ?pwd=aa2m 提取码: aa2m 

Transfer Learning for Computer Vision Tutorial

这个示例中将介绍如何使用迁移学习训练卷积神经网络进行图像分类。

实际上,很少有人从头开始(随机初始化)训练整个卷积网络,因为拥有足够大规模数据集的情况相对少见。通常的做法是在非常大的数据集上预训练 ConvNet(例如 ImageNet,其中包含 120 万张图像和 1000 个类别),然后将该 ConvNet 用作初始化或固定特征提取器来完成感兴趣的任务。

两个主要的迁移学习场景如下所示:

  • 微调 ConvNet:使用预训练网络来初始化网络,其余训练与常规一致;
  • ConvNet 作为固定特征提取器:冻结除最终全连接层之外的所有网络权重,将最后一个全连接层替换为具有随机权重的新层,并且只训练这一层;

导入依赖包:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import os, time
from PIL import Image
from tempfile import TemporaryDirectory

# Matplotlib interactive mode on; it is switched off again with plt.ioff()
# right before the final plt.show() calls.
plt.ion()
# Enable cuDNN autotuning of convolution algorithms.
cudnn.benchmark = True

Load Data

使用 torchvision 和 torch.utils.data 来加载数据。目标是训练一个模型来对蚂蚁和蜜蜂进行分类:蚂蚁和蜜蜂各有大约 120 张训练图像,每个类别有 75 张验证图像。该数据集是 ImageNet 的一个非常小的子集。请从这个链接(https://download.pytorch.org/tutorial/hymenoptera_data.zip)下载数据,并解压到 data/ 目录下(即 data/hymenoptera_data)。

定义数据增强与预处理变换(训练集使用随机裁剪和水平翻转做数据增强,验证集只做缩放和中心裁剪,两者都做相同的归一化)

# Per-split image pipelines, keyed by split name ('train' / 'val').
# Training images get random crop + horizontal flip as augmentation; validation
# images are deterministically resized and center-cropped. Both end by turning
# the PIL image into a tensor and normalizing with the same per-channel
# mean/std values.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

定义数据加载器

# Root of the extracted dataset; expected to contain 'train/' and 'val/'
# subfolders, each with one subfolder per class (the ImageFolder layout).
data_dir = 'data/hymenoptera_data'

# One ImageFolder dataset per split, each using that split's transform pipeline.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}
# Batches of 4 images, reshuffled each epoch, loaded by 4 worker processes.
data_loaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=4, shuffle=True, num_workers=4
    )
    for x in ['train', 'val']
}
# Number of images in each split. Exposed under both ``dataset_sizes`` (the
# name the training loop reads) and the original ``dataset_size`` spelling,
# so code using either name resolves.
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
dataset_size = dataset_sizes
# Class names in ImageFolder index order (index i <-> class_names[i]).
class_names = image_datasets['train'].classes

检查可用设备

# Pick the compute device as a string usable with ``tensor.to(device)``.
# ``torch.accelerator`` only exists in recent PyTorch releases, so fall back to
# a plain CUDA check (and finally CPU) on older versions instead of raising
# AttributeError.
if hasattr(torch, 'accelerator') and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

抽查几个数据

def imshow(inp, title=None):
    """Undo the input normalization on an image tensor and draw it.

    Args:
        inp: image tensor in channel-first layout (the ``(1, 2, 0)`` transpose
            below moves the first axis last), as produced by ``ToTensor`` /
            ``make_grid``. Must be on the CPU so ``.numpy()`` works.
        title: optional title placed above the plot.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Channel-first -> channel-last, the layout plt.imshow expects.
    hwc = np.transpose(inp.numpy(), (1, 2, 0))
    # Invert Normalize(mean, std) and clamp to the displayable [0, 1] range.
    restored = np.clip(hwc * std + mean, 0, 1)

    plt.imshow(restored)
    if title is not None:
        plt.title(title)
    # Short pause so the figure refreshes.
    plt.pause(0.001)

# Grab one training batch, tile its images into a single grid tensor and show
# it, titled with each image's class name.
inputs, classes = next(iter(data_loaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[label] for label in classes])

Training the model

定义训练函数

def train_model(model, criterion, optmizer, scheduler, num_epochs=25,
                loaders=None, target_device=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a ``'train'`` phase (forward, loss, backward, optimizer
    step) followed by a ``'val'`` phase (forward only). Whenever validation
    accuracy improves, ``model.state_dict()`` is saved to a file in a
    temporary directory. After the last epoch, that best checkpoint is loaded
    back into ``model``.

    Args:
        model: the ``nn.Module`` to train; it is modified in place.
        criterion: loss called as ``criterion(outputs, labels)`` on the raw
            model outputs.
        optmizer: optimizer over the parameters to update. The original
            (misspelled) parameter name is kept so keyword callers still work.
        scheduler: learning-rate scheduler, stepped once per training phase.
        num_epochs: number of train+val epochs to run.
        loaders: optional mapping with ``'train'`` and ``'val'`` iterables of
            ``(inputs, labels)`` batches. Defaults to the module-level
            ``data_loaders``.
        target_device: optional device to move each batch to. Defaults to the
            module-level ``device``.

    Returns:
        ``model``, with the best-validation-accuracy weights loaded.
    """
    # The original body silently used a global ``optimizer`` instead of the
    # argument; bind the argument under the name the loop uses.
    optimizer = optmizer
    if loaders is None:
        loaders = data_loaders
    if target_device is None:
        target_device = device

    since = time.time()

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
        print(f"Best model save as {best_model_params_path}")
        # Save the initial weights so there is always a checkpoint to reload,
        # even if validation accuracy never rises above 0.
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print('-' * 30)
            print(f"Epoch {epoch+1}/{num_epochs}")

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                if is_train:
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0
                num_samples = 0

                for inputs, labels in loaders[phase]:
                    inputs = inputs.to(target_device)
                    labels = labels.to(target_device)
                    optimizer.zero_grad()

                    # Track gradients only while training.
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        # The loss takes the raw outputs, not the argmax
                        # indices (which carry no gradient).
                        loss = criterion(outputs, labels)

                        if is_train:
                            loss.backward()
                            optimizer.step()

                    # Statistics, accumulated as plain Python numbers.
                    batch_size = inputs.size(0)
                    running_loss += loss.item() * batch_size
                    running_corrects += torch.sum(preds == labels).item()
                    num_samples += batch_size

                if is_train:
                    scheduler.step()

                # Average over the samples actually seen; guard against an
                # empty loader instead of dividing by zero.
                denom = max(num_samples, 1)
                epoch_loss = running_loss / denom
                epoch_acc = running_corrects / denom
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)
            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Restore the best checkpoint before the temporary directory is deleted.
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return model

Visualizing the model predictions

定义模型可视化工具

def visualize_model(model, num_images=6):
    """Plot ``model``'s predicted class for the first ``num_images`` val images.

    Images come from ``data_loaders['val']`` and are drawn in a grid two
    columns wide, each titled with its predicted class name. The model runs in
    eval mode without gradient tracking. Its previous train/eval mode is
    restored before returning, including on error.

    Args:
        model: classifier whose outputs are reduced with an argmax over dim 1.
        num_images: how many images to plot; nothing is drawn if <= 0.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure()
    # ceil(num_images / 2) rows, so an odd count still fits the 2-column grid.
    num_rows = (num_images + 1) // 2

    try:
        with torch.no_grad():
            for inputs, _labels in data_loaders['val']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(num_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {class_names[preds[j].item()]}')
                    imshow(inputs.cpu().data[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)

Finetuning the ConvNet

拉取预训练模型

# ResNet-18 initialised with ImageNet-pretrained weights. Nothing is frozen
# here: the whole network will be fine-tuned.
model_ft = models.resnet18(weights="IMAGENET1K_V1")
# Input width of the final fully connected layer, reused to size the new head.
num_ftrs = model_ft.fc.in_features

定义优化器与损失函数

# Replace the pretrained classification head with one output per class in this
# dataset, derived from the data rather than hard-coding 2.
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# Every parameter is optimized: this is full fine-tuning.
optimizer = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler steps
# (train_model steps the scheduler once per training epoch).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

训练模型

# Fine-tune the whole network for 25 epochs; keep the best-on-val weights.
model_ft = train_model(
    model_ft,
    criterion,
    optimizer,
    scheduler=exp_lr_scheduler,
    num_epochs=25,
)

抽查可视化

# Spot-check predictions of the fine-tuned model on validation images.
visualize_model(model_ft, num_images=6)

Pytorch学习笔记(十六)Image and Video - Transfer Learning for Computer Vision Tutorial_第1张图片


ConvNet as fixed feature extractor

冻结除最后一层之外的所有参数:设置 requires_grad = False 来冻结参数,这样在 backward() 中就不会为这些参数计算梯度。

加载预训练模型

# Pretrained ResNet-18 used as a fixed feature extractor.
model_conv = torchvision.models.resnet18(weights="IMAGENET1K_V1")
# Freeze every pretrained parameter (requires_grad=False) so backward()
# computes no gradients for the backbone.
model_conv.requires_grad_(False)

替换掉模型的最后一层

# Replace the final fully connected layer with a freshly initialised one, sized
# to this dataset's classes instead of a hard-coded 2. Parameters of a newly
# constructed module are trainable by default, so this head is the only part of
# the network that was not frozen.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, len(class_names))

定义优化器与损失函数

model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()
# Fixed-feature-extractor setup: give the optimizer only the new head's
# parameters. This matches the intent of training only the last layer, and the
# optimizer does not carry the frozen backbone tensors.
optimizer = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler steps
# (train_model steps the scheduler once per training epoch).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

训练模型

# Train only the new head for 20 epochs; keep the best-on-val weights.
model_conv = train_model(
    model_conv,
    criterion,
    optimizer,
    scheduler=exp_lr_scheduler,
    num_epochs=20,
)

抽查可视化

# Spot-check predictions of the feature-extractor model on validation images.
visualize_model(model_conv, num_images=6)

# Turn matplotlib interactive mode off, then display the figures.
plt.ioff()
plt.show()

Inference on custom images

使用指定路径文件进行推理

def visualize_model_predictions(model, img_path):
    """Classify a single image file with ``model`` and plot it with the prediction.

    The image is preprocessed with ``data_transforms['val']``, run through the
    model in eval mode without gradient tracking, and drawn in the first cell of
    a 2x2 subplot grid, titled with the predicted class name. The model's
    previous train/eval mode is restored before returning, including on error.

    Args:
        model: classifier whose outputs are reduced with an argmax over dim 1.
        img_path: path to an image file readable by PIL.
    """
    was_training = model.training
    model.eval()
    try:
        # Use a context manager so the file handle is closed. Convert to RGB so
        # grayscale/RGBA files still produce the 3 channels that the 3-value
        # Normalize in data_transforms['val'] expects.
        with Image.open(img_path) as raw:
            img = data_transforms['val'](raw.convert('RGB'))
        # Add a leading batch dimension of size 1.
        img = img.unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img)
            _, preds = torch.max(outputs, 1)

        ax = plt.subplot(2, 2, 1)
        ax.axis('off')
        ax.set_title(f"Predicted: {class_names[preds[0]]}")
        imshow(img.cpu().data[0])
    finally:
        model.train(mode=was_training)

绘制图像

# Run the feature-extractor model on one validation bee image.
visualize_model_predictions(
    model_conv,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg',
)

# Turn matplotlib interactive mode off, then display the figures.
plt.ioff()
plt.show()

你可能感兴趣的:(pytorch学习笔记,pytorch,学习,笔记)