MIMNCell超详细分解 论文看不懂点这里就对了!

简介

最近的一个工作需要用到MIMNCell,但是原本的论文其实一篇比较工业化的论文,里面对于离线部分的MIMN可以说完全没有解释,我一步一步的将官方的实现在这里做一些分享。

官方paper在这里:
传送门

MIMNCell的主要模块

首先MIMNCell我认为是 RNN结构的一种改进。当然其中增加了很多的模块,但是输入仍然是一个时序序列。
其中主要包括:

  1. Controller
  2. Memory Read
  3. Memory Write
  4. MIU部分,其中重要的是理解MIU部分维护的S矩阵

基本的工作流程是:

controller

在这里我们假设 MIMNCell的输入是 x;以及上一个MIMNCell的各种状态,如果是第一次输入,也就是t=0的情况下,就用0初始化。
首先根据将 x和上一个状态的Memory Read的输出 read_vector 输入到Controller中,输入的时候需要将x和read_vector 拼接起来
而Controller就是一个标准的GRU。
然后对于Controller这个GRU来说,inputs = x和read_vector的concat,初始状态就是上一个controller上一个时间步的state。
这样我们就得到了一个controller_output & controller_state.

根据一个fully_connect layer,输入是controller_output, 输出是一个非常大的维度(包括head_parameter & erase_add parameters):

  • 前num_parameters_per_head * num_heads是head的参数向量

其中num_parameters_per_head = memory_vector_dim(超参)
其中num_head = read_head_num + write_head_num

  • 后self.memory_vector_dim * 2 * self.write_head_num 表示的是erase parameter 和 add parameter
    这些paramerter都是根据controller_output生成,然后用于下面的处理。

NTM 部分(memory read & write)

NTM维护着一个M矩阵,read和write也是用于更新和修改这个M矩阵的。

对于read和write部分,每一次输入第t个behavior vector,都会生成一个paper中叫weight vector,这里无论是read head还是write head都是一样的weight vector的获取方式,论文中的部分我直接复制在这里:
MIMNCell超详细分解 论文看不懂点这里就对了!_第1张图片
但是这里首先 k t k_t kt的生成论文中并没有说明,
在coding中 k t k_t kt的获取是利用每个head parameter中的memory_vector_dim经过一个tanh激活函数得到。然后经过上图的计算就可以为每个head(read & write)都获得一个w向量。

注:
在后面的操作中,memory read的部分是利用read head,然后read head利用的是相对应的w向量
memory read的部分是利用read write,然后write head利用的是相对应的w向量

对每个read head

每一个read head都会输出一个read vector,
MIMNCell超详细分解 论文看不懂点这里就对了!_第2张图片
这里需要注意的是此时的M还是没有更新过的,也就是t-1状态的M。
到这里为止,memory read就完成了他的 工作。

现在其实我们是在介绍NTM的部分,还没有介绍memory write是如何更新,但是现在Memory read已经获取到了输出,所以现在就到了MIU的部分。memory write对于M的更新是在MIU更新之后。

MIU部分

翻译过来就是memory 归纳单元。
这里我认为是paper和coding的实现差别有点大的地方。

  1. 首先,根据第一个read head的w向量,找到M中权重最大的slot。这里假设是index=0的slot,然后进行一个one-hot编码,获得一个mask。
    mask = [1, 0, 0, …],当然别忘了要考虑batchsize,这里只是简单举一个例子。
  2. 另外维护一个channel_rnn,就是一个orignal GRU结构。
    向channel_rnn输入 concat([x, M*mask]),初始化state就是t-1 step的channel_rnn state。

这里注意M是t-1 step的M,因为在t step,还没有利用memory write对M进行更新。

然后更新S
S = channel_rnn_state * mask + channel_rnn_prev_state * (1-mask)
这里可以理解为用w中权重最大的index对应的channel_rnn的state来更新t-1 step的旧state,然后就获得了S。
同时获得t step的channel_rnn的输出:
ouput_t. = channel_rnn_output * mask + channel_rnn_prev_ouput * (1-mask)

上面这个过程就是论文中的:
在这里插入图片描述

回到NTM中的Memory write

论文中的memory write的更新如下所示:
MIMNCell超详细分解 论文看不懂点这里就对了!_第3张图片
其中 e t e_t et a t a_t at 就是上述erase parameter 和 add parameter, w t w w_t^w wtw表示的就是write w向量,不同的write head对应不同的w。
然后就可以对M进行更新。

总结

这样一个完整的t step的MIMNCell的更新就是这样完成的。
可能我的逻辑有一些不好理解,如果有问题和指正欢迎大家直接留言。

大家共勉~~

你可能感兴趣的:(推荐系统,推荐系统,MIMN)