目前的深度学习模型在医疗领域的应用面临着两个问题:
本文首先介绍了一下EHR在医疗领域的应用,不再赘述。
然后提到关于训练深度学习模型时可能面临的数据量不足的问题。举例说明了cerebral degenerations (e.g. Leukodystrophy, Cerebral lipidoses) or developmental disorders (e.g. autistic disorder, Heller’s syndrome)的预测是很困难的,因为训练数据集中对于这些罕见病的学习机会是很少的。
解决上述问题的方法之一是引入医学本体,从而将医学概念的结构与关系都进行编码,从而在缩小search space的同时也不损失信息。
这里提到了几个应该可以用的本体:
这些本体结构中比较相近的nodes可能会和相似的患者有更紧密的联系,可以为模型进行knowledge transfer提供可能。
这篇文章中提出的方法叫做GRAM(GRaph-based Attention Model),这个模型可以通过注意力机制将医学本体的知识融合到深度学习模型中。
结果比较由以下几个方面展开:
GRAM的目的是通过本体 G \mathcal{G} G中的 p a r e n t − c h i l d parent-child parent−child的关系在数据量不足的情况下学习到稳健的表示。
GRAM的设计中面临着使用本体的信息和数据中的信息的一个权衡,这会影响所学医学概念的特异性,这就是说如果用本体结构的信息更多,那么学到的概念应该是更加范化的,用数据的信息更多(数据量足够的话)那么学到的概念就会比较聚焦。
在上图中,下面几个实线的圈代表叶子结点,上面的虚线代表不同层次的 p a r e n t parent parent,其中叶子结点 c i c_i ci的表示可以由它几个 p a r e n t parent parent的表示组合得到,然后可以使用几个 p a r e n t parent parent的向量排出来的 G G G矩阵进行嵌入学习,再结合注意力机制,可以将患者的 v t v_t vt表示为一个向量 x t \bm{x}_t xt
在knowledge DAG中,每一个结点 c i c_i ci都有一个基础的嵌入向量 e i ∈ R m \bm{e}_i \isin \mathbb{R}^m ei∈Rm。对应就有 e 1 , e 1 , . . . , e ∣ C ∣ \bm{e}_1,\bm{e}_1,...,\bm{e}_{|\mathcal{C}|} e1,e1,...,e∣C∣一直到 e ∣ C ∣ + 1 , e ∣ C ∣ + 2 , . . . , e ∣ C ∣ + ∣ C ′ ∣ \bm{e}_{|\mathcal{C}|+1},\bm{e}_{|\mathcal{C}|+2},...,\bm{e}_{|\mathcal{C}|+|\mathcal{C}'|} e∣C∣+1,e∣C∣+2,...,e∣C∣+∣C′∣分别为叶子结点和非叶子结点的嵌入向量。而一个叶子结点的最终表示由它自己的嵌入向量和它的ancestors的嵌入向量的凸组合来表示:
g i = ∑ j ∈ A ( i ) α i j e j , ∑ j ∈ A ( i ) α i j = 1 , α i j ≥ 0 f o r j ∈ A ( i ) \bm{g}_i=\sum_{j \isin \mathcal{A}(i)}\alpha_{ij} \bm{e}_j,\sum_{j \isin \bm{\mathcal{A}(i)}}\alpha_{ij}=1,\alpha_{ij} \geq0 \ for j \isin \mathcal{A}(i) gi=j∈A(i)∑αijej,j∈A(i)∑αij=1,αij≥0 forj∈A(i)
其中 g i \bm{g}_i gi表示叶子结点的最终表示, A ( i ) \mathcal{A}(i) A(i)表示 c i c_i ci的所有祖先和自己的集合。 α i j \alpha_{ij} αij是注意力机制的权重值,用来权衡使用第 j j j个祖先的嵌入来表示 i i i的信息量权重,可以由下式计算得到:
α i j = exp ( f ( e i , e j ) ) ∑ k ∈ A ( i ) exp f ( e i , e j ) \alpha_{ij}=\frac {\exp({f(\bm{e_i},\bm{e_j})})} {\sum_{k\isin \mathcal{A}(i)}\exp{f(\bm{e_i},\bm{e_j})}} αij=∑k∈A(i)expf(ei,ej)exp(f(ei,ej))
其中 f ( e i , e j ) f(\bm{e_i},\bm{e_j}) f(ei,ej)度量的是 e i \bm{e}_i ei与 e k \bm{e}_k ek的匹配度,本文使用了一个单层的MLP来计算:
f ( e i , e j ) = u a T tanh ( W a [ e i e j ] + b a ) f(\bm{e}_i,\bm{e}_j)=\bm{u}_a^T\tanh({\bm{W}_a} \begin{bmatrix} \bm{e}_i \\ \bm{e}_j \end{bmatrix} + \bm{b}_a) f(ei,ej)=uaTtanh(Wa[eiej]+ba)
其中的 W a ∈ R l × 2 m \bm{W}_a \isin \mathbb{R}^{l\times 2m} Wa∈Rl×2m用于对 e i \bm{e_i} ei和 e j \bm{e_j} ej拼接后的向量进行线性组合, b ∈ R l \bm{b} \isin \mathbb{R}^l b∈Rl是偏置向量, u a T \bm{u}_a^T uaT是一个用于生成标量值的向量。式子中的 l l l其实就是MLP的中间神经元的个数。
文章在这里特意说明了一下,虽然在这个研究中本体结构基本上是single path形式的(应该是指一个子结点只有一个父母),但是也可以拓展到多path的形式。
当得到了每一个code的表示 g 1 , g 2 , . . . , g ∣ C ∣ \bm{g}_1,\bm{g}_2,...,\bm{g}_{|\mathcal{C}|} g1,g2,...,g∣C∣后,可以对它们进行拼接然后得到嵌入矩阵 G ∈ R m × ∣ C ∣ \bm{G} \isin \mathcal{R}^{m \times |\mathcal{C}|} G∈Rm×∣C∣, g i \bm{g}_i gi对应着这个矩阵的第 i i i列。
对于每个患者的一个visit V t V_t Vt,可以通过一个多热向量 x t \bm{x}_t xt与 G \bm{G} G作点乘再通过非线性转换获得一个向量表示 v t \bm{v}_t vt,然后可以预测最终的label y t \bm{y}_t yt。本文使用RNN来预测疾病序列。也就是说,给定了前几次的visits V 1 , V 2 , . . . , V t V_1,V_2,...,V_t V1,V2,...,Vt以预测下一次的 V t + 1 V_{t+1} Vt+1。可以表示为如下形式:
v 1 , v 2 , . . . , v t = tanh G [ x 1 , x 2 , . . . , x t ] , h 1 , h 2 , . . . , h t = R N N ( G [ v 1 , v 2 , . . . , v t , θ r ] ) , y ^ t = x ^ t + 1 = S o f t m a x ( W h t + b ) \bm{v_1},\bm{v_2},...,\bm{v_t}=\tanh{\bm{G}[\bm{x}_1,\bm{x}_2,...,\bm{x}_t]},\\ \bm{h_1},\bm{h_2},...,\bm{h_t}=RNN( {\bm{G}[\bm{v}_1,\bm{v}_2,...,\bm{v}_t,\theta_r]}),\\ \hat{y}_t=\hat{x}_{t+1} =Softmax(\bm{Wh}_t+\bm{b}) v1,v2,...,vt=tanhG[x1,x2,...,xt],h1,h2,...,ht=RNN(G[v1,v2,...,vt,θr]),y^t=x^t+1=Softmax(Wht+b)
其中 x t \bm{x}_t xt是第 t t t次visit对应的多热向量,然后可以对应得到表示 v t \bm{v}_t vt, h t \bm{h}_t ht为对应的隐藏层向量, θ r , W , b \theta_r,\bm{W},\bm{b} θr,W,b为参数。作者说这里最后预测使用的是Softmax而非dimension-wise sigmoid是因为前者有更好的预测效果。
模型使用的损失函数为交叉熵函数:
L ( x 1 , x 2 , . . . , x T ) = − 1 T − 1 ∑ t = 1 T − 1 ( y t T log ( y ^ t ) ) + ( 1 − y t T ) log ( 1 − y ^ t ) ) \mathcal{L}(\bm{x}_1,\bm{x}_2,...,\bm{x}_T)=-\frac{1}{T-1}\sum ^{T-1}_{t=1}(\bm{y}_t^T\log(\hat\bm{y}_t))+(1-\bm{y}_t^T)\log(1-\hat\bm{y}_t)) L(x1,x2,...,xT)=−T−11t=1∑T−1(ytTlog(y^t))+(1−ytT)log(1−y^t))
这里的损失是针对一个患者的,并且是对所有的时刻的损失进行了加和。在实际的操作中,还需要对所有的患者进行取平均的操作。
下面是整个算法的伪代码:
本文使用的是medical codes的共现关系来学习它们的基础表示,具体的方法用的是GloVe.
在这篇文章中,medical codes及它们的共现矩阵是根据患者的每次visit进行计数的。
这里举了一个例子:
其中 c o u n t ( c i , V t ′ ) count(c_i,V_t') count(ci,Vt′)是 c i c_i ci在 V t ′ V_t' Vt′中的个数。在这个例子中, c a c_a ca与 c c c_c cc的共线值为 3 × 2 3\times 2 3×2, c i c_i ci与 c a c_a ca的共线值为 1 × 3 1\times 3 1×3。最终可以得到一个共现矩阵 M ∈ R ∣ D ∣ × ∣ D ∣ \bm{M}\isin{\mathbb{R}^{|\mathcal{D}|\times |\mathcal{D}|}} M∈R∣D∣×∣D∣
然后可以使用下面文献的方法来学习嵌入表示。
Jeffrey Pennington, Richard Socher, and Christopher D Manning. 2014. Glove: Global Vectors for Word Representation. In EMNLP.
Prediction tasks and source of data
本文的预测任务是序列疾预测任务,即使用这次以前所有的患者记录以预测下一次患者可能诊断出的所有病。
使用了俩数据集,Sutter Palo Alto Medical Foundation (PAMF) 和MIMIC-III。前者每个患者有比较多的visits,后者只有比较少的visits。
疾病ICD通过CCS降低了类别数量以降低任务难度和加快训练速度,这里疾病编码出现的频率反映着数据的不充足程度。使用的指标是 A c c u r a c y @ k Accuracy@k Accuracy@k,其含义是给定 V t V_t Vt,如果预测的top k k k疾病中有 V t V_t Vt中的疾病那么得到一个1,反之为0。此外,本文还整了一个预测心衰(Heart Failure)的预测任务。
本文使用的knowledge Graph 是CCS multi-level diagnoses hierarchy。
最后的结果从预测性能、表示的可视化以及注意力机制的得分等方面展开,不再赘述。