【代码阅读】最大均值差异(Maximum Mean Discrepancy, MMD)损失函数代码解读(Pytroch版)

##代码及参考资料来源
Source code: easezyc/deep-transfer-learning [Github]
参考资料:迁移学习简明手册

MMD介绍

MMD(最大均值差异)是迁移学习,尤其是Domain adaptation (域适应)中使用最广泛(目前)的一种损失函数,主要用来度量两个不同但相关的分布的距离。两个分布的距离定义为:
(1) M M D ( X , Y ) = ∣ ∣ 1 n ∑ i = 1 n ϕ ( x i ) − 1 m ∑ j = 1 m ϕ ( y j ) ∣ ∣ H 2 MMD(X,Y) = ||\frac{1}{n}\sum_{i=1}^n\phi(x_i)-\frac{1}{m}\sum_{j=1}^m\phi(y_j)||_H^2\tag{1} MMD(X,Y)=n1i=1nϕ(xi)m1j=1mϕ(yj)H2(1)
其中 H H H 表示这个距离是由 ϕ ( ) \phi() ϕ() 将数据映射到再生希尔伯特空间(RKHS)中进行度量的。

为什么要用MMD?

Domain adaptation的目的是将源域(Source domain)中学到的知识可以应用到不同但相关的目标域(Target domain)。本质上是要找到一个变换函数,使得变换后的源域数据和目标域数据的距离是最小的。所以这其中就要涉及如何度量两个域中数据分布差异的问题,因此也就用到了MMD。至于Domain adaptation的前生今世可以参考王晋东大佬的知乎专栏

MMD的理论推导

MMD的关键在于如何找到一个合适的 ϕ ( ) \phi() ϕ() 来作为一个映射函数。但是这个映射函数可能在不同的任务中都不是固定的,并且这个映射可能高维空间中的映射,所以是很难去选取或者定义的。那如果不能知道 ϕ \phi ϕ,那MMD该如何求呢?我们先展开把MMD展开:
(2) M M D ( X , Y ) = ∣ ∣ 1 n 2 ∑ i n ∑ i ′ n ϕ ( x i ) ϕ ( x i ′ ) − 2 n m ∑ i n ∑ j m ϕ ( x i ) ϕ ( y j ) + 1 m 2 ∑ j m ∑ j ′ m ϕ ( y j ) ϕ ( y j ′ ) ∣ ∣ H 2 MMD(X,Y) =||\frac{1}{n^2}\sum_{i}^n\sum_{i'}^n\phi(x_i)\phi(x_i')-\frac{2}{nm}\sum_{i}^n\sum_{j}^m\phi(x_i)\phi(y_j)+\frac{1}{m^2}\sum_{j}^m\sum_{j'}^m\phi(y_j)\phi(y_j')||_H^2\tag{2} MMD(X,Y)=n21ininϕ(xi)ϕ(xi)nm2injmϕ(xi)ϕ(yj)+m21jmjmϕ(yj)ϕ(yj)H2(2)
展开后就出现了 ϕ ( x i ) ϕ ( x i ′ ) \phi(x_i)\phi(x_i') ϕ(xi)ϕ(xi)的形式,这样联系SVM中的核函数 k ( ∗ ) k(*) k(),就可以跳过计算 ϕ \phi ϕ的部分,直接求 k ( x i ) k ( x i ′ ) k(x_i)k(x_i') k(xi)k(xi)。所以MMD又可以表示为:
(3) M M D ( X , Y ) = ∣ ∣ 1 n 2 ∑ i n ∑ i ′ n k ( x i , x i ′ ) − 2 n m ∑ i n ∑ j m k ( x i , y j ) + 1 m 2 ∑ j m ∑ j ′ m k ( y j , y j ′ ) ∣ ∣ H MMD(X,Y) =||\frac{1}{n^2}\sum_{i}^n\sum_{i'}^nk(x_i, x_i')-\frac{2}{nm}\sum_{i}^n\sum_{j}^mk(x_i, y_j)+\frac{1}{m^2}\sum_{j}^m\sum_{j'}^mk(y_j, y_j')||_H\tag{3} MMD(X,Y)=n21inink(xi,xi)nm2injmk(xi,yj)+m21jmjmk(yj,yj)H(3)
在大多数论文中(比如DDC, DAN),都是用高斯核函数 k ( u , v ) = e − ∣ ∣ u − v ∣ ∣ 2 σ k(u,v) = e^{\frac{-||u-v||^2}{\sigma}} k(u,v)=eσuv2来作为核函数,至于为什么选用高斯核,最主要的应该是高斯核可以映射无穷维空间(具体的之后再分析)

理论到这里就差不多了,那如何进行实现呢?

