本文记录学习过程中遇到的问题、我的解决过程以及学习心得,如有错误之处,欢迎指正!
在学习用pytorch进行数据批处理的过程中用到了torch.utils.data.TensorDataset()和torch.utils.data.DataLoader()函数,练习的代码如下:
import torch
import torch.utils.data as Data
torch.manual_seed(1) # 这句有关生成随机数,他会使得随机生成的结果是确定的
BATCH_SIZE= 5 # 设置批次训练数量
# 定义数据
x = torch.linspace(1, 10, steps=10) # torch.linspace()线性等分向量,前两个参数是向量的开始和结束值,steps是分割出的点数,缺省值100
y = torch.linspace(10, 1, steps=10) # x,y都是十维向量
torch_dataset = Data.TensorDataset(x, y) # x,y对应整合进数据集,应该是一个二维数据的队列(10*2矩阵)
loader = Data.DataLoader(
dataset=torch_dataset, # 加载数据集
batch_size=BATCH_SIZE, # 批次大小
shuffle=True, # 是否打乱顺序训练
num_workers=2 # 设置线程数
)
def show_batch():
for epoch in range(3): # 进行三轮训练
for step, (batch_x, batch_y) in enumerate(loader): # 每轮训练进行两批(一批5个数据)
# train your data...
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__== '__main__':
show_batch()
莫烦pytorch课程中整合数据集这一步骤用到的代码是:
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
这一句在运行过程中会报错:
TypeError: __init__() got an unexpected keyword argument 'data_tensor'
查看TensorDataset的声明:
class TensorDataset(Dataset):
def __init__(self, *tensors):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
再查看Dataset的的声明(篇幅太长,略去代码展示)确实不存在data_tensor或者target_tensor的参数,论坛上有人说这是由于版本不同造成的。去掉形参名直接赋值即可解决问题。
DEBUG之前的代码没有定义函数show_batch,也没有if __name__=='__main__'语句,直接执行show_batch()函数中的内容。或者定义show_batch()函数并直接调用也会导致出错。错误代码:
RuntimeError: DataLoader worker (pid(s) 12384, 10160) exited unexpectedly
这类错误是由于多线程处理造成的,若要直接对loader迭代,则需要去掉对loader赋值时Data.DataLoader()函数中的num_workers=2语句,或者赋值0.
如果需要采用多线程处理的话,最好采用最上面代码的方法定义主函数并调用步骤,直接在主函数中执行show_batch()函数的执行语句同样可以实现。
为了方便理解,我采用了如下代码,希望了解各个类中数据时如何存储的:
print(x)
print(y)
print(torch_dataset)
print(loader)
'''
输出结果为:
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
tensor([10., 9., 8., 7., 6., 5., 4., 3., 2., 1.])
'''
由输出结果可见,类的数据不能直接读取,查看了pytorch中文文档中关于torch.utils.data的描述之后了解到TensorDataset通过第一个维度索引两个张量来恢复每个样本,可以通过torch_dataset[index]来读出TensorDataset中的数据,也可以通过for循环遍历数据:
for each in torch_dataset:
print(each)
'''
输出结果:
(tensor(1.), tensor(10.))
(tensor(2.), tensor(9.))
(tensor(3.), tensor(8.))
(tensor(4.), tensor(7.))
(tensor(5.), tensor(6.))
(tensor(6.), tensor(5.))
(tensor(7.), tensor(4.))
(tensor(8.), tensor(3.))
(tensor(9.), tensor(2.))
(tensor(10.), tensor(1.))
'''
而在研究DataLoader中数据存储时无法通过遍历或者迭代实现数据读出,代码和错误信息如下:
dataiter = iter(loader)
data, labels = next(loader)
print(data, labels)
# TypeError: 'DataLoader' object is not an iterator
print(loader[0])
for each in loader:
print(each)
# TypeError: 'DataLoader' object does not support indexing
看文档说DataLoader类是可以迭代的,但是错误信息指出DataLoader类不能索引或者迭代。
奈何本人半路出家,学艺不精...这个问题我仍然没有解决,并且暂时没有查找到相关资料,这里先挖一个坑,等解决了再来更新吧...