Pytorch Dataset类

Pytorch Dataset类

数据集的类

  • torch.utils.data.Dataset
    • 实现两个方法
    • _getitem_(self,index)
      • 获取索引对应位置的一条数据
    • _len_(self)
      • 返回数据的总数量
import torch
import torch.utils.data as Data

class myDataSet(Data.Dataset):
    def __init__(self):
        self.f=open('data','r').readlines()
    def __getitem__(self, item):
        return self.f[item]
    def __len__(self):
        return len(self.f)

dataset=myDataSet()
print('len = ',len(dataset))
print('dataset[1] = ',dataset[0])
print('dataset[2] = ',dataset[1])
len =  2
dataset[1] =  1    2   3     4    5    6    7     8    9   10

dataset[2] =  3.5    4.7    7.7    8.3    10.9   14   14.8   17.2    18.9    21.3

数据加载器类

  • 批处理数据(batch)
  • 打乱数据(shuffle=True)
  • 使用多线程相乘并行加载数据(num_workers)
  • 删除mod(batch)多余的元素(drop_last=True)

DataLoader(dataset=dataset,batch_size=10,shuffle=True,num_workers=2)

enumerate()返回遍历的序号

data_loader=Data.DataLoader(dataset=dataset,batch_size=1,shuffle=True)

print('data_loader = ',data_loader)
print('len(data_loader) = ',len(data_loader))
for index,i in enumerate(data_loader):
    print('这是第%d个元素'%index)
    print('i = ',i)
    print('i[0].strip() = ',i[0].strip())
data_loader =  <torch.utils.data.dataloader.DataLoader object at 0x0000022132F02E48>
len(data_loader) =  2
这是第0个元素
i =  ['1    2   3     4    5    6    7     8    9   10\n']
i[0].strip() =  1    2   3     4    5    6    7     8    9   10
这是第1个元素
i =  ['3.5    4.7    7.7    8.3    10.9   14   14.8   17.2    18.9    21.3']
i[0].strip() =  3.5    4.7    7.7    8.3    10.9   14   14.8   17.2    18.9    21.3

data_loader也支持len方法

  • 向上取整(math.ceil())

附录:

data(文件)

1    2   3     4    5    6    7     8    9   10
3.5    4.7    7.7    8.3    10.9   14   14.8   17.2    18.9    21.3

你可能感兴趣的:(Pytorch Dataset类)