torch.utils.data.TensorDataset(data_tensor, target_tensor)
# Toy data: 30 consecutive integers reshaped into rows of three features,
# and one integer label per row.
x_1 = torch.arange(30).reshape(-1, 3)
y_1 = torch.arange(10) * 3

# TensorDataset packs the tensors together; indexing it returns one
# (features, label) pair taken along the first dimension.
dataset = data.TensorDataset(x_1, y_1)
for x_1_train, y_1_train in dataset:
    print(x_1_train, y_1_train)

# DataLoader wraps the dataset and serves it in shuffled mini-batches.
print('=' * 80)
train_loader = data.DataLoader(dataset=dataset, batch_size=4, shuffle=True)
for i, data_ in enumerate(train_loader):
    # enumerate yields two values: the batch index and the batch itself
    # (which contains both the training data and the labels).
    x_1_train, y_1_train = data_
    print(f' batch:{i+1} x_data:{x_1_train} y_data:{y_1_train}')
运行结果:
tensor([0, 1, 2]) tensor(0)
tensor([3, 4, 5]) tensor(3)
tensor([6, 7, 8]) tensor(6)
tensor([ 9, 10, 11]) tensor(9)
tensor([12, 13, 14]) tensor(12)
tensor([15, 16, 17]) tensor(15)
tensor([18, 19, 20]) tensor(18)
tensor([21, 22, 23]) tensor(21)
tensor([24, 25, 26]) tensor(24)
tensor([27, 28, 29]) tensor(27)
================================================================================
batch:1 x_data:tensor([[27, 28, 29],
[15, 16, 17],
[ 3, 4, 5],
[ 6, 7, 8]]) y_data:tensor([27, 15, 3, 6])
batch:2 x_data:tensor([[ 9, 10, 11],
[ 0, 1, 2],
[21, 22, 23],
[18, 19, 20]]) y_data:tensor([ 9, 0, 21, 18])
batch:3 x_data:tensor([[12, 13, 14],
[24, 25, 26]]) y_data:tensor([12, 24])
这里也可以有多个输入和输出,比如可以输入两个数据,输出两个label。
torch.utils.data.TensorDataset(data_tensor1, data_tensor2, target_tensor1, target_tensor2)
# Two inputs (a 3-feature matrix and a 1-D tensor) and two label tensors;
# all four have the same size along the first dimension.
x_1 = torch.arange(30).reshape(-1, 3)
x_2 = torch.arange(10) * 3
y_1 = torch.arange(10) * 3 + 1
y_2 = torch.arange(10) * 3 + 2

# TensorDataset packs the tensors together; indexing it returns one
# 4-tuple with the matching entry from each tensor.
dataset = data.TensorDataset(x_1, x_2, y_1, y_2)
for x_1_train, x_2_train, y_1_train, y_2_train in dataset:
    print(x_1_train, x_2_train, y_1_train, y_2_train)

# DataLoader wraps the dataset and serves it in shuffled mini-batches.
print('=' * 80)
train_loader = data.DataLoader(dataset=dataset, batch_size=4, shuffle=True)
for i, data_ in enumerate(train_loader):
    # Note that enumerate yields two values: the batch index and the batch
    # itself (which contains both the training data and the labels).
    x_1_train, x_2_train, y_1_train, y_2_train = data_
    print(f' batch:{i+1} x_data:{x_1_train, x_2_train} y_data:{y_1_train, y_2_train}')
输出结果:
tensor([0, 1, 2]) tensor(0) tensor(1) tensor(2)
tensor([3, 4, 5]) tensor(3) tensor(4) tensor(5)
tensor([6, 7, 8]) tensor(6) tensor(7) tensor(8)
tensor([ 9, 10, 11]) tensor(9) tensor(10) tensor(11)
tensor([12, 13, 14]) tensor(12) tensor(13) tensor(14)
tensor([15, 16, 17]) tensor(15) tensor(16) tensor(17)
tensor([18, 19, 20]) tensor(18) tensor(19) tensor(20)
tensor([21, 22, 23]) tensor(21) tensor(22) tensor(23)
tensor([24, 25, 26]) tensor(24) tensor(25) tensor(26)
tensor([27, 28, 29]) tensor(27) tensor(28) tensor(29)
================================================================================
batch:1 x_data:(tensor([[24, 25, 26],
[27, 28, 29],
[21, 22, 23],
[ 3, 4, 5]]), tensor([24, 27, 21, 3])) y_data:(tensor([25, 28, 22, 4]), tensor([26, 29, 23, 5]))
batch:2 x_data:(tensor([[18, 19, 20],
[ 6, 7, 8],
[15, 16, 17],
[12, 13, 14]]), tensor([18, 6, 15, 12])) y_data:(tensor([19, 7, 16, 13]), tensor([20, 8, 17, 14]))
batch:3 x_data:(tensor([[ 9, 10, 11],
[ 0, 1, 2]]), tensor([9, 0])) y_data:(tensor([10, 1]), tensor([11, 2]))
最后是源代码:
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    # The wrapped tensors, kept in the order they were passed to __init__.
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Sample i is assembled from entry i of every tensor, so all tensors
        # must agree on the size of their first dimension.
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # One element per wrapped tensor, each indexed along the first dimension.
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        # The sample count is the first-dimension size of the first tensor
        # (__init__ checks that the others match it).
        return self.tensors[0].size(0)