TensorFlow 2 Custom Networks, Part 1: Tree-LSTM Overview and Code Implementation

Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks

Overview

This paper proposes the Tree-LSTM, which extends the LSTM architecture to tree-shaped network topologies, providing a feature-extraction method for certain NLP tasks (or, more generally, for any data with a tree structure).

In short, the paper presents two variants: the Child-Sum Tree-LSTM and the N-ary Tree-LSTM.

In my view:

  • The drawback of Child-Sum is that feature extraction loses the children's positional information (the children's feature vectors are summed, so their positions cannot be recovered)
  • The drawback of N-ary is that when the number of children is not fixed, it is hard to declare a fixed number of per-child weight matrices
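The order-invariance of Child-Sum mentioned in the first bullet is easy to see numerically. This small sketch (with illustrative tensors of my own, not code from the paper) sums two child hidden states in both orders:

```python
import tensorflow as tf

h1 = tf.constant([[1.0, 2.0]])  # first child's hidden state
h2 = tf.constant([[3.0, 4.0]])  # second child's hidden state

order_a = h1 + h2  # children visited as (h1, h2)
order_b = h2 + h1  # children visited as (h2, h1)

same = bool(tf.reduce_all(order_a == order_b))
print(same)  # True -- the sum cannot distinguish the two orderings
```

Because addition is commutative, any permutation of the children produces the same aggregate, which is exactly the lost positional information the N-ary variant recovers at the cost of fixed per-child matrices.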

One of my own tasks is processing tree-structured ASTs (see my other article on AST processing methods), so I use the binary-tree N-ary Tree-LSTM.

Environment

TensorFlow 2 (not 1)

The Tree-LSTM cell is built by subclassing layers.Layer. To use it, call the layer inside a Model subclass's call() method (I have not tried using it with Sequential).

Code

import tensorflow as tf
from tensorflow.keras import layers

class MyDense(layers.Layer):
    """One gate: out = activation(inputs @ W + sum_l hiddens[:, l] @ U_l + b)."""
    def __init__(self, input_dim, output_dim, matrices_num=2, activation='sigmoid'):
        super(MyDense, self).__init__()
        # W projects the node embedding; b is the bias
        self.kernel = self.add_weight('w', [input_dim, output_dim])
        self.bias = self.add_weight('b', [output_dim])
        self.matrices_num = matrices_num
        # one U matrix per child hidden state
        self.matrices = []
        self.activation = layers.Activation(activation)

        for i in range(matrices_num):
            self.matrices.append(self.add_weight('u' + str(i), [output_dim, output_dim]))

    def call(self, inputs, hiddens=None, **kwargs):
        """
        How to use:
        MyDense(input_dim, output_dim, matrices_num, activation)(inputs, hiddens)

        :param inputs: the node's embedding, [b, emb]
        :param hiddens: the node's children's hidden states, [b, child_num, dim]
        :return: h, [b, dim]
        """
        out = inputs @ self.kernel
        for i in range(self.matrices_num):
            # hiddens[..., i, :] is the i-th child's hidden state
            out += hiddens[..., i, :] @ self.matrices[i]

        out += self.bias
        out = self.activation(out)
        return out

####### test MyDense
# dense = MyDense(10, 4,matrices_num=2,activation='tanh')
# # [b, seq, emb] => [b, seq, dim]
# a = tf.random.normal([4,30,10])
# # [b, seq, num, dim]
# h = tf.random.normal([4,30,2,4])
# print(dense(a, h))
# print(dense.kernel)
# print(dense.bias)
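For reference, the gate that MyDense computes can be written out by hand with plain tensor ops. The shapes below are illustrative and the names (x, h, W, U, bias) are my own, following the docstring's shape conventions:

```python
import tensorflow as tf

tf.random.set_seed(0)
b, emb, dim, num = 3, 5, 4, 2

x = tf.random.normal([b, emb])        # node embedding          [b, emb]
h = tf.random.normal([b, num, dim])   # children hidden states  [b, num, dim]
W = tf.random.normal([emb, dim])      # kernel
U = [tf.random.normal([dim, dim]) for _ in range(num)]  # one U per child
bias = tf.random.normal([dim])

# out = sigmoid(x @ W + sum_l h_l @ U_l + bias)
out = x @ W + bias
for l in range(num):
    out += h[:, l, :] @ U[l]
out = tf.sigmoid(out)
print(out.shape)  # (3, 4)
```

With a sigmoid activation every entry lands in (0, 1), which is what makes the result usable as a gate.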

class TreeLSTM(layers.Layer):
    def __init__(self, embed_dim, hidden_dim, child_num=2):
        super(TreeLSTM, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.child_num = child_num
        self.inputGate = MyDense(embed_dim, hidden_dim, matrices_num=child_num)
        # one forget gate per child
        self.forgetGate = []
        for i in range(child_num):
            self.forgetGate.append(MyDense(embed_dim, hidden_dim, matrices_num=child_num))
        self.outputGate = MyDense(embed_dim, hidden_dim, matrices_num=child_num)
        self.updateInfo = MyDense(embed_dim, hidden_dim, matrices_num=child_num, activation='tanh')
        self.activation = layers.Activation('tanh')

    def call(self, inputs, hiddens=None, cells=None, training=None, **kwargs):
        """
        How to use:
        TreeLSTM(embed_dim, hidden_dim, child_num)(inputs, hiddens, cells)

        :param inputs: the node's embedding, [b, embed]
        :param hiddens: children's hidden states, [b, child_num, dim]
        :param cells: children's cell states, [b, child_num, dim]
        :return: h: [b, dim], c: [b, dim]
        """
        # c = i * u  (input gate times candidate update)
        c = self.inputGate(inputs, hiddens) * self.updateInfo(inputs, hiddens)

        # mask = all-ones matrix with a zero diagonal, so hiddens * mask[i]
        # zeroes out the i-th child's own hidden state in its forget gate
        mask = 1 - tf.linalg.band_part(tf.ones([self.child_num, self.child_num], dtype=tf.float32), 0, 0)
        mask = tf.expand_dims(mask, axis=-1)
        for i in range(self.child_num):
            # c += f_i * c_i for each child i
            c += self.forgetGate[i](inputs, hiddens * mask[i]) * cells[..., i, :]

        h = self.outputGate(inputs, hiddens) * self.activation(c)
        return h, c

######## test TreeLSTM
# treelstm = TreeLSTM(10, 4)
# # x [b, 10]
# x = tf.random.normal([5, 10])
# # h [b, 2, 4]
# h = tf.random.normal([5,2,4])
# # c [b, 2, 4]
# c = tf.random.normal([5,2,4])
# print(treelstm(x,h,c))
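Since the layer is meant to be called from a Model's call(), one natural driver is a recursive bottom-up traversal of the tree. The sketch below is only my assumption about how such a driver might look: stub_cell, eval_tree, and the tuple-based tree encoding are all mine, and stub_cell merely mimics the (inputs, hiddens, cells) -> (h, c) signature of TreeLSTM so the sketch stays self-contained:

```python
import tensorflow as tf

hidden_dim = 4

def stub_cell(x, hiddens, cells):
    # Stand-in for a TreeLSTM layer: mixes the children's states so shapes line up.
    h = tf.tanh(x[:, :hidden_dim] + tf.reduce_sum(hiddens, axis=1))
    c = tf.reduce_sum(cells, axis=1)
    return h, c

def eval_tree(node, cell):
    """node = (embedding, left_child, right_child); leaves have None children."""
    x, left, right = node
    if left is None:
        # Leaf: feed zero hidden/cell states for the (absent) children.
        zeros = tf.zeros([tf.shape(x)[0], 2, hidden_dim])
        return cell(x, zeros, zeros)
    hl, cl = eval_tree(left, cell)
    hr, cr = eval_tree(right, cell)
    hiddens = tf.stack([hl, hr], axis=1)  # [b, 2, dim]
    cells = tf.stack([cl, cr], axis=1)    # [b, 2, dim]
    return cell(x, hiddens, cells)

leaf = (tf.random.normal([1, 10]), None, None)
root = (tf.random.normal([1, 10]), leaf, leaf)
h, c = eval_tree(root, stub_cell)
print(h.shape, c.shape)  # (1, 4) (1, 4)
```

Swapping stub_cell for an actual TreeLSTM instance keeps the traversal unchanged; the driver only needs the (inputs, hiddens, cells) -> (h, c) contract.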

The implemented N-ary Tree-LSTM equations are as follows: (image: the N-ary Tree-LSTM equation block from the paper)
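Transcribed into LaTeX from my reading of the paper (the N-ary equations in Tai et al., 2015), so treat this as a best-effort rendering of the image rather than the original:

```latex
i_j = \sigma\!\left(W^{(i)} x_j + \sum_{l=1}^{N} U^{(i)}_l h_{jl} + b^{(i)}\right) \\
f_{jk} = \sigma\!\left(W^{(f)} x_j + \sum_{l=1}^{N} U^{(f)}_{kl} h_{jl} + b^{(f)}\right) \\
o_j = \sigma\!\left(W^{(o)} x_j + \sum_{l=1}^{N} U^{(o)}_l h_{jl} + b^{(o)}\right) \\
u_j = \tanh\!\left(W^{(u)} x_j + \sum_{l=1}^{N} U^{(u)}_l h_{jl} + b^{(u)}\right) \\
c_j = i_j \odot u_j + \sum_{k=1}^{N} f_{jk} \odot c_{jk} \\
h_j = o_j \odot \tanh(c_j)
```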

MyDense class: implements the shared form of the basic gate equations (9–12 in the paper)
TreeLSTM class: combines all of the equations

When using it, if you need a different number of children, just change child_num.

If you have questions or spot problems, feel free to leave a comment.
