深度学习模型训练一般流程
功能:构建可迭代的数据装载器
概念辨析:
Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
Iteration:一批样本输入到模型中,称之为一个Iteration
Batchsize:批大小,决定一个Epoch有多少个Iteration
e.g.:
class MyDataset(Dataset):
def __init__(self, data_dir, transforms=None):
super().__init__()
self.Label = {
'1':0, '100',1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.data_info)
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem :接收一个索引,返回一个样本
torchvision.transforms : 常用的图像预处理方法
作用位置:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WEO1DU2s-1637072332139)(https://i.loli.net/2021/11/16/9c4xH5n6JegIPjS.png)]
数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2IQ1OkSu-1637072332143)(https://i.loli.net/2021/11/16/OxMqR7fTFnL8VQs.png)]
transforms.CenterCrop
功能:从图像中心裁剪图片
transforms.RandomCrop
功能:从图片中随机裁剪出尺寸为size的图片
RandomResizedCrop
功能:随机大小、长宽比裁剪图片
FiveCrop
TenCrop
功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
RandomHorizontalFlip
RandomVerticalFlip
RandomRotation
功能:随机旋转图片
**degrees:**旋转角度
当为a时,在(-a,a)之间选择旋转角度
当为(a, b)时,在(a, b)之间选择旋转角度
**resample:**重采样方法
**expand:**是否扩大图片,以保持原图信息
transforms.Pad(padding,
fill=0,
padding_mode='constant')
transforms.ColorJitter(brightness=0,
contrast=0,
saturation=0,
hue=0)
Grayscale
RandomGrayscale
功能:依概率将图片转换为灰度图
RandomGrayscale(num_output_channels,
p=0.1)
Grayscale(num_output_channels)
RandomAffine
功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
RandomAffine(degrees,
translate=None,
scale=None,
shear=None,
resample=False,
fillcolor=0)
RandomErasing
功能:对图像进行随机遮挡
transforms.Lambda
功能:用户自定义lambda方法
transforms.RandomChoice([transforms1, transforms2, transforms3])
功能:从一系列transforms方法中随机挑选一个
transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)
功能:依据概率执行一组transforms操作
transforms.RandomOrder([transforms1, transforms2, transforms3])
功能:对一组transforms操作打乱顺序
自定义transforms要素:
class YourTransforms(object):
def __init__(self, ...):
def __call__(self, img):
return img
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-z9YuW2g3-1637072332145)(https://i.loli.net/2021/11/16/nC9JMbkfG1dZQeT.png)]
原则:让训练集与测试集更接近
ansforms要素:
class YourTransforms(object):
def __init__(self, ...):
def __call__(self, img):
return img
[外链图片转存中…(img-z9YuW2g3-1637072332145)]
原则:让训练集与测试集更接近