import torch
from torch.utils.data import DataLoader, TensorDataset
This is mainly a note on an API I was unfamiliar with while using it.
Purpose: TensorDataset bundles several tensors into one dataset, so that each index returns the corresponding row from every wrapped tensor.
Code:
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
# TensorDataset bundles the two tensors into one dataset
train_ids = TensorDataset(a, b)
for x_train, y_label in train_ids:
    print(x_train, y_label)
# DataLoader wraps the dataset for batched, shuffled iteration
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):
    # enumerate returns two values: the batch index and the data (which contains both the training data and the labels)
    x_data, label = data
    print(' batch:{0} x_data:{1} label: {2}'.format(i, x_data, label))
Output:
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
batch:1 x_data:tensor([[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6]]) label: tensor([44, 44, 55, 55])
batch:2 x_data:tensor([[4, 5, 6],
[7, 8, 9],
[7, 8, 9],
[7, 8, 9]]) label: tensor([55, 66, 66, 66])
batch:3 x_data:tensor([[1, 2, 3],
[1, 2, 3],
[7, 8, 9],
[4, 5, 6]]) label: tensor([44, 44, 66, 55])
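As a quick check of what the bundling actually does: indexing the TensorDataset gives back a tuple with one row from each wrapped tensor, and all wrapped tensors must have the same length along dimension 0. A minimal sketch, reusing a, b and train_ids from the code above:
# train_ids[i] returns (a[i], b[i]); all wrapped tensors must share the same size in dim 0
x0, y0 = train_ids[0]
assert torch.equal(x0, a[0]) and torch.equal(y0, b[0])
print(len(train_ids))  # 12 samples, so batch_size=4 yields the 3 batches shown above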
Reference: "pytorch之TensorDataset", an article by Ronin on Zhihu - https://zhuanlan.zhihu.com/p/349083821