深度学习图像预处理

在 PyTorch 中训练模型时，一般会对输入图像做以下预处理：

  • 随机裁剪
  • 随机翻转
  • ToTensor
最近在用 PyTorch 写 data.Dataset 类。以前都是参考别人的程序：为每种操作写一个类，先把所有数据转成 numpy 数组再处理，这样可以处理各种类型的数据。但现在要预处理的数据全是图片，所以想直接用 PIL 做一些简单的操作。

PyTorch 版本（基于 numpy 数组）

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

class RandomCrop(object):
    """Randomly crop an aligned ``{'raw', 'gt'}`` pair of numpy arrays.

    A ``crop_size`` x ``crop_size`` window is cut from ``sample['raw']``.
    The window at the same location, with coordinates and size multiplied
    by ``scale``, is cut from ``sample['gt']``. The default ``scale=4`` keeps
    the original behaviour, where ``gt`` is 4x larger than ``raw`` in both
    spatial dimensions.

    Args:
        output_size: side length of the square crop taken from ``raw``.
        scale: size ratio between ``gt`` and ``raw`` (default 4).

    ``__call__`` returns a new dict ``{'raw': ..., 'gt': ...}`` with the
    cropped arrays. Arrays are indexed as ``[rows, cols, channels]`` (3-D).

    Raises (from ``__call__``):
        ValueError: if ``raw`` is smaller than the crop along either of its
            first two axes.
    """

    def __init__(self, output_size, scale=4):
        self.crop_size = output_size
        self.scale = scale

    def __call__(self, sample):
        raw, gt = sample['raw'], sample['gt']
        h, w = raw.shape[0], raw.shape[1]
        size, scale = self.crop_size, self.scale

        if size > h or size > w:
            raise ValueError(
                'crop size {} exceeds raw image size {}x{}'.format(size, h, w))

        # Reseed from OS entropy on every call. This looks intended to stop
        # forked DataLoader workers (which inherit one numpy RNG state) from
        # producing identical crops -- confirm. Side effect: a global
        # np.random.seed() cannot make these crops reproducible.
        np.random.seed()
        # randint's upper bound is exclusive: +1 makes the last valid offset
        # reachable and allows a crop equal to the full image size.
        xx = np.random.randint(0, w - size + 1)
        yy = np.random.randint(0, h - size + 1)

        raw = raw[yy:yy + size, xx:xx + size, :]
        gt = gt[yy * scale:(yy + size) * scale,
                xx * scale:(xx + size) * scale, :]

        return {'raw': raw, 'gt': gt}
    
class RandomFlip(object):
    """Apply the same random reflection/mirror/transpose to ``raw`` and ``gt``.

    Three fair coin flips are drawn from ``np.random``, in this order, before
    any operation is applied:

    1. reverse axis 0 (``np.flip(a, 0)``),
    2. reverse axis 1 (``np.flip(a, 1)``),
    3. swap axes 0 and 1 (``np.transpose(a, (1, 0, 2))``, so 3-D input).

    Every selected operation is applied to both arrays so they stay aligned.
    The results may be non-contiguous numpy views.
    """

    def __call__(self, sample):
        images = [sample['raw'], sample['gt']]

        # Draw all decisions up front so RNG consumption order is fixed.
        coins = [np.random.randint(2) for _ in range(3)]
        operations = (
            lambda arr: np.flip(arr, 0),
            lambda arr: np.flip(arr, 1),
            lambda arr: np.transpose(arr, (1, 0, 2)),
        )

        for coin, operation in zip(coins, operations):
            if coin:
                images = [operation(arr) for arr in images]

        return {'raw': images[0], 'gt': images[1]}
    
class ToTensor(object):
    """Convert a ``{'raw', 'gt'}`` sample of 3-D numpy arrays to torch tensors.

    Each array has its last axis moved to the front (HWC -> CHW, assuming
    channel-last input). It is then copied to a C-contiguous buffer and
    wrapped with ``torch.from_numpy``. No value scaling or dtype conversion
    is performed. The contiguous copy is presumably needed because upstream
    flips/transposes produce strided views that ``torch.from_numpy`` does
    not accept.
    """

    def __call__(self, sample):
        keys = ('raw', 'gt')
        arrays = [sample[key] for key in keys]

        channel_first = [arr.transpose((2, 0, 1)) for arr in arrays]
        contiguous = [np.ascontiguousarray(arr) for arr in channel_first]
        tensors = [torch.from_numpy(arr) for arr in contiguous]

        return dict(zip(keys, tensors))

