最优传输系列(一):最优传输入门及Monge-Kantorovitch Problem

目录

  • 引言
  • 背景
  • 什么是最优传输问题(Optimal Transport)
  • 概率测度
  • 如何解决最优传输问题
    • Monge
    • Kantorovitch
    • 推广到连续空间
  • 引用

引言

  • 本人刚刚接触最优传输,很多地方并不能也不敢写出自己尚不成熟的见解来误导别人,所以很多地方都是借鉴了别人的成果与经验。之所以记录这篇博客,一是为了和大家分享学习内容,也期望自己的文章能帮助到有需要的人;二是进行学习笔记式的记录,并时常回顾。本文不做商用,所引用的别人的成果也会在文章末尾贴出链接,本文只想将本人所涉猎的内容整理起来,形成更容易让人理解的更好的内容。下面进入正题。
  • 最优运输(Optimal Transport)近年来引起了广大学者的研究兴趣,并在NIPS和ICML等机器学习顶级会议频繁出现。然而,最优运输的基本理论对于初学者来说,并不友好:初看理论,感觉全是晦涩难懂的数学推理公式,让很多读者有点望而却步的感觉。此外,目前国内关于最优运输理论的研究还比较初步,相关中文资料也比较匮乏。因此,笔者对自己最近几天在网上博客、论文和视频等资料的学习过程进行了初步整理,希望对后续的初学者提供一点帮助。
  • 我的入门体验:最优运输相关理论的代码库已比较丰富(需要代码,可以去github上搜索,或者检索关于最优运输的热门顶会论文,基本都有开源代码),并且核心理论也没那么复杂,或者说只要你认真阅读完本文,我相信你应该能够较顺畅地把最优运输理论应用到你的实际应用中了。

背景

  • 最优运输问题最早是由法国数学家加斯帕德·蒙日(Gaspard Monge)在19世纪中期提出,它是一种将给定质量的泥土运输到给定洞里的最小成本解决方案。这个问题在20世纪中期重新出现在坎托罗维奇的著作中,并在近些年的研究中发现了一些令人惊讶的新进展,比如Sinkhorn算法。最优运输被广泛应用于多个领域,包括计算流体力学,多幅图像之间的颜色转移或图像处理背景下的变形,计算机图形学中的插值方案,以及经济学、通过匹配和均衡问题等。此外,最优传输最近也引起了生物医学相关学者的关注,并被广泛用于单细胞RNA发育过程中指导分化以及提高细胞观测数据的数据增强工具,从而提高各种下游分任务的准确性和稳定性。
  • 当前,许多现代统计和机器学习问题可以被重新描述为在两个概率分布之间寻找最优运输图。例如,领域适应旨在从源数据分布中学习一个训练良好的模型,并将该模型转换为采用目标数据分布。另一个例子是深度生成模型,其目标是将一个固定的分布,例如标准高斯或均匀分布,映射到真实样本的潜在总体分布。在最近几十年里,OT方法在现代数据科学应用的显著增殖中重新焕发了活力,包括机器学习、统计和计算机视觉。

