[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作

[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作

目录
  • [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作
    • 0x00 摘要
    • 0x01 引论
      • 1.1 调用
      • 1.2 参数说明
    • 0x02 Reducer 初始化
      • 2.1 构造函数
      • 2.2 初始化桶
      • 2.3 初始化视图
        • 2.3.1 BucketReplica成员变量
        • 2.3.2 调用
      • 2.4 初始化本地使用变量
    • 0x03 静态图
      • 3.1 缘由
      • 3.2 使用
      • 3.2 Reducer
    • 0x04 重建桶
      • 4.1 为何要重建
      • 4.2 准备重建
      • 4.3 重建
      • 4.4 何时设定重建
      • 4.5 直接调用
    • 0x05 Join
      • 5.1 缘起
      • 5.2 使用
        • 5.2.1 DistributedDataParallel
        • 5.2.2 ZeroRedundancyOptimizer
      • 5.3 原理
        • 5.3.1 Joinable
        • 5.3.2JoinHook
          • 5.3.2.1 ZeroRedundancyOptimizer
        • 5.3.3 Join
      • 5.4 例子
    • 0xFF 参考

0x00 摘要

因为前文已经围绕Reducer相关的各种成员变量做了相关分析,所以本文开始做动态逻辑分析,目的是:把前面几篇文章串联起来,为后面分析前向传播和反向传播设定基础。

本系列其他文章如下:

深度学习利器之自动微分(1)

深度学习利器之自动微分(2)

[源码解析]深度学习利器之自动微分(3) --- 示例解读

[源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)

[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

[源码解析] PyTorch如何实现前向传播(3) --- 具体实现

[源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

[源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构

[源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑

[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

[源码解析] PyTorch 分布式(1)------历史和概述

[源码解析] PyTorch 分布式(2) ----- DataParallel(上)

[源码解析] PyTorch 分布式(3) ----- DataParallel(下)

[源码解析] PyTorch 分布式(4)------分布式应用基础概念

[源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用

[源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store

[源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组

[源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇

[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

[源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

0x01 引论

为了更好的分析,我们还是需要看看如何调用。

1.1 调用

Reducer 的创建代码如下,是在_ddp_init_helper 之中。

        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(
            parameters, # parameters[0]是张量列表
            list(reversed(bucket_indices)), # 桶信息
            self.process_group,
            expect_sparse_gradient,
            self.bucket_bytes_cap,
            self.find_unused_parameters,
            self.gradient_as_bucket_view,
            param_to_name_mapping,
        )

1.2 参数说明

调用的 parameters 举例如下, parameters[0] 就是 rank 0 上模型的 parameters,可以看到其只有 [0] 元素有意义,这个 [0] 原始本身包括 20 个元素:

parameters = {list: 1} 
0 = {list: 4}           
 0 = {Parameter: 10} Parameter containing:\ntensor([[-4.0381e-02,  3.8828e-02, 1  )   
 1 = {Parameter: 10} Parameter containing:\ntensor([-0.0438, -0.2033,  0.2771,  0.0721,  ) 
 2 = {Parameter: 5} Parameter containing:\ntensor([[-0.0094, -0.1319,  0.0713,  0.3155,  )
 3 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )
 ...
 20 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )                                                   
 __len__ = {int} 20
__len__ = {int} 1

bucket_indices 举例如下:

关于 tensor indices,就是给所有的tensor一个index,从0开始递增,一直到 tensors.size()。假如模型的 parameters 一共有20个张量,则 tensor index 从 0 到 19,分成 6 个buckets,则在这6个buckets之中,每个 tensor index 都是唯一不重复的。

+-----------------------------------------------------------------------+
|                                                                       |
|       |
|                                                                       |
|                                                                       |
|                             |
|                                                                       |
|                                                                       |
|  ......                                                               |
|                                                                       |
|                                                                       |
|   |
|                                                                       |
+-----------------------------------------------------------------------+

接下来,我们就看看如何进行初始化 Reducer。

0x02 Reducer 初始化

代码位于:torch/lib/c10d/reducer.h 和 torch/lib/c10d/reducer.cpp

2.1 构造函数

具体逻辑如下:

  • 看看本模块是不是多设备模块,具体是: 遍历张量,得到张量的设备,把设备插入到一个set结构之中,如果set内的设备多于一个,是多设备
  • 如果 expect_sparse_gradients没有设置,就把expect_sparse_gradients_初始化为false。
  • 调用 initialize_buckets 初始化 buckets 并尽可能按照逆序将 parameters 分配到 buckets 之中,这样按桶通信就可以提高效率。后续在运行时候也可能再次重新初始化桶。
  • 为每个 parameter 加上 grad_accumulator,它们在 backward 时负责梯度同步。
    • 因为这些variables是autograd图的叶子张量,所以它们的grad_fn都被设置为 gradient accumulation function。
    • Reducer保存了指向这些functions的指针,这样Reducer就可以知道它们在autograd传播之中是否被使用,如果没有使用,那么就把这些functions的梯度张量(grad tensors)设置为规约就绪状态。
    • 遍历张量,为每个张量生成一个类型为VariableIndex的变量index。
    • 得到Variable::AutogradMeta的grad_accumulator_,即用于累加叶子 Variable 的梯度累加器。
    • 把reducer的autograd_hook函数添加进去每个grad_accumulator_之中,变量index是hook的参数。这个 hook 挂在 autograd graph 之上,在 backward 时负责梯度同步。grad_accumulator 执行完后,autograd_hook 就会运行。
  • gradAccToVariableMap_ 存了grad_accumulator & index 的对应关系(函数指针和参数张量的对应关系),这样以后在 autograd graph 遍历寻找 unused parameters 就方便了。
  • 初始化 backward_stats_。
  • 调用 initialize_local_used_map 初始化各种 unused map。
// The constructor takes a list of variables for every model replica.
// The bucket assignment for this reducer is specified as a list of
// buckets, each of which is specified as a list of indices into the
// variables list for **a single replica** (i.e. `variables[0]`).
Reducer::Reducer(
    std::vector> replicas, // 张量
    std::vector> bucket_indices, // 桶信息
    c10::intrusive_ptr process_group,
    std::vector> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map paramNames)
    : replicas_(std::move(replicas)),
      process_group_(std::move(process_group)),
      expect_sparse_gradients_(std::move(expect_sparse_gradients)),
      expect_autograd_hooks_(false),
      require_finalize_(false),
      next_bucket_(0),
      has_marked_unused_parameters_(false),
      find_unused_parameters_(find_unused_parameters),
      gradient_as_bucket_view_(gradient_as_bucket_view),
      local_used_maps_reduced_(false),
      num_iterations_(0),
      num_buckets_ready_(0),
      has_rebuilt_bucket_(false),
      bucket_bytes_cap_(bucket_bytes_cap),
      divFactor_(kUnsetDivFactor),
      static_graph_(false),
      comm_hook_(nullptr),
      thread_local_state_(at::ThreadLocalState()),
      ddp_debug_level_(parseDistDebugLevel()),
      param_names_(std::move(paramNames)) {

  // Check whether the module is multi_device_module
  // 看看本模块是不是多设备模块
  {
    std::set unique_devices;
    for (const auto& v : replicas_[0]) { // 遍历张量
      auto device_idx = int(v.device().index()); // 得到张量的设备
      if (unique_devices.find(device_idx) == unique_devices.end()) {
        unique_devices.insert(device_idx); // 把设备插入到一个set结构之中
        if (unique_devices.size() > 1) { // 如果set内的设备多于一个,是多设备
          is_multi_device_module_ = true; 
          break;
        }
      }
    }
  }

  // If `expect_sparse_gradients` is not specified, initialize it such that
  // we do not expect sparse gradients for any parameter.
  if (expect_sparse_gradients_.empty()) {
    expect_sparse_gradients_ = std::vector>(
        replicas_.size(), std::vector(replicas_[0].size(), false));
  }

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  {
    std::lock_guard lock(mutex_);
    initialize_buckets(std::move(bucket_indices)); //初始化桶
  }

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leafs in the autograd graph).
  // We store pointers to these functions such that we can check if they are
  // used in an autograd pass. If they are not, we know their grad tensors
  // can be marked as ready for reduction.
  {
    const auto replica_count = replicas_.size();
    grad_accumulators_.resize(replica_count);
    for (size_t replica_index = 0; replica_index < replica_count; // 只有replicas_[0]有意义
         replica_index++) {
      const auto variable_count = replicas_[replica_index].size(); //张量数目
      grad_accumulators_[replica_index].resize(variable_count); // 给grad_accumulators_分配内存
        
      for (size_t variable_index = 0; variable_index < variable_count;
           variable_index++) { // 遍历张量,variable_index 就是张量的index
        auto& variable = replicas_[replica_index][variable_index]; //得到具体的张量
        const auto index = VariableIndex(replica_index, variable_index); //每个张量生成一个VariableIndex

        // The gradient accumulator function is lazily initialized once.
        // Therefore we can use its presence in the autograd graph as
        // evidence that the parameter has participated in an iteration.
        auto grad_accumulator =
            torch::autograd::impl::grad_accumulator(variable); // 得到Variable::AutogradMeta的grad_accumulator_,即,用于累加叶子 Variable 的梯度累加器

#ifndef _WIN32
        using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
        // Hook to execute after the gradient accumulator has executed.
        hooks_.emplace_back(
            // 累加器添加hook,这个 hook 挂在 autograd graph 之上,在 backward 时负责梯度同步。
            // grad_accumulator 执行完后,autograd_hook 就会运行
            grad_accumulator->add_post_hook(
                torch::make_unique(
                    [=](const torch::autograd::variable_list& outputs,
                        const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
                      this->rpc_context_.set(
                          ThreadLocalDistAutogradContext::getContextPtr());
#endif
                      this->autograd_hook(index); // 把reducer的autograd_hook函数添加进去
                      return outputs;
                    })),
            grad_accumulator);

        // Map raw function pointer to replica index and parameter index.
        // This is used later on when the autograd graph is traversed
        // to check for parameters for which no gradient is computed, if
        // find_unused_parameters=True.
        // Note that the mapping of gradient accumulator to variable should be
        // one to one as we deduplicate shared parameters before constructing
        // Reducer.
          
        // gradAccToVariableMap_ 存了grad_accumulator & index 的对应关系(函数指针和参数张量的对应关系),这样以后在 autograd graph 遍历寻找 unused parameters 就方便了
        if (find_unused_parameters_) {
          gradAccToVariableMap_[grad_accumulator.get()] = index;
        }

        numGradHooksTriggeredMap_[index] = 0;

        // The gradient accumulator is stored as weak_ptr in the autograd
        // metadata of the variable, so we have to keep it alive here for
        // the raw pointer to be valid.
        TORCH_CHECK(
            grad_accumulators_[replica_index][variable_index] == nullptr,
            c10::str(
                "Reducer tried to register duplicate grad accumulator for replica ",
                replica_index,
                " variable ",
                variable_index));
        grad_accumulators_[replica_index][variable_index] =
            std::move(grad_accumulator);
      }
    }
  }

  // Initialize backward stats vector.
  {
    const auto replica_count = replicas_.size();
    backward_stats_.resize(replica_count);
    const auto variable_count = replicas_[0].size();
    std::for_each(
        backward_stats_.begin(),
        backward_stats_.end(),
        [=](std::vector& v) { v.resize(variable_count); });
  }

  // See Note [Skip allreducing local_used_maps_dev]
  if (find_unused_parameters_) {
    initialize_local_used_map();
  }
}

我们接下来具体分析每一个部分。

2.2 初始化桶

initialize_buckets方法用来初始化桶,具体逻辑是对于每一个桶,添加其模型副本,对于每一个模型副本,添加张量列表:

  • 用分布式上下文设置 rpc_context_。

    • 如果在DDP构造函数内调用initialize_bucket,则 rpc上下文指针(rpc context ptr)是否为null 无关紧要,因为grad不会发生变化。
    • 如果在训练循环期间调用initialize_bucket,例如在rebuild_bucket 内部,因为grad可能会发生改变并指向bucket_view,那么它需要检查rpc context ptr是否为null。
    • 如果rpc context ptr是null,则改变 variable.grad(),否则,在rpc上下文中改变梯度。
  • 清空buckets_ 和 variable_locators_。

  • 重置variable_locators_的尺寸,这样每个variable都有一个bucket index。

  • 利用如下得到所有桶的个数和每个桶中副本个数:bucket_count = bucket_indices.size(); replica_count = replicas_.size();

  • 从0开始递增到 bucket_count,逐一初始化 Bucket。

    • 生成一个 Bucket bucket
    • 如果bucket_indices[bucket_index].size() == 1,说明这个桶期待一个single sparse gradient,则设置 bucket.expect_sparse_gradient = true。
    • 从0开始递增到replica_count,逐一初始化 BucketReplica。
      • 生成一个 BucketReplica replica
      • 如果这个桶期待一个single sparse gradient,则
        • 利用bucket_indices[bucket_index].front()取出向量第一个元素,设置为 variable_index。
        • 利用 variable_index 得到副本之中对应的variable。
        • 设置副本replica的变量列表,代码为replica.variables = {variable},这个副本只包括一个variable。
      • 否则说明是dense gradient,则
        • 遍历桶的variable,即利用 replicas_[replica_index][variable_index] 得到variable。
        • 设置variable的设备和数据类型
        • 给副本设置其variables,代码为:replica.variables.push_back(variable)。
        • 设置replica 的一些关于variable的元信息,这些元信息是flat contents相关的,比如offsets存储了各个张量在flat bucket contents中的offset。
        • 给relica.contents分配内存
        • 利用 initialize_bucket_views(replica, replica.contents) 初始化 cotnents 和 views。
        • 利用 bucket.replicas.push_back(std::move(replica)) 把这个 replica 加入到 bucket。
    • 遍历桶中的variable,代码为 bucket_indices[bucket_index]。
      • 设置 Reducer.variable_locators_,这样 Reducer 就知道如何在 bucket 之中确定一个varaible。bucket_indexbuckets_列表的位置,表示 buckets_ 之上的一个bucket。intra_bucket_index 是在 bucket replica 之中 vector 域的 variable index。
    • 设置桶的变量,bucket.variable_indices = std::move(bucket_indices[bucket_index]);
    • 利用 buckets_.push_back(std::move(bucket)) 把bucket这个桶加入到 Reducer之中。

具体代码是:

void Reducer::initialize_buckets(
    std::vector> bucket_indices) {
  // If initialize_buckets is called inside DDP constructor, then
  // it does not matter rpc context ptr is nullptr or not, as grad
  // will not be mutated.
  // If initialize_buckets is called during training loop, e.g, inside
  // rebuild_buckets(), since grad could be mutated and be pointed to
  // bucket_view, then it needs to check rpc context ptr is nullptr or not,
  // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate grad in rpc context.
#ifndef _WIN32
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif

  // This shouldn't be called if we're expecting autograd hooks to fire.
  TORCH_CHECK(
      !expect_autograd_hooks_,
      "`initialize_buckets` must NOT be called during autograd execution.");

  // Clear current bucket assignment.
  buckets_.clear();
  variable_locators_.clear();

  // Ensure we have a bucket index for every variable.
  variable_locators_.resize(replicas_[0].size());

  // Iterate over buckets.
  const auto bucket_count = bucket_indices.size();
  const auto replica_count = replicas_.size();
  buckets_.reserve(bucket_count);
  // 从0开始递增到bucket_count
  for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
    Bucket bucket; // 生成一个桶

    // TODO(@pietern): Validate indices.
    // Must be non-empty, unique, and unique across buckets.
    TORCH_CHECK(
        bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");

    // Variables that expect sparse gradients must have their own bucket.
    if (bucket_indices[bucket_index].size() == 1) {
      // 说明这个桶期待一个single sparse gradient
      const auto variable_index = bucket_indices[bucket_index].front();
      bucket.expect_sparse_gradient =
          expect_sparse_gradients_[0][variable_index];
    } else {
      for (const auto variable_index : bucket_indices[bucket_index]) {
        TORCH_CHECK(
            !expect_sparse_gradients_[0][variable_index],
            "Buckets with more than one variable cannot include variables ",
            "that expect a sparse gradient.");
      }
    }

    // Iterate over model replicas. 从0开始递增到replica_count,遍历模型副本数目,为每一个模型副本都要做同样设置
    for (size_t replica_index = 0; replica_index < replica_count;
         replica_index++) {
      BucketReplica replica; // 生成一个副本

      if (bucket.expect_sparse_gradient) {
        // 说明这个桶期待一个single sparse gradient
        const auto variable_index = bucket_indices[bucket_index].front(); // 得到张量的index
        const auto& variable = replicas_[replica_index][variable_index]; // 得到张量
        TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
        replica.variables = {variable}; // 这个副本只包括一个variable
      } else {
        at::TensorOptions options;
        // The start index of the variable in the flattened tensor.
        size_t offset = 0;

        // Reserve enough space for the per-variable fields stored in bucket
        // replica for efficiency.
        const size_t num_variables = bucket_indices[bucket_index].size();
        replica.variables.reserve(num_variables); 
        replica.offsets.reserve(num_variables);
        replica.lengths.reserve(num_variables);
        replica.sizes_vec.reserve(num_variables);

        // Iterate over bucket variables.
        for (const auto variable_index : bucket_indices[bucket_index]) { //遍历桶中的variable
          TORCH_CHECK(
              variable_index < replicas_[replica_index].size(),
              "Out of range variable index specified.");
          const auto& variable = replicas_[replica_index][variable_index];
          if (!options.has_device()) {
            options = options.device(variable.device());
          } else {
            TORCH_CHECK(
                variable.device() == options.device(),
                "All parameters in a bucket must be ",
                "placed on the same device.");
          }
          if (!options.has_dtype()) {
            options = options.dtype(variable.dtype());
          } else {
            TORCH_CHECK(
                variable.dtype() == options.dtype(),
                "All parameters in a bucket must have the same dtype.");
          }
          
          const auto length = variable.numel();
          // 给副本设置其variables
          replica.variables.push_back(variable); // 这里添加了一个新变量,所以最终能知道该桶中的变量数目
          // 设置replica 的一些关于variable的元信息
          replica.offsets.push_back(offset);
          replica.lengths.push_back(length);
          replica.sizes_vec.push_back(variable.sizes());
          offset += length;
        }

        // Allocate bucket contents tensor.
        replica.contents = at::empty({static_cast(offset)}, options);

        initialize_bucket_views(replica, replica.contents); // 初始化cotents和views
      }

      // Add bucket replica to enclosing bucket.
      bucket.replicas.push_back(std::move(replica)); // 桶的副本列表中添加一个新副本
    }

    // Map participating variables to this bucket.
    // This is identical across replicas so we only need to do this once.
    size_t intra_bucket_index = 0;
    for (const auto variable_index : bucket_indices[bucket_index]) { // 遍历桶中的variable
      TORCH_CHECK(
          variable_index < variable_locators_.size(),
          "Out of range variable index specified.");
      variable_locators_[variable_index] = // 这样 Reducer 就知道如何在 bucket 之中确定一个varaible
          VariableLocator(bucket_index, intra_bucket_index++);
    }
    bucket.variable_indices = std::move(bucket_indices[bucket_index]);

    buckets_.push_back(std::move(bucket)); // 把桶插入Reducer
  }
}

2.3 初始化视图

initialize_bucket_views 这里是设置 Replica 的contents 和 views。

// (see Note:  "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(
    Reducer::BucketReplica& replica,
    at::Tensor& contents) {
  for (size_t i = 0; i < replica.variables.size(); i++) {
    auto& v = replica.variables[i];
    const auto offset = replica.offsets[i];
    const auto length = replica.lengths[i];
    if (v.is_non_overlapping_and_dense()) { // Dense类型的张量
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      replica.bucket_views_in.push_back( // replica.bucket_views_in里面都是视图
          contents.as_strided(v.sizes(), v.strides(), offset));
    } else { // Sparse类型的张量
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      replica.bucket_views_in.push_back( // replica.bucket_views_in里面都是视图
          contents.narrow(0, offset, length).view(v.sizes()));
    }
    // By default `bucket_views_out` and `bucket_views_in` are
    // essentially the same thing.
    replica.bucket_views_out = replica.bucket_views_in; // out也是视图

    // If gradient_as_bucket_view_ is set as true, then there are two cases to
    // handle: initialize_bucket_views could be called inside initialize_buckets
    // when rebuild_buckets, if grad has already been defined/calculated in
    // previous iteration, old grad needs to be copied into new bucket_view and
    // let grad point to the new bucket_view, initialize_bucket_views could also
    // be called inside initialize_buckets during construction. Grads are not
    // defined during construction time, in this case, do not let grad point to
    // bucket_view, because grads should be kept as being undefined for globally
    // unused parameters.
    if (gradient_as_bucket_view_) {
      auto& bucket_view = replica.bucket_views_in.back();
      runGradCallbackForVariable(v, [&](auto& grad) {
        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
          bucket_view.copy_(grad);
          grad = bucket_view; // 梯度被修改了,需要回写
          // The grad is modefied and needs to be written back.
          return true;
        }
        // The grad is not modified and does not need to be written back.
        return false; // 不需要回写,因为没有被修改
      });
    }
  }
}

2.3.1 BucketReplica成员变量

我们先回忆一下BucketReplica的几个成员变量。

  • at::Tensor contents :把桶的内容展平的结果,即Flattened (1 dimensional) 之后的结果。
  • std::vector bucket_views_in :提供了从输入角度在 contents 之中查看具体梯度的方法。
  • std::vector bucket_views_out :提供了从输入角度在 contents 之中查看具体梯度的方法。

关于 std::vector bucket_views_instd::vector bucket_views_out 的进一步说明:

  • 这两个变量提供在 contents 之中操作具体梯度的方法,或者说,它们提供了视图(views),该视图可以操作contents 之中每个张量的梯度。用户把这两个变量作为入口点来把每个梯度的数据从 content 之中移入和移出。
  • 在 PyTorch 之中,视图是指创建一个方便查看的东西,视图与原数据共享内存,它只是将原有的数据进行整理,直接显示其中部分内容或者进行重排序后再显示出来。

也需要对几个 PyTorch 函数进行说明。

  • as_strided :依据现有tensor以及给定的步长来创建一个视图(类型仍然为tensor),需要注意,这里的结果是视图,所以这个张量依然和原始张量共享内存。
  • narrow :返回一个新的张量,其是原来张量的缩小版,但是这个张量依然和原始张量共享内存。

BucketReplica 逻辑具体如下图:

+------------------------------------------+
| BucketReplica                            |
|                                          |
|       vector bucket_views_in +--------------------+
|                                          |                |
|                                          |                |
|       vector bucket_views_out +--------------+    |
|                                          |           |    |
|                                          |           |    |
|                                          |           v    v
|                                          |     +-----+----+--------------------------+
|       Tensor contents  +---------------------> |Flattened (Tensor1, Tensor2, Tensor3)|
|                                          |     +-------------------------------------+
|                                          |
|                                          |
|       vector variables  +------------>  [Tensor1,Tensor2,Tensor3]
|                                          |
|                                          |
|                                          |
+------------------------------------------+

2.3.2 调用

如何调用?如果gradient_as_bucket_view_设置为true,则有两种情况需要处理:

  • rebuild_buckets 之中可以在initialize_bucket内调用initialize_bucket_view,如果grad在上一次迭代中已经定义/计算过,则需要将旧的grad复制到新的bucket_view中,并让grad指向新的bucket_view,
  • 在构造过程中,也可以在initialize_bucket中调用initialize_bucket_views。在构造期间不会定义梯度,在这种情况下,不要让梯度指向bucket_view,因为对于全局未使用的参数,梯度应保持为未定义。

2.4 初始化本地使用变量

initialize_local_used_map此处是初始化 local_used_maps_,我们回忆一下论文内容,local_used_maps_ 就是用来查找全局未使用参数(Globally Unused Parameters):

全局未使用参数(Globally Unused Parameters)的梯度在向前和向后过程中应保持不变。检测未使用的参数需要全局信息,因为在一个DDP过程中,一个参数可能在一次操作中不存在,但可能在另一个过程的同一次迭代中参与训练。因此DDP在位图中维护本地未使用的参数信息,并启动额外的AllReduce以收集全局位图。由于位图比张量尺寸小得多,因此模型中的所有参数共享同一位图,而不是创建每桶位图(per-bucket bitmaps)。位图位于CPU上,以避免为每次更新启动专用CUDA内核。但是,某些ProcessGroup后端可能无法在CPU 张量上运行AllReduce。例如,ProcessGroupNCCL仅支持CUDA张量。此外,由于DDP应该与任何定制的ProcessGroup后端一起工作,它不能假设所有后端都支持CPU张量。为了解决这个问题,DDP在同一设备上维护另一个位图作为第一个模型参数,并调用非阻塞拷贝操作(non-blocking copy)将CPU位图移动到设备位图以进行集合通信

具体代码如下:

void Reducer::initialize_local_used_map() {
  const auto replica_count = replicas_.size();
  const auto variable_count = replicas_[0].size();
  local_used_maps_.resize(replica_count);
  local_used_maps_dev_.resize(replica_count);

  for (size_t i = 0; i < replica_count; i++) {
    at::TensorOptions options;
    options = options.dtype(at::kInt);

    // Deliberately don't pin the memory even if local_used_maps_dev_ will
    // be cuda. See Note [local_used_maps_ -> local_used_maps_dev copying]
    local_used_maps_[i] =
        at::zeros({static_cast(variable_count)}, options);

    // This tensor needs to be on the same device as replica because backend
    // such as NCCL may not support CPU tensors, and hence it might not work
    // if we always put it on CPU.
    options = options.device(replicas_[i][0].device());
    local_used_maps_dev_[i] =
        at::empty({static_cast(variable_count)}, options);
  }
}

初始化流程大致如下:

                                    +
                                    |
                                    |
                                    v
                  rpc_context_ = ThreadLocalDistAutogradContext
                                    +
                                    |
                                    |
                                    v
                  buckets_ & variable_locators_ (clear & resize)
                                    +
                                    |
                                    |
                                    v
+----------------------->  from 0 ~ bucket_count :  +--------------------------->
|                                                                                +
|                                                                                |
|      +-------------------------------------------------------------------+     |
|      | init Bucket          set bucket_indices                           |     |
|      |                            +                                      |     |
|      |                            |                                      |     |
|      |                            |                                      |     |
|      |                            v                                      |     |
|      |   ^ +------------> from 0 ~ replica_count : +----------------->   |     |
|      |   |                                                           |   |     |
|      |   |  +---------------------------------------------------+    |   |     |
|      |   |  | init BucketReplica                                |    |   |     |
|      |   |  |                                                   |    |   |     |
<----+ |   +--+                                                   | <--+   | <---+
       |      |    bucket.replicas.push_back(std::move(replica))  |        |
       |      |                                                   |        |
       |      +----------------------+----------------------------+        |
       |                             |                                     |
       |                             |                                     |
       |                             v                                     |
       |             buckets_.push_back(std::move(bucket))                 |
       |                             +                                     |
       +-------------------------------------------------------------------+
                                     |
                                     v

得到的 Reducer 大致如下,这里需要注意的是 ,BucketReplica 每个桶只有一个:

            +----------------------------------------+                 +------------------+
            |tensor index 4, tensor index 5, tensor 6| <------+        | index 2, index 3 |
            +----------------------------------------+        |        +--------------+---+
                                                              |                       ^
                                                              |                       |
+---------------------------+   +---------------------------------------------------------+
| Reducer                   |   | +----------------------------------+     +------------+ |
|                           |   | |Bucket                     |      |     |Bucket    | | |
|                           |   | |                           +      |     |          | | |
| vector buckets_ +---> | | vector variable_indices  |     | indices ++ | |
|                           |   | |                                  |     |            | |
|                           |   | |  vector replicas  | ... | replicas   | |
|                           |   | |                         +        |     |   +        | |
|                           |   | |                         |        |     |   |        | |
|                           |   | +----------------------------------+     +------------+ |
|                           |   |                           |                  |          |
+---------------------------+   +---------------------------------------------------------+
                                                            |                  |
                                                            |                  |
                                                            v                  v
                          +---------------------------------------+   +-------------------+
                          |  +----------------------------------+ |   | +---------------+ |
                          |  | BucketReplica                    | |   | | BucketReplica | |
                          |  |                                  | |   | |               | |
                          |  |                                  | |   | |               | |
                          |  |  vector bucket_views_in  | |   | |   views_in    | |
                          |  |                                  | |   | |               | |
                          |  |  vector bucket_views_out | |   | |   views_out   | |
                          |  |                                  | |   | |               | |
                          |  |  Tensor contents                 | |   | |   contents    | |
                          |  |                                  | |   | |               | |
                          |  |  vector variables        | |   | |   variables   | |
                          |  |                     +            | |   | |      +        | |
                          |  +----------------------------------+ |   | +---------------+ |
                          +---------------------------------------+   +-------------------+
                                                   |                           |
                                                   |                           |
                                                   v                           v
                                   +---------------+------------+    +---------+----------+
                                   |Tensor 4, Tensor 5, Tensor 6|    | Tensor 2, Tensor 3 |
                                   +----------------------------+    +--------------------+

0x03 静态图

3.1 缘由

虽然 PyTorch 是动态图,但是用户可以明确地让DDP知道训练图是静态的,有如下情况时候可以设定:

  1. 已使用和未使用的参数集在整个训练循环中不变,在这种情况下,用户是否将find_unsued_parameters设置为true并不重要。

  2. 图形的训练方式在整个训练循环过程中不会改变(意味着不存在依赖于迭代的控制流)。当图被设置为静态时,DDP将支持以前不支持的case,比如:

    1. 可重入的反向传播。
    2. 多次activation checkpointing。
    3. activation checkpointing 并且find_unused_parameters = true。
    4. 并不是所有的输出张量都用于损失计算。。
    5. 在前向函数之外有一个模型参数。
    6. 当find_unsued_parameters=true时或者存在未使用的参数,可能会提高性能,因为DDP在每个迭代之内不会搜索网络来检查未使用的参数。

3.2 使用

_set_static_graph 可以配置静态图,此API应在DistributedDataParallel构造之后,并且在训练循环开始之前调用。并且,也应该以同样的方式对所有的rank 进行调用。例如:

ddp_model = DistributedDataParallel(model)
ddp_model._set_static_graph()
for i in range(n):

_set_static_graph 代码为:

def _set_static_graph(self):
    """
    Users can explicitly let DDP know the trained graph is static,
    when 1) the set of used and unused parameters will not change
    during the whole training loop; in this case, it does not matter
    whether users set find_unsued_parameters = true or not.
    2) how the graph is trained will not change during the whole training
    loop (meaning there is no control flow depending on iterations).
    When graph is set to be static, DDP will support cases that can not
    be supported in the past: 1) reentrant backwards
    2) activation checkpointing multiple times 3)
    activation checkpointing with find_unused_parameters = true.
    4) not all output tensors are used in loss calculation.
    5) there is model parameter that is outside of forward function.
    6) potentially improve performance when find_unsued_parameters = true
    or there are unused parameters, as DDP will not search graph in each
    iteraton to detect unused parameters when static_graph is set to be True.

    This API should be called after DistributedDataParallel construction, and
    before training loops starts. Also it should be called in the same way for
    all ranks. For example:
        ddp_model = DistributedDataParallel(model)
        ddp_model._set_static_graph()
        for i in range(n):
            .....
    """
    self.static_graph = True
    self.reducer._set_static_graph() # 调用 Reducer 进行配置
    self.logger._set_static_graph()
    if self.find_unused_parameters:
        warnings.warn(
            "You passed find_unused_parameters=true to DistributedDataParallel, "
            "`_set_static_graph` will detect unused parameters automatically, so "
            "you do not need to set find_unused_parameters=true, just be sure these "
            "unused parameters will not change during training loop while calling "
            "`_set_static_graph`."
        )

3.2 Reducer

Reducer 只有在第一次迭代之后才能生成静态图,因为毕竟PyTorch还是动态的,无论如何也得走一步动态生成。

void Reducer::set_static_graph() {
  std::lock_guard lock(mutex_);
  TORCH_CHECK(
      num_iterations_ == 0,
      "set_static_graph() should be called before training loop starts "
      "and after DistributedDataParallel is constructed.");
  static_graph_ = true;
  // when static_graph_ is set as true, always initialize_local_used_map
  // and detect the global unused parameters in the first iteration.
  initialize_local_used_map();
}

0x04 重建桶

4.1 为何要重建

因为 PyTorch 是动态生成计算图,所以需要相应重建桶。但是只有设置了静态图 并且 第一次迭代之后才会重建,如果设置 find_unused_parameters_,就不重建。

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }

4.2 准备重建

我们首先看看重建之前的一些准备。

push_rebuilt_params 就是插入一个重建参数列表。

void Reducer::push_rebuilt_params(const VariableIndex& index) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}

其次,push_rebuilt_params_for_all_indices 会遍历每个 replica,针对 replica 之中的每个 variable 进行设置。

void Reducer::push_rebuilt_params_for_all_indices() {
  std::lock_guard lock(mutex_);
  if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
    return;
  }
  const auto replica_count = replicas_.size();
  for (size_t replica_index = 0; replica_index < replica_count;
       ++replica_index) {
    const auto variable_count = replicas_[replica_index].size();
    for (size_t variable_index = 0; variable_index < variable_count;
         ++variable_index) {
      const auto index = VariableIndex(replica_index, variable_index);
      push_rebuilt_params(index);
    }
  }
}

4.3 重建

我们接下来看看重建机制。

DDP 根据张量在后向传播中接收梯度的时间,使用 rebuilt_params_ 和 rebuilt_param_indices_ 来重建存储桶。

rebuild_buckets 函数进行广播通信调用,并且可以与下一个forward()调用重叠,因此它可以是异步的。

  • 在find_unused_parameters=true情况下重建bucket 就是异步操作,因为我们可以多次重建bucket,其中子图经过训练,参数索引顺序可能会更频繁地更改。
  • 对于find_unused_parameters=false的情况,bucket只重建一次,性能成本可以忽略不计。如果已重建存储桶, rebuild_buckets 则返回true。
bool Reducer::rebuild_buckets() {
  // Ensure reduction for previous backwards pass is finished. If user's model
  // has unused parameters for example, this will raise an error recommending to
  // run with find_unused_parameters=True, instead of the size mismatch
  // exception below.
  std::lock_guard lock(mutex_);
  ensure_prior_reduction_finished();
  if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
    return false;
  }

  std::vector> rebuilt_bucket_indices;
  std::vector bucket_size_limits;
  bucket_size_limits.push_back(kDefaultFirstBucketBytes);
  bucket_size_limits.push_back(bucket_bytes_cap_);
  rebuilt_bucket_indices = compute_bucket_assignment_by_size(
      rebuilt_params_,
      bucket_size_limits,
      expect_sparse_gradients_[0],
      rebuilt_param_indices_);

  // For rebuilt bucket indices, it needs to be synced across all ranks.
  // Broadcast the newly rebuilt bucket indices from rank 0 in default.
  // After syncing up rebuilt bucket indices, initialize buckets for reducer.
  sync_bucket_indices(rebuilt_bucket_indices);

  has_rebuilt_bucket_ = true; // 只重建一次
  rebuilt_params_.clear();
  rebuilt_param_indices_.clear();

  initialize_buckets(std::move(rebuilt_bucket_indices));
  return true;
}

4.4 何时设定重建

重建仅在以下情况进行设定:

  1. 第一次重建存储桶

  2. static_graph_ is true 或 find_unused_parameters_ is false

  3. 此反向传播过程需要运行allreduce。

在这里,我们只需基于梯度到达顺序将张量及其参数索引转储到rebuilt_params_rebuilt_param_indices_。然后在finalize_backward() 结束时,将基于rebuilt_params_rebuilt_param_indices_重建存储桶,然后广播和初始化存储桶。

此外,我们只需要转储一个副本的张量和参数索引。

以 mark_variable_ready 为例,其中就会调用 push_rebuilt_params(index) 来插入列表。

void Reducer::mark_variable_ready(VariableIndex index) {
  // Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
  // static_graph_ is true or find_unused_parameters_ is false,
  // 3) this backward pass needs to run allreduce.
  // Here, we just dump tensors and their parameter indices into
  // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
  // order, and then at the end of finalize_backward(), buckets will be
  // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
  // will be broadcasted and initialized. Also we only need to dump tensors
  // and parameter indices of one replica.
  if (should_rebuild_buckets()) {
    push_rebuilt_params(index); // 插入列表
  }

  const auto replica_index = index.replica_index;
  const auto variable_index = index.variable_index;

  if (replica_index == 0) {
    checkAndRaiseMarkedTwiceError(variable_index);
    perIterationReadyParams_.insert(variable_index);
  }
  backward_stats_[replica_index][variable_index] =
      current_time_in_nanos() - cpu_timer_.backward_compute_start_time;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& replica = bucket.replicas[replica_index];

  set_divide_factor();

  if (bucket.expect_sparse_gradient) {
    mark_variable_ready_sparse(index);
  } else {
    mark_variable_ready_dense(index);
  }

  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
  // When using CPU tensors we don't need to do this.
  // // Record event so that we can wait for all of them.
  // auto& event = replica.events[bucket_index.intra_bucket_index];
  // event.record();

  // Check if this was the final gradient for this bucket.
  if (--replica.pending == 0) {
    // Kick off reduction if all replicas for this bucket are ready.
    if (--bucket.pending == 0) {
      mark_bucket_ready(bucket_index.bucket_index);
    }
  }

  // Run finalizer function and kick off reduction for local_used_maps once the
  // final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {

    if (dynamic_graph_find_unused()) {
      all_reduce_local_used_map();
    }

    // The autograd engine uses the default stream when running callbacks, so we
    // pass in the current CUDA stream in case it is not the default.
    const c10::Stream currentStream = get_current_stream();
    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard lock(this->mutex_);
      // Run callback with the current stream
      c10::OptionalStreamGuard currentStreamGuard{currentStream};
      if (should_collect_runtime_stats()) {
        record_backward_compute_end_time();
      }
      // Check that all buckets were completed and had their work kicked off.
      TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
      this->finalize_backward();
    });
  }
}

4.5 直接调用

_rebuild_buckets 函数也可以直接调用,比如如下情况,就是在整个训练期间内在 forward 调用了一次。

def forward(self, *inputs, **kwargs):
    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
        self.reducer.save_thread_local_state()
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.num_iterations += 1
            self.reducer.prepare_for_forward()
        if self.ddp_uneven_inputs_config.ddp_join_enabled:
            ones = torch.ones(1, device=self.device)
            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
            if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                # Active ranks schedule an allreduce with zeros, inactive
                # ranks schedule them with 1. If the result != 0 it
                # indicates at least one rank has terminated and we should
                # throw.
                zeros = torch.zeros(1, device=self.device)
                dist.all_reduce(zeros, group=self.process_group)
                should_throw_stop_iteration = zeros.item()
                if should_throw_stop_iteration:
                    raise RuntimeError(
                        "Detected at least one rank that exhausted inputs. Throwing across all ranks."
                    )
            else:
                self.reducer._set_forward_pass_work_handle(
                    work,
                    self.ddp_uneven_inputs_config.ddp_join_divide_by_initial_world_size,
                )

        # Calling _rebuild_buckets before forward compuation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        
        # 在这里进行直接调用
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): # 设定
            logging.info("Reducer buckets have been rebuilt in this iteration.")

再比如 Join 方法也可以直接调用进行重建。

@contextmanager
def join(
    self,
    divide_by_initial_world_size=True,
    enable=True,
    throw_on_early_termination=False,
):
  
  									# 忽略其他代码
    
                    else:
                        # Some DDP process still needs to be joined.
                        if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                            # Schedule allreduce telling active ranks to terminate
                            ones = torch.ones(1, device=self.device)
                            dist.all_reduce(ones, group=self.process_group)
                            # Raising StopIteration doesn't throw error in python 3.6
                            # and throws RuntimeError in 3.7+ (PEP 479), so just
                            # raise RuntimeError here.
                            raise RuntimeError(
                                f"Rank {self._distributed_rank} exhausted all inputs."
                            )
                        if is_last_joiner:
                            is_last_joiner = False
                        # It will rebuild buckets only once during training period
                        
                        # 这里进行调用。
                        self.reducer._rebuild_buckets()
                        # Schedule a corresponding broadcast if we are syncing module
                        # buffers in the forward pass.
                        self._check_and_sync_module_buffers()   

既然提到了 Join,我们接下来就看看这个概念。

0x05 Join

Join 是为了解决训练数据不均匀的问题,就是允许某些输入较少的worker(其已经完成Join操作)可以继续和那些尚未结束的worker继续执行集合通信,就是一个欺骗操作(Shadow)。

5.1 缘起

支撑DDP背后的是几个集合通信库的all-reduce操作,其完成了各个worker之间的梯度同步。而当训练数据在 ranks 之间的输入是不均匀(uneven)的,就会导致DDP会挂起。因为集合通信要求在进程组中的所有rank都参与,因此如果一个rank的输入少,其他ranks会hang或者报错(取决于后端),而且任何类在执行同步集合通信时,在每次迭代都会遇到这个问题。

因此,DDP 给出了一个 "Join" API,Join是一个上下文管理器,在每个rank的训练循环之中使用。数据量少的 rank 会提前耗尽输入,这时它将给集合通信一个假象,从而会构建一个虚拟(dummy)的 all-reduce,以便在数据不足时候与其他 ranks 匹配。具体如何制造这个假象是由注册hook指定。

其大致思路如下:

                +----------------------------+
                |             Data           |
                |   +--------+   +--------+  |
                |   |        |   | Empty  |  |
                |   |        |   |        |  |
                |   +-----+--+   +--------+  |
                |         |                  |
                |         |                  |
                +----------------------------+
                          |
                          |
        +------------+    |               +------------+
        |            |    |               |            |
+---->  |    Model   |    |               |   Model    | <-----+
|       |            |    |               |            |       |
|       +------+-----+    |               +------+-----+       |
|              |          |                      |             |
|              |          |                      |             |
|              v          |                      v             |
|       +------+-----+    |             +--------+----------+  |
|       |  Forward   +<---+             | _JoinHook         |  |
|       |  (local)   |                  |                   |  |
|       +------+-----+                  |                   |  |
|              |                        |                   |  |
|              |                        |                   |  |
|              v                        | +---------------+ |  |
|       +------+-----+                  | | main_hook     | |  |
|       |  Backward  |                  | |               | |  |
|       |  (local)   |                  | |               | |  |
|       +------+-----+                  | |               | |  |
|              |                        | |               | |  |
|              |                        | |               | |  |
|              v                        | |               | |  |
|       +------+-----+                  | |               | |  |
|       | All-Reduce |     Sync grads   | |   All-Reduce  | |  |
|       |            | <--------------> | |   (Dummy)     | |  |
|       +------+-----+                  | |               | |  |
|              |                        | +---------------+ |  |
|              |                        +-------------------+  |
|              v                                 |             |
|     +--------+-------+                         |             |
|     | Update Weights |                         |             |
|     |                |                         |             |
|     +--------+-------+                         |             |
|              |                                 |             |
|              |                                 |             |
+--------------+                                 +-------------+

5.2 使用

5.2.1 DistributedDataParallel

Join 可以和 DistributedDataParallel 一起使用,比如下面的例子之中,会启动两个worker,分别是 rank 0 和 rank 1,rank 0 会得到5个输入,rank 1会得到6个输入,这就是输入不均衡。

如果没有使用 Join,则 rank 1 会在处理第6个输入时候死掉挂起,因为rank 0没有相关输入,所以rank 1只能等待。如果使用了 Join,则不会出现这种问题,可以顺利结束。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将产生以下输出(其中print来自 0 级和 1 级的 ranks,可以任意排序):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

5.2.2 ZeroRedundancyOptimizer

Join上下文不仅是和一个类合作,也可以和多个类一起,比如PyTorch 的ZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与以前相同的输出。显着的变化是需要另外将ZeroRedundancyOptimizer实例传入 Join()

后续会对ZeroRedundancyOptimizer等机制也进行分析。

5.3 原理

在最新文档 https://pytorch.org/tutorials/advanced/generic_join.html 之中,PyTorch 给出了一定解释,我们翻译如下。

为了更好的使用,我们将介绍Join类以及支持类JoinableJoinHook

备注:这部分在 v1.10.0 版本代码之中。

5.3.1 Joinable

首先,与Join上下文管理器兼容的类必须继承抽象基类Joinable。特别的,Joinable必须实现:

  • join_hook(self, **kwargs) -> JoinHook

这将返回 的JoinHook实例Joinable,用来确定加入的进程应如何影响由Joinable 执行的每次迭代集体通信。

  • join_device(self) -> torch.device

这将返回Join上下文管理器用来执行集体通信的设备,例如torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这将返回Join上下文管理器用于执行集体通信的进程组。

概括一下,JoinHook负责具体行为,join_device 和 join_process_group 负责具体集合通信

需要注意的是,join_devicejoin_process_group是必需的属性,他们可以确保上下文管理器能够安排"加入"和"未加入"进程之间的集体通信。一种用法是使用 all-reduce 计算每次迭代中"未加入"进程的数量。另一种用法是实现 throw_on_early_termination=True所需的机制,我们将在下面解释。

DistributedDataParallelZeroRedundancyOptimizer已经继承Joinable并实现了上面的方法,这就是为什么我们可以在前面的例子中直接使用它们。

class DistributedDataParallel(Module, Joinable):

class ZeroRedundancyOptimizer(Optimizer, Joinable):

DDP 涉及到提供数据,所以继承Joinable可以理解,ZeroRedundancyOptimizer 为何也需要继承?这是因为 ZeroRedundancyOptimizer 可以和 DDP 一起合作,并且 ZeroRedundancyOptimizer 内部也有集合操作,所以需要被 Join 一起管理。

Joinable类应该确保调用Joinable构造函数,因为它初始化了一个JoinConfig实例,上下文管理器在内部使用JoinConfig来确保正确性。JoinConfig将在每个Joinable _join_config字段中保存。

5.3.2JoinHook

接下来,让我们分解一下JoinHook类。JoinHook提供了两个进入上下文管理器的入口点:

  • main_hook(self) -> None

当存在尚未加入(Join)的 rank 时,每个加入(Join)的 rank 都会重复调用此钩子。它目的是在每次训练迭代(例如,在一次前向传递,反向传递和优化器步骤)之中,隐藏由Joinable所执行的集体通信,即已经Join的rank 如何与未Join的rank执行集合通信

  • post_hook(self, is_last_joiner: bool) -> None

一旦所有 ranks 都加入,这个钩子就会被调用。它传递了一个额外的 bool参数is_last_joiner,其表明此 rank 是否是最后加入的 rank 之一。该参数可能对同步有用。

5.3.2.1 ZeroRedundancyOptimizer

我们以 内置的 ZeroRedundancyOptimizer main hook 来给出一个钩子的具体例子:因为加入的 rank 仍然负责更新和同步其参数分片,所以 main hook 依然执行优化器步骤。

class _ZeROJoinHook(_JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), \
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " \
            "instance as the state"
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Performs an optimizer step, which updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()

step函数简略如下:

def step(
    self,
    closure: Optional[Callable[[], float]] = None,
    **kwargs: Any,
) -> Optional[float]:
    _Join.notify_join_context(self) # 这里会通知
    # Check if the model trainability has changed
    is_trainable_mask = self._get_is_trainable_mask()
    if is_trainable_mask != self._is_trainable_mask:
        self._build_param_buckets()
        self._is_trainable_mask = is_trainable_mask

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all of the updated parameter shards across the ranks
    self._sync_parameters()

    # Sync any updated attributes in the local optimizer to the exposed
    # `param_groups`
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss

再来看看DistributedDataParallel

  • main_hook 依然会做相关的一系列操作来欺骗其他rank。
  • post-hook 会从最后加入的rank之一来广播最终更新的模型,以确保模型在所有rank中都是相同的。
class _DDPJoinHook(_JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """
        Sets config variables for internal usage.
        """
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """
        Shadows the DDP collective communication operations in the forward and
        backward passes.
        """
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
        work.wait()
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """
        Syncs the final model to ensure that the model is the same across all
        processes.
        """
        self.ddp._sync_final_model(is_last_joiner)

_sync_final_model 这里会广播最新的模型。

# When running in join model, agrees upon a common rank and broadcast model
# parameters to all other ranks.
def _sync_final_model(self, is_last_joiner):
    # Agree upon the process that will be the authoritative model copy.
    # The current rank is a candidate for being the authoritative copy if
    # is_last_joiner=True. We break ties via picking the larger rank.
    self._authoritative_rank = self._find_common_rank(
        self._distributed_rank, is_last_joiner
    )
    self._sync_params_and_buffers(authoritative_rank=self._authoritative_rank)

5.3.3 Join

最后,让我们看看这些基础类是如何适应Join类本身的。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的例子中看到的,构造函数接收一个参与训练循环的Joinable列表 。这些应该是在每次迭代中执行集体通信的类。

enablebool类型,如果您知道不会有不均匀的输入,则可以设置为 False,在这种情况下,上下文管理器变得类似于contextlib.nullcontext(). 这也可能会在参与Joinable列表之中禁用join-related计算。

throw_on_early_terminationbool类型,其可以设置为True,以便让每个等级在检测到不均匀输入时引发异常。这对于不符合上下文管理器要求的情况很有用,这通常是当来自不同类的集体通信可以任意交错(interleaved)时,例如DistributedDataParallel与具有SyncBatchNorm层的模型一起使用时 。在这种情况下,应将此参数设置为 True以便应用程序逻辑可以捕获异常并确定如何继续。

  • 核心逻辑出现在该__exit__()方法中,该方法在存在未加入的 rank 时会进行循环调用每个 Joinable的主钩子,然后一旦所有rank加入,就调用它们的 post 钩子。主钩子和后钩子都按照Joinables 传入的顺序进行迭代。
  • 上下文管理器需要来自未加入进程的心跳。因此,每个Joinable类都应该在每次迭代的集体通信之前调用Join.notify_join_context() 。上下文管理器将确保只有第一个传入的Joinable实际发送心跳。

5.4 例子

我们通过一个例子来具体看看。下面代码之中,每个rank会打印(1)在Join之前看到的所有rank的输入数量,以及(2)所有rank的输入总数。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # 确定最后join的rank,由于后加入的rank可能不止一个,所以选择rank最大的rank来同步  
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于rank 0看到5个输入,rank 1看到6个,因此产生输出:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

需要强调的一些要点:

  • Counter实例在每次迭代中执行一个all reduce操作,因此:
    • 对于已经Join的rank,其 main hook 也执行单个all reduce来对整体通信进行蒙骗操作( shadow it),注意这个 all-reduce是调用一个为0的tensor,所以对整体结果不影响。
    • 其他未 Join 的 rank 会以为这依然是一个正确的满员的集合操作。
    • 这样就处理了不均匀输入。
  • Counter类在其 __call__()方法的开头调用 Join.notify_join_context() ,因为这是每次集合操作(all-reduce)的地方,需要在这里通知上下文管理器,本示例还没有Join(已经结束的rank不会调用到这里)。
  • 'is_last_joiner'参数用于确定post-hooks中的广播源。
  • 我们将 sync_max_count 关键字参数传递给上下文管理器,上下文管理器会将其转发给'Counter'的join hook。
  • post-hooks之中,会对 self.counter.max_count 进行广播。

0xFF 参考

pytorch分布式系列3——分布式训练时,torch.utils.data.distributed.DistributedSampler做了什么?

pytorch分布式系列1——搞清torch.distributed.launch相关的环境变量

pytorch分布式系列2——DistributedDataParallel是如何做同步的?

pytorch(分布式)数据并行个人实践总结——DataParallel/DistributedDataParallel

Pytorch的nn.DataParallel

https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/20

https://pytorch.org/docs/stable/distributed.html

PyTorch 源码解读之分布式训练了解一下?

实操教程|PyTorch AutoGrad C++层实现

PYTORCH 自动微分(一)

PyTorch如何加速数据并行训练?分布式秘籍大揭秘

pytorch分布式训练(二init_process_group)

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

https://pytorch.org/docs/master/notes/ddp.html

https://pytorch.org/tutorials/intermediate/dist_tuto.html

PyTorch 源码解读之 DP & DDP:模型并行和分布式训练解析

Pytorch模型中的parameter与buffer

【PyTorch开发者日 2020】PyTorch分布式数据并行(DDP)

[中文字幕] 深入理解 PyTorch 中的 Hook 机制

[中文字幕] 深入解读 Pytorch AutoGrad

DISTRIBUTED TRAINING WITH UNEVEN INPUTS USING THE JOIN CONTEXT MANAGER

谈谈torch1.10中的ZeroRedundancyOptimizer和Join

你可能感兴趣的:([源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作)