Pytorch实战入门(一):MLP
Pytorch实战入门(二):CNN与MNIST
Pytorch实战入门(三):迁移学习
数据集下载地址,提取码:6smh
数据集格式满足 torchvision.datasets.ImageFolder
读取要求(根目录/类别名/图像名.jpg
)
主要涉及
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
def train(model, dataloader, loss_fn, optimizer, epoch):
model.train()
train_loss = 0.
train_corrects = 0.
for inputs, labels in dataloader:
inputs, labels = inputs.to(device), labels.to(device)
with torch.autograd.set_grad_enabled(True):
outputs = model(inputs)
loss = loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
preds = outputs.argmax(dim=1)
train_loss += loss.item() * inputs.size(0)
train_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()
epoch_loss = train_loss / len(dataloader.dataset)
epoch_acc = train_corrects / len(dataloader.dataset)
print("epoch {} train loss: {}, acc: {}".format(epoch, epoch_loss, epoch_acc))
return epoch_loss, epoch_acc
def test(model, dataloader):
model.eval()
test_corrects = 0.
with torch.no_grad():
for inputs, labels in dataloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
preds = outputs.argmax(dim=1)
test_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()
epoch_acc = test_corrects / len(dataloader.dataset)
print("test acc: {}".format(epoch_acc))
return epoch_acc
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_path = "./hymenoptera_data" # 数据集路径
input_size = 224 # 输入图像 224*224
num_classes = 2 # 类别数量 2 ants和bees
batch_size = 32
epochs = 20
lr = 0.001
pretrained = False
feature_extract = False
data_transforms = {
"train": transforms.Compose([
transforms.RandomResizedCrop(input_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.515, 0.469, 0.341], [0.271, 0.255, 0.281])
]),
"val": transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize([0.515, 0.469, 0.341], [0.271, 0.255, 0.281])
])
}
# 数据集
image_datasets = {x: datasets.ImageFolder(os.path.join(data_path, x), data_transforms[x]) for x in ["train", "val"]}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
batch_size=batch_size,
shuffle=True, num_workers=1) for x in ["train", "val"]
}
train_dataloader = dataloaders["train"]
val_dataloader = dataloaders["val"]
# 模型
# 拿到 pytorch定义的 resnet18
# pretrained=True则自动下载在ImageNet上训练好的模型并读取参数
model = models.resnet18(pretrained=pretrained)
# feature_extract==True 则网络已有的参数不参与训练
if feature_extract:
for param in model.parameters():
param.requires_grad = False
in_features = model.fc.in_features # 拿到全连接层的输入维度
# 原本网络是1000类分类,即最后一层是 nn.Linear(in_features, 1000)
# 而我们的任务是二分类,因此要重新定义一个 nn.Linear
# 如果feature_extract==True,则整个网络只训练这个重新定义的全连接层
model.fc = nn.Linear(in_features, num_classes)
# param = torch.load("./models/best_model3.pt")
# model.load_state_dict(param)
model = model.to(device)
# 优化器
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
lr=lr, momentum=0.9)
# 损失函数
loss_fn = nn.CrossEntropyLoss()
train_loss_history = []
train_acc_history = []
val_acc_history = []
for epoch in range(epochs):
best_model = copy.deepcopy(model.state_dict())
best_acc = 0.
train_loss, train_acc = train(model, train_dataloader, loss_fn, optimizer, epoch)
train_loss_history.append(train_loss)
train_acc_history.append(train_acc)
val_acc = test(model, val_dataloader)
val_acc_history.append(val_acc)
if val_acc > best_acc:
best_acc = val_acc
best_model = copy.deepcopy(model.state_dict())
torch.save(best_model, "./models/best_model1.pt")
对数据集的 transform 可见 图像变换 torchvision.transforms 笔记
直接训练
只拿现成的网络结构,从头训练整个网络。
pretrained = False
,不使用预训练模型
feature_extract = False
,不冻结除最后一个全连接层以外的网络参数
epochs = 20
预训练
除最后一个全连接层以外的网络参数使用预训练模型参数,并且整个网络一起训练。
pretrained = True
feature_extract = False
epochs = 20
预训练 + 微调
使用预训练模型参数,但只训练最后一个分类层,保存 best_model3.pt
。
pretrained = True
feature_extract = True
epochs = 10
之后读取 best_model3.pt
,训练整个网络。
pretrained = False
feature_extract = False
epochs = 10
param = torch.load("./models/best_model3.pt")
model.load_state_dict(param)
transforms.Normalize()
参数获取,由于网络是在ImageNet上预训练的,理论上用 ImageNet的标准化参数 比较好 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# 对未使用 transforms.Normalize() 的 image_datasets
data1 = [d[0][0].cpu().numpy() for d in image_datasets["train"]]
data2 = [d[0][1].cpu().numpy() for d in image_datasets["train"]]
data3 = [d[0][2].cpu().numpy() for d in image_datasets["train"]]
print(np.mean(data1), np.mean(data2), np.mean(data3))
print(np.std(data1), np.std(data2), np.std(data3))
看一看数据图像
img = image_datasets["train"]
unloader = transforms.ToPILImage() # reconvert into PIL image
plt.ion()
def imshow(tensor, title=None):
image = tensor.cpu().clone() # we clone the tensor to not do changes on it
image = unloader(image)
plt.axis('off')
plt.imshow(image)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
imshow(img[1][0], title='Image')
未使用 transforms.Normalize()
和使用后的两张图像
拿到模型以后若要修改首先要知道模型原本的结构。
model = models.resnet18(pretrained=pretrained)
print(model)
直接打印模型,就可以看出 resnet-18 由 conv1 + bn1 + relu + maxpool + layer1 + layer2 + layer3 + layer4 + avgpool + fc
构成,用 model.fc
可以直接拿到最后一个全连接层的信息:
Linear(in_features=512, out_features=1000, bias=True)
任务需求是二分类问题,简单修改 model.fc
即可。