先定义x和y作为数据及其标签,再定义我们自己的TestDataset类。这个类必须实现__getitem__和__len__两个方法,才可以在DataLoader中使用。注意观察DataLoader输出的格式。
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
class TestDataset(Dataset):
    """Minimal map-style dataset pairing each sample in ``x`` with a label in ``y``.

    Implements ``__getitem__`` and ``__len__`` so instances can be indexed
    directly and passed to a ``DataLoader``.
    """

    def __init__(self, x, y):
        """Store the samples ``x`` and their labels ``y``.

        Both must be indexable; ``x`` must also support ``len()``.
        """
        self.x = x
        self.y = y

    def __getitem__(self, item):
        """Return the ``(sample, label)`` pair stored at position ``item``."""
        return self.x[item], self.y[item]

    def __len__(self) -> int:
        """Return the number of samples, taken from the length of ``x``."""
        return len(self.x)
if __name__ == '__main__':
    BATCH_SIZE = 5  # number of samples per mini-batch

    x = torch.linspace(1, 10, 10)  # sample values: 1., 2., ..., 10.
    y = torch.linspace(10, 1, 10)  # label values: 10., 9., ..., 1.
    print(x)
    print(y)
    print('--------------------')

    myDataset = TestDataset(x, y)
    # Iterating the dataset itself yields one (sample, label) pair per step,
    # in index order, with no batching.
    for step, (batch_x, batch_y) in enumerate(myDataset):
        print('step:{},x:{},y:{}'.format(step, batch_x, batch_y))
    print('-----------------')

    loader = torch.utils.data.DataLoader(
        dataset=myDataset,      # the custom map-style dataset defined above
        batch_size=BATCH_SIZE,  # mini-batch size
        shuffle=True,           # reshuffle the sample order every epoch
        num_workers=2,          # number of worker subprocesses loading data
    )
    # With the default collate function each step yields two tensors of
    # BATCH_SIZE elements: the batched samples and the batched labels.
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            print('epoch:{},step:{},x:{},y:{}'.format(epoch, step, batch_x, batch_y))
# 输出的结果如下
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
tensor([10., 9., 8., 7., 6., 5., 4., 3., 2., 1.])
--------------------
step:0,x:1.0,y:10.0
step:1,x:2.0,y:9.0
step:2,x:3.0,y:8.0
step:3,x:4.0,y:7.0
step:4,x:5.0,y:6.0
step:5,x:6.0,y:5.0
step:6,x:7.0,y:4.0
step:7,x:8.0,y:3.0
step:8,x:9.0,y:2.0
step:9,x:10.0,y:1.0
-----------------
epoch:0,step:0,x:tensor([5., 1., 4., 8., 7.]),y:tensor([ 6., 10., 7., 3., 4.])
epoch:0,step:1,x:tensor([ 3., 10., 9., 6., 2.]),y:tensor([8., 1., 2., 5., 9.])
epoch:1,step:0,x:tensor([ 4., 7., 10., 1., 5.]),y:tensor([ 7., 4., 1., 10., 6.])
epoch:1,step:1,x:tensor([8., 9., 6., 3., 2.]),y:tensor([3., 2., 5., 8., 9.])
epoch:2,step:0,x:tensor([2., 7., 3., 1., 4.]),y:tensor([ 9., 4., 8., 10., 7.])
epoch:2,step:1,x:tensor([ 8., 9., 6., 5., 10.]),y:tensor([3., 2., 5., 6., 1.])
在pytorch官网中,找到collate_fn参数相关的页面。地址为https://pytorch.org/docs/stable/data.html?highlight=dataloader#module-torch.utils.data
个人理解:indices 是当前这一批样本的索引,比如 [1,3,5,7,9];DataLoader 先按这些索引从 dataset 中取出样本,即 [dataset[i] for i in indices],再把得到的样本列表交给 collate_fn 函数进行处理。
下面是我的测试代码:
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
def test(x):
    """Return ``x`` unchanged.

    Stands in for a collate function so the effect of an identity collate can
    be reproduced by hand on a list of samples.
    """
    return x