def get_transform(self=None, crop_size=None):
    """Build the paired augmentation pipeline: crop -> flip -> tensor.

    Args:
        self: optional object, accepted for compatibility with the original
            one-argument signature. If ``crop_size`` is not given, this
            object's ``crop_size`` attribute is used.
        crop_size: side length of the square crop taken from
            ``sample['raw']``.

    Returns:
        A ``torchvision.transforms.Compose`` that applies ``RandomCrop``,
        ``RandomFlip`` and ``ToTensor`` in that order.

    Raises:
        ValueError: if no crop size is provided by either argument.
    """
    # The crop size used to come from an undefined free name; resolve it
    # explicitly instead.
    if crop_size is None:
        crop_size = getattr(self, 'crop_size', None)
    if crop_size is None:
        raise ValueError('crop_size must be given (or set as self.crop_size)')

    return transforms.Compose([
        RandomCrop(crop_size),
        RandomFlip(),
        ToTensor(),
    ])

# Usage example. NOTE(review): `sample` is not defined anywhere in this
# snippet -- it must already exist as a dict with 'raw' and 'gt' numpy arrays
# (the keys the transforms above read), and get_transform needs a crop size
# to be available before this line can run.
sample = get_transform()(sample)

PIL 版本

import Image
import torchvision.transforms.functional as tf
def transform(self, img, gt, crop_size=512):
    '''Jointly augment an (image, ground-truth) pair of PIL images.

    Steps:

    - Take one random ``crop_size`` x ``crop_size`` crop, using the same
      window for both images.
    - Independently, each with probability 1/2, flip left-right, flip
      top-bottom, and transpose (swap x and y). On the square crop these
      three ops reach all 8 orientations, matching the reflect/mirror/
      transpose scheme of the numpy ``RandomFlip``.
    - Convert both images with ``torchvision.transforms.functional.to_tensor``.

    :param self: unused; kept so existing method-style calls keep working.
    :param img: PIL image
    :param gt: PIL image with the same size as ``img``
    :param crop_size: side length of the square crop,
        an int with 0 < crop_size <= min(width, height)
    :return: ``(img_tensor, gt_tensor)``
    :raises ValueError: if ``img`` and ``gt`` differ in size, or if
        ``crop_size`` is not a positive int that fits inside the image.
    '''
    # Pillow exposes Image inside the PIL package; a bare ``import Image``
    # only works with the legacy standalone PIL distribution.
    from PIL import Image
    import numpy as np

    w, h = img.size
    if gt.size != img.size:
        # Both images are cropped with the same box, so they must match.
        raise ValueError(
            'img and gt must have the same size, got {} and {}'.format(
                img.size, gt.size))
    if not isinstance(crop_size, int) or not 0 < crop_size <= min(w, h):
        raise ValueError('check crop_size type or num')

    # crop: randint's upper bound is exclusive, so +1 keeps the last valid
    # offset reachable and allows crop_size == width/height.
    xx = np.random.randint(0, w - crop_size + 1)
    yy = np.random.randint(0, h - crop_size + 1)
    box = (xx, yy, xx + crop_size, yy + crop_size)
    img = img.crop(box)
    gt = gt.crop(box)

    # flip / transpose: the previous third op, rotate(180), equals
    # FLIP_LEFT_RIGHT followed by FLIP_TOP_BOTTOM and so added no new
    # orientation. TRANSPOSE is safe here because the crop is square.
    for op in (Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM, Image.TRANSPOSE):
        if np.random.randint(2):
            img = img.transpose(op)
            gt = gt.transpose(op)

    # toTensor
    img = tf.to_tensor(img)
    gt = tf.to_tensor(gt)
    return img, gt

torchvision.transforms.functional 这个模块还挺好用的 0.0

你可能感兴趣的:(pytorch)