Pytorch中TensorDataset,DataLoader的联合使用

    首先从字面意义上来理解TensorDataset和DataLoader,TensorDataset是个只用来存放tensor(张量)的数据集,而DataLoader是一个数据加载器,一般用到DataLoader的时候就说明需要遍历和操作数据了。TensorDataset(tensor1,tensor2)的功能就是形成数据tensor1和标签tensor2的对应,也就是说tensor1中是数据,而tensor2是tensor1所对应的标签。来个小例子:

from torch.utils.data import TensorDataset,DataLoader
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a,b)
# 切片输出
print(train_ids[0:4]) # 第0,1,2,3行
# 循环取数据
for x_train,y_label in train_ids:
    print(x_train,y_label)

    下面是对应的输出:

(tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
        [1, 2, 3]]), tensor([44, 55, 66, 44]))
===============================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)

    从输出结果我们就可以很好的理解,tensor型数据和tensor型标签的对应了,这就是TensorDataset的基本应用。接下来我们把构造好的TensorDataset封装到DataLoader来操作里面的数据:

# 参数说明,dataset=train_ids表示需要封装的数据集,batch_size表示一次取几个
# shuffle表示乱序取数据,设为False表示顺序取数据,True表示乱序取数据
train_loader = DataLoader(dataset=train_ids,batch_size=4,shuffle=False)
# 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
for i,data in enumerate(train_loader,1):
    train_data, label = data
    print(' batch:{0} train_data:{1}  label: {2}'.format(i+1, train_data, label))

    下面是对应的输出:

 batch:1 x_data:tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
        [1, 2, 3]])  label: tensor([44, 55, 66, 44])
 batch:2 x_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [1, 2, 3],
        [4, 5, 6]])  label: tensor([55, 66, 44, 55])
 batch:3 x_data:tensor([[7, 8, 9],
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])  label: tensor([66, 44, 55, 66])

    至此,TensorDataset和DataLoader的联合使用就介绍完了。我们再看一下这两种方法的源码:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

# 由于此类内容过多,故仅列举了与本文相关的参数,其余参数可以自行去查看源码
class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False):

        self.dataset = dataset
        self.batch_size = batch_size
       

   感谢大家的阅读~

你可能感兴趣的:(推荐算法,pytorch,深度学习,机器学习)