pytorch之TensorDataset

  • 包装数据和目标张量的数据集。
torch.utils.data.TensorDataset(data_tensor, target_tensor)
x_1 = torch.arange(30).reshape(-1, 3)   # 10 samples, 3 features each
y_1 = torch.arange(10) * 3              # one scalar target per sample



# TensorDataset packs the tensors together, pairing them sample-by-sample.
dataset = data.TensorDataset(x_1, y_1)
for sample_x, sample_y in dataset:
    print(sample_x, sample_y)

# DataLoader wraps the dataset for batched (and here shuffled) iteration.
print('=' * 80)
train_loader = data.DataLoader(dataset=dataset, batch_size=4, shuffle=True)
# enumerate yields (batch index, batch); each batch holds inputs and labels.
for batch_idx, (x_1_train, y_1_train) in enumerate(train_loader, start=1):
    print(f' batch:{batch_idx} x_data:{x_1_train}  y_data:{y_1_train}')

运行结果:

tensor([0, 1, 2]) tensor(0)
tensor([3, 4, 5]) tensor(3)
tensor([6, 7, 8]) tensor(6)
tensor([ 9, 10, 11]) tensor(9)
tensor([12, 13, 14]) tensor(12)
tensor([15, 16, 17]) tensor(15)
tensor([18, 19, 20]) tensor(18)
tensor([21, 22, 23]) tensor(21)
tensor([24, 25, 26]) tensor(24)
tensor([27, 28, 29]) tensor(27)
================================================================================
 batch:1 x_data:tensor([[27, 28, 29],
        [15, 16, 17],
        [ 3,  4,  5],
        [ 6,  7,  8]])  y_data:tensor([27, 15,  3,  6])
 batch:2 x_data:tensor([[ 9, 10, 11],
        [ 0,  1,  2],
        [21, 22, 23],
        [18, 19, 20]])  y_data:tensor([ 9,  0, 21, 18])
 batch:3 x_data:tensor([[12, 13, 14],
        [24, 25, 26]])  y_data:tensor([12, 24])

这里也可以有多个输入和输出,比如可以输入两个数据,输出两个label。

torch.utils.data.TensorDataset(data_tensor1, data_tensor2, target_tensor1, target_tensor2)
x_1 = torch.arange(30).reshape(-1, 3)   # 10 samples, 3 features
x_2 = torch.arange(10) * 3              # second input: one scalar per sample
y_1 = torch.arange(10) * 3 + 1          # first label
y_2 = torch.arange(10) * 3 + 2          # second label


# TensorDataset accepts any number of tensors; each sample is a 4-tuple here.
dataset = data.TensorDataset(x_1, x_2, y_1, y_2)
for in_a, in_b, lab_a, lab_b in dataset:
    print(in_a, in_b, lab_a, lab_b)

# DataLoader batches all four tensors in lockstep.
print('=' * 80)
train_loader = data.DataLoader(dataset=dataset, batch_size=4, shuffle=True)
# enumerate yields (batch index, batch); batch is a list of four tensors.
for batch_idx, (x_1_train, x_2_train, y_1_train, y_2_train) in enumerate(train_loader, start=1):
    print(f' batch:{batch_idx} x_data:{x_1_train, x_2_train}  y_data:{y_1_train, y_2_train}')

输出结果:

tensor([0, 1, 2]) tensor(0) tensor(1) tensor(2)
tensor([3, 4, 5]) tensor(3) tensor(4) tensor(5)
tensor([6, 7, 8]) tensor(6) tensor(7) tensor(8)
tensor([ 9, 10, 11]) tensor(9) tensor(10) tensor(11)
tensor([12, 13, 14]) tensor(12) tensor(13) tensor(14)
tensor([15, 16, 17]) tensor(15) tensor(16) tensor(17)
tensor([18, 19, 20]) tensor(18) tensor(19) tensor(20)
tensor([21, 22, 23]) tensor(21) tensor(22) tensor(23)
tensor([24, 25, 26]) tensor(24) tensor(25) tensor(26)
tensor([27, 28, 29]) tensor(27) tensor(28) tensor(29)
================================================================================
 batch:1 x_data:(tensor([[24, 25, 26],
        [27, 28, 29],
        [21, 22, 23],
        [ 3,  4,  5]]), tensor([24, 27, 21,  3]))  y_data:(tensor([25, 28, 22,  4]), tensor([26, 29, 23,  5]))
 batch:2 x_data:(tensor([[18, 19, 20],
        [ 6,  7,  8],
        [15, 16, 17],
        [12, 13, 14]]), tensor([18,  6, 15, 12]))  y_data:(tensor([19,  7, 16, 13]), tensor([20,  8, 17, 14]))
 batch:3 x_data:(tensor([[ 9, 10, 11],
        [ 0,  1,  2]]), tensor([9, 0]))  y_data:(tensor([10,  1]), tensor([11,  2]))

最后是源代码:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample is the tuple of slices taken along the first dimension of
    every wrapped tensor, so ``ds[i]`` is ``(t0[i], t1[i], ...)``.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    # The wrapped tensors, in the order they were supplied.
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Every tensor must share the leading (sample) dimension with the
        # first one.  NOTE(review): `assert` is stripped under `python -O`,
        # so this validation is best-effort only.
        if tensors:
            expected = tensors[0].size(0)
            for tensor in tensors[1:]:
                assert tensor.size(0) == expected, "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # Slice each tensor at the same position and bundle into a tuple.
        sample = []
        for tensor in self.tensors:
            sample.append(tensor[index])
        return tuple(sample)

    def __len__(self):
        # Number of samples == leading-dimension length of the first tensor.
        return self.tensors[0].shape[0]

你可能感兴趣的:(pytorch,深度学习,机器学习)