【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block

Dual_RNN_Block应该是整个网络中最重要的部分了。

这里,每一个Block相当于网络内部的一层 ,源码中默认设置4层Dual_RNN_Block。

每一个Dual_RNN_Block又分为intra_rnn(块内rnn)和inter_rnn(块间rnn)

intra_rnninter_rnn是dual的灵魂,但是刚开始接触很难理解这个概念。
结合代码和原论文的配图,可以理解为对Dual_RNN_Block的3D上对K和S维度训练

输入张量

输入的张量shape为[B, N, K, S], 具体的来源可以参考这里。

其中B为batch-size,每一个batch里的N,K,S,如下图。(K=2P)
输入张量

intra_rnn

RNN是最最后一维做训练,但是与其他维度也有关联。尤其是-2维度。
intra_rnn是针对K的训练,K是形容block的变量,即在这个维度上理解为intra

下图为intra_rnn block的流程图
【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_第1张图片

inter_rnn

intra_rnn是针对S的训练,S是形容block个数的变量,是block与block之间的关系,即在这个维度上理解为inter

下图为inter_rnn block的流程图
【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_第2张图片

双剑合璧 Dual_RNN_Block

上述两个intra_rnn + inter_rnn就是dual_rnn了。

但是有点细节:

  1. intra_rnn的结果是加上了输入张量x 再送到 inter_rnn计算
  2. inter_rnn的结果是加上了intra_rnn的结果再输出
    【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_第3张图片
    最后,把paper中的图贴在这里方便大家理解。
    【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_第4张图片

你可能感兴趣的:(源码分析,pytorch,rnn,深度学习)