[作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor]
因为Rendezvous所涉及的模块组件较多,为了让读者循序渐进地理解TensorFlow中的通信机制,笔者决定将Rendezvous的分析拆分成多篇系列文章,由浅入深地分别梳理。这样做不但能让读者在阅读时较好地把握整体层次结构,而且每篇篇幅较短,也便于阅读,所以建议读者按顺序阅读本系列。本文是TensorFlow通信机制系列的第一篇文章,侧重于整体结构和本地传输通信的梳理。
1 // Parses the key constructed by CreateKey and parse src/dst device 2 // names into structures respectively. 3 struct ParsedKey { 4 StringPiece src_device; 5 DeviceNameUtils::ParsedName src; 6 uint64 src_incarnation = 0; 7 StringPiece dst_device; 8 DeviceNameUtils::ParsedName dst; 9 StringPiece edge_name; 10 11 ParsedKey() {} 12 ParsedKey(const ParsedKey& b) { *this = b; } 13 14 ParsedKey& operator=(const ParsedKey& b); 15 StringPiece FullKey() const { return buf_; } 16 17 private: 18 friend class Rendezvous; 19 friend class SendOp; 20 friend class RecvOp; 21 string buf_; 22 };
CreateKey只需接收五个参数即可安全地构造字符串形式的Key,其中有两处特殊之处:a. 参数frame_and_iter一般直接取自OpKernelContext中的FrameAndIter对象;b. src_incarnation需要转换为十六进制字符串。CreateKey函数的输出是以分号(";")为分隔符的字符串,该字符串同样包含五个域,依次为src_device;src_incarnation;dst_device;edge_name;frame_and_iter。CreateKey是一个static函数,代码比较简单,就不在这里列出。随后我们将这个字符串传入ParseKey函数,即可解析得到结构体ParsedKey,解析过程如下。
1 /* static */ 2 Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { 3 if (key.data() == out->buf_.data()) { 4 // Caller used our buf_ string directly, so we don't need to copy. (The 5 // SendOp and RecvOp implementations do this, for example). 6 DCHECK_EQ(key.size(), out->buf_.size()); 7 } else { 8 // Make a copy that our StringPieces can point at a copy that will persist 9 // for the lifetime of the ParsedKey object. 10 out->buf_.assign(key.data(), key.size()); 11 } 12 StringPiece s(out->buf_); 13 StringPiece parts[5]; 14 for (int i = 0; i < 5; i++) { 15 parts[i] = ConsumeNextPart(&s, ';'); 16 } 17 if (s.empty() && // Consumed the whole string 18 !parts[4].empty() && // Exactly five parts 19 DeviceNameUtils::ParseFullName(parts[0], &out->src) && 20 strings::HexStringToUint64(parts[1], &out->src_incarnation) && 21 DeviceNameUtils::ParseFullName(parts[2], &out->dst) && 22 !parts[3].empty()) { 23 out->src_device = StringPiece(parts[0].data(), parts[0].size()); 24 out->dst_device = StringPiece(parts[2].data(), parts[2].size()); 25 out->edge_name = StringPiece(parts[3].data(), parts[3].size()); 26 return Status::OK(); 27 } 28 return errors::InvalidArgument("Invalid rendezvous key: ", key); 29 }
1 // The caller is a tensor producer and it sends a message (a tensor 2 // "val" and a bool "is_dead") under the given "key". 3 // 4 // {val, is_dead} is bundled as a message sent and received. 5 // Typically, is_dead is set by some control flow nodes 6 // (e.g., a not-taken branch). args is passed by Send to the 7 // Recv function to communicate any information that the Recv 8 // function might need. This is typically only necessary for 9 // Send/Recv on the same worker. 10 // 11 // Send() never blocks. 12 virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val, const bool is_dead) = 0; 13 14 virtual void RecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) = 0; 15 16 // Synchronous wrapper for RecvAsync. 17 Status Recv(const ParsedKey& key, const Args& args, Tensor* val, bool* is_dead, int64 timeout_ms); 18 Status Recv(const ParsedKey& key, const Args& args, Tensor* val, bool* is_dead);
1 Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args, 2 Tensor* val, bool* is_dead, int64 timeout_ms) { 3 Status ret; 4 Notification n; 5 RecvAsync(key, recv_args, 6 [&ret, &n, val, is_dead](const Status& s, const Args& send_args, 7 const Args& recv_args, const Tensor& v, 8 const bool dead) { 9 ret = s; 10 *val = v; 11 *is_dead = dead; 12 n.Notify(); 13 }); 14 if (timeout_ms > 0) { 15 int64 timeout_us = timeout_ms * 1000; 16 bool notified = WaitForNotificationWithTimeout(&n, timeout_us); 17 if (!notified) { 18 return Status(error::DEADLINE_EXCEEDED, 19 "Timed out waiting for notification"); 20 } 21 } else { 22 n.WaitForNotification(); 23 } 24 return ret; 25 }
1 Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, 2 const bool is_dead) override { 3 uint64 key_hash = KeyHash(key.FullKey()); 4 VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); 5 6 mu_.lock(); 7 if (!status_.ok()) { 8 // Rendezvous has been aborted. 9 Status s = status_; 10 mu_.unlock(); 11 return s; 12 } 13 14 ItemQueue* queue = &table_[key_hash]; 15 if (queue->empty() || queue->front()->IsSendValue()) { 16 // There is no waiter for this message. Append the message 17 // into the queue. The waiter will pick it up when arrives. 18 // Only send-related fields need to be filled. 19 Item* item = new Item; 20 item->value = val; 21 item->is_dead = is_dead; 22 item->send_args = send_args; 23 if (item->send_args.device_context) { 24 item->send_args.device_context->Ref(); 25 } 26 queue->push_back(item); 27 mu_.unlock(); 28 return Status::OK(); 29 } 30 31 // There is an earliest waiter to consume this message. 32 Item* item = queue->front(); 33 queue->pop_front(); 34 mu_.unlock(); 35 36 // Notify the waiter by invoking its done closure, outside the 37 // lock. 38 DCHECK(!item->IsSendValue()); 39 item->waiter(Status::OK(), send_args, item->recv_args, val, is_dead); 40 delete item; 41 return Status::OK(); 42 }
1 void RecvAsync(const ParsedKey& key, const Args& recv_args, 2 DoneCallback done) override { 3 uint64 key_hash = KeyHash(key.FullKey()); 4 VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); 5 6 mu_.lock(); 7 if (!status_.ok()) { 8 // Rendezvous has been aborted. 9 Status s = status_; 10 mu_.unlock(); 11 done(s, Args(), recv_args, Tensor(), false); 12 return; 13 } 14 15 ItemQueue* queue = &table_[key_hash]; 16 if (queue->empty() || !queue->front()->IsSendValue()) { 17 // There is no message to pick up. 18 // Only recv-related fields need to be filled. 19 Item* item = new Item; 20 item->waiter = std::move(done); 21 item->recv_args = recv_args; 22 if (item->recv_args.device_context) { 23 item->recv_args.device_context->Ref(); 24 } 25 queue->push_back(item); 26 mu_.unlock(); 27 return; 28 } 29 30 // A message has already arrived and is queued in the table under 31 // this key. Consumes the message and invokes the done closure. 32 Item* item = queue->front(); 33 queue->pop_front(); 34 mu_.unlock(); 35 36 // Invokes the done() by invoking its done closure, outside scope 37 // of the table lock. 38 DCHECK(item->IsSendValue()); 39 done(Status::OK(), item->send_args, recv_args, item->value, item->is_dead); 40 delete item; 41 }
1 // Copies "input" to "output" between devices accessible to the 2 // local process via some DMA-like method. "edge_name" is the name 3 // of the tensor being copied, for debugging purposes. Depending on 4 // the type of devices and memory in use, the copy may be performed 5 // synchronously or asynchronously. 'done' will be invoked only 6 // after the copy is actually complete. 7 static void ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, 8 DeviceContext* recv_dev_context, Device* src, Device* dst, 9 const AllocatorAttributes src_alloc_attr, 10 const AllocatorAttributes dst_alloc_attr, 11 const Tensor* input, Tensor* output, 12 int dev_to_dev_stream_index, StatusCallback done);
本文是TensorFlow通信机制系列的第一篇文章,先通过提出高并发情况下消息通信两端如何对应的问题,引出TensorFlow中ParsedKey结构设计的必要性,然后给出了Rendezvous全局类图,最后详细地分析了LocalRendezvous的消息传输实现过程。TensorFlow的通信机制完美地阐释了Rendezvous一词的含义——无论是Send端还是Recv端,都需要在临界资源Table中“约会”,进行消息的传输。随后还着重分析了异步情况下,本属于consumer的waiter函数的调用时机设计问题——为了保证waiter函数的执行不被阻塞,设计上采用了Late invoke的方案。IntraProcessRendezvous本质上是LocalRendezvous的一层封装,它在数据拷贝方面做了更多的工作,借助LocalRendezvous实现了Send端和Recv端位于相同或不同种类Device时,对上层完全透明的拷贝过程。由于篇幅原因,特意将TensorFlow通信机制分为多个系列分析,作为第一篇文章,本篇介绍了Rendezvous的基本框架。在该系列之后的文章中,还会对跨进程的通信进行详细的分析。