Deep Coral loss

import torch


def CORAL(source, target):
    d = source.data.shape[1] #coral公式中的分母部分
    ns, nt = source.data.shape[0], target.data.shape[0]
    # source covariance
    xm = torch.mean(source, 0, keepdim=True) - source #对应着Cs的分子部分
    xc = xm.t() @ xm/(ns-1)  #对应着Cs的分子部分

    # target covariance
    xmt = torch.mean(target, 0, keepdim=True) - target#对应着Ct的分子部分
    xct = xmt.t() @ xmt/(nt-1)#对应着Ct的分子部分

    # frobenius norm between source and target
    loss = torch.mean(torch.mul((xc - xct), (xc - xct))) #Cs-Ct的点乘
    loss = loss/(4*d*d)

    return loss

Coral公式:

Deep Coral loss_第1张图片

只做学习使用,作者也是看了别人进行了学习总结,希望能对你有所帮助。

 故障诊断与python学习

Deep CORAL: Correlation Alignment for Deep Domain Adaptation_gdtop818的博客-CSDN博客

你可能感兴趣的:(迁移学习)