最近使用pytorch训练人脸属性分类,关注到训练时的图像增强方法,常规使用方法是transform,之前没有特别留意过transform的使用。今天看了几篇帖子,稍微总结一下pytorch是怎样对数据进行预处理。
直接上源码,还是放最后吧(因为我也讨厌直接看有源码的博客,呵呵呵),先讲一下调用流程:
目录
一、pytorch框架加载数据的调用函数解析
二、数据在哪里正式对数据进行预处理
还是上代码:
train_loader = torch.utils.data.DataLoader(
ImageFolder(traindir, args.traintxtroot, transforms.Compose([
# transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),#变换色彩
transforms.RandomGrayscale(p=0.1), #转换成灰度图
# transforms.Pad(padding = 4, fill=0, padding_mode='edge'),#进行边缘填充
# transforms.Resize((135, 135)), # 图像尺寸缩放#mobilenetv2
# transforms.RandomCrop((128, 128)),
transforms.Resize((120, 120)), # 图像尺寸缩放
# transforms.RandomCrop((112, 112)),
# transforms.RandomHorizontalFlip(p=0.2), #依据概率p对PIL图片进行水平翻转
# transforms.RandomVerticalFlip(p=0.2), #依据概率p对PIL图片进行垂直翻转
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 用均值和方差归一化图片
])),
batch_size=args.train_batch, shuffle=True,
num_workers=args.workers, pin_memory=True)
1) ImageFolder
功能:用于加载数据
源码中有说到:
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
然后这片的代码我是有修改的,直接读取的txt(人脸是识别几百万的数据,这篇代码需要加载半个小时以上,修改后的代码参考链接:https://blog.csdn.net/qq_22764813/article/details/94589717)
2)transforms.Compose
功能:包含了各种各样的数据增强方法
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 用均值和方差归一化图片
以上两句话是一般都需要加上去的,ToTensor()的作用功能具体可查看以下源码。
注意点:输入图像是可以是PIL的读入图像,也可以是opencv读入的图像;
class ToTensor(object):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
return F.to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
以下代码就是判断图像的读入是否符合PIL或是OpenCV格式,然后根据图像的读入方式和图像的通道不同,进行归一化到[0, 1]的操作,方法是所有像素除以255,进行图像像素归一化;
def to_tensor(pic):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
See ``ToTensor`` for more details.
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if not(_is_pil_image(pic) or _is_numpy_image(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
return img
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic)
# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
elif pic.mode == 'I;16':
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
elif pic.mode == 'F':
img = torch.from_numpy(np.array(pic, np.float32, copy=False))
elif pic.mode == '1':
img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
# PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
elif pic.mode == 'I;16':
nchannel = 1
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
return img
3) batch_size=args.train_batch, shuffle=True, num_workers=args.workers, pin_memory=True
batch_size:训练迭代一次所用的图像的数量的大小;
shuffle: # 是否随机打乱顺序
num_workers: #每次提取数据多进程数量
pin_memory: #就是锁页内存,创建DataLoader时,设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
主机中的内存,有两种存在方式,一是锁页,二是不锁页,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。
而显卡中的显存全部是锁页内存!
当计算机的内存充足的时候,可以设置pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False。因为pin_memory与电脑硬件性能有关,pytorch开发者不能确保每一个炼丹玩家都有高端设备,因此pin_memory默认为False。(转自:http://www.voidcn.com/article/p-fsdktdik-bry.html)
在调用torch.utils.data.DataLoader只是将数据的路径和标签信息读入工程,以及其他信息的初始化,那在哪里具体使用transform呢?
代码:
for batch_idx, (inputs, targets) in enumerate(train_loader)
在进行上面的循环代码时,训练时数据的预处理就开始了,最终经过这次循环后,该循环执行后,得出的(inputs,target)就是预处理好的数据,也就是可以送进模型训练的数据。
1)数据读入代码
库里面默认是PIL,不利于工程部署,因为工程部署使用的是opencv读图,opencv读图和PIL读图是有差异的,差异有两点,1.图像通道,一个是RGB,一个是BGR;2.读入图像的数据类型;请仔细查阅,特别是在转模型的时候. 而且为了和PIL读入图片格式保持一致,项目中使用opencv读图后需要进行图像类型的转换,这时又会耗费更多的时间使用在图像转换上,建议使用opencv读图;
参考:https://www.jianshu.com/p/aba1142c0453(具体案例)