Mol2Image: 连接药物分子与细胞显微图像的条件流模型【2021CVPR】
Mol2Image: Improved Conditional Flow Models for Molecule to Image Synthesis
paper:Improved Conditional Flow Models for Molecule to Image Synthesis | Papers With Code
code:GitHub - uhlerlab/mol2image
2021年,来自麻省理工和多伦多大学的团队在2021 CVPR发表文章,利用流模型(Flow Models)合成药物分子作用于细胞后的显微图像,以此模拟药物分子的加入对细胞形态的影响。
以下是全文主要内容。
输入:分子
输出:分子在细胞上的影响的图像
本文的目标是生成不同药物分子作用下的细胞显微镜图像。基于最近图神经网络在学习分子嵌入和基于流的图像生成模型方面取得的成功,我们提出了Mol2Image:一个连接药物分子和细胞图像的流模型。为了生成不同分辨率的细胞特征并扩展到高分辨率的图像,我们开发了一个基于Haar小波图像金字塔的新型多尺度流架构。为了最大化生成的图像和药物分子干预之间的相互信息,我们设计了一个基于对比学习的训练策略。为了评估我们的模型,我们提出了一套新的生物图像生成指标,这些指标是稳健的、可解释的,并且与从业者相关。我们的方法学习了药物分子作用下的有意义的嵌入,它被转化为反映生物效应的图像表示。
近年来,细胞显微镜检测越来越受到关注,与传统的靶向筛选相比,图像中丰富的形态学数据为药物发现提供了更多信息。在这些发展的推动下,我们的目标是建立第一个细胞显微图像(不同药物作用下的)的生成模型,将药物分子作用的信息转化为内容丰富和可解释的图像表示。这样的系统在药物开发中具有许多实际应用——例如,它可以使从业者能够根据化合物对细胞的预测形态学效应来虚拟筛选化合物,从而更有效地探索广阔的化学空间并减少进行大量实验所需的资源。小分子可以进入细胞并改变其生物学功能和途径,导致细胞形状、结构、组织等发生变化,并且这些改变是在显微镜图像中可见的。与预测特定化学性质的传统模型相比,分子图像合成模型有可能对药物的形态学效应提供不一样的视角,从而获得广泛的特性,例如作用机制和基因靶点的发现。
在生物的需求之外,还存在一些技术上的问题。已有的生成式流模型要求输入图像与隐变量同维度,如果输入一张高清图片,由于维度过高(这里指像素数)导致其超出显存而无法进行模型训练。以最近常用的生成式流模型Glow为代表,由于显存的限制,最大只能生成256 x 256像素的图像。此外,现有的条件生成式流模型,未能将条件与生成结果很好地结合,即生成结果与条件的相关性不大。
因此,本文利用Haar小波图像金字塔变换,构建了多尺度的流模型,使得模型能够生成512 x 512像素的图像;还利用对比学习的方法,增强输入条件(即药物分子作用)对生成图像的影响;最后提出了几个关于细胞显微图像形态学的指标,对Mol2Image这类任务提供评价。
本文选用的数据集是Bray等人于2016年发表在Nature Protocol的Cell Painting数据集(数据的获取如下图)。这个数据集包括284K张分别经过10.5K种药物分子作用的细胞图像。涉及五种细胞/信息类型:核DNA、内质网、核/质RNA、线粒体及细胞骨架。每张图像对应五个彩色通道。在数据集分配方面,本文选取与8.5K药物分子作用的219K张细胞显微图像作为训练集;剩下的2K个及对应图像作为测试集。
图2. Cell Painting数据集的获取
模型的整体框架如下:
图中的x代表512 x 512像素的细胞显微图像,其经过一次Haar小波变换后能够分成一张256 x 256像素的均值图像x1和三张256 x 256像素差值图像x0~。对于四张256 x 256的图像,可利用Haar小波逆变换将它们无损变换为512 x 512的高清图像。一张256 x 256像素的均值图像x1和三张256 x 256像素差值图像x0~可分别由高清图像经过如下卷积核卷积而成:
图4. 用于Haar小波变换的卷积核
有了Haar小波变换后,就能够将一张512 x 512像素的图像转成四张256 x 256像素的图像了。这样就能够利用流模型(最大能够256 x 256像素图像)去间接生成512 x 512像素的图片了。
以下阐述模型的训练过程:已知有一张512 x 512像素的细胞显微图像x及对应的分子输入。对512 x 512像素的细胞显微图像x做一次Haar小波变换,得到一张256 x 256像素的均值图像x1和三张256 x 256像素差值图像x0~;对于每一张256 x 256像素差值图像x0~,经过一系列可逆流变换(流变换具体操作同Glow模型[4])得到对应维度的隐变量z;对得到的x1,做一次Haar小波变换,得到得到一张128 x 128像素的均值图像x2和三张128 x 128像素差值图像x1~;对于每一张128 x 128像素差值图像x1~,经过一系列可逆流变换得到对应维度的隐变量z1;如此循环,直至做到16 x 16像素的均值图像x6;对于16 x 16像素的均值图像x6,直接对其进行一系列可逆流变换得到对应维度的隐变量z6。此时,流模型对应的似然函数为:
此外,为了使得扰动分子和生成图像的相关性足够大,即需要图像的隐变量与做出的分子编码(图3中绿色部分)尽可能相关,本文通过对比学习来实现。对比学习的损失函数为:
其中的h(x,y)指的是图像的隐变量z与扰动分子的编码g(y)的余弦相似度。
训练阶段的最终损失函数包括了负的似然函数和上述对比损失:
以下阐述细胞显微图像的生成过程:已知给定的扰动分子作为输入。对于给定的扰动分子,利用图神经网络做出分子的编码g(y);将分子编码作为条件,做出16 x 16像素图象x6对应隐变量z6的均值和方差:μ6=MLP(g(y));Σ6=MLP(g(y)),其中MLP()为简单的神经网络;得到均值方差后,z6可从此正态分布中进行采样;将z6经过逆的流变换可得到16 x 16像素图象x6;对于32 x 32像素图象x5对应的插值图像x5~(16 x 16像素),对应的隐变量的均值和方差μ5=MLP(x6,g(y));Σ5=MLP(x5,g(y))可以求得;z5从此正态分布采样得到,经过逆的流变换可得到16 x 16像素的差值图像x5~;由差值图像x5~和16 x 16像素图像x6即可通过Haar小波变换的逆变换得到32 x 32像素图象x5;如此循环,最终得到512 x 512像素图像x。
将Mol2Image模型训练好之后,与其他及基准模型进行比较:
图5. Mol2Image模型与基准模型、真实图像的比较
对于此细胞显微图像生成任务,本文提出了几个细胞特征形态学指标:1、覆盖度(Coverage):被细胞占据的总的图像面积;2、细胞/核数量(Cell/Nuclei Count):图像中总的细胞/核的数量;3、细胞尺寸(Cell Size):图像中平均的细胞尺寸;4、Zernike Shape:用Zernike多项式描述细胞形状的30个特征;5、表达水平(Expression Level):一组5个特征,用于测量图像中不同细胞隔室的信号水平。以上指标都是越大越优,在下图表中,由于空间限制,对Zernike Shape和表达水平都取了均值。
此外,本文还针对分子和细胞图像的匹配程度,预训练了分类器。如果生成图像和扰动分子输入分类器的准确率与真实图像和扰动分子输入分类器的准确率相近时,表明在同一个扰动分子输入时,生成的图像和真实的细胞显微图像足够相似。此指标在下图标中为Correspondence Accuracy。为了衡量真实图像和生成图像的数据相似度,本文还引入了Sliced Wasserstein Distance (SWD)。对比数据如下图表:
图6. 各模型与真值的比较结果
本文为从药物分子到图像合成开发了一种新的基于多尺度流的架构和训练策略,并展示了此方法在为生物细胞图像生成量身定制的新评估指标上的好处。我们的工作代表了基于图像的化学物质虚拟筛选的第一步,并为研究分子结构和细胞形态中的信息奠定了基础。未来工作的一个有希望的途径是整合辅助信息(例如,已知的化学性质或药物剂量),对分子嵌入空间施加限制并提高对以前看不见的分子的泛化。
CVPR|Mol2Image:连接药物分子与细胞显微图像的条件流模型_腾讯新闻