【PyTorch学习】(三)自定义Datasets

torchvision.datasets源码地址:https://github.com/pytorch/vision/blob/master/torchvision/datasets


前两篇从搭建经典的ResNet,DenseNet入手简单的了解了下PyTorch搭建网络的方式,但训练一个模型光光搭建好一个网络是不够的,正所谓巧妇难为无米之炊,如何将数据处理成网络可以传递的Tensor也尤为重要,而数据准备过程最最最最最重要的就是DatasetsDataloader两部分!

torchvision.datasets.ImageFolder就是官方给出的一个datasets的事例,具体使用直接贴上官方tutorial上的代码供参考:

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

但由于torchvision.datasets.ImageFolder函数的使用必须对数据的放置有要求,必须在data_dir目录下放置train和val两个文件夹,然后每个文件夹下,每一类图片单独放在一个文件夹里。官方的例子是ants和bees,所以在train和val文件夹下都有ants和bees这两个文件夹,分别放置相应的文件。

那么问题就来了,我们通常打完标签,是不会根据标签进行分类,而且在进行目标检测时一张图可能对应有多个标签,而是通过一个xml文件或者json文件用于记录label信息,所以是不满足ImageFolder的要求的。

所以根据实际数据情况,自定义Datasets就很关键,接下来我们就根据ImageFolder的函数形式,顺藤摸瓜从头来看如何自定义一个Datasets!


一、torch.utils.data.Dataset

首先,可以看到ImageFolder类继承了DatasetFolder类,DatasetFolder类又继承了torch一个基础的抽象类torch.utils.data.Dataset类。

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

自定义Datasets的关键就是重载 "__len__"和"__getitem__"两个函数!而 "__add__"函数的作用是使得类定义对象拥有"object1 + object2"的功能,一般情况不需要重载该函数。

  1. __len__函数:使得类对象拥有 "len(object)"功能,返回dataset的size。

  2. __getitem__函数:使得类对象拥有"object[index]"功能,可以用索引i去获得第i+1个样本。

二、torchvision.datasets.CocoDetection

再来看看同样继承于torch.utils.data.Dataset的CocoDetection dataset是如何定义上述两个函数的!

1.__init__:

def __init__(self, root, annFile, transform=None, target_transform=None):
    # 从cocoapi导入pycocotools下的COCO类
    from pycocotools.coco import COCO
    self.root = root
    # 初始化一个COCO对象
    self.coco = COCO(annFile)
    # 将每张图unique的id属性转化为list存储在self.ids中
    self.ids = list(self.coco.imgs.keys())
    self.transform = transform
    self.target_transform = target_transform

(1)初始化函数可以接受四个参数:

  • root: COCO形式的数据集的根目录地址。
  • annFile: COCO形式的数据集中.json文件的目录地址。
  • transform: 原始图像是否需要进行变换(数据增强,默认是None不做增强)。
  • target_transform: 标签是否需要进行变换(标签变换需要和原始图像变换相对应,默认是None不做增强)。

(2)初始化COCO对象时,将.json文件解析为字典形式导入内存,并创建调用createIndex()创建索引

(3)self.coco.imgs是以每张图unique的id作为key,json文件images下每一image信息作为value的一个字典。

2.__len__:

def __len__(self):
    # 因为图片的id是unique的,所以self.ids的长度就等于总图片数
    return len(self.ids)

3.__getitem__:

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
    """
    coco = self.coco
    # 通过索引获得图片的id
    img_id = self.ids[index]
    # 再通过getAnnIds方法利用img_id找到对应的anno_id
    ann_ids = coco.getAnnIds(imgIds=img_id)
    # 根据anno_id和标签之间的映射关系,解析出标签target
    target = coco.loadAnns(ann_ids)
       
    path = coco.loadImgs(img_id)[0]['file_name']
    # 根据每张图的file_name结合之前传入的图片放置的根目录读取图片信息
    img = Image.open(os.path.join(self.root, path)).convert('RGB')
    # 判断是否需要进行数据增强
    if self.transform is not None:
        img = self.transform(img)
    # 判断标签是否需要进行变换
    if self.target_transform is not None:
        target = self.target_transform(target)
        
    # 最终返回值形式可以根据自己需要进行设计。此处为一个tuple,包含一张图片以及对应的标签。
    return img, target

三、自定义人脸关键点dataset

以下这个例子就是自定义的FaceLandmarksDataset,效果是从.csv文件中读取每张图上的68个人脸面部关键点的坐标x,y,然后根据.csv文件中对应的图片名,读取相应的图片,然后返回值是一个sample字典,包含'image'和'landmarks'两个key。

class FaceLandmarksDataset(Dataset):
    
    def __init__(self, root_dir, csv_file, transform=None):
        self.root_dir = root_dir
        self.landmarks_frame = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.landmarks_frame)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

数据准备阶段datasets部分就简单介绍完了,下篇继续介绍另一个关键部分dataloader!

你可能感兴趣的:(PyTorch学习)