pytorch自定义数据集

我们做深度学习大部分时候的数据都是以数据+标注(CV)或者是纯文本(NLP)的形式存在的。

在开始一个项目时首先面对的就是如何把未经处理的数据整合成torch能识别的tensor。为此,torch提供了抽象类Datasets,它能很方便的把你的数据封装成一个可迭代的DataLoader供你使用。

要自定义数据集,首先要继承抽象类torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

import torch
from torch.utils import data
class MyDataset(data.Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = torch.randn(8,2)#八个数据,两个一组
    def __getitem__(self, index):
        img,label=self.data[index][0],self.data[index][1]
        return img,label
    def __len__(self):
        return self.data.size()[0]
mydata = MyDataset()

在有标注(例如csv文件)时,我们可以简单的将csv转化为列表来完成__getitem__和__len__操作,__len__需要我们返回自己数据集的长度,__getitem__需要我们返回遍历时每次需要读取的数据(例如图片+标注数据集就返回img和label)

这样,我们自己的数据集就定义好了。接下来需要加载。加载之后的dataloader对象就可以直接遍历了。

print(len(mydata))
data_loader = data.DataLoader(mydata,batch_size=2,shuffle=False)
for img,label in enumerate(data_loader):
    print(img,labbel)

在更多时候我们需要将数据提前处理成对应shape的tensor,这就是数据预处理了,例如图像增强之类的操作都可以在__init__里面写。

你可能感兴趣的:(pytorch,深度学习,python)