GRAM: Graph-based Attention Model for Healthcare Representation Learning文献阅读记录

Abstract

目前的深度学习模型在医疗领域的应用面临着两个问题:

  • 数据不充分:深度学习的模型需要比较大的数据量,在医疗领域的建模有时数据量不太够
  • 解释性:表示学习所学得的方法应该能够与医学知识匹配

Introduction

本文首先介绍了一下EHR在医疗领域的应用,不再赘述。
然后提到关于训练深度学习模型时可能面临的数据量不足的问题。举例说明了cerebral degenerations (e.g. Leukodystrophy, Cerebral lipidoses) or developmental disorders (e.g. autistic disorder, Heller’s syndrome)的预测是很困难的,因为训练数据集中对于这些罕见病的学习机会是很少的。

解决上述问题的方法之一是引入医学本体,从而将医学概念的结构与关系都进行编码,从而在缩小search space的同时也不损失信息。
这里提到了几个应该可以用的本体:

  • International Classification of Diseases (ICD)
  • Clinical Classifications Software (CCS)
  • Systematized Nomenclature of Medicine-Clinical Terms (SNOMED-CT)

这些本体结构中比较相近的nodes可能会和相似的患者有更紧密的联系,可以为模型进行knowledge transfer提供可能。
这篇文章中提出的方法叫做GRAM(GRaph-based Attention Model),这个模型可以通过注意力机制将医学本体的知识融合到深度学习模型中。

结果比较由以下几个方面展开:

  • 预测性能与RNN在两个sequential diagnoses预测任务和heart failure预测任务中展开,GRAM在那些不常见的疾病预测任务上最多可以获得比RNN高10%的accuracy
  • 对表示的结果进行了可视化,GRAM所表示的医学概念中更相似的位置也会更相近

Methodology

