TensorFlow Source Code Analysis: ExecutorState

Preface

This article tries to explain the meaning of every statement in the source code with comments that are as detailed as possible.
TensorFlow version: 1.10
I will go through things in whatever order I am used to, or whatever order is easiest to follow, and only cover the parts I know. If you spot any mistakes, please point them out in the comments; much appreciated.
These annotations are meant only as my own notes, not for any other use.

1. ScheduleReady

  // This function distributes the nodes in the ready queue, produced by the Process call that just finished, across worker threads.
  void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
                                    TaggedNodeReadyQueue* inline_ready) {
    // Every node activated by this node ends up in the ready queue.
    // For example, if a node's output feeds 3 other nodes and each of those 3 nodes has only that single input, then all three are placed in the ready queue by PropagateOutputs.

    // If the ready queue is empty, the node that just finished did not activate any successors. There can be several reasons for that (for example, a successor needs two inputs but only one of them has arrived). Either way, with nothing to schedule we simply return.
    if (ready.empty()) return;

    int64 scheduled_nsec = 0;
    if (stats_collector_) {
      scheduled_nsec = nodestats::NowInNsec();
    }
    // inline_ready holds the nodes that are ready but not yet executed within this thread (i.e. within this Process call).
    // If there is no such queue:
    if (inline_ready == nullptr) {
      // Schedule to run all the ready ops in thread pool.
      // then go fully parallel: each node in the ready queue is handed to runner_, which runs it in a separate Process call.
      for (auto& tagged_node : ready) {
        runner_([=]() { Process(tagged_node, scheduled_nsec); });
      }
      return;
    }
    const GraphView& gview = impl_->gview_;
    // Tracks whether an expensive (i.e. time-consuming) op has already been assigned to the current thread.
    const TaggedNode* curr_expensive_node = nullptr;
    // This branch handles the case where an inline_ready queue exists. Again, walk over every node.
    for (auto& tagged_node : ready) {
      const NodeItem& item = *gview.node(tagged_node.node->id());
      // If the node is dead, or it is not expensive (expensive meaning compute-heavy, e.g. a GPU kernel):
      if (tagged_node.is_dead || !item.kernel_is_expensive) {
        // Inline this inexpensive node.
        // then put it on inline_ready (it is cheap anyway, so letting it wait a bit does no harm).
        inline_ready->push_back(tagged_node);
      } else {
        // Otherwise, this is an expensive node. In that case:
        // if this thread (i.e. this Process call) has already been assigned an expensive op,
        if (curr_expensive_node) {
          // Dispatch to another thread since there is plenty of work to
          // do for this thread.
          // then hand the expensive op previously reserved for this thread over to another thread.
          runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
                            scheduled_nsec));
        }
        // and reserve the new expensive op for this thread instead.
        curr_expensive_node = &tagged_node;
        // Repeated across the loop, this means every expensive op in the ready queue except the last one is dispatched to other threads, and only the last one may stay with this thread, which spreads the load fairly evenly.
      }
    }
    // At the very end, this thread may still be holding onto one expensive op.
    if (curr_expensive_node) {
      // If inline_ready is empty, i.e. this thread has nothing else to do, then:
      if (inline_ready->empty()) {
        // Tail recursion optimization
        // let this thread handle the expensive op itself.
        inline_ready->push_back(*curr_expensive_node);
      } else {
        // There are inline nodes to run already. We dispatch this expensive
        // node to other thread.
        // If this thread already has plenty of inline work queued up, hand the expensive op to another thread instead.
        runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
                          scheduled_nsec));
      }
    }
  }
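
To see the dispatch policy in isolation, here is a minimal stand-alone sketch of the same branching. This is illustrative code, not TensorFlow's; Node, is_expensive, and runner are simplified stand-ins for TaggedNode, item.kernel_is_expensive, and runner_.

#include <deque>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool is_expensive;
};

// Mirrors the policy above: cheap nodes stay on the local queue; every
// expensive node except the last is handed to `runner`; the last expensive
// node is kept locally only if the local queue would otherwise be empty.
void ScheduleReadySketch(const std::vector<Node>& ready,
                         std::deque<Node>* inline_ready,
                         const std::function<void(const Node&)>& runner) {
  if (ready.empty()) return;
  const Node* curr_expensive = nullptr;
  for (const Node& n : ready) {
    if (!n.is_expensive) {
      inline_ready->push_back(n);  // cheap: keep it for this thread
    } else {
      if (curr_expensive != nullptr) {
        runner(*curr_expensive);   // already holding one: dispatch it
      }
      curr_expensive = &n;         // hold on to the newest expensive node
    }
  }
  if (curr_expensive != nullptr) {
    if (inline_ready->empty()) {
      inline_ready->push_back(*curr_expensive);  // nothing else to do: run it here
    } else {
      runner(*curr_expensive);     // plenty of local work: dispatch it
    }
  }
}

int main() {
  std::deque<Node> inline_ready;
  auto runner = [](const Node& n) {
    std::cout << "dispatched to pool: " << n.name << "\n";
  };
  ScheduleReadySketch({{"a", false}, {"b", true}, {"c", true}, {"d", false}},
                      &inline_ready, runner);
  for (const Node& n : inline_ready) {
    std::cout << "run locally: " << n.name << "\n";
  }
}

Running this, b and c are dispatched to the pool while a and d stay on the local queue: cheap work is kept inline, and because inline work already exists, even the last expensive node is pushed out to another thread.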

2. Process

// Process executes one ready op and activates the nodes fed by its outputs. It takes two parameters; tagged_node is the node that is already ready and is now being processed.
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
    WithContext wc(context_);
    const GraphView& gview = impl_->gview_;
    // The ready queue holds all nodes (ops) that have become ready.
    TaggedNodeSeq ready;
    // inline_ready holds the nodes (ops) activated by tagged_node that this thread will process inline.
    TaggedNodeReadyQueue inline_ready;
	
    // The following simply sets up the parameters passed to the kernel.
    // Parameters passed to OpKernel::Compute.
    TensorValueVec inputs;
    DeviceContextVec input_device_contexts;
    AllocatorAttributeVec input_alloc_attrs;

    OpKernelContext::Params params;
    params.step_id = step_id_;
    Device* device = impl_->params_.device;
    params.device = device;
    params.log_memory = log_memory_;
    params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
    params.rendezvous = rendezvous_;
    params.collective_executor = collective_executor_;
    params.session_state = session_state_;
    params.tensor_store = tensor_store_;
    params.cancellation_manager = cancellation_manager_;
    params.call_frame = call_frame_;
    params.function_library = impl_->params_.function_library;
    params.resource_manager = device->resource_manager();
    params.step_container = step_container_;
    params.slice_reader_cache = slice_reader_cache_;
    params.inputs = &inputs;
    params.input_device_contexts = &input_device_contexts;
    params.input_alloc_attrs = &input_alloc_attrs;
    params.runner = &runner_;
    params.stats_collector = stats_collector_;

    Status s;
    NodeExecStatsInterface* stats = nullptr;

    EntryVector outputs;
    bool completed = false;
    // Push the already-ready tagged_node onto inline_ready and start the central loop of Process.
    // As a reminder, inline_ready holds the nodes activated by tagged_node that are processed inline here.
    inline_ready.push_back(tagged_node);
    while (!inline_ready.empty()) {
      // Pop the first node off the queue.
      tagged_node = inline_ready.front();
      inline_ready.pop_front();
      // Fetch this node's key attributes, such as the node pointer, its id, and its NodeItem.
      const Node* node = tagged_node.node;
      FrameState* input_frame = tagged_node.input_frame;
      const int64 input_iter = tagged_node.input_iter;
      const int id = node->id();
      // gview can be thought of as a lightweight view over the computation graph, similar in spirit to string_view.
      const NodeItem& item = *gview.node(id);

      // TODO(misard) Replace with a finer-grain enabling flag once we
      // add better optional debugging support.
      if (vlog_ && VLOG_IS_ON(1)) {
        mutex_lock l(input_frame->mu);
        input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
      }

      // Set the device_context for this node id, if it exists.
      // Look up which device context this node id runs with, if one was recorded.
      if (id < device_context_map_.size()) {
        params.op_device_context = device_context_map_[id];
      }

      params.track_allocations = false;
      stats = nullptr;
      if (stats_collector_ && !tagged_node.is_dead) {
        stats = stats_collector_->CreateNodeExecStats(node);
        // Track allocations if and only if we are collecting statistics, and
        // `stats` object is expecting allocations to be tracked.
        params.track_allocations = stats ? stats->TrackAllocations() : false;
        nodestats::SetScheduled(stats, scheduled_nsec);
        nodestats::SetAllStart(stats);
      }

      if (vlog_) {
        VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
                << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
                << " device: " << device->name();
      }

      Entry* input_tensors = GetInputTensors(input_frame, input_iter);
      Entry* first_input = input_tensors + item.input_start;
      outputs.clear();

      TensorReferenceVector accessed_tensors;
      DeviceContext* device_context = nullptr;
      // Only execute this node if it is not dead or it is a send/recv
      // transfer node. For transfer nodes, we need to propagate the "dead"
      // bit even when the node is dead.
      bool launched_asynchronously = false;
      if (tagged_node.is_dead && !IsTransferNode(node)) {
        outputs.resize(item.num_outputs);
      } else {
        // Prepares inputs.
        bool is_input_dead = false;
        s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
                          &input_alloc_attrs, &is_input_dead);
        if (!s.ok()) {
          // Clear inputs.
          int num_inputs = item.num_inputs;
          for (int i = 0; i < num_inputs; ++i) {
            (first_input + i)->ClearVal();
          }
          MaybeMarkCompleted(input_frame, input_iter, id);
          // Continue to process the nodes in 'inline_ready'.
          completed = NodeDone(s, item.node, ready, stats, &inline_ready);
          continue;
        }

        // Set up compute params.
        OpKernel* op_kernel = item.kernel;
        params.op_kernel = op_kernel;
        params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
        params.is_input_dead = is_input_dead;
        params.output_attr_array = item.output_attrs();
        params.forward_from_array = item.forward_from();

        if (item.kernel_is_async) {
          // Asynchronous computes.
          AsyncOpKernel* async = item.kernel->AsAsync();
          DCHECK(async != nullptr);
          launched_asynchronously = true;
          AsyncState* state =
              new AsyncState(params, tagged_node, &item, first_input, stats);
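          // AsyncState copies params and the per-node state so they stay
          // alive until the asynchronous kernel finishes. The 'done'
          // callback defined below is invoked at that point: it processes
          // and propagates the outputs, clears the inputs, and calls
          // NodeDone (and possibly Finish).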

          auto done = [this, state]() {
            Device* device = impl_->params_.device;
            NodeExecStatsInterface* stats = state->stats;  // Shorthand
            Entry* first_input = state->first_input;       // Shorthand

            nodestats::SetOpEnd(stats);
            EntryVector outputs;
            Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
            nodestats::SetMemory(stats, &state->ctx);
            if (vlog_) {
              VLOG(2) << "Async kernel done: " << state->item->node->id()
                      << " step " << step_id_ << " "
                      << SummarizeNode(*state->item->node)
                      << (state->tagged_node.is_dead ? " is dead" : "")
                      << " device: " << device->name();
            }

            // Clears inputs.
            const int num_inputs = state->item->num_inputs;
            for (int i = 0; i < num_inputs; ++i) {
              (first_input + i)->ClearVal();
            }
            FrameState* input_frame = state->tagged_node.input_frame;
            const int64 input_iter = state->tagged_node.input_iter;
            const int id = state->tagged_node.node->id();
            MaybeMarkCompleted(input_frame, input_iter, id);
            TaggedNodeSeq ready;
            if (s.ok()) {
              PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
            }
            outputs.clear();
            if (s.ok() && impl_->device_record_tensor_accesses_) {
              // Get the list of all tensors accessed during the execution
              TensorReferenceVector accessed;
              state->ctx.retrieve_accessed_tensors(&accessed);
              nodestats::SetReferencedTensors(stats, accessed);
              // callee takes ownership of the vector
              device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
                                                   accessed);
            }
            const bool completed =
                NodeDone(s, state->item->node, ready, stats, nullptr);
            delete state;
            if (completed) Finish();
          };
          nodestats::SetOpStart(stats);
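          // ComputeAsync may return before the kernel has actually finished;
          // completion is signalled by invoking the 'done' callback.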
          device->ComputeAsync(async, &state->ctx, done);
        } else {
          // Synchronous computes.
          OpKernelContext ctx(&params, item.num_outputs);
          nodestats::SetOpStart(stats);

          if (TF_PREDICT_FALSE(
                  MightTrace(item, event_collector_, trace_using_annotations_))) {
            const string& op_name = op_kernel->name();
            tracing::ScopedRegion region(tracing::EventCategory::kCompute,
                                         op_name);
            if (trace_using_annotations_) {
              // The OpKernel may create child activities (such as GPU kernel
              // launches), so use a `ScopedAnnotation` to relate these activities
              // in the trace.
              tracing::ScopedAnnotation activity(
                  op_name, strings::StrCat(op_kernel->type_string(),
                                           "#id=", step_id_, "#"));
              device->Compute(op_kernel, &ctx);
            } else {
              // Use the cheaper `ScopedActivity` to trace just the OpKernel
              // execution.
              tracing::ScopedActivity activity(
                  op_name,
                  strings::StrCat(op_kernel->type_string(), "#id=", step_id_,
                                  "#"),
                  item.kernel_is_expensive);
              device->Compute(op_kernel, &ctx);
            }
          } else {
            // In the common case, avoid creating any tracing objects.
            device->Compute(op_kernel, &ctx);
          }

          nodestats::SetOpEnd(stats);
          s = ProcessOutputs(item, &ctx, &outputs, stats);
          if (s.ok() && impl_->device_record_tensor_accesses_) {
            // Get the list of all tensors accessed during the execution
            ctx.retrieve_accessed_tensors(&accessed_tensors);
            device_context = ctx.op_device_context();
          }
          nodestats::SetMemory(stats, &ctx);
        }
      }

      if (!launched_asynchronously) {
        if (vlog_) {
          VLOG(2) << "Synchronous kernel done: " << id << " step "
                  << params.step_id << " " << SummarizeNode(*node)
                  << (tagged_node.is_dead ? " is dead: " : "")
                  << " device: " << device->name();
        }

        // Clears inputs.
        const int num_inputs = item.num_inputs;
        for (int i = 0; i < num_inputs; ++i) {
          (first_input + i)->ClearVal();
        }
        MaybeMarkCompleted(input_frame, input_iter, id);
        // Propagates outputs.
        if (s.ok()) {
          PropagateOutputs(tagged_node, &item, &outputs, &ready);
        }
        outputs.clear();
        if (!accessed_tensors.empty()) {
          nodestats::SetReferencedTensors(stats, accessed_tensors);
          // device_context is set above in synchronous computes
          device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
        }
        if (stats) {
          scheduled_nsec = nodestats::NowInNsec();
        }
        // Postprocess.
        completed = NodeDone(s, item.node, ready, stats, &inline_ready);
      }
    }  // while !inline_ready.empty()

    // This thread of computation is done if completed = true.
    if (completed) Finish();
  }
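
Both ScheduleReady and the asynchronous completion path hand closures to runner_. In the executor, runner_ is simply a std::function that takes a closure and schedules it somewhere; a Session typically wires it to its inter-op thread pool. The following stand-alone sketch uses a deliberately simplified pool of my own (not TensorFlow's thread::ThreadPool) just to show the shape such a runner takes.

#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// A tiny fixed-size worker pool; Schedule() queues a closure that will be
// executed by one of the worker threads.
class SimplePool {
 public:
  explicit SimplePool(int num_threads) {
    for (int i = 0; i < num_threads; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lock(mu_);
            cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) return;
            task = std::move(tasks_.front());
            tasks_.pop_front();
          }
          task();  // run the closure outside the lock
        }
      });
    }
  }
  ~SimplePool() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    for (std::thread& t : workers_) t.join();
  }
  void Schedule(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      tasks_.push_back(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  std::vector<std::thread> workers_;
  std::deque<std::function<void()>> tasks_;
  std::mutex mu_;
  std::condition_variable cv_;
  bool stop_ = false;
};

int main() {
  SimplePool pool(4);
  // This callable plays the same role as runner_ in the executor: it takes
  // a closure and schedules it onto a worker thread, the way
  // runner_([=]() { Process(tagged_node, scheduled_nsec); }); does above.
  std::function<void(std::function<void()>)> runner =
      [&pool](std::function<void()> closure) { pool.Schedule(std::move(closure)); };
  runner([] { std::cout << "closure ran on a pool thread\n"; });
}  // ~SimplePool joins the workers, so the scheduled closure finishes before exit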

3. NodeDone
