在tensorflow RNN layer的搭建(GRU,LSTM等)中,我们展示了如何调用 tensorflow 内置模块和函数,搭建RNN layer。然而,当一般的GRU/LSTM layer不适用时,我们希望对其 cell 进行改进,实现自主设计的改造版的RNN cell。
这方面研究工作代表的典型有:Time-LSTM,论文链接为:What to Do Next: Modeling User Behaviors by Time-LSTM
下面,我们从tensorflow的内置函数 tf.scan()出发,展示如何自主实现/改造 RNN cell。
tf.scan(
fn,
elems,
initializer=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
fn : 一个二元函数
elems:一个tensor list
initializer:一个tensor,作为初始化值
实际上,tf.scan()所能应用的类型不止如此,这里只举了我们所需要用到的部分
在tf.scan 记录中有一个很好的例子,我们借鉴一下:
x = [1,2,3]
z = 10
x = tf.convert_to_tensor(x)
z = tf.convert_to_tensor(z)
def f(x,y):
return x+y
g = tf.scan(fn=f,elems = x,initializer=z)
sess = tf.Session()
sess.run(tf.global_variables_initializer)
sess.run(g)
得到:
In [97]: sess.run(g)
Out[97]: array([11, 13, 16], dtype=int32)
详细的计算逻辑如下:
11 = 10(初始值initializer)+ 1(x[0])
13 = 11(上次的计算结果)+2(x[1])
16 = 13(上次的计算结果)+3(x[2])
可以发现,tf.scan() 从initializer 开始,把函数 fn 不断应用在上次计算结果和elems当前的每一个元素上,不断迭代,得到一系列输出。
如果我们把elems看作RNN 的输入seq,把fn 看作cell 的内部作用函数,那么输出seq 就是一系列隐状态[ h 1 , h 2 , ⋯ , h N h_1, h_2, \cdots, h_N h1,h2,⋯,hN]。这和RNN的作用机制是相同的!
下面我们以较为简洁典型的GRU cell 为例,来看tf.scan()的应用
def GRUunit(prev_h, x):
dim_item = tf.shape(x)[1]
dim_hid = DIM_HID
w_xr = tf.get_variable('w_xr', [dim_item, dim_hid])
w_hr = tf.get_variable('w_hr', [dim_hid, dim_hid])
br = tf.get_variable('br', dim_hid)
r = tf.sigmoid(tf.matmul(x, w_xr) + tf.matmul(prev_h, w_hr) + br)
w_xz = tf.get_variable('w_xz', [dim_item, dim_hid])
w_hz = tf.get_variable('w_hz', [dim_hid, dim_hid])
bz = tf.get_variable('bz', dim_hid)
z = tf.sigmoid(tf.matmul(x, w_xz) + tf.matmul(prev_h, w_hz) + bz)
w_xh = tf.get_variable('w_xh', [dim_item, dim_hid])
w_hh = tf.get_variable('w_hh', [dim_hid, dim_hid])
bh = tf.get_variable('bh', dim_hid)
h_ = tf.nn.tanh(tf.matmul(x, w_xh) + tf.matmul(tf.multiply(r, prev_h), w_hh) + bh)
h = tf.multiply(z, h_) + tf.multiply(1-z, prev_h)
return h
def GRUlayer(inputs, layer_name, dim_hid):
with tf.variable_scope(layer_name):
batch_size = tf.shape(inputs)[0]
initial_hidden = tf.zeros([batch_size, dim_hid], tf.float32)
states = tf.scan(GRUunit, tf.transpose(inputs,[1,0,2]), initializer = initial_hidden, name='states')
return tf.transpose(states,[1,0,2]), states[-1,:]
至此,我们就完成了GRU cell 的自主实现。
注意到RNN input 的维度分别为[batch_size, steps, item_dim],而tf.scan() 是对steps维度展开,因此在输入和输出时要对input的前两维进行转置。
对于其他RNN cell,只需要对GRUunit函数进行改写即可。