Pytorch自定义数据集(Custom Dataset)的读取方式

外部数据集的接入

      • 相关模块:torchvision
      • 具体操作
          • 自定义数据集的基础方法:
          • 使用 Torchvision Transforms
      • 结合 Pandas 使用 __getitem__()
      • 使用 Dataloader 读取自定义数据集
      • Stanford Dogs 数据集自定义实例
      • FaceLandmarks实例

相关模块:torchvision

torchvision 是独立于pytorch 之外的图像操作库
具体介绍详见:DrHW的文章

torchvision主要包括一下几个包:1

  • torchvision.datasets : 几个常用视觉数据集,可以下载和加载这里主要的高级用法就是可以看源码如何自己写自己的Dataset的子类
    这部分就是本文要介绍的重点
  • torchvision.models: 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
  • torchvision.transforms : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到tensor ,numpy 数组到tensor , tensor 到 图像等。
  • torchvision.utils : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个mini-batch的图像可以产生一个图像格网。
    shape = (channel, height, weight)

具体操作

  • 自定义数据集的基础方法:

引文2

"""
inout pipline for custom dataset
"""
from torch.utils.data.dataset import Dataset
class CustomDataset(Dataset):
    def __init__(self):
    	"""
    	一些初始化过程写在这里
    	"""
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
    	"""
    	返回数据和标签,可以这样显示调用:
    	img, label = MyCustomDataset.__getitem__(99)
    	"""
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
    	"""
    	返回所有数据的数量
    	"""
        # You should change 9 to the total size of your dataset.
        return 9 # e.g. 9 is size of dataset
使用 Torchvision Transforms
  • 方法一:
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # stuff
        ...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # 一些读取的数据
        if self.transforms is not None:
            data = self.transforms(data)
        # 如果 transform 不为 None,则进行 transform 操作
        return (img, label)
 
    def __len__(self):
        return count 
        
if __name__ == \'__main__\':
    # 定义我们的 transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # 创建 dataset
    custom_dataset = MyCustomDataset(..., transformations)
  • 方法二:
    有些人不喜欢将transform写在Dataset外, 即在Dataset内定义transform

from torch.utils.data.dataset import Dataset
from torchvision import transforms
 
class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # (2) 一种方法是单独定义 transform
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # (3) 或者写成下面这样 
        self.transformations = \
            transforms.Compose([transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = #一些读取的数据
        
        # 当第二次调用 transform 时,调用的是 __call__()
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)  # (2)
        
        # 或者写成下面这样
        data = self.trasnformations(data)  # (3)
        
        # 注意 (2) 和 (3) 中只需要实现一种
        return (img, label)
 
    def __len__(self):
        return count
        
if __name__ == \'__main__\':
    custom_dataset = MyCustomDataset(...)

结合 Pandas 使用 getitem()

另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 getitem() 函数。

Label pixel_1 pixel_2
1 50 99
0 21 223
9 44 112
class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): csv 文件路径
            height (int): 图像高度
            width (int): 图像宽度
            transform: transform 操作
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transform
 
    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) 
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype(\'uint8\')
	# 把 numpy array 格式的图像转换成灰度 PIL image
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert(\'L\')
        # 将图像转换成 tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # 返回图像及其 label
        return (img_as_tensor, single_image_label)
 
    def __len__(self):
        return len(self.data.index)
        
 
if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = \
        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\', 28, 28, transformations)

使用 Dataloader 读取自定义数据集

PyTorch 中的 Dataloader 只是调用 getitem() 方法并组合成 batch,我们可以这样调用:


...
if __name__ == "__main__":
    # 定义 transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # 自定义数据集
    custom_mnist_from_csv = \
        CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\',
                             28, 28,
                             transformations)
    # 定义 data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # 将数据传给网络模型 

需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。

Stanford Dogs 数据集自定义实例

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyDateset(Dataset):
    def __init__(self, file_folder, is_test=False, transform=None):
        self.img_folder_path = '../input/images/Images/'
        self.annotation_folder_path = '../input/annotations/Annotation/'
        self.file_folder = file_folder
        self.transform = transform
        #self.transform = transforms.Compose
        self.is_test = is_test
        
    def __getitem__(self, idx):
        file = self.file_folder[idx]
        img_path = self.img_folder_path + file
        img = Image.open(img_path).convert('RGB')
        
        if not self.is_test:
            annotation_path = self.annotation_folder_path + file.split('.')[0]
            with open(annotation_path) as f:
                annotation = f.read()

            xy = self.get_xy(annotation)
            box = torch.FloatTensor(list(xy))

            new_box = self.box_resize(box, img)
            if self.transform is not None:
                img = self.transform(img)

            return img, new_box
        else:
            if self.transform is not None:
                img = self.transform(img)
            return img
    
    def __len__(self):
        return len(self.file_folder)
        
    def get_xy(self, annotation):
        xmin = int(re.findall('(?<=)[0-9]+?(?=)', annotation)[0])
        xmax = int(re.findall('(?<=)[0-9]+?(?=)', annotation)[0])
        ymin = int(re.findall('(?<=)[0-9]+?(?=)', annotation)[0])
        ymax = int(re.findall('(?<=)[0-9]+?(?=)', annotation)[0])
        
        return xmin, ymin, xmax, ymax
    
    def show_box(self):
        file = random.choice(self.file_folder)
        annotation_path = self.annotation_folder_path + file.split('.')[0]
        
        img_box = Image.open(self.img_folder_path + file)
        with open(annotation_path) as f:
            annotation = f.read()
            
        draw = ImageDraw.Draw(img_box)
        xy = self.get_xy(annotation)
        print('bbox:', xy)
        draw.rectangle(xy=[xy[:2], xy[2:]])
        
        return img_box
        
    def box_resize(self, box, img, dims=(332, 332)):
        old_dims = torch.FloatTensor([img.width, img.height, img.width, img.height]).unsqueeze(0)
        new_box = box / old_dims
        new_dims = torch.FloatTensor([dims[1], dims[0], dims[1], dims[0]]).unsqueeze(0)
        new_box = new_box * new_dims
        
        return new_box

FaceLandmarks实例

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

参考文献:

  • yunjey的github代码
  • pytorch官方教程
  • 数据集:Stanford Dogs Dataset
  • pytorch中文网:PyTorch 中自定义数据集的读取方法小结

  1. https://www.cnblogs.com/yjphhw/p/9773333.html ↩︎

  2. https://github.com/yunjey/pytorch-tutorial/ ↩︎

你可能感兴趣的:(应用,理论)