# The name matches pytest's default ``test*`` collection pattern; opt out so
# that importing this helper into a test module does not make pytest try to
# run it as a test case (it would fail looking for a fixture named ``x``).
test.__test__ = False
class TestDataset(Dataset):
    """Minimal map-style dataset that also provides an identity ``collate_fn``.

    Each sample in ``x`` is paired with the label at the same position in
    ``y``. Implements ``__getitem__`` and ``__len__`` so instances can be
    indexed directly and passed to a ``DataLoader``.
    """

    def __init__(self, x, y):
        """Store the samples ``x`` and their labels ``y``.

        Both must be indexable; ``x`` must also support ``len()``.
        """
        self.x = x
        self.y = y

    def __getitem__(self, item):
        """Return the ``(sample, label)`` pair stored at position ``item``."""
        return self.x[item], self.y[item]

    def __len__(self) -> int:
        """Return the number of samples, taken from the length of ``x``."""
        return len(self.x)

    def collate_fn(self, batch):
        """Return ``batch`` untouched.

        Passing this as ``collate_fn`` to a ``DataLoader`` makes each step
        yield the raw list of ``(sample, label)`` pairs instead of stacked
        tensors.
        """
        return batch
if __name__ == '__main__':
    BATCH_SIZE = 5  # number of samples per mini-batch

    x = torch.linspace(1, 10, 10)  # x data (torch tensor)
    y = torch.linspace(10, 1, 10)  # y data (torch tensor)
    myDataset = TestDataset(x, y)

    # Reproduce by hand what the loader does for one batch: fetch the samples
    # at the given indices, then hand the resulting list to the collate
    # function (here the identity helper `test`).
    print(test([myDataset[index] for index in [1, 3, 5, 7, 9]]))
    print('----------------')

    loader2 = torch.utils.data.DataLoader(
        dataset=myDataset,                # the custom map-style dataset
        batch_size=BATCH_SIZE,            # mini-batch size
        shuffle=True,                     # reshuffle the sample order every epoch
        num_workers=2,                    # worker subprocesses (not threads) loading data
        collate_fn=myDataset.collate_fn,  # identity collate: keep the raw sample list
    )
    for epoch in range(3):  # go over the whole dataset 3 times
        for step, t in enumerate(loader2):  # each step the loader yields one mini-batch
            # Pretend the training step happens here...
            # Print the batch so its structure can be inspected.
            print('epoch:{},step:{},data:{}'.format(epoch, step, t))
    print('----------------')
下面是结果:
[(tensor(2.), tensor(9.)), (tensor(4.), tensor(7.)), (tensor(6.), tensor(5.)), (tensor(8.), tensor(3.)), (tensor(10.), tensor(1.))]
----------------
epoch:0,step:0,data:[(tensor(8.), tensor(3.)), (tensor(9.), tensor(2.)), (tensor(6.), tensor(5.)), (tensor(5.), tensor(6.)), (tensor(2.), tensor(9.))]
epoch:0,step:1,data:[(tensor(4.), tensor(7.)), (tensor(3.), tensor(8.)), (tensor(7.), tensor(4.)), (tensor(1.), tensor(10.)), (tensor(10.), tensor(1.))]
epoch:1,step:0,data:[(tensor(10.), tensor(1.)), (tensor(7.), tensor(4.)), (tensor(1.), tensor(10.)), (tensor(6.), tensor(5.)), (tensor(4.), tensor(7.))]
epoch:1,step:1,data:[(tensor(5.), tensor(6.)), (tensor(8.), tensor(3.)), (tensor(2.), tensor(9.)), (tensor(9.), tensor(2.)), (tensor(3.), tensor(8.))]
epoch:2,step:0,data:[(tensor(6.), tensor(5.)), (tensor(5.), tensor(6.)), (tensor(7.), tensor(4.)), (tensor(8.), tensor(3.)), (tensor(10.), tensor(1.))]
epoch:2,step:1,data:[(tensor(2.), tensor(9.)), (tensor(4.), tensor(7.)), (tensor(3.), tensor(8.)), (tensor(9.), tensor(2.)), (tensor(1.), tensor(10.))]
----------------
可以看出,我手动按索引取样得到的结果,与DataLoader使用collate_fn函数后输出的形式是一样的。如果有理解错误的地方,欢迎指教。