在TCA中,引入了一个核矩阵方便计算
(4) [ K s , s K s , s K s , t K t , t ] \begin{bmatrix} K_{s,s} & K_{s,s} \\ K_{s,t} & K_{t,t} \\ \end{bmatrix} \tag{4} [Ks,sKs,tKs,sKt,t](4)
以及L矩阵:
(5) l i , j = { 1 / n 2 , x i , x j ∈ D s 1 / m 2 , x i , x j ∈ D s − 1 / n m , otherwise l_{i,j} = \begin{cases} 1/{n^2}, & \text{$x_i, x_j\in D_s$} \\ 1/{m^2}, & \text{$x_i, x_j\in D_s$} \\ -1/{nm},& \text{otherwise} \end{cases} \tag{5} li,j=1/n2,1/m2,1/nm,xi,xjDsxi,xjDsotherwise(5)
在实际应用中,高斯核的 σ \sigma σ会取多个值,分别求核函数然后取和,作为最后的核函数。
##代码解读

import torch

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    '''
    将源域数据和目标域数据转化为核矩阵,即上文中的K
    Params: 
	    source: 源域数据(n * len(x))
	    target: 目标域数据(m * len(y))
	    kernel_mul: 
	    kernel_num: 取不同高斯核的数量
	    fix_sigma: 不同高斯核的sigma值
	Return:
		sum(kernel_val): 多个核矩阵之和
    '''
    n_samples = int(source.size()[0])+int(target.size()[0])# 求矩阵的行数,一般source和target的尺度是一样的,这样便于计算
    total = torch.cat([source, target], dim=0)#将source,target按列方向合并
    #将total复制(n+m)份
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    #将total的每一行都复制成(n+m)行,即每个数据都扩展成(n+m)份
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    #求任意两个数据之间的和,得到的矩阵中坐标(i,j)代表total中第i行数据和第j行数据之间的l2 distance(i==j时为0)
    L2_distance = ((total0-total1)**2).sum(2) 
    #调整高斯核函数的sigma值
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
    #以fix_sigma为中值,以kernel_mul为倍数取kernel_num个bandwidth值(比如fix_sigma为1时,得到[0.25,0.5,1,2,4]
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    #高斯核函数的数学表达式
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    #得到最终的核矩阵
    return sum(kernel_val)#/len(kernel_val)

def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    '''
    计算源域数据和目标域数据的MMD距离
    Params: 
	    source: 源域数据(n * len(x))
	    target: 目标域数据(m * len(y))
	    kernel_mul: 
	    kernel_num: 取不同高斯核的数量
	    fix_sigma: 不同高斯核的sigma值
	Return:
		loss: MMD loss
    '''
    batch_size = int(source.size()[0])#一般默认为源域和目标域的batchsize相同
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    #根据式(3)将核矩阵分成4部分
    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    YX = kernels[batch_size:, :batch_size]
    loss = torch.mean(XX + YY - XY -YX)
    return loss#因为一般都是n==m,所以L矩阵一般不加入计算

代码示例

为了体现以上代码的有效性,我们参考链接生成了两组不同分布的数据。

import random
import matplotlib
import matplotlib.pyplot as plt

SAMPLE_SIZE = 500
buckets = 50

#第一种分布:对数正态分布,得到一个中值为mu,标准差为sigma的正态分布。mu可以取任何值,sigma必须大于零。
plt.subplot(1,2,1)
plt.xlabel("random.lognormalvariate")
mu = -0.6
sigma = 0.15#将输出数据限制到0-1之间
res1 = [random.lognormvariate(mu, sigma) for _ in xrange(1, SAMPLE_SIZE)]
plt.hist(res1, buckets)

#第二种分布:beta分布。参数的条件是alpha 和 beta 都要大于0, 返回值在0~1之间。
plt.subplot(1,2,2)
plt.xlabel("random.betavariate")
alpha = 1
beta = 10
res2 = [random.betavariate(alpha, beta) for _ in xrange(1, SAMPLE_SIZE)]
plt.hist(res2, buckets)

plt.savefig('data.jpg)
plt.show()

两种数据分布如下图
【代码阅读】最大均值差异(Maximum Mean Discrepancy, MMD)损失函数代码解读(Pytroch版)_第1张图片
两种分布有明显的差异,下面从两个方面用MMD来量化这种差异:
1. 分别从不同分布取两组数据(每组为10*500)

from torch.autograd import Variable

#参数值见上段代码
#分别从对数正态分布和beta分布取两组数据
diff_1 = []
for i in range(10):
    diff_1.append([random.lognormvariate(mu, sigma) for _ in xrange(1, SAMPLE_SIZE)])

diff_2 = []
for i in range(10):
    diff_2.append([random.betavariate(alpha, beta) for _ in xrange(1, SAMPLE_SIZE)])

X = torch.Tensor(diff_1)
Y = torch.Tensor(diff_2)
X,Y = Variable(X), Variable(Y)
print mmd_rbf(X,Y)

输出结果为

Variable containing:
 6.1926
[torch.FloatTensor of size 1]

2. 分别从相同分布取两组数据(每组为10*500)

from torch.autograd import Variable

#参数值见以上代码
#从对数正态分布取两组数据
same_1 = []
for i in range(10):
    same_1.append([random.lognormvariate(mu, sigma) for _ in xrange(1, SAMPLE_SIZE)])

same_2 = []
for i in range(10):
    same_2.append([random.lognormvariate(mu, sigma) for _ in xrange(1, SAMPLE_SIZE)])

X = torch.Tensor(same_1)
Y = torch.Tensor(same_2)
X,Y = Variable(X), Variable(Y)
print mmd_rbf(X,Y)

输出结果为

Variable containing:
 0.6014
[torch.FloatTensor of size 1]

可以明显看出同分布数据和不同分布数据之间的差距被量化了出来,且符合之前理论所说:不同分布MMD的值大于相同分布MMD的值。
PS,在实验中发现一个问题,就是取数据时要在0-1的范围内取,不然MMD就失效了。
如果错误之处,请指正,感谢阅读

你可能感兴趣的:(Transfer,learning)