JTVAE( Junction Tree Variational Autoencoder )

Junction Tree Variational Autoencoder for Molecular Graph Generation

Year: 2018
Authors: Wengong Jin, Regina Barzilay, Tommi Jaakkola
Journal Name: ICML

Contributions

  1. 使用分子图自动设计分子结构
  2. 将整个任务分为编码(以连续方法表示分子)和解码(将连续的表示映射回分子图)

Junction Tree Variational Autoencoder

JTVAE( Junction Tree Variational Autoencoder )_第1张图片
分子图和连接树提供了两个表示 z = [ z T , z G ] \bm{z} = [\bm{z}_{\mathcal{T}}, \bm{z}_G] z=[zT,zG] ,两者由编码器 q ( z T ∣ T ) q(\bm{z}_{\mathcal{T}} | \mathcal{T}) q(zTT) q ( z G ∣ G ) q(\bm{z}_{G} | G) q(zGG) 产生。两个解码器 p ( T ∣ z T ) p(\mathcal{T} | \bm{z}_{\mathcal{T}}) p(TzT) p ( G ∣ T , z G ) p(G | \mathcal{T}, \bm{z}_{G}) p(GT,zG) 重构分子图。

Junction Tree

已知分子图 G = ( V , E ) G = (V, E) G=(V,E) ,连接树为 T G = ( V , E , X ) \mathcal{T}_G = (\mathcal{V}, \mathcal{E}, \mathcal{X}) TG=(V,E,X) ,其中 X \mathcal{X} X 为特征字典, V = { C 1 , . . . , C n } \mathcal{V} = \{ C_1, ..., C_n \} V={C1,...,Cn} C i = ( V i , E i ) C_i = (V_i, E_i) Ci=(Vi,Ei) G G G 的子结构,满足以下限制

  1. ∪ i V i = V \cup_i V_i = V iVi=V ∪ i E i = E \cup_i E_i = E iEi=E
  2. 如果 C k C_k Ck 在从 C i C_i Ci C j C_j Cj 的路径上, V i ∩ V j ⊆ V k V_i \cap V_j \subseteq V_k ViVjVk

Graph Encoder

每个节点 v v v 和边缘 ( u , v ) ∈ E (u, v) \in E (u,v)E 都有相对应的特征向量 x v \bm{x}_v xv x u v \bm{x}_{uv} xuv 。定义 v u v \bm{v}_{uv} vuv 为从 u u u v v v 的信息
v u v ( t ) = τ ( W 1 g x u + W 2 g x u v + W 3 g ∑ w ∈ N ( u ) ∖ v v w u ( t − 1 ) ) \bm{v}_{uv}^{(t)} = \tau(W_1^g \bm{x}_u + W_2^g \bm{x}_{uv} + W_3^g \sum_{w \in N(u) \setminus v}\bm{v}_{wu}^{(t-1)}) vuv(t)=τ(W1gxu+W2gxuv+W3gwN(u)vvwu(t1))

其中, τ \tau τ 为 RELU , v u v ( t ) \bm{v}_{uv}^{(t)} vuv(t) 表示第 t t t 轮迭代后的信息, v u v ( 0 ) = 0 \bm{v}_{uv}^{(0)} = 0 vuv(0)=0 T T T 轮迭代后,将信息聚合为每个节点的隐向量
h u = τ ( U 1 g x u + ∑ v ∈ N ( u ) U 2 g v v u ( T ) ) \bm{h}_u = \tau(U_1^g \bm{x}_u + \sum_{v \in N(u)} U_2^g \bm{v}_{vu}^{(T)}) hu=τ(U1gxu+vN(u)U2gvvu(T))

最终的图表示为 h G = ∑ i h i / ∣ V ∣ \bm{h}_G = \sum_{i} \bm{h}_i / |V| hG=ihi/V z G \bm{z}_G zG N ( μ G , σ G ) \mathcal{N}(\bm{\mu}_G, \bm{\sigma}_G) N(μG,σG) 中采样, μ G \bm{\mu}_G μG l o g σ G log \bm{\sigma}_G logσG 通过两个独立的仿射层根据 h G \bm{h}_G hG 计算得出。

Tree Encoder

对于每条边缘 ( C i , C j ) (C_i, C_j) (Ci,Cj) ,定义信息向量 m i j \bm{m}_{ij} mij m j i \bm{m}_{ji} mji
m i j = G R U ( x i , { m k i } k ∈ N ( i ) ∖ j ) \bm{m}_{ij} = GRU(\bm{x}_i, \{ \bm{m}_{ki} \}_{k \in N(i) \setminus j}) mij=GRU(xi,{mki}kN(i)j)

