我是做图像修复方向的,所以需要对数据做随机擦除(Random Erasing)的预处理操作。下面先看 torchvision 中 RandomErasing 的构造函数:
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
    """Validate and store the random-erasing parameters.

    Args:
        p: Probability of applying the erase; must lie in [0, 1].
        scale: ``(min, max)`` range for the erased area; both ends must lie
            in [0, 1].
        ratio: ``(min, max)`` range for the aspect ratio of the erased area.
        value: Fill value; a number, a str, a tuple or a list. Stored as-is.
        inplace: Stored as-is.

    Raises:
        TypeError: If ``value`` is not a number, str, tuple or list.
        ValueError: If ``scale`` leaves [0, 1] or ``p`` leaves [0, 1].
    """
    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would silently disable this validation.
    if not isinstance(value, (numbers.Number, str, tuple, list)):
        raise TypeError("Argument value should be either a number or str or a sequence")
    # A reversed (max, min) range is only warned about, not rejected.
    if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
        warnings.warn("range should be of kind (min, max)")
    if scale[0] < 0 or scale[1] > 1:
        raise ValueError("range of scale should be between 0 and 1")
    if p < 0 or p > 1:
        raise ValueError("range of random erasing probability should be between 0 and 1")

    self.p = p
    self.scale = scale
    self.ratio = ratio
    self.value = value
    self.inplace = inplace
其中:
p=执行擦除的概率,scale=擦除区域占整幅图像面积的比例范围,ratio=擦除区域的长宽比范围,value=擦除区域的填充值(一般用默认值就行;可以是单个数值,也可以按通道给出,以适配RGB或者灰度图)
如果你想将所有图片都进行擦除(erase),可以直接将 p 设为 1:
transforms.RandomErasing(p=1)  # p=1: the erase is applied to every image
官网上有相关的模板教程,这里我就写一个我自己使用的模板:
my_trans = transforms.Compose([
    # Resize every image to 512x512.
    transforms.Resize((512, 512)),
    # Random horizontal flip.
    transforms.RandomHorizontalFlip(),
    # Convert to a tensor; keep this step before RandomErasing to avoid
    # assorted small bugs.
    transforms.ToTensor(),
    # Random erasing: p=1 erases every image, and the small scale range keeps
    # the erased area from getting too large.
    transforms.RandomErasing(p=1, scale=(0.02, 0.03), value=1),
])
下面给出完整代码,可直接用于自己的数据集(修改数据集路径即可):
import torch
from torchvision import transforms, utils
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
"""
数据预处理
"""
def detection_collate(batch):
    """Custom collate fn for batches whose images carry a different number
    of associated object annotations (bounding boxes).

    Images are stacked into a single tensor; annotations stay in a Python
    list because their first dimension may differ from image to image.

    Arguments:
        batch: (sequence) samples where ``sample[0]`` is an image tensor and
            ``sample[1]`` holds that image's annotations (a nested list of
            numbers, or a single numeric label).
    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) one float32 annotation tensor per image
    """
    imgs = [sample[0] for sample in batch]
    # torch.FloatTensor(n) with a plain int n allocates an uninitialised
    # tensor of length n instead of wrapping the value, so integer labels
    # came out as garbage. as_tensor converts the data itself.
    targets = [torch.as_tensor(sample[1], dtype=torch.float32) for sample in batch]
    return torch.stack(imgs, 0), targets
# Pipeline for the full example: resize, convert to tensor, then erase.
hy_trans = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    # p=1 erases every image; the small scale range keeps the erased area small.
    transforms.RandomErasing(p=1, scale=(0.02, 0.03), value=1),
])
"""
生成batch_size迭代器,用于加载数据
"""
# ImageFolder expects the layout root/<class_name>/<image files>. The images
# actually live in E:\pycharm\dataset\shujuzengqiang\test\test, so the root
# given here is the parent folder and the inner "test" directory acts as the
# single class. Pointing root at the inner folder fails because that folder
# contains no class sub-directories.
train_data = ImageFolder(r'E:\pycharm\dataset\shujuzengqiang\test', transform=hy_trans)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)

for i_batch, img in enumerate(train_loader):
    # img[0] -> images, img[1] -> labels
    # NOTE(review): the figure is drawn into but never shown or closed here;
    # add plt.show() if on-screen display is wanted.
    fig = plt.figure()
    # img[0] is laid out as (batch, channel, height, width), which is what
    # make_grid expects; make_grid tiles the batch into a single
    # (channel, height, width) image.
    grid = utils.make_grid(img[0], padding=0)
    # tensor -> numpy and CHW -> HWC so matplotlib can draw it.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    utils.save_image(grid, f'{i_batch}.jpg')
    # Stop after batches 0..10 have been written.
    if i_batch == 10:
        break
print("computer finision")