MXNet: wait_to_read 方法

wait_to_read

在mxnet中,类ndarray可以调用 wait_to_read,官方给出的该函数解释是:

Waits until all previous write operations on the current array are finished.
This method guarantees that all previous write operations that pushed
into the backend engine for execution are actually finished.
Examples

    >>> import time
    >>> tic = time.time()
    >>> a = mx.nd.ones((1000,1000))
    >>> b = mx.nd.dot(a, a)
    >>> print(time.time() - tic) # doctest: +SKIP
    0.003854036331176758
    >>> b.wait_to_read()
    >>> print(time.time() - tic) # doctest: +SKIP
    0.0893700122833252

翻译过来就是,调用该方法可以保证,之前进行的该ndarray的所有写操作都完成了。由于MXNet是一个异步框架,我们使用python接口调用方法的时候,只是将该方法push给了执行者,由执行者来安排什么时候执行。所以,我们就无法知道,某一个操作是不是完成了。

那么,在知道了该方法的功能后,我们的疑问就变成了,mxnet是怎么做到的呢?

在mxnet的类ndarray中定义了WaitToRead方法,如下:

inline void WaitToRead() const {
    if (is_none()) return;
    Engine::Get()->WaitForVar(ptr_->var);
}

其中ptr_->var指向了该ndarray实例对应的唯一的var,engine会根据该var来进行判断,当前操作对应的是哪个ndarray。

那么,我们只需要知道Engine对应的WariForVar做了什么就好了。在mxnet中实现了很多种engine。但是我们用的engine都继承了类ThreadedEngineWaitForVar方法。没有重载。
在这个方法里面,主要做了一件事情,就是将一个操作push给了engine。
该操作需要完成的函数是:

this, &done](RunContext, CallbackOnComplete on_complete) {
      if (engine_info_) {
        LOG(INFO) << "Sync is executed";
      }
      {
        std::unique_lock lock{finished_m_};
        done.store(true);
      }
      finished_cv_.notify_all();
      if (engine_info_) {
        LOG(INFO) << "Sync is notified";
      }
      on_complete();
    }

输出是var
函数将done赋值为true,它被初始化为false。之后会调用finished_cv_.notify_all();来通知该方法继续运行。
因为,该方法会在将操作push之后,调用

{
    std::unique_lock lock{finished_m_};
    finished_cv_.wait(lock, [this, &done]() {
        return done.load() || kill_.load();
    });
  }

导致一直等待,直到done为true。

到这里,我们大概明白了ndarray的wait_to_read的方法如何实现等待。
总体思路就是,调用了engine的WaitForVar方法。这个方法会将一个操作异步push给engine,这个操作将一个原子变量done从false改变为true。由于push是异步的,它会立刻返回,返回后,WaitForVar一直等待,知道done为true。
而engine会根据变量来安排执行,由于这次push的操作的输出是ndarray对应的var,因此,engine会保证所有之前的写入操作完成。

你可能感兴趣的:(MXNet: wait_to_read 方法)