Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
这篇论文提出了Tree-LSTM,把LSTM结构扩展到树状的网络拓扑结构,对于一些NLP任务(或者说 处理具有树型数据结构的数据)提供了特征提取的方法。
简单来说,有两种模型,一种是Child-Sum,一种是N-ary。
个人认为:
本人的一个任务是处理树型结构AST(可以查看我的文章,有关AST处理方法),所以我采用 二叉树 N-ary Tree-LSTM。
tensorflow2 (不是1)
采用自定义layers.Layer类,建立Tree-LSTM单元。如果要使用,可以把定义的类放入Model类的call()做操作(个人没有用Sequential方式使用过)。
import tensorflow as tf
from tensorflow.keras import *
class MyDense(layers.Layer):
    """One Tree-LSTM gate: activation(x @ W + sum_k h_k @ U_k + b).

    Implements a single per-gate affine map from the N-ary Tree-LSTM
    equations: one kernel W for the node's own embedding plus one square
    matrix U_k per child hidden state, a bias, and a pointwise activation.
    """

    def __init__(self, input_dim, output_dim, matrices_num=2, activation='sigmoid'):
        """
        :param input_dim: dimensionality of the node embedding x
        :param output_dim: dimensionality of the gate output (hidden size)
        :param matrices_num: number of children, i.e. number of U matrices
        :param activation: Keras activation name applied to the result
        """
        super(MyDense, self).__init__()
        self.kernel = self.add_weight('w', [input_dim, output_dim])
        self.bias = self.add_weight('b', [output_dim])
        self.matrices_num = matrices_num
        self.matrices = []
        self.activation = layers.Activation(activation)
        # One [output_dim, output_dim] matrix per child hidden state.
        for i in range(matrices_num):
            self.matrices.append(self.add_weight('u' + str(i), [output_dim, output_dim]))

    def call(self, inputs, hiddens=None, **kwargs):
        """
        How to use:
            MyDense(input_dim, output_dim, matrices_num, activation)(inputs, hiddens)
        :param inputs: x's embedding, shape [b, input_dim]
        :param hiddens: children's hidden states, shape [..., matrices_num, output_dim],
                        or None to skip the child terms (e.g. a leaf node)
        :param kwargs: unused
        :return: [b, emb], [b, child_num, dim] => h: [b, dim]
        """
        out = inputs @ self.kernel
        # Bug fix: the original indexed `hiddens` unconditionally even though
        # the signature advertises a None default, which would raise on leaf
        # nodes. Guard it so hiddens=None simply omits the child terms.
        if hiddens is not None:
            for i in range(self.matrices_num):
                out += hiddens[..., i, :] @ self.matrices[i]
        out += self.bias
        out = self.activation(out)
        return out
####### test MyDense
# dense = MyDense(10, 4,matrices_num=2,activation='tanh')
# # [b, seq, emb] => [b, seq, dim]
# a = tf.random.normal([4,30,10])
# # [b, seq, num, dim]
# h = tf.random.normal([4,30,2,4])
# print(dense(a, h))
# print(dense.kernel)
# print(dense.bias)
class TreeLSTM(layers.Layer):
    """N-ary Tree-LSTM cell for a fixed number of children per node.

    Combines a node's embedding with its children's hidden/cell states:
        c = i * u + sum_k f_k * c_k
        h = o * tanh(c)
    where each gate (i, f_k, o, u) is a MyDense layer.
    """

    def __init__(self, embed_dim, hidden_dim, child_num=2):
        """
        :param embed_dim: dimensionality of the node embedding
        :param hidden_dim: dimensionality of the hidden/cell states
        :param child_num: fixed number of children per node (2 = binary tree)
        """
        super(TreeLSTM, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.child_num = child_num
        self.inputGate = MyDense(embed_dim, hidden_dim, matrices_num=child_num)
        # One forget gate per child, as in the N-ary Tree-LSTM equations.
        self.forgetGate = []
        for i in range(child_num):
            self.forgetGate.append(MyDense(embed_dim, hidden_dim, matrices_num=child_num))
        self.outputGate = MyDense(embed_dim, hidden_dim, matrices_num=child_num)
        self.updateInfo = MyDense(embed_dim, hidden_dim, matrices_num=child_num, activation='tanh')
        self.activation = layers.Activation('tanh')
        # Perf: the mask depends only on child_num, so build it once here
        # instead of on every call. 1 - I gives row k a zero at column k,
        # so mask[k] zeroes out child k's own hidden state before it feeds
        # forget gate k. (tf.eye == band_part(ones, 0, 0) in the original.)
        # NOTE(review): the paper's forget gate f_k uses ALL children's
        # hiddens; dropping child k's own hidden is a deviation — confirm
        # this is intended.
        mask = 1.0 - tf.eye(self.child_num, dtype=tf.float32)
        self.mask = tf.expand_dims(mask, axis=-1)  # [child_num, child_num, 1]

    def call(self, inputs, hiddens=None, cells=None, training=None, **kwargs):
        """
        How to use:
            TreeLSTM(embed_dim, hidden_dim, child_num)(inputs, hiddens, cells)
        :param inputs: node's embedding, shape [b, embed_dim]
        :param hiddens: children's hidden states, shape [b, child_num, hidden_dim]
        :param cells: children's cell states, shape [b, child_num, hidden_dim]
        :param kwargs: unused
        :return: [b, embed], [b, child_num, dim], [b, child_num, dim] => h: [b, dim], c: [b, dim]
        """
        # c = i * u, then accumulate each child's forget-gated cell state.
        c = self.inputGate(inputs, hiddens) * self.updateInfo(inputs, hiddens)
        for i in range(self.child_num):
            c += self.forgetGate[i](inputs, hiddens * self.mask[i]) * cells[..., i, :]
        h = self.outputGate(inputs, hiddens) * self.activation(c)
        return h, c
######## test TreeLSTM
# treelstm = TreeLSTM(10, 4)
# # x [b, 10]
# x = tf.random.normal([5, 10])
# # h [b, 2, 4]
# h = tf.random.normal([5,2,4])
# # c [b, 2, 4]
# c = tf.random.normal([5,2,4])
# print(treelstm(x,h,c))
MyDense 类:9~12基本公式实现
Tree-LSTM类:所有公式的集合
使用时,如果对孩子数量有要求,修改child_num参数即可
有疑问,或者有问题,欢迎留言