Basic Notation

  • 整个EHR中的medical codes记录为 c 1 , c 2 , . . . , c ∣ C ∣ c_1,c_2,...,c_{|\mathcal{C}|} c1,c2,...,cC的集合,总共有 ∣ C ∣ |\mathcal{C}| C种medical codes
  • 每个患者的clinical records可以看作是多次visits,即 V 1 , V 2 , . . . , V T V_1,V_2,...,V_T V1,V2,...,VT,而每次visit中又包含着一系列medical codes。 V t V_t Vt可以表示成binary vector, x t ∈ { 0 , 1 } ∣ C ∣ \bm{x}_t \isin\{0,1\}^{|\mathcal{C}|} xt{0,1}C,其中第 i i i个元素代表这个患者的这次记录中是否包含了code c i c_i ci
  • 本体结构 G \mathcal{G} G所展示的是医学概念的层次结构,是以 p a r e n t − c h i l d parent-child parentchild的形式展示关系的,而上面的 C \mathcal{C} C形成的是所有的叶子结点。本体结构 G \mathcal{G} G可以以directed acyclic graph(DAG) 的形式进行表示,这个图中的所有结点为 D = C + C ′ \mathcal{D}=\mathcal{C}+\mathcal{C}' D=C+C,其中的 C ′ = { c ∣ C ∣ + 1 , c ∣ C ∣ + 2 , . . . , c ∣ C ∣ + ∣ C ′ ∣ } \mathcal{C}'=\{c_{|\mathcal{C}|+1},c_{|\mathcal{C}|+2},...,c_{|\mathcal{C}|+|\mathcal{C}'|}\} C={cC+1,cC+2,...,cC+C},为所有的非叶子结点(也就是叶子结点的ancestors)。像ICD-9和CCS这样的本体结构中的分支都是没有交叉的,而SNOMED-CT中会存在少量少量交叉,但是本文主要只是考虑 p a r e n t − c h i l d parent-child parentchild的关系。

Knowledge DAG and the Attention Mechanism

GRAM的目的是通过本体 G \mathcal{G} G中的 p a r e n t − c h i l d parent-child parentchild的关系在数据量不足的情况下学习到稳健的表示。
GRAM的设计中面临着使用本体的信息数据中的信息的一个权衡,这会影响所学医学概念的特异性,这就是说如果用本体结构的信息更多,那么学到的概念应该是更加范化的,用数据的信息更多(数据量足够的话)那么学到的概念就会比较聚焦。
GRAM: Graph-based Attention Model for Healthcare Representation Learning文献阅读记录_第1张图片
在上图中,下面几个实线的圈代表叶子结点,上面的虚线代表不同层次的 p a r e n t parent parent,其中叶子结点 c i c_i ci的表示可以由它几个 p a r e n t parent parent的表示组合得到,然后可以使用几个 p a r e n t parent parent的向量排出来的 G G G矩阵进行嵌入学习,再结合注意力机制,可以将患者的 v t v_t vt表示为一个向量 x t \bm{x}_t xt

在knowledge DAG中,每一个结点 c i c_i ci都有一个基础的嵌入向量 e i ∈ R m \bm{e}_i \isin \mathbb{R}^m eiRm。对应就有 e 1 , e 1 , . . . , e ∣ C ∣ \bm{e}_1,\bm{e}_1,...,\bm{e}_{|\mathcal{C}|} e1,e1,...,eC一直到 e ∣ C ∣ + 1 , e ∣ C ∣ + 2 , . . . , e ∣ C ∣ + ∣ C ′ ∣ \bm{e}_{|\mathcal{C}|+1},\bm{e}_{|\mathcal{C}|+2},...,\bm{e}_{|\mathcal{C}|+|\mathcal{C}'|} eC+1,eC+2,...,eC+C分别为叶子结点和非叶子结点的嵌入向量。而一个叶子结点的最终表示由它自己的嵌入向量和它的ancestors的嵌入向量的凸组合来表示:
g i = ∑ j ∈ A ( i ) α i j e j , ∑ j ∈ A ( i ) α i j = 1 , α i j ≥ 0   f o r j ∈ A ( i ) \bm{g}_i=\sum_{j \isin \mathcal{A}(i)}\alpha_{ij} \bm{e}_j,\sum_{j \isin \bm{\mathcal{A}(i)}}\alpha_{ij}=1,\alpha_{ij} \geq0 \ for j \isin \mathcal{A}(i) gi=jA(i)αijej,jA(i)αij=1,αij0 forjA(i)
其中 g i \bm{g}_i gi表示叶子结点的最终表示, A ( i ) \mathcal{A}(i) A(i)表示 c i c_i ci的所有祖先和自己的集合。 α i j \alpha_{ij} αij是注意力机制的权重值,用来权衡使用第 j j j个祖先的嵌入来表示 i i i的信息量权重,可以由下式计算得到:
α i j = exp ⁡ ( f ( e i , e j ) ) ∑ k ∈ A ( i ) exp ⁡ f ( e i , e j ) \alpha_{ij}=\frac {\exp({f(\bm{e_i},\bm{e_j})})} {\sum_{k\isin \mathcal{A}(i)}\exp{f(\bm{e_i},\bm{e_j})}} αij=kA(i)expf(ei,ej)exp(f(ei,ej))
其中 f ( e i , e j ) f(\bm{e_i},\bm{e_j}) f(ei,ej)度量的是 e i \bm{e}_i ei e k \bm{e}_k ek的匹配度,本文使用了一个单层的MLP来计算:
f ( e i , e j ) = u a T tanh ⁡ ( W a [ e i e j ] + b a ) f(\bm{e}_i,\bm{e}_j)=\bm{u}_a^T\tanh({\bm{W}_a} \begin{bmatrix} \bm{e}_i \\ \bm{e}_j \end{bmatrix} + \bm{b}_a) f(ei,ej)=uaTtanh(Wa[eiej]+ba)
其中的 W a ∈ R l × 2 m \bm{W}_a \isin \mathbb{R}^{l\times 2m} WaRl×2m用于对 e i \bm{e_i} ei e j \bm{e_j} ej拼接后的向量进行线性组合, b ∈ R l \bm{b} \isin \mathbb{R}^l bRl是偏置向量, u a T \bm{u}_a^T uaT是一个用于生成标量值的向量。式子中的 l l l其实就是MLP的中间神经元的个数。

文章在这里特意说明了一下,虽然在这个研究中本体结构基本上是single path形式的(应该是指一个子结点只有一个父母),但是也可以拓展到多path的形式。

End-to-End Training with a Predictive Model

当得到了每一个code的表示 g 1 , g 2 , . . . , g ∣ C ∣ \bm{g}_1,\bm{g}_2,...,\bm{g}_{|\mathcal{C}|} g1,g2,...,gC后,可以对它们进行拼接然后得到嵌入矩阵 G ∈ R m × ∣ C ∣ \bm{G} \isin \mathcal{R}^{m \times |\mathcal{C}|} GRm×C g i \bm{g}_i gi对应着这个矩阵的第 i i i列。
对于每个患者的一个visit V t V_t Vt,可以通过一个多热向量 x t \bm{x}_t xt G \bm{G} G作点乘再通过非线性转换获得一个向量表示 v t \bm{v}_t vt,然后可以预测最终的label y t \bm{y}_t yt。本文使用RNN来预测疾病序列。也就是说,给定了前几次的visits V 1 , V 2 , . . . , V t V_1,V_2,...,V_t V1,V2,...,Vt以预测下一次的 V t + 1 V_{t+1} Vt+1。可以表示为如下形式:
v 1 , v 2 , . . . , v t = tanh ⁡ G [ x 1 , x 2 , . . . , x t ] , h 1 , h 2 , . . . , h t = R N N ( G [ v 1 , v 2 , . . . , v t , θ r ] ) , y ^ t = x ^ t + 1 = S o f t m a x ( W h t + b ) \bm{v_1},\bm{v_2},...,\bm{v_t}=\tanh{\bm{G}[\bm{x}_1,\bm{x}_2,...,\bm{x}_t]},\\ \bm{h_1},\bm{h_2},...,\bm{h_t}=RNN( {\bm{G}[\bm{v}_1,\bm{v}_2,...,\bm{v}_t,\theta_r]}),\\ \hat{y}_t=\hat{x}_{t+1} =Softmax(\bm{Wh}_t+\bm{b}) v1,v2,...,vt=tanhG[x1,x2,...,xt],h1,h2,...,ht=RNN(G[v1,v2,...,vt,θr]),y^t=x^t+1=Softmax(Wht+b)
其中 x t \bm{x}_t xt是第 t t t次visit对应的多热向量,然后可以对应得到表示 v t \bm{v}_t vt h t \bm{h}_t ht为对应的隐藏层向量, θ r , W , b \theta_r,\bm{W},\bm{b} θr,W,b为参数。作者说这里最后预测使用的是Softmax而非dimension-wise sigmoid是因为前者有更好的预测效果。

模型使用的损失函数为交叉熵函数:
L ( x 1 , x 2 , . . . , x T ) = − 1 T − 1 ∑ t = 1 T − 1 ( y t T log ⁡ ( y ^ t ) ) + ( 1 − y t T ) log ⁡ ( 1 − y ^ t ) ) \mathcal{L}(\bm{x}_1,\bm{x}_2,...,\bm{x}_T)=-\frac{1}{T-1}\sum ^{T-1}_{t=1}(\bm{y}_t^T\log(\hat\bm{y}_t))+(1-\bm{y}_t^T)\log(1-\hat\bm{y}_t)) L(x1,x2,...,xT)=T11t=1T1(ytTlog(y^t))+(1ytT)log(1y^t))
这里的损失是针对一个患者的,并且是对所有的时刻的损失进行了加和。在实际的操作中,还需要对所有的患者进行取平均的操作。
下面是整个算法的伪代码:
GRAM: Graph-based Attention Model for Healthcare Representation Learning文献阅读记录_第2张图片

