实际操作中自己创建的训练集比较小所以重新训练一个模型比较复杂,因此可以利用别人在一些大型数据集上训练好的预训练模型,然后在自己的数据集上训练。
import torch
from torchvision import models
vgg = models.vgg16(pretrained=True)
for layer in range(10):
for p in vgg.features[layer].parameters():
p.requires_grad = False
torch.save({'model':vgg.state_dict()},'C:\\Users\\czk\\Desktop\\Pytorch\\Models\\model1.path')
print(vgg.features)
checkpoint = torch.load('C:\\Users\\czk\\Desktop\\Pytorch\\Models\\model1.path')
vgg2 = models.vgg16()
vgg2.load_state_dict(checkpoint['model'])
离线下载模型:
Downloading: “https://download.pytorch.org/models/vgg16-397923af.pth” to C:\Users\czk/.cache\torch\checkpoints\vgg16-397923af.pth
torch提供Dataset抽象类,使用时继承该类并重写len和getiterm函数
from torch.utils.data import Dataset
#第一步 继承Dataset类
class my_data(Dataset):
def __init__(self, image_path, annotation_path, transform=None):
# 初始化读取数据
def __len__(self):
# 获取数据集总大小
def __getitem__(self, id):
# 对于指定的id,读取该数据并返回
实际使用过程中数据大小不一样,需要进行缩放、裁剪、翻转填充以及归一化等操作,可以利用transform.Compose类将多个变换整合,但操作对象必须是PIL的Image或者Tensor
from torchvision import transforms
#第二步 定义数据变换与增强
change = transforms.Compose([transforms.Resize(256),#将图像最短边缩小至256,宽高比不变
transforms.RandomHorizontalFlip(),#以0.5的概率随机翻转指定的PIL图像
transforms.ToTensor(),#将PIL图像转为Tensor,元素由[0,255]归一到[0,1]
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) #进行均值为0.5,标准差为0.5的标准化
dataset = my_data("your image path", "your annotation path", transforms=change)
前两步可以得到变换后的每一个样本,但是批量处理和随机选取还需要dataloader类
from torch.utils.data import Dataloader
#第三步 进一步封装dataset获得批量处理和随机选取
dataloader = Dataloader(dataset, batch_size=16, shuffle=True, num_worker=4)#第一个参数是之前继承的dataset实类,第二个参数是批处理每捆的大小,第三个参数是是否打乱数据,第四个数据是使用几个线程来加载数据
#迭代训练
data_iter = iter(dataloader)
for step in range(iters_per_epoch):
data = next(data_iter)
# 将data用于训练
Pytorch 常用可视化工具有TensorBoardX(和tensorflow的tensorborad不一样)和Visdom,但Visdom的功能更强大。
pip3 install visdom
python -m visdom.server
import torch
import visdom
# 创建visdom客户端,使用默认端口8097, 环境为first, 环境是对可视化空间进行分区
vis = visdom.Visdom(env='first')
vis.text('first visdom', win='text1')#win参数是窗口名称
vis.text('Hello Pytorch', win='text1', append=True)#append为False时会覆盖之前的text
for i in range(20):#可视化函数
vis.line(X=torch.FloatTensor([i]), Y=torch.FloatTensor([2*i]),
opts={'title':'y=2*x'}, win='loss', update='append')
vis.image(torch.randn(3, 256, 256), win='random_image')#随机可视化图片
打开浏览器,输入http://localhost:8097可以看到可视化结果