Author: Zihao Ye, Qipeng Guo, Minjie Wang, Jake Zhao, Zheng Zhang
In this tutorial, you learn how to use Tree-LSTM networks for sentiment analysis. Tree-LSTM is a generalization of long short-term memory (LSTM) networks to tree-structured network topologies.
The Tree-LSTM structure was first introduced by Kai et al. in the ACL 2015 paper Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks. The core idea is to introduce syntactic information into language tasks by extending the chain-structured LSTM to a tree-structured LSTM. Dependency-tree and constituency-tree techniques are used to obtain a "latent tree".
The challenge in training Tree-LSTMs is batching, a standard technique in machine learning to accelerate optimization. However, since trees generally have different shapes, parallelization is not trivial. DGL offers an alternative: pool all the trees into one single graph, then induce the message passing over them, guided by the structure of each tree.
The steps here use the Stanford Sentiment Treebank from dgl.data. The dataset provides fine-grained, tree-level sentiment annotation. There are five classes: very negative, negative, neutral, positive, and very positive, which indicate the sentiment of the current subtree. Non-leaf nodes in a constituency tree do not contain words, so a special PAD_WORD token is used to represent them. During training and inference, their embeddings are masked to all zeros.
The figure shows one sample of the SST dataset, a constituency parse tree whose nodes are labeled with sentiment. To speed things up, build a tiny set with five sentences and take a look at the first one.
import dgl
from dgl.data.tree import SST
from dgl.data import SSTBatch
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SST(mode='tiny') # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes
vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word
a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():
    if token != trainset.PAD_WORD:
        print(inv_vocab[token], end=" ")
out:
Preprocessing...
Dataset creation finished. #Trees: 5
the rock is destined to be the 21st century 's new `` conan '' and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .
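As a quick sanity check on the node fields described above (a small illustrative sketch, not part of the original tutorial), you can separate leaf and non-leaf nodes of the first tree by comparing the "x" field against PAD_WORD:

import torch as th

# leaves carry real word ids; internal nodes carry the PAD_WORD placeholder
is_leaf = a_tree.ndata['x'] != trainset.PAD_WORD
print('leaf nodes:', int(is_leaf.sum()), 'of', a_tree.number_of_nodes())
# every node, leaf or not, has one of the five sentiment labels in the "y" field
print('distinct labels:', th.unique(a_tree.ndata['y']).tolist())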
Add all the trees to one graph, using the batch() API.
import networkx as nx
import matplotlib.pyplot as plt
graph = dgl.batch(tiny_sst)
def plot_tree(g):
    # this plot requires pygraphviz package
    pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
    nx.draw(g, pos, with_labels=False, node_size=10,
            node_color=[[.5, .5, .5]], arrowsize=4)
    plt.show()

plot_tree(graph.to_networkx())
You can read more about the definition of batch(), or skip ahead to the next step.
Note:
**Definition**: A :class:`~dgl.batched_graph.BatchedDGLGraph` is a
:class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
- The union includes all the nodes,
  edges, and their features. The order of nodes, edges, and features is
  preserved.
- Given that you have :math:`V_i` nodes for graph
  :math:`\mathcal{G}_i`, the node ID :math:`j` in graph
  :math:`\mathcal{G}_i` corresponds to node ID
  :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph.
- Therefore, performing feature transformation and message passing on
``BatchedDGLGraph`` is equivalent to doing those
on all ``DGLGraph`` constituents in parallel.
- Duplicate references to the same graph are
treated as deep copies; the nodes, edges, and features are duplicated,
and mutation on one reference does not affect the other.
- Currently, ``BatchedDGLGraph`` is immutable in
  graph structure. You can't add
  nodes or edges to it. Support for mutable batched graphs is planned for the
  (far) future.
- The ``BatchedDGLGraph`` keeps track of the meta
information of the constituents so it can be
:func:`~dgl.batched_graph.unbatch`\ ed to list of ``DGLGraph``\ s.
For more details about the BatchedDGLGraph module in DGL, you can click the class name.
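To see the node-ID offset rule in action, here is a minimal sketch (it reuses the graph and tiny_sst objects defined above; the check itself is illustrative and not part of the tutorial):

import torch as th

# node j of tree i becomes node j + (sum of the sizes of the earlier trees)
sizes = [t.number_of_nodes() for t in tiny_sst]
print('tree sizes:', sizes, '| batched size:', graph.number_of_nodes())

offset = sizes[0]
# node 0 of the second tree and node `offset` of the batch carry the same word id
print(th.equal(tiny_sst[1].ndata['x'][0], graph.ndata['x'][offset]))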
Researchers have proposed two types of Tree-LSTMs: the Child-Sum Tree-LSTM and the N-ary Tree-LSTM. In this tutorial you focus on applying a binary Tree-LSTM to binarized constituency trees. This application is also known as a Constituency Tree-LSTM. Use PyTorch as the backend framework to set up the network.
In the :math:`N`-ary Tree-LSTM, each unit at node :math:`j` maintains a hidden representation :math:`h_j` and a memory cell :math:`c_j`. The unit :math:`j` takes the input vector :math:`x_j` and the hidden representations of its child units, :math:`h_{jl}` for :math:`1 \leq l \leq N`, as input, then updates its new hidden representation :math:`h_j` and memory cell :math:`c_j` by:
.. math::

   i_j = \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), \qquad (1)

   f_{jk} = \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)}\right), \qquad (2)

   o_j = \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)}\right), \qquad (3)

   u_j = \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)}\right), \qquad (4)

   c_j = i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, \qquad (5)

   h_j = o_j \cdot \textrm{tanh}(c_j), \qquad (6)
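Before mapping these equations onto DGL's message passing, it may help to see them spelled out once in plain PyTorch. The following is a minimal, stand-alone sketch of a single binary (N = 2) update step; the tensor and parameter names mirror the symbols above and are illustrative assumptions, not DGL's or the tutorial's API.

import torch as th

N, x_size, h_size = 2, 4, 5
x_j = th.randn(x_size)
h_children = th.randn(N, h_size)   # h_{j1}, h_{j2}
c_children = th.randn(N, h_size)   # c_{j1}, c_{j2}

# gate parameters W^{(.)}, U^{(.)}_l, b^{(.)} plus U^{(f)}_{kl} for the
# per-child forget gates (randomly initialized here, just for illustration)
W = {g: th.randn(h_size, x_size) for g in 'iou'}
U = {g: th.randn(N, h_size, h_size) for g in 'iou'}
b = {g: th.zeros(h_size) for g in 'iou'}
W_f = th.randn(h_size, x_size)
U_f = th.randn(N, N, h_size, h_size)
b_f = th.zeros(h_size)

def child_sum(g):
    return sum(U[g][l] @ h_children[l] for l in range(N))

i_j = th.sigmoid(W['i'] @ x_j + child_sum('i') + b['i'])                  # eq. (1)
f_j = th.stack([th.sigmoid(W_f @ x_j
                           + sum(U_f[k, l] @ h_children[l] for l in range(N))
                           + b_f)
                for k in range(N)])                                       # eq. (2)
o_j = th.sigmoid(W['o'] @ x_j + child_sum('o') + b['o'])                  # eq. (3)
u_j = th.tanh(W['u'] @ x_j + child_sum('u') + b['u'])                     # eq. (4)
c_j = i_j * u_j + (f_j * c_children).sum(0)                               # eq. (5)
h_j = o_j * th.tanh(c_j)                                                  # eq. (6)
print(h_j.shape, c_j.shape)   # both torch.Size([5])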
It can be decomposed into three phases: message_func, reduce_func, and apply_node_func.
Note:
apply_node_func is a new node UDF that has not been introduced before. In apply_node_func, a user specifies what to do with node features, without regard to edge features and messages. In the Tree-LSTM case, apply_node_func is a must, because there exist (leaf) nodes with 0 incoming edges, which would not be updated by reduce_func.
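In the batched SST graph built earlier, those zero-in-degree nodes are exactly the leaves, which you can verify with a small illustrative check:

# leaves have no incoming edges, so reduce_func never fires on them; only
# apply_node_func (plus the W_iou term fed in during the forward pass below)
# updates their state
print(int((graph.in_degrees() == 0).sum()), 'leaf nodes in the batched graph')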
import torch as th
import torch.nn as nn

class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data['c']
        # equation (6)
        h = o * th.tanh(c)
        return {'h': h, 'c': c}
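A quick, optional sanity check of how the cell packs the equations: the i, o, and u projections are fused into one linear layer with a 3 * h_size output, and the two children's hidden states are concatenated, hence the 2 * h_size inputs of U_iou and U_f.

cell = TreeLSTMCell(x_size=256, h_size=256)
print(cell.W_iou)   # Linear(in_features=256, out_features=768, bias=False)
print(cell.U_iou)   # Linear(in_features=512, out_features=768, bias=False)
print(cell.U_f)     # Linear(in_features=512, out_features=512, bias=True)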
After you define the message-passing functions, trigger them in the right order. This is a significant departure from models such as GCN, where all nodes pull messages from upstream nodes simultaneously.
In the case of Tree-LSTM, messages start from the leaves of the tree and propagate upward until they reach the root. A visualization is as follows:
DGL defines a generator to perform the topological sort; each item it yields is a tensor recording the nodes from the bottom level up to the root. You can appreciate the degree of parallelism by inspecting the following:
print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree))
print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))
out:
Traversing one tree:
(tensor([ 2, 3, 6, 8, 13, 15, 17, 19, 22, 23, 25, 27, 28, 29, 30, 32, 34, 36,
38, 40, 43, 46, 47, 49, 50, 52, 58, 59, 60, 62, 64, 65, 66, 68, 69, 70]), tensor([ 1, 21, 26, 45, 48, 57, 63, 67]), tensor([24, 44, 56, 61]), tensor([20, 42, 55]), tensor([18, 54]), tensor([16, 53]), tensor([14, 51]), tensor([12, 41]), tensor([11, 39]), tensor([10, 37]), tensor([35]), tensor([33]), tensor([31]), tensor([9]), tensor([7]), tensor([5]), tensor([4]), tensor([0]))
Traversing many trees at the same time:
(tensor([ 2, 3, 6, 8, 13, 15, 17, 19, 22, 23, 25, 27, 28, 29,
30, 32, 34, 36, 38, 40, 43, 46, 47, 49, 50, 52, 58, 59,
60, 62, 64, 65, 66, 68, 69, 70, 74, 76, 78, 79, 82, 83,
85, 88, 90, 92, 93, 95, 96, 100, 102, 103, 105, 109, 110, 112,
113, 117, 118, 119, 121, 125, 127, 129, 130, 132, 133, 135, 138, 140,
141, 142, 143, 150, 152, 153, 155, 158, 159, 161, 162, 164, 168, 170,
171, 174, 175, 178, 179, 182, 184, 185, 187, 189, 190, 191, 192, 195,
197, 198, 200, 202, 205, 208, 210, 212, 213, 214, 216, 218, 219, 220,
223, 225, 227, 229, 230, 232, 235, 237, 240, 242, 244, 246, 248, 249,
251, 253, 255, 256, 257, 259, 262, 263, 267, 269, 270, 271, 272]), tensor([ 1, 21, 26, 45, 48, 57, 63, 67, 77, 81, 91, 94, 101, 108,
111, 116, 128, 131, 139, 151, 157, 160, 169, 173, 177, 183, 188, 196,
211, 217, 228, 247, 254, 261, 268]), tensor([ 24, 44, 56, 61, 75, 89, 99, 107, 115, 126, 137, 149, 156, 167,
181, 186, 194, 209, 215, 226, 245, 252, 266]), tensor([ 20, 42, 55, 73, 87, 124, 136, 154, 180, 207, 224, 243, 250, 265]), tensor([ 18, 54, 86, 123, 134, 148, 176, 206, 222, 241, 264]), tensor([ 16, 53, 84, 122, 172, 204, 239, 260]), tensor([ 14, 51, 80, 120, 166, 203, 238, 258]), tensor([ 12, 41, 72, 114, 165, 201, 236]), tensor([ 11, 39, 106, 163, 199, 234]), tensor([ 10, 37, 104, 147, 193, 233]), tensor([ 35, 98, 146, 231]), tensor([ 33, 97, 145, 221]), tensor([ 31, 71, 144]), tensor([9]), tensor([7]), tensor([5]), tensor([4]), tensor([0]))
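If you want to see roughly what the generator computes, here is a hand-written equivalent based on in-degree counting (a sketch under the assumption that edges point from child to parent, which is how these trees are stored; it is not DGL's actual implementation):

import torch as th

def manual_frontiers(g):
    # a node becomes ready once all of its children (incoming edges) are processed
    remaining = g.in_degrees().clone()
    done = th.zeros(g.number_of_nodes(), dtype=th.bool)
    while not bool(done.all()):
        ready = ((remaining == 0) & ~done).nonzero().view(-1)
        yield ready
        done[ready] = True
        _, parents = g.out_edges(ready)        # each node reports to its parent
        for p in parents.tolist():
            remaining[p] -= 1

print(list(manual_frontiers(a_tree))[:3])       # compare with the first frontiers above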
Call **prop_nodes()** to trigger the message passing:
import dgl.function as fn
import torch as th
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a'))
traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order)
# the following is a syntax sugar that does the same
# dgl.prop_nodes_topo(graph)
Note:
Before you call prop_nodes(), specify message_func and reduce_func in advance. The example uses the built-in copy-from-source and sum functions as the message and reduce functions for demonstration.
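As a side note on the toy propagation above: because every node starts with a = 1 and the reduce step overwrites a with the sum of the children's values, after the traversal each internal node should hold the number of leaves in its subtree. A small illustrative check:

# node 0 is the root of the first tree in the batched graph, so its 'a' value
# should equal the number of leaves of that tree
num_leaves = int((a_tree.ndata['x'] != trainset.PAD_WORD).sum())
print(graph.ndata['a'][0].item(), '==', num_leaves)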
Here is the complete code that specifies the Tree-LSTM class.
class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_node_func)
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate
        dgl.prop_nodes_topo(g)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
Finally, you can write the training loop in PyTorch.
from torch.utils.data import DataLoader
import torch.nn.functional as F

device = th.device('cpu')
# hyper parameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(trainset.num_vocabs,
                 x_size,
                 h_size,
                 trainset.num_classes,
                 dropout)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)

def batcher(dev):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(graph=batch_trees,
                        mask=batch_trees.ndata['mask'].to(device),
                        wordid=batch_trees.ndata['x'].to(device),
                        label=batch_trees.ndata['y'].to(device))
    return batcher_dev

train_loader = DataLoader(dataset=tiny_sst,
                          batch_size=5,
                          collate_fn=batcher(device),
                          shuffle=False,
                          num_workers=0)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction='sum')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
            epoch, step, loss.item(), acc))
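The accuracy above is computed over all nodes. If you also want the root-level accuracy that is usually reported for SST, one hedged way to extend the loop (relying on the roots being the only nodes without an outgoing child-to-parent edge) is:

# inside the loop, after computing `pred`
root_ids = (g.out_degrees() == 0).nonzero().view(-1)
root_acc = float(th.sum(th.eq(batch.label[root_ids], pred[root_ids]))) / len(root_ids)
print("Root Acc {:.4f}".format(root_acc))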
out:
TreeLSTM(
(embedding): Embedding(19536, 256)
(dropout): Dropout(p=0.5, inplace=False)
(linear): Linear(in_features=256, out_features=5, bias=True)
(cell): TreeLSTMCell(
(W_iou): Linear(in_features=256, out_features=768, bias=False)
(U_iou): Linear(in_features=512, out_features=768, bias=False)
(U_f): Linear(in_features=512, out_features=512, bias=True)
)
)
Epoch 00000 | Step 00000 | Loss 431.9546 | Acc 0.3480 |
Epoch 00001 | Step 00000 | Loss 267.9747 | Acc 0.7289 |
Epoch 00002 | Step 00000 | Loss 491.0571 | Acc 0.6117 |
Epoch 00003 | Step 00000 | Loss 425.6686 | Acc 0.7985 |
Epoch 00004 | Step 00000 | Loss 213.4947 | Acc 0.7436 |
Epoch 00005 | Step 00000 | Loss 188.6720 | Acc 0.8388 |
Epoch 00006 | Step 00000 | Loss 105.7077 | Acc 0.8498 |
Epoch 00007 | Step 00000 | Loss 77.9390 | Acc 0.9121 |
Epoch 00008 | Step 00000 | Loss 60.1893 | Acc 0.9377 |
Epoch 00009 | Step 00000 | Loss 53.4182 | Acc 0.9414 |
To train the model on the full dataset with different settings (CPU or GPU), refer to the PyTorch example. There is also an implementation of the Child-Sum Tree-LSTM.
Total running time of the script: (0 minutes 1.721 seconds)