Pytorch目标检测(一)

实际操作中自己创建的训练集比较小所以重新训练一个模型比较复杂,因此可以利用别人在一些大型数据集上训练好的预训练模型,然后在自己的数据集上训练。

预加载模型

  1. 利用torchvision.models自带的预训练模型,设置参数pretrained=True
  2. 加载训练过的本地模型可以利用,torch.load(path)以及model.load_state_dict()加载
    例:
import torch
from torchvision import models

vgg = models.vgg16(pretrained=True)
for layer in range(10):
   for p in vgg.features[layer].parameters():
       p.requires_grad = False
torch.save({'model':vgg.state_dict()},'C:\\Users\\czk\\Desktop\\Pytorch\\Models\\model1.path')
print(vgg.features)
checkpoint = torch.load('C:\\Users\\czk\\Desktop\\Pytorch\\Models\\model1.path')
vgg2 = models.vgg16()
vgg2.load_state_dict(checkpoint['model'])

离线下载模型:
Downloading: “https://download.pytorch.org/models/vgg16-397923af.pth” to C:\Users\czk/.cache\torch\checkpoints\vgg16-397923af.pth

数据读取

继承Dataset类

torch提供Dataset抽象类,使用时继承该类并重写len和getiterm函数

from torch.utils.data import Dataset

#第一步 继承Dataset类
class my_data(Dataset):
    def __init__(self, image_path, annotation_path, transform=None):
        # 初始化读取数据
    def __len__(self):
        # 获取数据集总大小
    def __getitem__(self, id):
        # 对于指定的id,读取该数据并返回

数据变换与增强

实际使用过程中数据大小不一样,需要进行缩放、裁剪、翻转填充以及归一化等操作,可以利用transform.Compose类将多个变换整合,但操作对象必须是PIL的Image或者Tensor

from torchvision import transforms

#第二步 定义数据变换与增强
change = transforms.Compose([transforms.Resize(256),#将图像最短边缩小至256,宽高比不变
                             transforms.RandomHorizontalFlip(),#以0.5的概率随机翻转指定的PIL图像
                             transforms.ToTensor(),#将PIL图像转为Tensor,元素由[0,255]归一到[0,1]
                             transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) #进行均值为0.5,标准差为0.5的标准化
dataset = my_data("your image path", "your annotation path", transforms=change)

继承dataloader

前两步可以得到变换后的每一个样本,但是批量处理和随机选取还需要dataloader类

from torch.utils.data import Dataloader

#第三步 进一步封装dataset获得批量处理和随机选取
dataloader = Dataloader(dataset, batch_size=16, shuffle=True, num_worker=4)#第一个参数是之前继承的dataset实类,第二个参数是批处理每捆的大小,第三个参数是是否打乱数据,第四个数据是使用几个线程来加载数据
#迭代训练
data_iter = iter(dataloader)
for step in range(iters_per_epoch):
    data = next(data_iter)
    # 将data用于训练

数据可视化

Pytorch 常用可视化工具有TensorBoardX(和tensorflow的tensorborad不一样)和Visdom,但Visdom的功能更强大。

Visdom 的安装

pip3 install visdom
python -m visdom.server

Visdom 实例

import torch
import visdom

# 创建visdom客户端,使用默认端口8097, 环境为first, 环境是对可视化空间进行分区
vis = visdom.Visdom(env='first')

vis.text('first visdom', win='text1')#win参数是窗口名称
vis.text('Hello Pytorch', win='text1', append=True)#append为False时会覆盖之前的text

for i in range(20):#可视化函数
    vis.line(X=torch.FloatTensor([i]), Y=torch.FloatTensor([2*i]),
             opts={'title':'y=2*x'}, win='loss', update='append')

vis.image(torch.randn(3, 256, 256), win='random_image')#随机可视化图片

打开浏览器,输入http://localhost:8097可以看到可视化结果

资料参考:《深度学习之PyTorch物体检测实战》

你可能感兴趣的:(Pytorch目标检测(一))