Pytorch | torch.nn.SyncBatchNorm

torch.nn.SyncBatchNorm(num_features,eps = 1e-05,动量= 0.1,仿射= True,track_running_stats = True,process_group = None)[源代码]

如论文“Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.”中所述,对N维输入(具有附加通道维的[N-2] D输入的小批量)应用批量归一化。
在这里插入图片描述
在这里插入图片描述
同样默认情况下,在训练过程中,该层会继续对其计算的均值和方差进行估算,然后将其用于评估期间的标准化。运行估算值保持默认值momentum0.1。

如果track_running_stats将设置为False,则此层将不保持运行估计,而是在评估期间也使用批处理统计信息。

笔记

在这里插入图片描述
因为批处理规范化是在C维上完成的,计算(N,+)切片的统计信息,所以通常将此术语称为“体积批处理规范化”或“时空批处理规范化”。

当前,SyncBatchNorm仅支持每个进程具有单个GPU的DistributedDataParallel。在使用DDP包装网络之前,使用torch.nn.SyncBatchNorm.convert_sync_batchnorm()将BatchNorm层转换为SyncBatchNorm。

参数:

  • num_features –预期的大小(N,C,+)输入的CCC
  • eps –为分母增加数值的稳定性。默认值:1e-5
  • momentum–用于running_mean和running_var计算的值。可以设置None为累积移动平均线(即简单平均线)。默认值:0.1
  • affine–一个布尔值,当设置True为时,此模块具有可学习的仿射参数。默认:True
  • track_running_stats –一个布尔值,设置为时True,此模块跟踪运行平均值和方差;设置为时False,此模块不跟踪此类统计信息,并且始终在训练和评估模式下使用批处理统计信息。默认:True
  • process_group –统计信息的同步分别在每个进程组内发生。默认行为是在整个世界范围内同步

形状:

  • 输入:(N,C,+)(N,C,+)(N,C,+)
  • 输出:(N,C,+)(N,C,+)(N,C,+)(与输入形状相同)

例子:

>>> # With Learnable Parameters
>>> m = nn.SyncBatchNorm(100)
>>> # creating process group (optional)
>>> # process_ids is a list of int identifying rank ids.
>>> process_group = torch.distributed.new_group(process_ids)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)
 
>>> # network is nn.BatchNorm layer
>>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
>>> # only single gpu per process is currently supported
>>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
>>>                         sync_bn_network,
>>>                         device_ids=[args.local_rank],
>>>                         output_device=args.local_rank)

类方法

convert_sync_batchnorm(module,process_group = None)[源代码]

辅助函数将模型中的torch.nn.BatchNormND层转换为torch.nn.SyncBatchNorm层。

参数:

  • module(nn.Module)–包含模块
  • process_group(可选)–进程组到范围的同步,

默认是整个世界。

返回:

  • 具有转换后的torch.nn.SyncBatchNorm层的原始模块。

例子:

>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>>            torch.nn.Linear(20, 100),
>>>            torch.nn.BatchNorm1d(100)
>>>          ).cuda()
>>> # creating process group (optional)
>>> # process_ids is a list of int identifying rank ids.
>>> process_group = torch.distributed.new_group(process_ids)
>>> sync_bn_module = convert_sync_batchnorm(module, process_group)

你可能感兴趣的:(Python)