MMD最大均值差异代码解析

import torch
#采用的高斯核函数
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    '''
    source: 源域
    target:
    kernel_mul:核的倍数
    kernel_num:多少个核心
   
    n_samples = int(source.size()[0])+int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0-total1)**2).sum(2)
    #求出高斯核函数的分母||u-v||**2
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
        #上面初始化bandwidth
    bandwidth /= kernel_mul ** (kernel_num // 2)#// 表示两数相除取整 **表示幂运算
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    '''
    一系列操作求出分母的列表bandwidth_list,一共设置有五个核,所以求出的列表i从0、1、2、3、4共5个值的列表
    
       '''
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
# kernel_val也求出5个值作为一个列表,求和值
    return sum(kernel_val))

'''
当使用MMD方法的时候我们使用它,输入源域和目标域数据,这里是每次迭代都使用MMD方法。
print(batch_size)可以帮助你每次查看batch的大小,防止由于drop_last的原因导致的batch源域与目标域数据不匹配的现象出现。
'''
def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    batch_size = int(source.size()[0])
    #print(batch_size)
    kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
#kernels返回sum(kernel_val)
    '''
    实现高斯核函数kernels = sum(kernel_val)
    '''
    XX = kernels[:batch_size, :batch_size]#都是取源域的数据拼接前半部分
    YY = kernels[batch_size:, batch_size:]#都是取目标域数据拼接后半部分
    XY = kernels[:batch_size, batch_size:]
    YX = kernels[batch_size:, :batch_size]
    loss = torch.mean(XX + YY - XY - YX)
    return loss

'''
为什么使用高斯核函数,某种意义上它可以实现无限维度的映射。 

'''

你可能感兴趣的:(pytorch,深度学习,python)