pytorch为我们封装好了很多经典的数据集在torchvision.datasets包里, torchvision.datasets这个包中包含MNIST、FakeData、COCO、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR等一些常用的数据集,并且提供了数据集设置的一些重要参数设置,可以通过简单数据集设置来进行数据集的调用。从这些数据集中我们也可以看出数据集设置的主要变量有哪些并且有什么功能对将来自己数据集的设置也有极大的帮助。
在这里简单的举个例子:
torchvision.datasets.MNIST(root,train = True,transform = None,target_transform = None,download = False )
# 参数介绍:
#root(string) - 数据集的根目录在哪里MNIST/processed/training.pt 和 MNIST/processed/test.pt存在。
# train(bool,optional) - 如果为True,则创建数据集training.pt,否则创建数据集test.pt。
#download(bool,optional) - 如果为true,则从Internet下载数据集并将其放在根目录中。如果已下载数据集,则不会再次下载。
# transform(callable ,optional) - 一个函数/转换,它接收PIL图像并返回转换后的版本。例如,transforms.RandomCrop
# target_transform(callable ,optional) - 接收目标并对其进行转换的函数/转换。
torchvision.datasets的具体使用方法详见pytorch官方文档
但是在实践中很多时候我们还是需要设计和加载自己的数据集的,虽然我们可能有现成的数据,比如说图片和它们的标签,但是我们需要将其设计成类方便使用。
为了用pytorch实现自己的数据集,一般是要将自己的数据集设计成一个类,这个类必须包含三个函数:
以下的示例是逼着自己在实现视频小样本分类的时候为Kinetics数据集实现的一个可以方便被pytorch调用的数据集类,值得一提的是在自己实现类的时候需要继承torch.utils.data.Dataset。
class VideoDataset(Dataset):
def __init__(self, info_txt, root_dir, mode='train',data_aug=None,transform=None):
# set params
self.info_txt=info_txt
self.root_dir=root_dir
self.mode=mode
self.data_aug=data_aug
self.transform = transform
# read info_list
self.info_list=open(self.info_txt).readlines()
def __len__(self):
return len(self.info_list)
def __getitem__(self, idx):
info_line=self.info_list[idx]
video_info=info_line.strip('\n')
video, video_frame_path, video_shape=get_video_from_video_info_2(video_info,mode=self.mode)
video_label=get_label_from_video_info(video_info,self.info_txt)
sample = {'video': video, 'label': [int(video_label)],'video_frame_path':video_frame_path,'video_shape':video_shape}
sample['video'] = torch.FloatTensor(sample['video'])
sample['label'] = torch.FloatTensor(sample['label'])
return sample
数据集的加载分为三部分:数据集的初始化,数据集的load和数据集的使用:
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
#然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
for batch_index, data, target in test_loader:
if use_cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)