GRU 的结构如下所示
s i j = ∑ k ∈ N ( i ) ∖ j m k i z i j = σ ( W z x i + U z s i j + b z ) r k i = σ ( W r x i + U r m i j + b r ) m ~ i j = t a n h ( W x i + U ∑ k ∈ N ( i ) ∖ j r k i ⊙ m k i ) m i j = ( 1 − z i j ) ⊙ s i j + z i j ⊙ m ~ i j \bm{s}_{ij} = \sum_{k \in N(i) \setminus j} \bm{m}_{ki} \\ \bm{z}_{ij} = \sigma (W^z \bm{x}_i + U^z \bm{s}_{ij} + b^z) \\ \bm{r}_{ki} = \sigma(W^r \bm{x}_i + U^r \bm{m}_{ij} + b^r) \\ \widetilde{\bm{m}}_{ij} = tanh(W \bm{x}_i + U \sum_{k \in N(i) \setminus j} \bm{r}_{ki} \odot \bm{m}_{ki}) \\ \bm{m}_{ij} = (1 - \bm{z}_{ij}) \odot \bm{s}_{ij} + \bm{z}_{ij} \odot \widetilde{\bm{m}}_{ij} sij=kN(i)jmkizij=σ(Wzxi+Uzsij+bz)rki=σ(Wrxi+Urmij+br)m ij=tanh(Wxi+UkN(i)jrkimki)mij=(1zij)sij+zijm ij

其中, σ \sigma σ 为 sigmoid 函数。信息传递之后,每个节点的隐向量
h i = τ ( W o x i + ∑ k ∈ N ( i ) U o m k i ) \bm{h}_i = \tau(W^o \bm{x}_i + \sum_{k \in N(i)}U^o \bm{m}_{ki}) hi=τ(Woxi+kN(i)Uomki)

采样 z T \bm{z}_{\mathcal{T}} zT 的方法和图编码器类似。

Tree Decoder

JTVAE( Junction Tree Variational Autoencoder )_第2张图片
JTVAE( Junction Tree Variational Autoencoder )_第3张图片
解码过程在原分子的基础上,利用树采样继续扩展新的子结构,原分子的所有子结构均为根节点。
定义 E ~ t \widetilde{\mathcal{E}}_t E t 为到 t t t 时刻为止已经采样的边缘, h i t j t \bm{h}_{i_t j_t} hitjt 为采样过程中产生的信息。
h i t j t = G R U ( x i t , { h k i t } ( k , i t ) ∈ E ~ t , k ≠ j t ) \bm{h}_{i_t j_t} = GRU(\bm{x}_{i_t}, \{ \bm{h}_{k i_t} \}_{(k, i_t) \in \widetilde{\mathcal{E}}_t, k \neq j_t}) hitjt=GRU(xit,{hkit}(k,it)E t,k=jt)

定义 p t p_t pt 为当前叶节点是否继续扩展的概率
p t = σ ( u d ⋅ τ ( W 1 d x i t + W 2 d z T + W 3 d ∑ ( k , i t ) ∈ E ~ t h k i t ) ) p_t = \sigma(u^d · \tau(W_1^d \bm{x}_{i_t} + W_2^d \bm{z}_{\mathcal{T}} + W_3^d \sum_{(k, i_t) \in \widetilde{\mathcal{E}}_t} \bm{h}_{k i_t})) pt=σ(udτ(W1dxit+W2dzT+W3d(k,it)E thkit))

定义
q j = s o f t m a x ( U l τ ( W 1 l z T + W 2 l h i j ) ) q_j = softmax(U^l \tau(W_1^l \bm{z}_{\mathcal{T}} + W_2^l \bm{h}_{ij})) qj=softmax(Ulτ(W1lzT+W2lhij))

表示扩展节点 j j j 的特征 x j \bm{x}_j xj 在特征字典 X \mathcal{X} X 中的概率。当 j j j 为根节点时, h i j = 0 \bm{h}_{ij} = 0 hij=0 。训练时采用 teacher forcing 最小化交叉熵损失
L c ( T ) = ∑ t L d ( p t , p ^ t ) + ∑ j L l ( q j , q ^ j ) L_c(\mathcal{T}) = \sum_t L^d(p_t, \hat{p}_t) + \sum_j L^l(q_j, \hat{q}_j) Lc(T)=tLd(pt,p^t)+jLl(qj,q^j)