Initializing Basic Embeddings

本文使用的是medical codes的共现关系来学习它们的基础表示,具体的方法用的是GloVe.
在这篇文章中,medical codes及它们的共现矩阵是根据患者的每次visit进行计数的。
这里举了一个例子:
GRAM: Graph-based Attention Model for Healthcare Representation Learning文献阅读记录_第3张图片
其中 c o u n t ( c i , V t ′ ) count(c_i,V_t') count(ci,Vt) c i c_i ci V t ′ V_t' Vt中的个数。在这个例子中, c a c_a ca c c c_c cc的共线值为 3 × 2 3\times 2 3×2, c i c_i ci c a c_a ca的共线值为 1 × 3 1\times 3 1×3。最终可以得到一个共现矩阵 M ∈ R ∣ D ∣ × ∣ D ∣ \bm{M}\isin{\mathbb{R}^{|\mathcal{D}|\times |\mathcal{D}|}} MRD×D
GRAM: Graph-based Attention Model for Healthcare Representation Learning文献阅读记录_第4张图片
然后可以使用下面文献的方法来学习嵌入表示。

Jeffrey Pennington, Richard Socher, and Christopher D Manning. 2014. Glove: Global Vectors for Word Representation. In EMNLP.
GRAM: Graph-based Attention Model for Healthcare Representation Learning文献阅读记录_第5张图片

Experiments

Experiment Setup

Prediction tasks and source of data
本文的预测任务是序列疾预测任务,即使用这次以前所有的患者记录以预测下一次患者可能诊断出的所有病。
使用了俩数据集,Sutter Palo Alto Medical Foundation (PAMF) 和MIMIC-III。前者每个患者有比较多的visits,后者只有比较少的visits。
疾病ICD通过CCS降低了类别数量以降低任务难度和加快训练速度,这里疾病编码出现的频率反映着数据的不充足程度。使用的指标是 A c c u r a c y @ k Accuracy@k Accuracy@k,其含义是给定 V t V_t Vt,如果预测的top k k k疾病中有 V t V_t Vt中的疾病那么得到一个1,反之为0。此外,本文还整了一个预测心衰(Heart Failure)的预测任务。
本文使用的knowledge Graph 是CCS multi-level diagnoses hierarchy

最后的结果从预测性能、表示的可视化以及注意力机制的得分等方面展开,不再赘述。

你可能感兴趣的:(文献阅读记录)