Hacking a PyTorch Dataset so that DataLoader's enumerate(test_loader) yields image, target, name, and oriimg

Take the project pytorch-deeplab-xception as an example:

The test code:

https://github.com/jfzhang95/pytorch-deeplab-xception/issues/122

    def test(self):
        self.model.eval()
        self.evaluator.reset()
        # tbar = tqdm(self.test_loader, desc='\r')
        for i, sample in enumerate(self.test_loader):
            image, target = sample['image'], sample['label']
            with torch.no_grad():
                output = self.model(image)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()

Enumerating the loader here yields only image and target.

But the image has already been through data augmentation, so it may have changed, and it has also passed through ToTensor() and mean/std normalization,

which makes it unusable for the visualization we want to do afterwards.

So we also need name and oriimg (the original image after resizing, to help with visualization).
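An alternative would be to undo the normalization at save time; a minimal sketch (denormalize is a hypothetical helper, assuming the ImageNet mean/std used in this repo). Carrying the original image through the pipeline, as done below, is simpler and exact:

    import numpy as np

    def denormalize(img_chw, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Invert Normalize + ToTensor for visualization: CHW float tensor -> HWC uint8 array."""
        img = img_chw.cpu().numpy().transpose((1, 2, 0))  # C,H,W -> H,W,C
        img = (img * std + mean) * 255.0                  # undo (x / 255 - mean) / std
        return np.clip(img, 0, 255).astype(np.uint8)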


    def test(self):
        self.model.eval()
        self.evaluator.reset()
        # tbar = tqdm(self.test_loader, desc='\r')
        num = len(self.test_loader)
        for i, sample in enumerate(self.test_loader):
            image, target = sample['image'], sample['label']
            print(i, "/", num)
            torch.cuda.synchronize()  # flush pending GPU work before starting the clock
            start = time.time()
            with torch.no_grad():
                output = self.model(image)
            torch.cuda.synchronize()  # CUDA kernels run asynchronously: sync before reading the clock
            end = time.time()
            times = (end - start) * 1000
            print(times, "ms")
            pred = output.data.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            target = target.cpu().numpy()

The original dataset code:

train_set = coco.COCOSegmentation(args, split='train')
val_set = coco.COCOSegmentation(args, split='val')

train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
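
For reference, kwargs here is the DataLoader option dict built in train.py; it usually looks like this (an assumption, check your copy of the repo):

kwargs = {'num_workers': args.workers, 'pin_memory': True}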

from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd
from torch.utils.data import DataLoader

def make_data_loader(args, **kwargs):

    # ... (the 'pascal' and 'cityscapes' branches are omitted here)
    elif args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError
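
To actually get a test_loader out of make_data_loader (it is returned as None above), the coco branch can be extended. A minimal sketch, assuming the modified COCOSegmentation shown later is given a split='test' branch and that a matching annotation file exists (reusing split='val' also works):

    elif args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        test_set = coco.COCOSegmentation(args, split='test')  # assumption: a 'test' annotation file exists
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        # batch_size=1 keeps one prediction per image, which is convenient for visualization
        test_loader = DataLoader(test_set, batch_size=1, shuffle=False, **kwargs)
        return train_loader, val_loader, test_loader, num_class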

The COCO dataset is produced by this processing class:

import numpy as np
import torch
from torch.utils.data import Dataset
from mypath import Path
from tqdm import trange
import os
from pycocotools.coco import COCO
from pycocotools import mask
from torchvision import transforms
from dataloaders import custom_transforms as tr
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class COCOSegmentation(Dataset):
    NUM_CLASSES = 21
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
        1, 64, 20, 63, 7, 72]

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir('coco'),
                 split='train',
                 year='2017'):
        super().__init__()
        ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year))
        ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year))
        self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year))
        self.split = split
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.args = args

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.split == "train":
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))

        return _img, _target

    def _preprocess(self, ids, ids_file):
        print("Preprocessing mask, this will take a while. " + \
              "But don't worry, it only run once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'. \
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids

    def _gen_seg_mask(self, target, h, w):
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            cat = instance['category_id']
            if cat in self.CAT_LIST:
                c = self.CAT_LIST.index(cat)
            else:
                continue
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)


    def __len__(self):
        return len(self.ids)
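
To sanity-check the dataset before wiring it into a DataLoader, you can index it directly. A minimal sketch (assumes the COCO path in mypath.py is configured; 513 mirrors the repo's default base_size/crop_size, and the first run will trigger the slow _preprocess pass):

from argparse import Namespace

args = Namespace(base_size=513, crop_size=513)  # repo defaults; adjust to your config
val_set = COCOSegmentation(args, split='val')
sample = val_set[0]
print(sample['image'].shape, sample['label'].shape)  # torch.Size([3, 513, 513]) torch.Size([513, 513])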

The augmentation functions used here are hand-written, not the transforms from torchvision.

Only the container function transforms.Compose() is reused.

The augmentation classes are:

class Normalize(object):

class ToTensor(object):

class RandomHorizontalFlip(object):

class RandomRotate(object):

class RandomGaussianBlur(object):

class RandomScaleCrop(object):

class FixScaleCrop(object):

class FixedResize(object):
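
All of these share one convention: __call__ takes a sample dict and returns a sample dict, which is exactly why Compose can chain them. Any new transform just has to follow suit; a minimal sketch (RandomVerticalFlip is a hypothetical addition, shown already passing the extra keys through the way the modified transforms below do):

import random
from PIL import Image

class RandomVerticalFlip(object):
    def __call__(self, sample):
        img, mask = sample['image'], sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
        # keep every other key (e.g. ori_image, path) flowing through the pipeline
        return {**sample, 'image': img, 'label': mask}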

The call chain is as follows:

pytorch-deeplab-xception/train.py

self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs) 

pytorch-deeplab-xception/dataloaders/__init__.py (the cityscapes branch is shown; the coco branch is analogous)

train_set = cityscapes.CityscapesSegmentation(args, split='train')

pytorch-deeplab-xception/dataloaders/datasets/coco.py

def __getitem__(self, index):

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.split == "train":
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)

def _make_img_gt_point_pair(self, index):

    def _make_img_gt_point_pair(self, index):
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))

        return _img, _target

def transform_tr(self, sample): 

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

 

pytorch-deeplab-xception/dataloaders/custom_transforms.py

class Normalize(object):

class ToTensor(object):

class FixScaleCrop(object):

class FixScaleCrop(object):
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return {'image': img,
                'label': mask}

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        img = np.array(img).astype(np.float32)
        mask = np.array(mask).astype(np.float32)
        img /= 255.0
        img -= self.mean
        img /= self.std

        return {'image': img,
                'label': mask}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        img = sample['image']
        mask = sample['label']
        img = np.array(img).astype(np.float32).transpose((2, 0, 1))
        mask = np.array(mask).astype(np.float32)

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img,
                'label': mask}

Now the modification, so that the original image and the path are returned as well:

import torch
import random
import numpy as np

from PIL import Image, ImageOps, ImageFilter

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        img = np.array(img).astype(np.float32)
        mask = np.array(mask).astype(np.float32)
        img /= 255.0
        img -= self.mean
        img /= self.std

        # ori_image stays a PIL image here; ToTensor will convert it later
        return {'image': img, 'label': mask, 'ori_image': sample['ori_image'], 'path': sample['path']}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        img = sample['image']
        mask = sample['label']

        # note: the transpose below must be (2, 0, 1), i.e. HWC -> CHW;
        # (2, 1, 0) would additionally swap height and width
        img = np.array(img).astype(np.float32).transpose((2, 0, 1))
        mask = np.array(mask).astype(np.float32)

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        # convert the carried-along original image to a tensor too, so the default
        # collate_fn can batch it (a PIL.Image cannot be batched by the DataLoader)
        ori_image = np.array(sample['ori_image']).astype(np.float32).transpose((2, 0, 1))
        ori_image = torch.from_numpy(ori_image).float()

        return {'image': img, 'label': mask, 'ori_image': ori_image, 'path': sample['path']}


class RandomHorizontalFlip(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img, 'label': mask, 'ori_image': sample['ori_image'], 'path': sample['path']}


class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.uniform(-1*self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        # forward ori_image and path so RandomRotate stays compatible with the new sample dict
        return {'image': img, 'label': mask, 'ori_image': sample['ori_image'], 'path': sample['path']}


class RandomGaussianBlur(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))

        return {'image': img, 'label': mask, 'ori_image': sample['ori_image'], 'path': sample['path']}

class RandomScaleCrop(object):
    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        # ori_image is set to the scaled-and-cropped image itself (pre-normalization),
        # so it lines up pixel-for-pixel with the network input and the prediction
        return {'image': img, 'label': mask, 'ori_image': img, 'path': sample['path']}


class FixScaleCrop(object):
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        # as in RandomScaleCrop, ori_image is the geometrically transformed image
        # before Normalize/ToTensor, which is exactly what we want to visualize against
        return {'image': img, 'label': mask, 'ori_image': img, 'path': sample['path']}

class FixedResize(object):
    def __init__(self, size):
        self.size = (size, size)  # size: (h, w)

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']

        assert img.size == mask.size

        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)

        # follow the pattern of the crop transforms above: ori_image becomes the resized image
        return {'image': img, 'label': mask, 'ori_image': img, 'path': sample['path']}

And the modified methods in pytorch-deeplab-xception/dataloaders/datasets/coco.py:

    def __getitem__(self, index):
        _img, _target, _path = self._make_img_gt_point_pair(index)
        # seed the sample dict with the untouched PIL image and the file name;
        # the transforms above carry (or update) these extra keys
        sample = {'image': _img, 'label': _target, 'ori_image': _img, 'path': _path}

        if self.split == "train":
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)
        elif self.split == 'test':
            # no separate test-time transform here; reuse the deterministic val pipeline
            return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        _path = path.split('.jpg')[0]  # file name without the extension, used as the sample's name
        _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))

        return _img, _target, _path
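
With the dataset modified, the test loop from the top of the post can finally unpack everything it needs. A minimal sketch of the consuming side (assuming cv2, numpy as np and torch are imported; the output file names are illustrative, and with the default collate function sample['path'] arrives as a list of strings, one per batch element):

    def test(self):
        self.model.eval()
        self.evaluator.reset()
        for i, sample in enumerate(self.test_loader):
            image, target = sample['image'], sample['label']
            ori_image, name = sample['ori_image'], sample['path']  # the two extra keys
            with torch.no_grad():
                output = self.model(image)
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            # ori_image is an un-normalized float tensor in CHW order (values 0-255): back to HWC uint8
            ori = ori_image[0].cpu().numpy().transpose((1, 2, 0)).astype(np.uint8)
            cv2.imwrite('./{}_ori.jpg'.format(name[0]), cv2.cvtColor(ori, cv2.COLOR_RGB2BGR))
            cv2.imwrite('./{}_pred.png'.format(name[0]), pred[0].astype(np.uint8))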

 