什么是最优传输问题(Optimal Transport)

  • 这个问题就像它的名字一样通俗易懂,简单来说,我们有m个工厂,总共需要G吨原材料,而这些原材料分布在n个产地,那么问题来了,在n个产地到m个工厂分别运输成本不同的情况下,如何用最低成本将G吨材料从n个产地分别运送到m个工厂?
  • 下面举个例子,设m=3,n=3,工厂A、B、C分别需要原材料100吨、50吨、100吨,而产地A、B、C分别拥有原材料70吨、80吨、100吨,设 c ( i , j ) c(i,j) c(i,j)为从产地 i i i到工厂 j j j运送一吨的成本, r ( i , j ) r(i,j) r(i,j)为从产地 i i i到工厂 j j j运送的吨数。而我们要做的就是
    m i n ∑ i = 1 n ∑ j = 1 m c ( i , j ) ∗ d ( i , j ) min\displaystyle\sum_{i=1}^{n} \displaystyle\sum_{j=1}^{m} c(i,j)*d(i,j) mini=1nj=1mc(i,j)d(i,j)
    且需要满足以下约束条件:
    ∑ i = 1 n r ( i , 1 ) = 100 \displaystyle\sum_{i=1}^{n} r(i,1)=100 i=1nr(i,1)=100 ∑ i = 1 n r ( i , 2 ) = 50 \displaystyle\sum_{i=1}^{n} r(i,2)=50 i=1nr(i,2)=50
    ∑ i = 1 n r ( i , 3 ) = 100 \displaystyle\sum_{i=1}^{n} r(i,3)=100 i=1nr(i,3)=100 ∑ j = 1 m r ( 1 , j ) = 70 \displaystyle\sum_{j=1}^{m} r(1,j)=70 j=1mr(1,j)=70
    ∑ j = 1 m r ( 2 , j ) = 80 \displaystyle\sum_{j=1}^{m} r(2,j)=80 j=1mr(2,j)=80 ∑ j = 1 m r ( 3 , j ) = 100 \displaystyle\sum_{j=1}^{m} r(3,j)=100 j=1mr(3,j)=100
    为了方便表示,设 ν = [ 100 , 50 , 100 ] ν=[100,50,100] ν=[100,50,100], μ = [ 70 , 80 , 100 ] μ=[70,80,100] μ=[70,80,100],则上述约束条件变成了:
    ∑ i = 1 n r ( i , j ) = ν j \displaystyle\sum_{i=1}^{n} r(i,j)=ν_j i=1nr(i,j)=νj ∑ j = 1 m r ( i , j ) = μ i \displaystyle\sum_{j=1}^{m} r(i,j)=μ_i j=1mr(i,j)=μi

概率测度

  • 上面我们定义了 μ μ μ ν ν ν,其实呢,我们还可以把它们归一化一下,即
    ν = [ 100 / G , 50 / G , 100 / G ] = [ 0.4 , 0.2 , 0.4 ] ν=[100/G,50/G,100/G]=[0.4,0.2,0.4] ν=[100/G,50/G,100/G]=[0.4,0.2,0.4], μ = [ 70 / G , 80 / G , 100 / G ] = [ 0.28 , 0.32 , 0.4 ] μ=[70/G,80/G,100/G]=[0.28,0.32,0.4] μ=[70/G,80/G,100/G]=[0.28,0.32,0.4]
  • 在数值上,已经从没有规范的数据转化为概率了,这样的好处是:
    1. 首先这些值非0
    2. 其次它们的加和为1
    3. 当然还满足一些我们一般不关注的性质(次可数可加性)
  • 其实这里的 μ μ μ ν ν ν概率向量(probability vector,也被称为直方图Histograms),它们的官方定义如下:
    在这里插入图片描述
    而测度与上面的向量稍有不同,测度其实是一个函数,它的官方定义如下:
    在这里插入图片描述
    上述公式含义:以 a i a_i ai为概率和对应位置 x i x_i xi的狄拉克δ函数值乘积的累加和。下图很好地阐述了一组不同元素点的概率向量分布
    最优传输系列(一):最优传输入门及Monge-Kantorovitch Problem_第1张图片
    关于狄拉克(Dirac)函数大家可以参考这篇文章https://spaces.ac.cn/archives/1870,但Dirac在此处出现的原因及作用我并不清楚,希望评论区大佬们不吝赐教!
  • 所以,设 x t = [ 100 , 50 , 100 ] 、 x s = [ 70 , 80 , 100 ] x^t=[100,50,100]、x^s=[70,80,100] xt=[100,50,100]xs=[70,80,100],则 α = ∑ i = 1 n μ i δ x i s α=\displaystyle\sum_{i=1}^{n} μ_iδ_{x^s_{i}} α=i=1nμiδxis β = ∑ j = 1 m ν j δ x j t β=\displaystyle\sum_{j=1}^{m} ν_jδ_{x^t_{j}} β=j=1mνjδxjt就是概率测度(probability measure),更进一步讲,由于这里n和m都是确定的整数,所以这是一个离散的问题,所以 α α α β β β属于概率测度中的离散测度(Discrete measure)
  • 另外我想提前补充一点,这里的 x s = [ 100 , 50 , 100 ] 、 x t = [ 70 , 80 , 100 ] x^s=[100,50,100]、x^t=[70,80,100] xs=[100,50,100]xt=[70,80,100]中每一个元素都是一个实数,其实在迁移学习中,上面列表里的每一个元素是一个数据样本,在计算机视觉领域,则每一个数据样本是一张图片像素矩阵。

