CLASS torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
Implements data parallelism at the module level.
This container parallelizes the application of the given module by splitting the input across the specified devices, chunking along the batch dimension (all other objects are copied once per device). In the forward pass, the module is replicated onto each device and each replica handles a slice of the input. In the backward pass, gradients from every replica are summed into the original module.
The batch size should be larger than the number of GPUs being used.
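The split along the batch dimension behaves roughly like torch.chunk: a batch of N samples is cut into len(device_ids) nearly equal slices, which is why a batch smaller than the device count leaves some GPUs idle. A minimal CPU-only sketch of that chunking (illustrative only; DataParallel performs the actual split internally in its scatter step):

import torch

batch = torch.randn(8, 3, 32, 32)         # a batch of 8 samples
num_devices = 3                           # pretend three GPUs are in use
slices = batch.chunk(num_devices, dim=0)  # split along the batch dimension
print([s.shape[0] for s in slices])       # [3, 3, 2] -- each device gets one slice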
Warning:
It is recommended to use DistributedDataParallel instead of this class for multi-GPU training, even with a single node. See: Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel and Distributed Data Parallel.
Arbitrary positional and keyword inputs are allowed to be passed into DataParallel, but some types are specially handled. Tensors will be scattered on the specified dim (default 0). tuple, list and dict types will be shallow-copied. Other types will be shared among the different threads and can be corrupted if written to in the model's forward pass.
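As a rough sketch of that scattering behavior, the helper torch.nn.parallel.scatter (the primitive DataParallel uses internally) can be called directly; this assumes at least two visible CUDA devices and is illustrative only:

import torch
from torch.nn.parallel import scatter

x = torch.randn(4, 10)             # tensors are split along dim 0
meta = {"x": x, "tag": "batch-0"}  # dicts are shallow-copied; tensors inside them are still scattered

chunks = scatter(x, [0, 1], dim=0)
print([c.shape for c in chunks])                 # two (2, 10) chunks, one per device

copies = scatter(meta, [0, 1])
print(copies[0]["tag"], copies[0]["x"].device)   # batch-0 cuda:0
print(copies[1]["tag"], copies[1]["x"].device)   # batch-0 cuda:1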
The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.
Warning:
In each forward pass, module is replicated onto every device, so any update to the running module inside forward will be lost. For example, if module has a counter attribute that is incremented in every forward, it will always stay at its initial value, because the updates are done on the replicas, which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] has its parameters and buffers sharing storage with the base parallelized module, so in-place updates to the parameters or buffers on device[0] will be recorded. For example, BatchNorm2d and spectral_norm() rely on this behavior to update the buffers.
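A small sketch of both effects, assuming at least two visible CUDA devices (the Counter module below is hypothetical and only serves to illustrate the guarantee described above):

import torch
import torch.nn as nn

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.calls = 0                                 # plain Python attribute
        self.register_buffer("steps", torch.zeros(1))  # buffer, lives on device[0]
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        self.calls += 1      # modified on a throw-away replica -> lost
        self.steps.add_(1)   # in-place buffer update; the device[0] replica shares
                             # storage with the base module, so this one is recorded
        return self.fc(x)

model = Counter().cuda()
net = nn.DataParallel(model, device_ids=[0, 1])
net(torch.randn(8, 10).cuda())

print(model.calls)   # still 0: the attribute only changed on the replicas
print(model.steps)   # tensor([1.], device='cuda:0'): the device[0] update stuck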
Warning:
Forward and backward hooks defined on module and its submodules will be invoked len(device_ids) times, each with inputs located on a particular device. In particular, the hooks are only guaranteed to be executed in the correct order with respect to operations on their corresponding device.
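As an illustrative sketch (assuming two visible CUDA devices), a forward pre-hook registered on the wrapped module fires once per device, each time seeing the input slice that lives on that device:

import torch
import torch.nn as nn

seen_devices = []

def log_device(module, inputs):
    # forward pre-hook: record the device of this replica's input slice
    seen_devices.append(inputs[0].device)

model = nn.Linear(10, 5).cuda()
model.register_forward_pre_hook(log_device)

net = nn.DataParallel(model, device_ids=[0, 1])
net(torch.randn(8, 10))

print(seen_devices)  # invoked len(device_ids) times, e.g. [cuda:0, cuda:1] (order may vary)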
Warning:
When module returns a scalar (i.e., a 0-dimensional tensor) in forward(), this wrapper will return a vector of length equal to the number of devices used in data parallelism, containing the result from each device.
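A short sketch of this behavior, assuming two visible CUDA devices (the WithLoss module is hypothetical); the per-device scalars are gathered into a length-2 vector, which is typically reduced with mean() or sum() before calling backward():

import torch
import torch.nn as nn

class WithLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x, target):
        return nn.functional.mse_loss(self.fc(x), target)  # a scalar per replica

net = nn.DataParallel(WithLoss().cuda(), device_ids=[0, 1])
loss_vec = net(torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda())
print(loss_vec.shape)        # torch.Size([2]): one scalar gathered from each device
loss_vec.mean().backward()   # reduce to a scalar before backward()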
Example:
>>> model = model.cuda()  # the model must have its parameters and buffers on device_ids[0] before wrapping it with DataParallel
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU
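A slightly fuller sketch of a typical training step built around the wrapper (the model, optimizer and random data below are hypothetical placeholders, and three visible CUDA devices are assumed):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)).cuda()
net = nn.DataParallel(model, device_ids=[0, 1, 2])

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(96, 32)                  # batch of 96 -> 32 samples per GPU
targets = torch.randint(0, 10, (96,)).cuda(0)

optimizer.zero_grad()
outputs = net(inputs)                         # scattered, run in parallel, gathered on device_ids[0]
loss = criterion(outputs, targets)            # outputs already live on device_ids[0]
loss.backward()                               # gradients are summed into the original module
optimizer.step()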
A walkthrough of the torch.nn.DataParallel source code:
# Excerpt from torch/nn/parallel/data_parallel.py. replicate, scatter_kwargs,
# parallel_apply and gather come from the sibling modules under torch.nn.parallel;
# the underscore-prefixed helpers (_get_available_device_type, _get_device_index,
# _check_balance, ...) are internal utilities used by this file.
class DataParallel(Module):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        device_type = _get_available_device_type()
        if device_type is None:
            # No GPU available: fall back to running the wrapped module as-is.
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = _get_all_device_indices()

        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(device_type, self.device_ids[0])

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        # Parameters and buffers must already live on device_ids[0].
        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        # 1. scatter: split the inputs along self.dim across the devices.
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        # 2. replicate: copy the module onto each device actually used.
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        # 3. parallel_apply: run each replica on its input slice in its own thread.
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        # 4. gather: collect the per-device outputs onto output_device.
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)
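The four helpers chained by forward() are also exposed as torch.nn.parallel.scatter, replicate, parallel_apply and gather, so the same pipeline can be driven by hand. A minimal sketch mirroring forward(), assuming two visible CUDA devices:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

module = nn.Linear(10, 5).cuda(0)    # parameters already on device_ids[0]
device_ids = [0, 1]
inputs = torch.randn(8, 10).cuda(0)

scattered = scatter(inputs, device_ids)        # 1. split the batch across the devices
replicas = replicate(module, device_ids)       # 2. copy the module onto each device
outputs = parallel_apply(replicas, scattered)  # 3. run every replica on its slice in its own thread
result = gather(outputs, 0)                    # 4. collect the outputs on device 0
print(result.shape)                            # torch.Size([8, 5])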
def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output.
            Use -1 to indicate the CPU. (default: device_ids[0])

    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    device_type = _get_available_device_type()

    if device_ids is None:
        device_ids = _get_all_device_indices()

    if output_device is None:
        output_device = device_ids[0]

    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device(device_type, device_ids[0])

    # Same precondition as the module version: everything must start on device_ids[0].
    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    # scatter -> (replicate -> parallel_apply) -> gather, as in DataParallel.forward.
    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
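A short usage sketch of this functional form, via its public entry point torch.nn.parallel.data_parallel (assuming two visible CUDA devices):

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

model = nn.Linear(10, 5).cuda(0)   # parameters and buffers on device_ids[0]
x = torch.randn(8, 10).cuda(0)

# One-shot parallel evaluation; no wrapper object is kept around.
out = data_parallel(model, x, device_ids=[0, 1], output_device=0)
print(out.shape, out.device)       # torch.Size([8, 5]) cuda:0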