collate_fn参数

collate_fn理解

  • collate英文释义
  • PyTorch中提供的参数细节
    • 代码示例
  • 疑问解析
    • 为什么设置好collate_fn会出现报错?
      • 维度不一致代码示例

collate英文释义

vt. 校对,整理
n. 核对;小吃
n. 校对者,整理者

PyTorch中提供的参数细节

【link】
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
这里是把一些列样本融合为小batch的张量。

代码示例

DataLoader中默认的collate_function函数:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

主要功能:对我们输入的batch进行处理,得到新的batch。

疑问解析

为什么设置好collate_fn会出现报错?

数据的输入过程中不仅和collate_fn有关系和batch_size也存在关系。
Pytorch对于训练维度的检查时按照每个batch_size的维度进行检查的,因此数据的数量和batch_size的大小会影响结果,不同的batch维度不匹配也会导致报错。

维度不一致代码示例

import numpy as np
import torch


# collate_fn:即用于collate的function,用于整理数据的函数

class NewDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):  # 必须重写
        return len(self.data)

    # 这个方法对数据集添加索引,进行相应取值
    def __getitem__(self, idx):  # 必须重写
        return self.data[idx]


# batch作为参数交给collate_fn这个函数进行进一步整理数据,然后得到real_batch,作为返回值
def New_collate(batch):
    real_batch = np.array(batch)  # 等式右边是对batch的处理过程
    real_batch = torch.from_numpy(real_batch)
    return real_batch


# 如果数据维度不一致,使用默认的collate_fn会报错
# 大多数的神经网络都是定长输入的,而且很多的操作也要求相同维度才能相加或相乘。
# 如果后面解决这个问题的方法是:在不足维度上进行补0操作
test = [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]
dataset = NewDataset(test)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=New_collate)
# list、tuple等都是可迭代对象,我们可以通过iter()函数获取这些可迭代对象的迭代器。
# 对获取到的迭代器不断使⽤next()函数来获取下⼀条数据。
it = iter(dataloader)
c = next(it)
print(c)
c = next(it)
print(c)

运行结果示例:
collate_fn参数_第1张图片

你可能感兴趣的:(numpy,深度学习,python,pytorch)