Speeding up PyTorch DDP: setting up gradient accumulation

PyTorch DDP

References: https://zhuanlan.zhihu.com/p/250471767
Efficient GPU communication - Ring AllReduce: https://www.zhihu.com/question/57799212/answer/612786337
Gradient accumulation: https://www.zhihu.com/question/303070254/answer/573037166

gradient accumulation

With gradient accumulation, suppose one accumulation cycle consists of accumulation_steps steps. Each cycle then triggers accumulation_steps gradient all_reduce operations, but only one optimizer.step(), i.e. a single parameter update. This means that a single gradient all_reduce per cycle is enough; the other accumulation_steps - 1 all_reduce calls are wasted, and each all_reduce is relatively expensive. The fix is to skip gradient synchronization on the first accumulation_steps - 1 steps. DDP provides exactly this: the context manager no_sync() (see its source code), inside which DDP performs no gradient synchronization.
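
For reference, here is a minimal sketch of what no_sync() does, following the behavior described in the DistributedDataParallel documentation (micro_batches, last_batch, and local_rank are placeholder names): gradients computed inside the context accumulate locally in param.grad and are only all_reduced by the first backward pass that runs outside the context.

ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
with ddp.no_sync():
    for inp in micro_batches:        # hypothetical list of micro-batch inputs
        ddp(inp).sum().backward()    # no all_reduce; gradients accumulate locally
ddp(last_batch).sum().backward()     # first backward outside the context all_reduces the accumulated gradients

Applying the same idea inside a training loop gives: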

accumulation_count = 1  # start at 1 so the first accumulation_steps - 1 steps skip synchronization
for epoch in range(epochs):
    for j, data in enumerate(train_loader):
        if accumulation_count % accumulation_steps != 0:
            # First accumulation_steps - 1 steps of the cycle: skip gradient synchronization, just accumulate local gradients.
            with model.no_sync():
                loss = model(data)
                loss = loss / accumulation_steps
                loss.backward()
        else:
            # Last step of the cycle: this backward triggers the all_reduce, then clip and update.
            loss = model(data)
            loss = loss / accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            model_optimizer.step()
            if model_scheduler is not None:
                model_scheduler.step()
            model_optimizer.zero_grad()
        accumulation_count += 1
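
The loop above assumes that model has already been wrapped in DistributedDataParallel and moved to the right GPU, and that the data is sharded across processes. A minimal sketch of that setup, assuming a torchrun launch (train_dataset and batch_size are placeholder names, not from the original post):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
torch.cuda.set_device(local_rank)

model = model.cuda(local_rank)
model = DDP(model, device_ids=[local_rank])

train_sampler = DistributedSampler(train_dataset)   # shards the dataset across processes
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

# Launched with e.g.: torchrun --nproc_per_node=4 train.py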

A cleaner version (works in both single-GPU and DDP mode):

from contextlib import nullcontext
# For Python < 3.7, nullcontext is unavailable; suppress() with no arguments works as a stand-in:
# from contextlib import suppress as nullcontext

if local_rank != -1:
    model = DDP(model)

optimizer.zero_grad()
accumulation_count = 1  # start at 1 so the first accumulation_steps - 1 steps skip synchronization
for epoch in range(epochs):
    for i, data in enumerate(train_loader):
        # Use no_sync only in DDP mode and only on steps that are not multiples of accumulation_steps.
        mcontext = model.no_sync if local_rank != -1 and accumulation_count % accumulation_steps != 0 else nullcontext
        with mcontext():
            loss = model(data)
            loss = loss / accumulation_steps
            loss.backward()
        # Every accumulation_steps-th step: gradients have just been synchronized, so update the parameters.
        if accumulation_count % accumulation_steps == 0:
            optimizer.step()
            if model_scheduler is not None:
                model_scheduler.step()
            optimizer.zero_grad()
        accumulation_count += 1
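
One caveat: both loops assume the number of batches per epoch is a multiple of accumulation_steps; otherwise the last few batches of an epoch leave accumulated (and, in DDP mode, unsynchronized) gradients that are never applied. A minimal sketch of one way to handle this, assuming train_loader has a length (a map-style dataset): force a synchronization and update on the final batch of each epoch. Note that the final partial cycle is slightly down-weighted because the loss is still divided by accumulation_steps.

accumulation_count = 1
for epoch in range(epochs):
    for i, data in enumerate(train_loader):
        is_last_batch = (i + 1) == len(train_loader)
        # Synchronize on every accumulation_steps-th step, and also on the last batch of the epoch.
        sync_now = accumulation_count % accumulation_steps == 0 or is_last_batch
        mcontext = model.no_sync if local_rank != -1 and not sync_now else nullcontext
        with mcontext():
            loss = model(data) / accumulation_steps
            loss.backward()
        if sync_now:
            optimizer.step()
            optimizer.zero_grad()
            accumulation_count = 0   # restart the accumulation cycle after a flush
        accumulation_count += 1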
