github: https://github.com/NVIDIA/Megatron-LM
The recomputation arguments are defined in megatron/arguments.py:
group.add_argument('--recompute-activations', action='store_true',
                   help='recompute activation to allow for training '
                   'with larger models, sequences, and batch sizes.')
group.add_argument('--recompute-granularity', type=str, default=None,
                   choices=['full', 'selective'],
                   help='Checkpoint activations to allow for training '
                   'with larger models, sequences, and batch sizes. '
                   'It is supported at two granularities 1) full: '
                   'whole transformer layer is recomputed, '
                   '2) selective: core attention part of the transformer '
                   'layer is recomputed.')
group.add_argument('--distribute-saved-activations',
                   action='store_true',
                   help='If set, distribute recomputed activations '
                   'across model parallel group.')
group.add_argument('--recompute-method', type=str, default=None,
                   choices=['uniform', 'block'],
                   help='1) uniform: uniformly divide the total number of '
                   'Transformer layers and recompute the input activation of '
                   'each divided chunk at specified granularity, '
                   '2) recompute the input activations of only a set number of '
                   'individual Transformer layers per pipeline stage and do the '
                   'rest without any recomputing at specified granularity'
                   'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=1,
                   help='1) uniform: the number of Transformer layers in each '
                   'uniformly divided recompute unit, '
                   '2) block: the number of individual Transformer layers '
                   'to recompute within each pipeline stage.')
Explanation:

--recompute-activations: setting recompute_activations is equivalent to setting recompute_granularity to selective. selective is the more efficient option and is all you need in most scenarios; if memory is still tight, switch to full via --recompute-granularity.

--recompute-granularity: selects the recomputation granularity. full recomputes the entire transformer layer, while selective recomputes only the core_attention part of the transformer layer.

--distribute-saved-activations: shards the saved activations across the tensor-parallel (TP) group.

--recompute-method: uniform divides all transformer layers into groups and keeps only each group's input activation in memory; when GPU memory is insufficient, increasing the number of layers per group lets you fit a larger model. block checkpoints the input activations of only some transformer layers in each pipeline stage and leaves the rest uncheckpointed; for a pipeline stage with 8 layers and a setting of 5, the input activation of each of the first 5 layers is cached (those layers are recomputed during backward), while the last 3 layers keep their activations and run the backward pass normally.

--recompute-num-layers: for uniform, the number of transformer layers in each recompute group; the default of 1 checkpoints every transformer layer individually. For block, a value of N means the first N layers of each pipeline stage cache their input activations.

Setting recompute_activations is equivalent to recompute_granularity = 'selective'; when set, it overrides the value of recompute_granularity:
if args.recompute_activations:
    args.recompute_granularity = 'selective'
del args.recompute_activations
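For reference, selective recomputation can therefore be enabled from the launch command with just:

--recompute-activations \

which, as the snippet above shows, is rewritten internally into recompute_granularity = 'selective'.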
Recomputation is supported at two granularities: full recomputes the entire transformer layer, while selective recomputes only the core_attention part of the transformer layer.

The entry point is the ParallelTransformer class in megatron/model/transformer.py: in its forward, when the granularity is full, the whole forward pass through the layer stack is checkpointed.
class ParallelTransformer(MegatronModule):

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):
        ...
        # Forward pass.
        if self.recompute_granularity == 'full':
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask,
                                                       rotary_pos_emb,
                                                       is_first_microbatch)
        ...
Checkpointing is carried out with one of two methods, uniform and block. With uniform, the input activation is saved once every self.recompute_num_layers layers: each chunk is wrapped as custom(l, l + self.recompute_num_layers), where custom builds the forward function that runs that range of layers. With block, layers whose index is smaller than self.recompute_num_layers have their input activations checkpointed, while layers with index greater than or equal to self.recompute_num_layers run the ordinary, non-checkpointed forward. The code is as follows:
class ParallelTransformer(MegatronModule):

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):

        def custom(start, end):
            def custom_forward(*args, **kwargs):
                x_, *args = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *args, **kwargs)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            l = 0
            while l < self.num_layers:
                if self.transformer_impl == 'transformer_engine':
                    ...
                else:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask,
                        None, None, None, None, rotary_pos_emb)

                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    if self.transformer_impl == 'transformer_engine':
                        ...
                    else:
                        hidden_states = tensor_parallel.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
                else:
                    if self.transformer_impl == 'transformer_engine':
                        ...
                    else:
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
        ...
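To see the two methods outside of Megatron, here is a minimal, hypothetical sketch of the same chunking logic. torch.utils.checkpoint.checkpoint stands in for tensor_parallel.checkpoint, and TinyStack, _run, and the layer sizes are made up for illustration; this is not Megatron code.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint  # stand-in for tensor_parallel.checkpoint


class TinyStack(nn.Module):
    """Hypothetical layer stack illustrating 'uniform' vs 'block' recompute."""

    def __init__(self, num_layers=8, hidden=64,
                 recompute_method='uniform', recompute_num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.recompute_method = recompute_method
        self.recompute_num_layers = recompute_num_layers

    def _run(self, start, end):
        # Same role as custom(start, end): run layers [start, end) on x.
        def fwd(x):
            for idx in range(start, end):
                x = self.layers[idx](x).relu()
            return x
        return fwd

    def forward(self, x):
        n = len(self.layers)
        if self.recompute_method == 'uniform':
            # Checkpoint one chunk of recompute_num_layers layers at a time:
            # only each chunk's input is kept; the chunk is re-run in backward.
            l = 0
            while l < n:
                x = checkpoint(self._run(l, min(l + self.recompute_num_layers, n)),
                               x, use_reentrant=False)  # use_reentrant needs a recent PyTorch
                l += self.recompute_num_layers
        else:  # 'block'
            # Checkpoint only the first recompute_num_layers layers;
            # the rest run normally (all activations kept, no recompute).
            for l in range(n):
                if l < self.recompute_num_layers:
                    x = checkpoint(self._run(l, l + 1), x, use_reentrant=False)
                else:
                    x = self._run(l, l + 1)(x)
        return x


x = torch.randn(4, 64, requires_grad=True)
TinyStack(recompute_method='block', recompute_num_layers=5)(x).sum().backward()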
For selective granularity, choosing a recompute method is not yet supported:
if args.recompute_granularity == 'selective':
    assert args.recompute_method is None, \
        'recompute method is not yet supported for ' \
        'selective recomputing granularity'
During the forward pass, only the inputs to core attention are checkpointed. This is implemented in the attention module, ParallelAttention:
class ParallelAttention(MegatronModule):

    def __init__(...):
        ...
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'
        ...

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        ...
        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
        ...

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing."""

        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
            else rotary_pos_emb

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

        return hidden_states
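To illustrate what selective granularity saves, here is a small hypothetical sketch that checkpoints only a bare softmax(QK^T)V core with torch.utils.checkpoint: the large attention-score and probability tensors are recomputed during backward instead of being stored. attention_core and the tensor shapes are made up; this is not Megatron's core_attention.

import math
import torch
from torch.utils.checkpoint import checkpoint  # stand-in for tensor_parallel.checkpoint


def attention_core(q, k, v):
    # The part that 'selective' recomputes: scores + softmax + weighted sum.
    # The [batch, heads, seq, seq] score tensor dominates the activation cost.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


q = torch.randn(2, 8, 128, 64, requires_grad=True)
k = torch.randn(2, 8, 128, 64, requires_grad=True)
v = torch.randn(2, 8, 128, 64, requires_grad=True)

# Only q, k, v are saved for backward; scores and probabilities are
# recomputed when gradients are needed.
out = checkpoint(attention_core, q, k, v, use_reentrant=False)
out.sum().backward()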
--distribute-saved-activations is only supported with TP > 1 and with full recompute granularity, and a recompute method must be chosen; both uniform and block work with it:
# Activation recomputing.
if args.distribute_saved_activations:
    assert args.tensor_model_parallel_size > 1, 'can distribute ' \
        'recomputed activations only across tensor model ' \
        'parallel groups'
    assert args.recompute_granularity == 'full', \
        'distributed recompute activations is only '\
        'application to full recompute granularity'
    assert args.recompute_method is not None, \
        'for distributed recompute activations to work you '\
        'need to use a recompute method '
    assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
        'distributed recompute activations are supported for pytorch ' \
        'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
        'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
The implementation is in CheckpointFunction in megatron/core/tensor_parallel/random.py. The forward pass runs inside a torch.no_grad() context, so no autograd graph is kept; when the input activation is saved, the tensor is flattened to 1D and each rank stores only its own chunk:
class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        ...
        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
        ...
        # Store everything.
        ctx.save_for_backward(*args)

        return outputs
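The split step itself is simple. Below is a sketch of the behaviour just described, not the actual split_tensor_into_1d_equal_chunks implementation; split_into_1d_chunk is a made-up name and the tensor-parallel group handle is assumed to be passed in.

import torch
import torch.distributed as dist


def split_into_1d_chunk(tensor, group=None):
    # Flatten the tensor and keep only the 1/TP-th slice owned by this rank
    # (assumes numel is divisible by the tensor-parallel world size).
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    chunk = tensor.numel() // world_size
    flat = tensor.view(-1)
    return flat[rank * chunk:(rank + 1) * chunk].clone()  # analogue of new_buffer=True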
In the backward pass, the saved chunk is first all-gathered across the tensor model parallel group via gather_split_1d_tensor, then the forward is recomputed and the actual backward computation runs:
class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
        ...
        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)
        ...
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads
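Correspondingly, here is a sketch of the gather side (again not the actual gather_split_1d_tensor code; gather_full_1d_tensor is a made-up name, and torch.distributed.all_gather_into_tensor assumes a recent PyTorch): every rank's 1D chunk is all-gathered and the result is reshaped back to the original input shape.

import torch
import torch.distributed as dist


def gather_full_1d_tensor(local_chunk, full_shape, group=None):
    # All-gather every rank's chunk into one flat buffer, then restore
    # the shape recorded in ctx.input_0_shape during the forward pass.
    world_size = dist.get_world_size(group=group)
    full = torch.empty(local_chunk.numel() * world_size,
                       dtype=local_chunk.dtype, device=local_chunk.device)
    dist.all_gather_into_tensor(full, local_chunk, group=group)
    return full.view(full_shape)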
As model size grows, sequence parallelism and recomputation both save memory, cutting activation memory requirements by roughly 5x. In the reported per-layer breakdown of forward, backward, and recompute time (against a baseline with neither recomputation nor sequence parallelism), these techniques remove most of the overhead that recomputing all activations would otherwise incur: for the largest model the overhead drops from 36% to only 2%.
Setting only --recompute-activations is equivalent to using selective: only the input activations of core_attention are checkpointed. The advantage of selective is that it adds less extra computation than the other schemes, so it offers the best cost/benefit ratio.

With full granularity, the block method makes the fullest use of device memory and keeps recomputation to a minimum: each of the first N layers in a pipeline stage saves its input activation, so it combines well with pipeline parallelism. --distribute-saved-activations can be added as needed:
--recompute-granularity full \
--recompute-method block \
--recompute-num-layers [N] \
--distribute-saved-activations \
With uniform, an input activation is cached once every N layers and used for later recomputation; the default N = 1 caches it for every layer. On top of uniform, adding --distribute-saved-activations shards the saved activations across each TP group. In the extreme case N equals the number of layers in a pipeline stage, which maximizes recomputation. The advantage is that larger models fit in memory; the drawback is that the amount of recomputation is large:
--recompute-granularity full \
--recompute-method uniform \
--recompute-num-layers [N] \
--distribute-saved-activations \