torchvision中的transforms

torchvision.transforms提供了很多数据增强的方法,如下:
- Compose:统一的接口,用来方便组合各种不同的数据增强方法
- ToTensor
- ToPILImage
- Normalize
- Resize
- Scale
- Pad
- CenterCrop
- FiveCrop
- TenCrop
- RandomCrop
- RandomHorizontalFlip
- RandomVerticalFlip
- RandomRotation
- RandomResizedCrop
- RandomSizedCrop
- LinearTransformation
- ColorJitter
- Grayscale
- RandomGrayscale
- Lambda

torchvision使用PIL.Image作为核心数据结构,并没有使用cv2,想必是因为cv2不够pythonic.

一般在代码中,transforms处于的角色是,把dataset的输出值,变成网络直接的输入值,中间的桥梁是由transforms来完成。因为一般dataset主要负责读取各种各样形式的数据,然后在__getitem__里面,调用transformtarget_transform,把原始数据,变成网络的输入数据。遵循unix的每次只做一件事的风格。

所以,transforms需要具备如下特点:
1. 统一使用Compose组合所有的操作,方便使用。
2. 同时提供traintest两种接口(因为像RandomHorizontalFlip这样的方法,在推断的时候,不用使用)。
3. 针对不同的问题,比如classification可能只针对输入图像操作,object detection可能还要对2组坐标点处理,saliency可能需要定义各种groundtruth的形式,提供不同的接口操作。
4. 所有的代码,均要测试。
5. 每一类操作的接口,都必须是相同的,如果不同,有两种可能:
- 它是其他类型的操作。
- 它应该在dataset中被实现。

一个合理的写代码的流程是:
1. 针对具体数据集形式,写dataset类,留好transforms和target_transforms接口。
2. 设计transforms和target_transforms分别要做的事情。
3. 在augmentation.py中写各个操作的接口。
4. gen-test生成测试脚本,写测试脚本。
5. 实现augmentation.py中写各个操作,并通过测试。

你可能感兴趣的:(pytorch)