Remember this article? Transfer Learning | Code Implementation
In that article, we saw that when building a model we can start from well-known architectures that have already proven themselves on the ImageNet dataset.
The torchvision module also provides pretrained versions of these models. With only minor modifications, we can apply them to our own tasks!
We still follow these steps to train our model:
Prepare an iterable dataset
Define a neural network
Feed the dataset through the network
Compute the loss
Update the parameters via gradient descent
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import models
Dataset preparation
cifar10_train = torchvision.datasets.CIFAR10(
    root='cifar10/',
    train=True,
    download=True
)
cifar10_test = torchvision.datasets.CIFAR10(
    root='cifar10/',
    train=False,
    download=True
)

# Convert PIL images to tensors and resize them to 224x224, the input size ResNet expects
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224))
])

# Keep only cats (label 3) and dogs (label 5) and remap their labels to 0 and 1
cifar2_train = [(transform(img), [3, 5].index(label)) for img, label in cifar10_train if label in [3, 5]]
cifar2_test = [(transform(img), [3, 5].index(label)) for img, label in cifar10_test if label in [3, 5]]

train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64, shuffle=True)
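As a quick sanity check (a minimal sketch, not part of the original pipeline), we can pull one batch from train_loader and confirm that the Resize transform produced 224x224 images, which is the input size the pretrained ResNet expects:
# Optional sanity check: inspect one batch from the loader
images, labels = next(iter(train_loader))
print(images.shape)   # expected: torch.Size([64, 3, 224, 224])
print(labels[:10])    # each label is 0 (cat) or 1 (dug up from index 3) or 1 (dog, index 5)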
The dataset consists of the cat and dog images from CIFAR-10.
CIFAR-10 classes:
Class   Label
plane   0
car     1
bird    2
cat     3
deer    4
dog     5
frog    6
horse   7
ship    8
truck   9
As you can see, cat and dog are labeled 3 and 5 respectively.
With the help of:
[3,5].index(label)
we can remap the cat label to 0 and the dog label to 1, turning this back into a binary classification problem.
For example:
>>> [3,5].index(3)
0
>>> [3,5].index(5)
1
Defining the model
See this article: Transfer Learning | Code Implementation
# Build the network: start from a pretrained ResNet-18 and freeze its weights
network = models.resnet18(pretrained=True)
for param in network.parameters():
    param.requires_grad = False
# Replace the final fully connected layer: 512 features in, 2 classes out (cat, dog)
network.fc = nn.Linear(512, 2)

# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizer: only the parameters of the new fc layer are updated
optimizer = optim.SGD(network.fc.parameters(), lr=0.01, momentum=0.9)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = network.to(device)
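To confirm that the freezing worked as intended, here is a small check (not in the original article) that lists which parameters will actually be updated; only the weight and bias of the new fc layer should require gradients:
# Optional check: only the new fc layer should be trainable
trainable = [name for name, p in network.named_parameters() if p.requires_grad]
print(trainable)    # expected: ['fc.weight', 'fc.bias']
num_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
print(num_params)   # 512*2 + 2 = 1026 trainable parameters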
Training the model:
for epoch in range(10):
    total_loss = 0
    total_correct = 0
    for batch in train_loader:  # get a batch
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # zero the gradients, otherwise PyTorch accumulates them across iterations
        preds = network(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, prelabels = torch.max(preds, dim=1)
        total_correct += int((prelabels == labels).sum())

    accuracy = total_correct / len(cifar2_train)
    print("Epoch:%d , Loss:%f , Accuracy:%f " % (epoch, total_loss, accuracy))
Epoch:0 , Loss:78.549439 , Accuracy:0.788900
Epoch:1 , Loss:77.828066 , Accuracy:0.801500
Epoch:2 , Loss:66.151785 , Accuracy:0.828100
Epoch:3 , Loss:76.204446 , Accuracy:0.816800
Epoch:4 , Loss:68.886606 , Accuracy:0.828100
Epoch:5 , Loss:71.129405 , Accuracy:0.821200
Epoch:6 , Loss:66.096364 , Accuracy:0.829900
Epoch:7 , Loss:65.504227 , Accuracy:0.827700
Epoch:8 , Loss:76.303878 , Accuracy:0.817100
Epoch:9 , Loss:70.546953 , Accuracy:0.820700
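Since matplotlib.pyplot is already imported, here is a minimal sketch for visualizing how the loss evolves. It assumes we append total_loss to a list at the end of each epoch; the values below are simply copied from the run above:
# Hypothetical: collect losses during training, e.g.
#   epoch_losses.append(total_loss)   # at the end of each epoch
epoch_losses = [78.5, 77.8, 66.2, 76.2, 68.9, 71.1, 66.1, 65.5, 76.3, 70.5]  # values from the run above
plt.plot(range(len(epoch_losses)), epoch_losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Total loss')
plt.title('Training loss per epoch')
plt.show()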
Testing the model:
correct = 0
total = 0
network.eval()
with torch.no_grad():
    for batch in test_loader:
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = network(imgs)
        _, prelabels = torch.max(preds, dim=1)
        total = total + labels.size(0)
        correct = correct + int((prelabels == labels).sum())

accuracy = correct / total
print("Accuracy: ", accuracy)
Accuracy: 0.8025
The pretrained model used here is resnet18; we could also use VGG16 instead, remembering to change the output size of its last fully connected layer so that it matches our own task (see the sketch below).
Besides trying a different pretrained model, we can also tune some of the hyperparameters to improve the final results!
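For reference, a minimal sketch of the VGG16 variant (assuming the same two-class setup; in torchvision's VGG16 the last fully connected layer is classifier[6], which takes 4096 input features):
# Sketch: swap in VGG16 and replace only its final fully connected layer
vgg = models.vgg16(pretrained=True)
for param in vgg.parameters():
    param.requires_grad = False
vgg.classifier[6] = nn.Linear(4096, 2)   # 4096 features in, 2 classes out (cat, dog)
vgg = vgg.to(device)
optimizer = optim.SGD(vgg.classifier[6].parameters(), lr=0.01, momentum=0.9)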
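As one example of hyperparameter tuning (a sketch, not part of the original code), we could decay the learning rate over time with a scheduler such as torch.optim.lr_scheduler.StepLR:
# Sketch: decay the learning rate by 10x every 5 epochs
optimizer = optim.SGD(network.fc.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(10):
    # ... run the training loop shown above ...
    scheduler.step()   # update the learning rate at the end of each epoch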