如何解决最优传输问题

Monge

  • 最优传输的问题就是由Monge(蒙日)启发的,他也提出了对应的解决方案。
  • 蒙日的方法非常直观,就是寻找一个合适的映射关系,得到合理的安排方案使得成本最低,其官方定义如下:
    在这里插入图片描述
    T T T就是解决方案, α α α β β β是离散测度,即:
    α = ∑ i = 1 n μ i δ x i s α=\displaystyle\sum_{i=1}^{n} μ_iδ_{x^s_{i}} α=i=1nμiδxis β = ∑ j = 1 m ν j δ x j t β=\displaystyle\sum_{j=1}^{m} ν_jδ_{x^t_{j}} β=j=1mνjδxjt
    蒙日的意思就是找到一个映射 T : [ x i s ] − > [ x j t ] T:[x^s_i]->[x^t_j] T:[xis]>[xjt]使得成本最低,并且满足假设:
    任 意 j ∈ m , ν j = ∑ i : T ( x i s ) = x j t μ i 任意j∈m,ν_j=\displaystyle\sum_{i:T(x^s_i)=x^t_j} μ_i jmνj=i:T(xis)=xjtμi
    这个假设的意思就是,工厂 j j j需要的原材料需要从所有产地中获取,以上面的例子为例,工厂1需要100吨,那么就是:
    100 / 250 = ∑ i : T ( x i s ) = 100 μ i 100/250=\displaystyle\sum_{i:T(x^s_i)=100} μ_i 100/250=i:T(xis)=100μi T T T就是一个安排策略,它安排几个工厂提供原材料。
    示意图如下:
    最优传输系列(一):最优传输入门及Monge-Kantorovitch Problem_第2张图片
  • 但是蒙日的方法有个弊端,那就是只要 T T T决定了产地 i i i将材料运送至哪个工厂,就会将全部的原材料转运过去,而不能转运一部分。这一点从上图的示意图也能看出来,可以一个红色点接受多个蓝色点(即一个工厂接受多个产地,在迁移学习中就是一个目标域样本接受多个源域样本换句话说多个源域样本可以向同一个目标域样本迁移),但一个蓝色点不能向多个红色点运输(即一个产地的原材料不能向多个工厂运输,一个源域样本不能向多个目标域迁移)。
  • 我们再看这个式子
    100 / 250 = ∑ i : T ( x i s ) = 100 μ i 100/250=\displaystyle\sum_{i:T(x^s_i)=100} μ_i 100/250=i:T(xis)=100μi
    由于蒙日的弊端,对于 x i s x^s_i xis T T T的安排要么是用,要么就是不用,而不能使用 0.5 x i s 0.5x^s_i 0.5xis来参与构成100吨这个任务。
  • 再进一步看 r ( i , j ) r(i,j) r(i,j),由于蒙日的约束条件, r ( i , j ) r(i,j) r(i,j)的每一行都只有一个元素非零,且值为1。最后每一列元素相加等于对应工厂需要的原材料数。
  • 这个约束条件是非线性的,有的时候蒙日的方法是无解的。

