迁移学习 - 域适应Coral损失函数 - 代码实现

Coral损失函数介绍

首先定义Coral损失函数

import torch
def CORAL(source, target, **kwargs):
    d = source.data.shape[1]
    ns, nt = source.data.shape[0], target.data.shape[0]
    # source covariance
    xm = torch.mean(source, 0, keepdim=True) - source
    xc = xm.t() @ xm / (ns - 1)

    # target covariance
    xmt = torch.mean(target, 0, keepdim=True) - target
    xct = xmt.t() @ xmt / (nt - 1)

    # frobenius norm between source and target
    loss = torch.mul((xc - xct), (xc - xct))
    loss = torch.sum(loss) / (4*d*d)
    return loss

用模拟数据进行验证

# 通过随机数模拟产生经过模型输出的结果source和target
# batch可以不一样,但分类类数要一样
source = torch.rand(64,4)  # 源域输出结果为batch=64, 4分类
target = torch.rand(64,4)  # 目标域域输出结果为batch=64, 4分类
Coral_loss = CORAL(source=source, target=target)
print(Coral_loss)
>>>output
tensor(3.3486e-05)

参考资料
链接: https://zhuanlan.zhihu.com/p/108778552.

你可能感兴趣的:(基于深度学习的故障诊断,迁移学习,pytorch,深度学习)