Graph Decoder

因为相同的树所重构出的图并不唯一,定义 G ( T ) \mathcal{G}(\mathcal{T}) G(T) 为树 T \mathcal{T} T 所能重构的图的集合。
G ^ = arg max ⁡ G ′ ∈ G ( T ) f a ( G ′ ) \hat{G} = \argmax_{G' \in \mathcal{G}(\mathcal{T})} f^a(G') G^=GG(T)argmaxfa(G)

其中, f a f^a fa 为评分函数。出于效率原因,作者按照树本身的解码顺序,一次扩展一个子结构进行计算。
假设根据树节点 C j C_j Cj 新扩展的子结构为 C i C_i Ci ,生成了子图 G i G_i Gi ,子图所对应的向量表示为 h G i \bm{h}_{G_i} hGi ,评分函数为
f a ( G i ) = h G i ⋅ z G f^a (G_i) = \bm{h}_{G_i} · \bm{z}_G fa(Gi)=hGizG

定义 u u u v v v G i G_i Gi 中的两个原子。如果 v ∈ C i v \in C_i vCi α v = i \alpha_v = i αv=i 。如果 v ∈ C j ∖ C i v \in C_j \setminus C_i vCjCi α v = j \alpha_v = j αv=j 。设立 α v \alpha_v αv 是为了标注原子在树中的位置。仿照图编码器,定义 μ u v \bm{\mu}_{uv} μuv 为从 u u u v v v 的信息
μ u v ( t ) = τ ( W 1 a x u + W 2 a x u v + W 3 a μ ~ u v ( t − 1 ) ) μ ~ u v ( t − 1 ) = { ∑ w ∈ N ( u ) ∖ v μ w u ( t − 1 ) , α u = α v , m ^ α u α v + ∑ w ∈ N ( u ) ∖ v μ w u ( t − 1 ) , α u ≠ α v . \bm{\mu}_{uv}^{(t)} = \tau(W_1^a \bm{x}_u + W_2^a \bm{x}_{uv} + W_3^a \widetilde{\bm{\mu}}_{uv}^{(t-1)}) \\ \widetilde{\bm{\mu}}_{uv}^{(t-1)} = \left\{ \begin{aligned} \sum_{w \in N(u) \setminus v} \bm{\mu}_{wu}^{(t-1)} & , & \alpha_u = \alpha_v, \\ \hat{\bm{m}}_{\alpha_u \alpha_v} + \sum_{w \in N(u) \setminus v} \bm{\mu}_{wu}^{(t-1)} & , & \alpha_u \neq \alpha_v. \end{aligned} \right. μuv(t)=τ(W1axu+W2axuv+W3aμ uv(t1))μ uv(t1)=wN(u)vμwu(t1)m^αuαv+wN(u)vμwu(t1),,αu=αv,αu=αv.

计算 h G i \bm{h}_{G_i} hGi 的方法与图编码器相同。
学习图解码器参数以最大化在每个树节点处预测地面真实图 G 的正确子图 G i 的对数似然
该过程的损失函数为
L g ( G ) = ∑ i [ f a ( G i ) − l o g ∑ G i ′ ∈ G i e x p ( f a ( G i ′ ) ) ] L_g(G) = \sum_i \Big[ f^a(G_i) - log \sum_{G_i' \in \mathcal{G}_i} exp(f^a(G_i')) \Big] Lg(G)=i[fa(Gi)logGiGiexp(fa(Gi))]

其中, i i i 为树的节点, G i G_i Gi 为正确子图。
以我的理解, l o g ∑ G i ′ ∈ G i e x p ( f a ( G i ′ ) ) log \sum_{G_i' \in \mathcal{G}_i} exp(f^a(G_i')) logGiGiexp(fa(Gi)) 放大了较大 f a ( G i ′ ) f^a(G_i') fa(Gi) 的影响,减少了较小 f a ( G i ′ ) f^a(G_i') fa(Gi) 的影响。所以,该损失函数倾向于使正确子图的分数无穷大,错误子图的分数为 0 ,但这样的话 f a ( G i ) f^a (G_i) fa(Gi) 直接使用内积计算相似度是否不太合理?

Results

JTVAE( Junction Tree Variational Autoencoder )_第4张图片
JTVAE( Junction Tree Variational Autoencoder )_第5张图片
JTVAE( Junction Tree Variational Autoencoder )_第6张图片

你可能感兴趣的:(RNA结构预测,深度学习,人工智能,神经网络,生物信息学,机器学习)