在 PyTorch 训练模型的时候,一般会在 data.Dataset 类里对输入图像做预处理。
以前都是学习别人的程序,把这些操作各写成一个类,并将所有数据变为 numpy 数组,这样可以处理各种想要处理的数据;
但是现在要预处理的数据全是图片,所以想用 PIL 做一些简单的操作。
pytorch
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
class RandomCrop(object):
    """Randomly crop an aligned ``raw`` / ``gt`` array pair.

    A square window of ``output_size`` pixels is cut from ``sample['raw']``
    and the matching window, ``scale`` times larger in each spatial
    direction, is cut from ``sample['gt']``.

    Args:
        output_size: Side length of the square crop, in ``raw`` pixels.
        scale: Integer ratio between ``gt`` and ``raw`` coordinates.
            Defaults to 4, the factor that used to be hard-coded.
    """

    def __init__(self, output_size, scale=4):
        self.crop_size = output_size
        self.scale = scale

    def __call__(self, sample):
        """Return a new ``{'raw': ..., 'gt': ...}`` dict with the cropped arrays.

        Args:
            sample: Mapping with array entries ``'raw'`` and ``'gt'``, both
                indexed as (row, column, channel).

        Raises:
            ValueError: If the crop window does not fit inside ``raw``.
        """
        raw, gt = sample['raw'], sample['gt']
        h, w = raw.shape[0], raw.shape[1]
        size, scale = self.crop_size, self.scale
        if size > h or size > w:
            raise ValueError(
                'crop size {} does not fit a {}x{} input'.format(size, h, w)
            )

        # Re-seed the global NumPy RNG from OS entropy on every call.
        # NOTE(review): presumably meant to stop forked DataLoader workers
        # from drawing identical crops; it also discards any seed the caller
        # set -- confirm this is intended.
        np.random.seed()

        # randint's upper bound is exclusive, so "+ 1" keeps the last valid
        # offset reachable and lets a crop equal to the image size succeed.
        xx = np.random.randint(0, w - size + 1)
        yy = np.random.randint(0, h - size + 1)

        raw = raw[yy:yy + size, xx:xx + size, :]
        gt = gt[yy * scale:(yy + size) * scale,
                xx * scale:(xx + size) * scale, :]
        return {'raw': raw, 'gt': gt}
class RandomFlip(object):
    """Apply the same random flips / axis swap to ``raw`` and ``gt``.

    Three independent coin tosses decide whether to reverse axis 0, reverse
    axis 1, and swap axes 0 and 1. Both arrays always receive the same
    combination of operations.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        """Return a new ``{'raw': ..., 'gt': ...}`` dict with the augmented arrays."""
        low, high = sample['raw'], sample['gt']

        # Draw the three decisions up front, in a fixed order, from the
        # global NumPy RNG.
        flip_rows, flip_cols, swap_axes = (
            np.random.randint(2) for _ in range(3)
        )

        if flip_rows:
            low, high = np.flip(low, 0), np.flip(high, 0)
        if flip_cols:
            low, high = np.flip(low, 1), np.flip(high, 1)
        if swap_axes:
            axes = (1, 0, 2)
            low, high = np.transpose(low, axes), np.transpose(high, axes)

        return {'raw': low, 'gt': high}
class ToTensor(object):
    """Turn the ``raw`` and ``gt`` arrays of a sample into torch tensors.

    Each array has its axes reordered with ``(2, 0, 1)`` (third axis moved
    to the front), is copied into contiguous memory if needed, and is then
    wrapped with ``torch.from_numpy``.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        """Return a new ``{'raw': ..., 'gt': ...}`` dict holding tensors."""
        arrays = (('raw', sample['raw']), ('gt', sample['gt']))
        converted = {}
        for key, arr in arrays:
            reordered = arr.transpose((2, 0, 1))
            # from_numpy needs a regular memory layout; transposed or
            # flipped views are made contiguous first.
            converted[key] = torch.from_numpy(np.ascontiguousarray(reordered))
        return converted
def get_transform(self):
    """Build the augmentation pipeline: random crop, random flip, to-tensor.

    Returns:
        A ``transforms.Compose`` that applies ``RandomCrop(crop_size)``,
        ``RandomFlip()`` and ``ToTensor()`` in that order.

    NOTE(review): ``self`` is never used, and ``crop_size`` is not defined
    inside this function -- presumably it is a module-level or enclosing
    name; confirm against the surrounding code.
    """
    pipeline = [
        RandomCrop(crop_size),
        RandomFlip(),
        ToTensor(),
    ]
    return transforms.Compose(pipeline)
# Build the composed pipeline and apply it to one sample dict.
# NOTE(review): get_transform is declared with a required `self` parameter,
# so this argument-less call raises TypeError as written -- confirm whether
# it was meant to be `self.get_transform()(sample)` inside a Dataset method.
sample = get_transform()(sample)
PIL
import Image
import torchvision.transforms.functional as tf
def transform(self, img, gt, crop_size=512):
    """Apply one random crop and one set of random flips to an image pair.

    The same crop box and the same flip decisions are used for ``img`` and
    ``gt``; both results are then converted with ``tf.to_tensor``.

    Args:
        self: Unused; kept so existing call sites keep working.
        img: PIL image.
        gt: PIL image, cropped with the same box as ``img``.
        crop_size: Side length in pixels of the square crop. Must be an
            ``int`` between 1 and ``min(width, height)`` of ``img``.

    Returns:
        Tuple ``(img, gt)`` of the two ``tf.to_tensor`` results.

    Raises:
        ValueError: If ``crop_size`` is not an int or does not fit ``img``.
            (This used to be an ``assert``, which is skipped under ``-O``.)

    NOTE(review): ``Image`` must resolve to Pillow's ``PIL.Image``; a bare
    ``import Image`` only works with the legacy PIL package -- confirm.
    """
    # crop
    w, h = img.size
    if not isinstance(crop_size, int) or not 1 <= crop_size <= min(w, h):
        raise ValueError('check crop_size type or num')
    # randint's upper bound is exclusive: "+ 1" keeps the last valid offset
    # reachable and makes crop_size == min(w, h) work instead of raising.
    xx = np.random.randint(0, w - crop_size + 1)
    yy = np.random.randint(0, h - crop_size + 1)
    box = (xx, yy, xx + crop_size, yy + crop_size)
    img = img.crop(box)
    gt = gt.crop(box)

    # flip
    if np.random.randint(2):
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        gt = gt.transpose(Image.FLIP_LEFT_RIGHT)
    if np.random.randint(2):
        img = img.transpose(Image.FLIP_TOP_BOTTOM)
        gt = gt.transpose(Image.FLIP_TOP_BOTTOM)
    if np.random.randint(2):
        img = img.rotate(180)
        gt = gt.rotate(180)

    # toTensor
    img = tf.to_tensor(img)
    gt = tf.to_tensor(gt)
    return img, gt
torchvision.transforms.functional 这个库还挺好用的 0.0