PyTorch中的KL散度

 

import torch.nn as nn
import torch
import torch.nn.functional as F

def kl_reductions(input_scores, target_scores):
    """Compare how PyTorch reduces the KL divergence of two score tensors.

    Both tensors are normalized along their last dimension first, because
    ``kl_div`` expects its input as log-probabilities and (with the default
    ``log_target=False``) its target as plain probabilities.

    Args:
        input_scores: Unnormalized scores for the predicted distribution.
        target_scores: Unnormalized scores for the reference distribution.

    Returns:
        A tuple ``(module_default, summed, averaged)`` of scalar tensors:
        the result of ``nn.KLDivLoss()`` with its default reduction, of
        ``F.kl_div(..., reduction='sum')`` and of
        ``F.kl_div(..., reduction='mean')``.
    """
    log_probs = F.log_softmax(input_scores, dim=-1)
    target_probs = F.softmax(target_scores, dim=-1)

    # The module is built without arguments on purpose: the point of the
    # comparison is to show which reduction the default corresponds to.
    module_default = nn.KLDivLoss()(log_probs, target_probs)

    # 'sum' adds up every pointwise term target * (log(target) - input).
    summed = F.kl_div(log_probs, target_probs, reduction='sum')

    # 'mean' divides that sum by the total number of elements, so it is NOT
    # the per-sample KL divergence; PyTorch's documentation recommends
    # reduction='batchmean' (sum / batch size) when that is what is wanted.
    averaged = F.kl_div(log_probs, target_probs, reduction='mean')

    return module_default, summed, averaged


if __name__ == '__main__':
    # torch.tensor is the recommended factory for literal data; the float
    # literals keep the dtype floating point, which softmax requires.
    x_o = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    y_o = torch.tensor([[0.1, 0.2], [0.3, 0.4]])

    klloss, kl, kl2 = kl_reductions(x_o, y_o)

    print('klloss', klloss)
    print('kl', kl)
    print('kl2', kl2)

运行结果：
klloss tensor(0.0482)
kl tensor(0.1928)
kl2 tensor(0.0482)

你可能感兴趣的:(Python,Pytorch,大数据)