[pytorch]构建并加载自己的数据集

[pytorch]构建并加载自己的数据集)

pytorch为我们封装好了很多经典的数据集在torchvision.datasets包里, torchvision.datasets这个包中包含MNIST、FakeData、COCO、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR等一些常用的数据集,并且提供了数据集设置的一些重要参数设置,可以通过简单数据集设置来进行数据集的调用。从这些数据集中我们也可以看出数据集设置的主要变量有哪些并且有什么功能对将来自己数据集的设置也有极大的帮助。

在这里简单的举个例子:

torchvision.datasets.MNIST(root,train = True,transform = None,target_transform = None,download = False#  参数介绍:
 #root(string) - 数据集的根目录在哪里MNIST/processed/training.pt 和 MNIST/processed/test.pt存在。
 # train(bool,optional) - 如果为True,则创建数据集training.pt,否则创建数据集test.pt。
 #download(bool,optional) - 如果为true,则从Internet下载数据集并将其放在根目录中。如果已下载数据集,则不会再次下载。
 # transform(callable ,optional) - 一个函数/转换,它接收PIL图像并返回转换后的版本。例如,transforms.RandomCrop
 # target_transform(callable ,optional) - 接收目标并对其进行转换的函数/转换。

torchvision.datasets的具体使用方法详见pytorch官方文档

但是在实践中很多时候我们还是需要设计和加载自己的数据集的,虽然我们可能有现成的数据,比如说图片和它们的标签,但是我们需要将其设计成类方便使用。

必要的函数介绍

为了用pytorch实现自己的数据集,一般是要将自己的数据集设计成一个类,这个类必须包含三个函数:

  1. init(): 这个含糊的参数一般是包含数据所在的文件夹,还有就是对数据进行的transform。
  2. len(): 这个函数不需要参数,一般返回的是数据集的大小。
  3. getitem(): 这个函数的参数一般是索引,返回的是数据集的某一个样本,返回的一般是tensor。

数据集实现示例

以下的示例是逼着自己在实现视频小样本分类的时候为Kinetics数据集实现的一个可以方便被pytorch调用的数据集类,值得一提的是在自己实现类的时候需要继承torch.utils.data.Dataset。

class VideoDataset(Dataset):
    def __init__(self, info_txt, root_dir, mode='train',data_aug=None,transform=None):
        # set params
        self.info_txt=info_txt
        self.root_dir=root_dir
        self.mode=mode
        self.data_aug=data_aug
        self.transform = transform

        # read info_list
        self.info_list=open(self.info_txt).readlines()

    def __len__(self):
        return len(self.info_list)

    def __getitem__(self, idx):
        info_line=self.info_list[idx]
        video_info=info_line.strip('\n')
        video, video_frame_path, video_shape=get_video_from_video_info_2(video_info,mode=self.mode)
        video_label=get_label_from_video_info(video_info,self.info_txt)

        sample = {'video': video, 'label': [int(video_label)],'video_frame_path':video_frame_path,'video_shape':video_shape}

        sample['video'] = torch.FloatTensor(sample['video'])
        sample['label'] = torch.FloatTensor(sample['label'])

        return sample

数据集的加载

数据集的加载分为三部分:数据集的初始化,数据集的load和数据集的使用:

  1. 数据集初始化:
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
  1. 数据集的load,这个只需要调用torch.utils.data.dataloader这个函数:
#然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
  1. 最后方便理解,笔者贴上了使用的方法:
for batch_index, data, target in test_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

你可能感兴趣的:([pytorch]构建并加载自己的数据集)