Kantorovitch

  • Kantorovitch方法就是用来解决这个问题的,它是蒙日问题的松弛版本,不要求精确求解蒙日问题,而是对原来的要求进行松弛,允许每个元素拆分运送到目的地。
  • 如下图所示:
    最优传输系列(一):最优传输入门及Monge-Kantorovitch Problem_第3张图片
  • 其实我在上面那个例子讲解的时候提到,约束如下:
    ∑ i = 1 n r ( i , j ) = ν j \displaystyle\sum_{i=1}^{n} r(i,j)=ν_j i=1nr(i,j)=νj ∑ j = 1 m r ( i , j ) = μ i \displaystyle\sum_{j=1}^{m} r(i,j)=μ_i j=1mr(i,j)=μi
    其实这里对应的就是Kantorovitch方法的约束,我觉得如果只列出蒙日方法的约束条件 ∑ i = 1 n r ( i , j ) = ν j \displaystyle\sum_{i=1}^{n} r(i,j)=ν_j i=1nr(i,j)=νj会让人觉得缺少了另一半,会有些奇怪,对于这个例子来说一个产地只能运送材料到一个工厂确实是一件奇怪的事。
  • 对于约束条件的官方定义如下:
    最优传输系列(一):最优传输入门及Monge-Kantorovitch Problem_第4张图片
    官方定义的数学符号比较多,我上面其实就是对它定义的一种解读,方便大家理解,然后也方便大家进行对照。
    然后我上面提出的 r ( i , j ) r(i,j) r(i,j)其实就是对应这里的 U ( a , b ) U(a,b) U(a,b),每一行相加得源域的概率测度 α α α,每一列相加得到目标域的概率测度 β β β
    图示如下:
    最优传输系列(一):最优传输入门及Monge-Kantorovitch Problem_第5张图片
  • 官方定义下,最终对应的优化目标为:
    在这里插入图片描述
    上面的 P P P其实就对应我说的 r r r,符号 r r r是比较规范的符号, P P P的话一般用在与概率相关的方面,为了怕混淆了,大家可以记住符号 r r r,称为耦合测度,也比较好理解,就是耦合在一起的 α 和 β α和β αβ测度。

推广到连续空间

  • 上面对于两个问题的描述都是离散空间的,推广到连续空间的话,前面提到的 r ( i , j ) r(i,j) r(i,j) μ μ μ ν ν ν都变成了连续的,目标函数就变成了下面这样
    蒙日问题:
    我的定义: T 0 = a r g T m i n ∫ Ω 1 c ( x s , T ( x s ) ) d μ ( x s ) T_{0}=arg_{T}min\int_{Ω_1}c(x^s,T(x^s))dμ(x^s) T0=argTminΩ1c(xs,T(xs))dμ(xs) s.t. T# μ μ μ= ν ν ν
    意思就是,在源域的空间即 Ω 1 Ω_1 Ω1中,对针对所有点的(意思是连续的)概率测度函数与对应的成本进行积分,这就对应损失 T T T,而我们要选择成本最低的 μ μ μ,即最优的部署 T 0 T_0 T0,并且要满足部署 T T T能够使按照得选出来的 μ μ μ对源域进行分配后能够达到目标域的分布。
    官方定义如下:
    在这里插入图片描述
    Kantorovitch问题:
    我的定义: T 0 = a r g T m i n ∫ Ω 1 × Ω 2 c ( x s , x t ) d r ( x s , x t ) T_{0}=arg_{T}min\int_{Ω_1×Ω_2}c(x^s,x^t)dr(x^s,x^t) T0=argTminΩ1×Ω2c(xs,xt)dr(xs,xt) s.t. T# r r r= μ μ μ、T# r r r= ν ν ν
    官方定义:
    在这里插入图片描述
    其实有名的推土机距离就是这个成本,即Wasserstein metric
    在这里插入图片描述

引用

  • https://www.cnblogs.com/liuzhen1995/p/14524932.html#a2
  • https://zhuanlan.zhihu.com/p/26988777
  • Computational Optimal Transport Handbook

你可能感兴趣的:(最优传输,python,深度学习,最优传输)