PyTorch学习笔记(16)——预训练模型微调

完整工程

  • 工程目录结构
    (图1:工程目录结构截图)
  • Code
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


# ---------------------------------------------------------
# Load AlexNet together with its pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument; kept unchanged here for compatibility
# with older installs — confirm against the torchvision version in use.
model = models.alexnet(pretrained=True)
# Swap the last classifier layer for a 2-class head. The input width is read
# from the layer being replaced instead of being hard-coded.
model.classifier[6] = nn.Linear(in_features=model.classifier[6].in_features,
                                out_features=2)


# ------------------------- Datasets ------------------------------------------------

# Every image is resized to 227x227 and converted to a tensor.
# NOTE(review): no transforms.Normalize step is applied here. torchvision
# documents its pretrained models as expecting ImageNet mean/std normalisation;
# confirm whether the omission is intentional before changing it, since any
# inference code must use the same preprocessing.
transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
])

# One sub-folder per class under each root (ImageFolder convention).
train_dataset = ImageFolder('./data/train', transform=transform)
val_dataset = ImageFolder('./data/val', transform=transform)

# Only the training loader shuffles; validation order is kept fixed.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)


# ------------------ Optimizer, loss function, LR schedule --------------------------
loss_fc = nn.CrossEntropyLoss()
# SGD with momentum over every parameter of the model (all layers are fine-tuned).
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0.001)
# Multiply the learning rate by 0.1 every 20 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)


# -------------------- Pick the compute device --------------------------------------
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Move the model's parameters and buffers onto the selected device.
model.to(device)

# ------------------- Training -------------------------------------------------------
# Fine-tune for `epoch_nums` epochs. After every 10th mini-batch the model is
# scored on the whole validation set, the mean training loss over those 10
# batches is printed, and the weights with the highest validation accuracy seen
# so far are snapshotted into `best_model_wts`.

epoch_nums = 50
# state_dict() hands back references to the live parameters, so until the first
# snapshot is taken below this simply tracks the current weights.
best_model_wts = model.state_dict()
best_acc = 0
for epoch in range(epoch_nums):
    running_loss = 0.0

    for i, sample_batch in enumerate(train_dataloader):
        # Tensor.to() returns the moved tensor rather than moving it in place,
        # so the result has to be rebound; otherwise the batch stays on the CPU
        # and the forward pass fails when the model lives on a GPU.
        inputs = sample_batch[0].to(device)
        labels = sample_batch[1].to(device)

        # The validation pass below switches to eval mode, so training mode is
        # re-entered at the top of every iteration.
        model.train()
        optimizer.zero_grad()
        # forward
        outputs = model(inputs)
        # loss
        loss = loss_fc(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 10 == 9:
            correct = 0
            total = 0
            model.eval()
            # Scoring needs no gradients; skipping graph construction saves
            # memory and time on every validation pass.
            with torch.no_grad():
                for images_test, labels_test in val_dataloader:
                    images_test = images_test.to(device)
                    labels_test = labels_test.to(device)
                    outputs_test = model(images_test)
                    # Predicted class = index of the largest logit per sample.
                    _, prediction = torch.max(outputs_test, 1)
                    correct += ((prediction == labels_test).sum()).item()
                    total += labels_test.size(0)
            accuracy = correct/total
            print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(epoch + 1, i + 1, running_loss/10, accuracy))
            running_loss = 0.0
            if accuracy > best_acc:
                best_acc = accuracy
                # Deep copy so later optimizer steps cannot overwrite the snapshot.
                best_model_wts = copy.deepcopy(model.state_dict())

    # Advance the LR schedule once per epoch, after the optimizer has stepped;
    # stepping the scheduler first skips the first value of the schedule.
    scheduler.step()

print('Train finish')
# torch.save does not create missing parent directories; make sure the target
# folder exists so the result of the whole training run is not lost at the end.
os.makedirs('./models', exist_ok=True)
# Persist the best-validation-accuracy weights (a state_dict, not the full model).
torch.save(best_model_wts, './models/model_50.pth')

https://www.jianshu.com/p/2e5a9bd5ad36

你可能感兴趣的:(pytorch学习,pytorch,深度学习)