Pytorch之torch.nn.DataParallel

CLASS torch.nn.DataParallel(moduledevice_ids=Noneoutput_device=Nonedim=0)

    在模块水平实现数据并行。

    该容器通过在批处理维度中分组,将输入分割到指定的设备上,从而并行化给定模块的应用程序(其它对象将在每个设备上复制一次)。在前向传播时,模块被复制到每个设备上,每个副本处理输入的一部分。在反向传播时,来自每个副本的梯度被累加到原始模块中。

    批处理大小应该大于所使用的GPU数量。

警告:

在使用多GPU训练时,推荐使用 DistributedDataParallel,而不是使用这个类,即使你仅仅使用一个节点。参考 Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel 和Distributed Data Parallel

 

允许将任意位置和关键字输入传递给DataParallel,但是有些类型是专门处理的。张量将被分散在指定的dim上(默认是0)。tuple,list,dict将被浅复制。其它类型将在不同的线程中共享,如果在模型前向传递中写入,则可能被破坏。

The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.

 

警告:

在每次forward时,module被复制到每个设备上,因此对forward中正在运行的模块的任何更新都将丢失。例如,如果模块有一个在每次forward时都递增的计数器属性,那么它始终保持初始值,因为更新是在forward后销毁的副本上完成的。但是,DataParallel保证了device[0]上的副本有其参数和缓冲区,与基本并行模块共享存储。所以,对device[0]上参数和缓冲区的就地更新将被记录。例如, BatchNorm2d 和 spectral_norm() 依赖这个行为更新缓冲区。

警告:

定义在模块和子模块上的forward和backward钩子将被调用len(device_ids)次,每个钩子的输入都位于特定的设备上。特别地,

警告:

当模块在forward()中返回一个标量时,这个包装器将会返回一个向量,且这个向量的长度等于数据并行机制中使用的设备数,包含每个设备计算结果。

example:

>>> model = model.cuda() # 使用DataParallel之前,model在device[0]上必须有parameters和buffers
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU

 

torch.nn.DataParallel源码解读:

class DataParallel(Module):
    

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        device_type = _get_available_device_type()
        if device_type is None:
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = _get_all_device_indices()

        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(device_type, self.device_ids[0])

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)


[docs]def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output  Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    device_type = _get_available_device_type()

    if device_ids is None:
        device_ids = _get_all_device_indices()

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device(device_type, device_ids[0])

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)

 

你可能感兴趣的:(并行训练,pytorch)