[PyTorch] collate_fn in torch.utils.data.DataLoader

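  • My understanding:
    The default collate_fn of torch.utils.data.DataLoader does three things when it builds a batch: it adds a batch dimension; it automatically converts NumPy arrays and native Python numeric types into PyTorch Tensors; and it preserves the data structure of each sample (dictionary, tuple, list, …).
    You can also define your own collate_fn so that the batches meet custom requirements, for example generating a mask from the input and returning it together with the input, or modifying the input as needed.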
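
    As an illustration of the custom case, here is a minimal sketch of a collate_fn that pads variable-length sequences and returns a mask alongside them. The dataset ToySequences, the function name pad_with_mask, and the choice of 0 as the padding value are all assumptions made for this example, not part of PyTorch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ToySequences(Dataset):
    """Hypothetical dataset: variable-length 1-D sequences, each with a label."""

    def __init__(self):
        self.samples = [
            (torch.tensor([1, 2, 3]), 0),
            (torch.tensor([4, 5]), 1),
            (torch.tensor([6]), 0),
            (torch.tensor([7, 8, 9, 10]), 1),
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def pad_with_mask(batch):
    """Custom collate_fn: `batch` is a list of (sequence, label) tuples."""
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in sequences])
    # Pad every sequence to the length of the longest one in this batch.
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    # mask[i, j] is True where position j of sample i holds real data.
    mask = torch.arange(padded.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask, torch.tensor(labels)


loader = DataLoader(ToySequences(), batch_size=2, shuffle=False,
                    collate_fn=pad_with_mask)
first_padded, first_mask, first_labels = next(iter(loader))
print(first_padded)
print(first_mask)
```

    Because padding is done per batch, each batch is only as wide as its own longest sequence.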

  • Official explanation:
    The use of collate_fn is slightly different when automatic batching is enabled or disabled.
    When automatic batching is disabled, collate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays into PyTorch tensors.
    When automatic batching is enabled, collate_fn is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes behavior of the default collate_fn in this case.
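
    The two modes can be compared with a small sketch. It assumes that passing batch_size=None is how automatic batching gets disabled, and it uses a plain Python list as a stand-in map-style dataset; recording_collate is a hypothetical helper that only logs what it receives.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# A plain list works as a map-style dataset (it has __getitem__ and __len__).
dataset = [np.array([i, i + 1], dtype=np.float32) for i in range(4)]

# Automatic batching disabled: each sample is converted and yielded on its own.
unbatched = DataLoader(dataset, batch_size=None)
sample = next(iter(unbatched))

# Automatic batching enabled: collate_fn receives a list of samples.
received = []


def recording_collate(samples):
    # Record the container type and how many samples were passed in.
    received.append(f"{type(samples).__name__}:{len(samples)}")
    return torch.stack([torch.as_tensor(s) for s in samples])


batched = DataLoader(dataset, batch_size=2, collate_fn=recording_collate)
batch = next(iter(batched))

print(type(sample), sample.shape)  # a single sample, no batch dimension
print(received, batch.shape)       # a list of 2 samples stacked into one batch
```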
    For instance, if each data sample consists of a 3-channel image and an integral class label, i.e., each element of the dataset returns a tuple (image, class_index), the default collate_fn collates a list of such tuples into a single tuple of a batched image tensor and a batched class label Tensor. In particular, the default collate_fn has the following properties:

    • It always prepends a new dimension as the batch dimension.
    • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
    • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for lists, tuples, namedtuples, etc.

    Users may use customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.
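
    The three properties of the default collate_fn listed above can be checked with a short sketch. The dictionary keys "image", "label", and "name" are made up for this example; the string field is there to show a value that stays a list instead of becoming a Tensor.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# Each sample is a dict holding a NumPy array, a Python int, and a string.
dataset = [
    {"image": np.zeros((3, 4, 4), dtype=np.float32), "label": i, "name": f"img{i}"}
    for i in range(4)
]

# No collate_fn is given, so the default one is used.
loader = DataLoader(dataset, batch_size=2)
batch = next(iter(loader))

print(type(batch))           # still a dict with the same keys
print(batch["image"].shape)  # a batch dimension is prepended
print(batch["label"])        # Python ints become one Tensor
print(batch["name"])         # strings are kept as a list
```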
