Pytorch中的TensorDataset与DataLoader的使用

TensorDataset

TensorDataset本质上与python zip方法类似,对数据进行打包整合。
官方文档说明:

**Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.*

Parameters:
tensors (Tensor) – tensors that have the same size of the first dimension.

该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。

import torch
from torch.utils.data import TensorDataset
# a的形状为(4*3)
a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
# b的第一维与a相同
b = torch.tensor([1,2,3,4])
train_data = TensorDataset(a,b)
print(train_data[0:4])

输出结果如下:

(tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([1, 2, 3, 4]))

DataLoader

DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存。

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
b = torch.tensor([1,2,3,4])
train_data = TensorDataset(a,b)
data = DataLoader(train_data, batch_size=2, shuffle=True)
for i, j in enumerate(data):
    x, y = j
    print(' batch:{0} x:{1}  y: {2}'.format(i, x, y))

输出:

 batch:0 x:tensor([[1, 1, 1],
        [2, 2, 2]])  y: tensor([1, 2])
 batch:1 x:tensor([[4, 4, 4],
        [3, 3, 3]])  y: tensor([4, 3])

你可能感兴趣的:(Pytorch中的TensorDataset与DataLoader的使用)