PyTorch (Notes 9) -- Loading Custom Data

      PyTorch provides a dataset interface, datasets, which wraps many public datasets such as CIFAR10/100 and ImageNet; they can be loaded with a simple call like the one below. So how do we use PyTorch to load a training set we built ourselves? Let's find the answer in the source code!

      train_data = datasets.CIFAR10('./cifar10', train=True, transform=train_transform, download=True)

     From the source we can see that the CIFAR10 class inherits from VisionDataset, which is itself a subclass of Dataset, and implements the three methods __init__, __len__ and __getitem__. In fact, building a custom data interface and training with PyTorch is just as simple: inherit the base class Dataset and implement those three methods (a toy sketch follows the abstract class below).

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
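       As a warm-up, here is a minimal sketch of such a subclass that simply wraps an in-memory list; the data is made up, and the only point is to show the three methods in action.

import torch
from torch.utils.data import Dataset

class ListDataset(Dataset):
    """Wraps a plain Python list of (features, label) pairs."""
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)                      # size of the dataset

    def __getitem__(self, index):
        x, y = self.pairs[index]                    # integer indexing, 0 <= index < len(self)
        return torch.tensor(x), torch.tensor(y)

toy = ListDataset([([0.0, 1.0], 0), ([1.0, 0.0], 1)])
print(len(toy), toy[0])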

      For loading your own dataset, PyTorch also provides a ready-made interface, torchvision.datasets.ImageFolder, but it is relatively restrictive: the data must follow its directory layout, root/<class_id>/*.jpg, where each subfolder name is treated as one class. A minimal usage sketch follows.
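      A minimal ImageFolder sketch, assuming the images live under ./data/train/<class_name>/*.jpg (a made-up path):

from torchvision import datasets, transforms

train_transform = transforms.Compose([transforms.Resize((224, 224)),
                                      transforms.ToTensor()])
train_set = datasets.ImageFolder('./data/train', transform=train_transform)
print(train_set.classes)        # one entry per subfolder name
print(train_set.class_to_idx)   # folder name -> integer label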

__init__ method

def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
        super(DatasetFolder, self).__init__(root)
        self.transform = transform
        self.target_transform = target_transform
        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
                                "Supported extensions are: " + ",".join(extensions)))

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

      Let's step through it quickly and see what this method actually does.

      First, self.root is the custom training-set directory we passed in. The constructor starts by calling _find_classes, so let's look at the _find_classes source.

 def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        Args:
            dir (string): Root directory path.
        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

       The return value classes is a list containing the sorted ids (i.e. the labels), and class_to_idx is the corresponding dictionary whose keys are the ids and whose values are their indices, as shown below:

['102091655-1-201811011700-16', '10209231-1-201811010900-2', '1020962212-2-201811010900-24', '1020966131-3-201811011700-0', '102097752-0-201811010900-6']

{'1020962212-2-201811010900-24': 2, '1020966131-3-201811011700-0': 3, '102097752-0-201811010900-6': 4, '10209231-1-201811010900-2': 1, '102091655-1-201811011700-16': 0}

      Next, samples receives the return value of make_dataset, where extensions is the set of image file extensions supported by torchvision and is_valid_file is an optional callable for validating files; as the XOR check below shows, exactly one of the two must be provided.

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    if not ((extensions is None) ^ (is_valid_file is None)):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x):
            return has_file_allowed_extension(x, extensions)
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = (path, class_to_idx[target])
                    images.append(item)
    return images
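      As an example of the is_valid_file branch, here is a hypothetical filter that keeps only .jpg files above a minimum size; it is passed instead of extensions (passing both, or neither, raises the ValueError above). The path and threshold are made up.

import os
from torchvision import datasets

def big_enough(path):
    # keep only .jpg files larger than 1 KB (threshold chosen arbitrarily)
    return path.lower().endswith('.jpg') and os.path.getsize(path) > 1024

dataset = datasets.ImageFolder('./data/train', is_valid_file=big_enough)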

    An example of samples is shown below: it is a list of tuples, each storing an image path and its label.

[('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000702_crop16.jpg', 0),
 ('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000880_crop16.jpg', 0),
 ('test/10209231-1-201811010900-2/10.209.23.1-1-201811010900-201811010903_00000092_crop2.jpg', 1),
 ('test/1020962212-2-201811010900-24/10.209.62.212-2-201811010900-201811010903_00000756_crop24.jpg', 2),
 ('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000295_crop0.jpg', 3),
 ('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000302_crop0.jpg', 3),
 ('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000395_crop6.jpg', 4),
 ('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000434_crop6.jpg', 4)]

       Next comes the assignment to loader, which is a function argument; by default default_loader is used, which falls back to pil_loader unless the accimage backend is selected.

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
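       Since loader is just a callable that takes a path and returns an image, you can also swap in your own. A hypothetical grayscale loader, which could be passed as ImageFolder(..., loader=gray_loader), might look like this:

from PIL import Image

def gray_loader(path):
    # same structure as pil_loader, but converts to single-channel grayscale
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('L')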

 

__getitem__ and __len__

     __getitem__ is what the DataLoader relies on to fetch data: given an index, it returns the transformed image and its label; __len__ returns the length of the dataset.

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

 

DIY Interface (Custom Data Interface)

      If you can follow how these functions are used, you can start defining the data interface you need. Suppose our train.txt, val.txt and test.txt files have the format shown below (an image path and a label, separated by a tab); how would we implement the three methods described above for this format?

/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002532_crop23.jpg	1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002521_crop23.jpg	1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002535_crop23.jpg	2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002528_crop23.jpg	2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002523_crop23.jpg	2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002529_crop23.jpg	3
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002527_crop23.jpg	3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000833_crop69.jpg	3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000834_crop69.jpg	4
/20190424/00001320000179-1556104800-30/SZ009SZZP3-32130200001320000179-1556104800_00001954_crop30.jpg	4

        Below is my sketch of such a class (tidied up, but treat it as illustrative rather than battle-tested); the point is to show the idea!

# _*_ coding:utf-8 _*_
import os
import torch.utils.data as data
from torchvision.datasets.folder import default_loader

class trueData(data.Dataset):
    def __init__(self, root, txt_path, dataset=None, transforms=None, loader=default_loader):
        with open(txt_path) as data_input:
            lines = [line.strip() for line in data_input if line.strip()]
            # each line: "<image path>\t<label>"; strip the leading '/' so os.path.join keeps root
            self.images = [os.path.join(root, line.split('\t')[0].lstrip('/')) for line in lines]
            self.labels = [int(line.split('\t')[1]) for line in lines]
        self.transform = transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        label = self.labels[index]
        img = self.loader(img_path)          # load the image (PIL by default)
        if self.transform is not None:
            img = self.transform(img)        # apply the transforms passed in
        return img, label

       The call can then be written like this, which completes the loading of our custom data.

 image_datasets = {x: trueData(root='/home/badoo/person',
                               txt_path='/home/badoo/train_list/' + x + '.txt',
                               transforms=data_transforms,
                               dataset=x) for x in ['train', 'val']}
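 A quick sanity check, assuming the paths above actually exist: indexing the dataset directly exercises __getitem__ without going through a DataLoader.

img, label = image_datasets['train'][0]   # calls __getitem__(0): load + transform one sample
print(len(image_datasets['train']), label)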

 DataLoader

    During training, as discussed earlier, the model usually takes a tensor of shape [N, C, H, W] as input. PyTorch provides a batching API, DataLoader, which draws samples from the Dataset (already transformed and converted to tensors by the Dataset's transform pipeline), optionally shuffles them, and stacks them into batches; the source code is worth a read.

 dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                               batch_size=batch_size,
                                               shuffle=True) for x in ['train', 'val']}

Parameters

1. dataset: the output of an existing PyTorch data interface (e.g. torchvision.datasets.ImageFolder) or of a custom one; it must be an instance of torch.utils.data.Dataset or of a class that inherits from it.

2. batch_size: set according to your situation.

3. shuffle: usually enabled for the training data.

4. collate_fn: merges the individual samples returned by the dataset into a batch; the default is fine unless your dataset returns something unusual (see the sketch after this list).

5. batch_sampler: as the docstring notes, it is mutually exclusive with batch_size, shuffle, sampler and drop_last; usually left at the default.

6. sampler: as the code shows, it is mutually exclusive with shuffle; usually left at the default.

7. num_workers: must be >= 0; 0 means the data is loaded in the main process, while a value greater than 0 loads data in that many worker processes, which can speed up data loading.

8. pin_memory: the docstring says it all: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. In other words, it controls how the data is copied on its way to the GPU.

9. timeout: the timeout for collecting a batch from the workers; if no data arrives within this time, an error is raised.
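     A minimal collate_fn sketch, assuming each sample is already a (C, H, W) tensor plus an integer label (the default collate handles this case anyway, so this is purely illustrative):

import torch

def my_collate(batch):
    imgs = torch.stack([sample[0] for sample in batch], dim=0)   # [N, C, H, W]
    labels = torch.tensor([sample[1] for sample in batch])       # [N]
    return imgs, labels

loader = torch.utils.data.DataLoader(image_datasets['train'], batch_size=32,
                                     shuffle=True, collate_fn=my_collate)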

     Below are two ways of calling the interface; I prefer the second one ^_^

# Style 1:
train_data=torch.utils.data.DataLoader(...) 
for i, (input, target) in enumerate(train_data): 
... 

# Style 2
train_load = torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True,num_workers=8)
for i,(ids,labels) in enumerate(train_load):
...
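     For completeness, a fuller sketch of style 2, assuming model, criterion, optimizer and batch_size are already defined and train_data is the Dataset built earlier:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                         shuffle=True, num_workers=8, pin_memory=True)
for i, (imgs, labels) in enumerate(train_load):
    imgs, labels = imgs.to(device), labels.to(device)   # pinned memory makes this copy faster
    optimizer.zero_grad()
    loss = criterion(model(imgs), labels)
    loss.backward()
    optimizer.step()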

          Sticking with something may be hard, but seeing it through is definitely cool! ^_^
