创建用于DataLoader的pytorch数据集

本文暂时只介绍 image-mask 型数据集, 并以手部分割数据集 EGTEA Gaze+ 为例.

准备数据文件夹

  • 需要将 Image 和 Mask 分别存放在两个文件夹中, 且对应文件的文件名 (不含扩展名) 必须保持一致. 提醒: Mask 图像一般为单通道 png 图像
  • EGTEA Gaze+ 数据集下载解压后即得到如下的目录, 无需处理
hand14k
┣━ Images
┃	┣━ OP01-R01-PastaSalad_000014.jpg
┃	┣━ OP01-R01-PastaSalad_000015.jpg
┃	┣━ OP01-R01-PastaSalad_000016.jpg
┃	┗━ ···
┗━ Masks
	┣━ OP01-R01-PastaSalad_000014.png
	┣━ OP01-R01-PastaSalad_000015.png
	┣━ OP01-R01-PastaSalad_000016.png
	┗━ ···

生成路径文件, 划分数据集

[图 1: 创建用于 DataLoader 的 PyTorch 数据集 (原文配图)]
脚本如下:

import cv2 as cv
import numpy as np
import PIL.Image as Image
import os

# Seed NumPy's global RNG so the random train/test assignment made in
# split_dataset (via np.random.rand) is reproducible across runs.
np.random.seed(42)


def split_dataset(images_path="./Images/", labels_path="./Masks/",
                  train_file="./train.data", test_file="./test.data",
                  split_ratio=0.8):
    """Randomly split image/mask pairs into train and test path-list files.

    Images and masks are paired by file name without extension, e.g.
    ``Images/foo.jpg`` <-> ``Masks/foo.png``; images with no matching mask
    are skipped. Each kept pair goes to ``train_file`` with probability
    ``split_ratio`` (drawn from ``np.random``) and to ``test_file`` otherwise.

    Every output line is the image path and the mask path separated by a tab,
    UTF-8 encoded. The files are written in binary mode so the line ending is
    always a bare newline, independent of the operating system.

    If both output files already exist the function returns immediately and
    leaves them untouched.

    Args:
        images_path: Directory containing the input images.
        labels_path: Directory containing the label masks.
        train_file: Output path for the training list.
        test_file: Output path for the test list.
        split_ratio: Probability that a pair is assigned to the training set.
    """
    if os.path.isfile(train_file) and os.path.isfile(test_file):
        return

    # os.listdir order is arbitrary; sort so the split is reproducible for a
    # fixed random seed.
    images_list = sorted(os.listdir(images_path))
    # Index masks by stem (splitext handles names containing extra dots) so a
    # missing or extra file only drops that one pair, instead of shifting
    # every later image onto the wrong mask as a zip over two lists would.
    masks_by_stem = {os.path.splitext(name)[0]: name
                     for name in sorted(os.listdir(labels_path))}

    with open(train_file, "wb") as train_out, open(test_file, "wb") as test_out:
        for image_name in images_list:
            label_name = masks_by_stem.get(os.path.splitext(image_name)[0])
            if label_name is None:
                continue
            image = os.path.join(images_path, image_name)
            label = os.path.join(labels_path, label_name)
            out = train_out if np.random.rand() < split_ratio else test_out
            out.write((image + "\t" + label + "\n").encode("utf-8"))
    print("成功划分数据集!")


def read_image(path):
    """Load an image file as a NumPy array, expanding grayscale to 3 channels.

    Args:
        path: Path of the image file to read.

    Returns:
        The decoded pixels as an ``np.ndarray``. A 2-D (single-channel) image
        is turned into an ``H x W x 3`` array of identical channels with
        ``cv.merge``; arrays that already have a channel axis are returned
        unchanged, in whatever channel order PIL decoded them.
    """
    # Context manager releases the file handle deterministically instead of
    # leaving it to garbage collection of the PIL Image object.
    with Image.open(path) as im:
        img = np.array(im)  # forces a full decode while the file is open
    if img.ndim == 2:
        img = cv.merge([img, img, img])
    return img


def test_read(data_file="./test.data"):
    """Display one randomly chosen image/mask pair from a path-list file.

    Reads a list produced by ``split_dataset`` (UTF-8, one tab-separated
    image path / mask path pair per line), draws a random entry with
    ``np.random`` and shows the image and mask in OpenCV windows until a key
    is pressed.

    Args:
        data_file: Path-list file to sample from (defaults to the test list).

    Raises:
        ValueError: If ``data_file`` contains no image/mask pairs.
    """
    with open(data_file, 'rb') as f:
        lines = [raw.decode('utf-8').rstrip('\r\n') for raw in f]
    # Skip blank lines; each remaining line must hold exactly two fields.
    datalist = [(k, v) for k, v in (line.split('\t') for line in lines if line.strip())]
    if not datalist:
        raise ValueError(f"no image/mask pairs found in {data_file!r}")

    def to_bgr(arr):
        # PIL decodes RGB images in RGB order, but cv.imshow interprets
        # 3-channel arrays as BGR; other channel counts are shown as-is.
        if arr.ndim == 3 and arr.shape[2] == 3:
            return cv.cvtColor(arr, cv.COLOR_RGB2BGR)
        return arr

    # Draw over the whole list so any entry can be shown and short lists
    # cannot index out of range.
    image_path, mask_path = datalist[np.random.randint(len(datalist))]
    image = read_image(image_path)
    mask = read_image(mask_path)
    cv.imshow("image", to_bgr(image))
    cv.imshow("mask", to_bgr(mask))
    cv.waitKey(0)
    cv.destroyAllWindows()


def _main():
    """Write the train/test list files (skipped if both exist), then preview one test sample."""
    split_dataset()
    test_read()


if __name__ == '__main__':
    _main()

派生 Dataset 类 (需先导入基类: from torch.utils.data import Dataset)

class MyDataset(Dataset):
    """Image/mask dataset backed by a tab-separated path-list file.

    Each non-blank line of the list file holds an image path and a mask path
    separated by a tab (UTF-8). Both paths are joined onto ``data_dir`` when a
    sample is loaded.

    ``__getitem__`` returns ``{'image': ..., 'mask': ...}`` and applies
    ``transform_trn`` in the ``'train'`` stage (the initial stage) or
    ``transform_val`` in the ``'val'`` stage; any other stage returns the
    sample untransformed.
    """

    def __init__(
        self, data_file, data_dir, transform_trn=None, transform_val=None
    ):
        """
        Args:
            data_file (string): Path to the tab-separated image/mask list file.
            data_dir (string): Directory the listed paths are joined onto.
            transform_trn (callable, optional): Applied to each sample dict in
                the 'train' stage.
            transform_val (callable, optional): Applied to each sample dict in
                the 'val' stage.
        """
        with open(data_file, 'rb') as f:
            # rstrip('\r\n') also tolerates lists saved with CRLF endings.
            lines = [raw.decode('utf-8').rstrip('\r\n') for raw in f]
        # Blank lines (e.g. a stray trailing newline) are skipped rather than
        # failing the two-field unpack.
        self.datalist = [
            (k, v) for k, v in (line.split('\t') for line in lines if line.strip())
        ]
        self.root_dir = data_dir
        self.transform_trn = transform_trn
        self.transform_val = transform_val
        self.stage = 'train'

    def set_stage(self, stage):
        """Select which transform __getitem__ applies ('train' or 'val')."""
        self.stage = stage

    def __len__(self):
        return len(self.datalist)

    @staticmethod
    def _read_image(path):
        """Load ``path`` as an array; 2-D (grayscale) images become H x W x 3."""
        with Image.open(path) as im:
            img_arr = np.array(im)
        if img_arr.ndim == 2:
            # Replicate the single channel so every image has 3 channels.
            img_arr = np.stack([img_arr] * 3, axis=-1)
        return img_arr

    def __getitem__(self, idx):
        img_rel, msk_rel = self.datalist[idx]
        img_name = os.path.join(self.root_dir, img_rel)
        msk_name = os.path.join(self.root_dir, msk_rel)
        image = self._read_image(img_name)
        with Image.open(msk_name) as im:
            mask = np.array(im)
        # The shape check is skipped when image and mask point at the same
        # file (presumably unlabeled entries -- confirm with the list author).
        # Raise instead of assert so the check survives `python -O`.
        if img_name != msk_name and mask.ndim != 2:
            raise ValueError('Masks must be encoded without colourmap')
        sample = {'image': image, 'mask': mask}
        if self.stage == 'train':
            transform = self.transform_trn
        elif self.stage == 'val':
            transform = self.transform_val
        else:
            transform = None
        if transform:
            sample = transform(sample)
        return sample

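构造 DataLoader (下面代码中的 ResizeShorterScale、Pad、RandomMirror 等 transform 类, 以及 shorter_side、crop_size、batch_size 等超参数与路径变量需自行定义)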

# Define the transforms: the training pipeline rescales, pads, mirrors and
# crops before normalising; the validation pipeline only normalises.
# NOTE(review): the Pad fill values look like an RGB mean on a 0-255 scale --
# confirm they agree with normalise_params.
composed_trn = transforms.Compose([
    ResizeShorterScale(shorter_side, low_scale, high_scale),
    Pad(crop_size, [123.675, 116.28, 103.53], ignore_label),
    RandomMirror(),
    RandomCrop(crop_size),
    Normalise(*normalise_params),
    ToTensor(),
])
composed_val = transforms.Compose([
    Normalise(*normalise_params),
    ToTensor(),
])

# Build the datasets.
trainset = MyDataset(data_file=train_list,
                     data_dir=train_dir,
                     transform_trn=composed_trn,
                     transform_val=composed_val)
valset = MyDataset(data_file=val_list,
                   data_dir=val_dir,
                   transform_trn=None,
                   transform_val=composed_val)
# MyDataset starts in the 'train' stage, where it applies transform_trn. valset
# has none, so without this switch it would return raw, un-normalised numpy
# samples instead of running composed_val.
valset.set_stage('val')

# Build the loaders: shuffled training batches (dropping a short final batch)
# and ordered single-sample validation.
train_loader = DataLoader(trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=True)
val_loader = DataLoader(valset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=num_workers,
                        pin_memory=True)

训练

for i, sample in enumerate(train_loader):
    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, so
    # plain tensors are used directly. non_blocking lets the host-to-device
    # copy run asynchronously from the pinned-memory batches.
    image = sample['image'].cuda(non_blocking=True)
    target = sample['mask'].cuda(non_blocking=True)
    image_var = image.float()
    target_var = target.long()
    # Compute output
    output = net(image_var)
    ...

你可能感兴趣的:(深度学习,dataset,pytorch,数据集,dataloader)