torch.nn.Softmax

  • 很多博客对于这个函数讲的并不清晰,我在这里重新梳理一下。
  • 话不多说,直接上例子。
  • 下面是代码
import torch
import math
import torch.nn as nn
input=torch.Tensor([[ 0.5450, -0.6264,  1.0446],
                    [ 0.6324,  1.9069,  0.7158],
                    [ 0.3224,  0.5342,  -0.4561]])
softmax_input=nn.Softmax(dim=0)(input)
logsoftmax_input=nn.LogSoftmax(dim=0)(input)
print(softmax_input)
print(logsoftmax_input)
#输出
#tensor([[0.3458, 0.0596, 0.5147],
#        [0.3774, 0.7503, 0.3705],
#        [0.2768, 0.1901, 0.1148]])
#tensor([[-1.0619, -2.8206, -0.6641],
#        [-0.9745, -0.2873, -0.9929],
#        [-1.2845, -1.6600, -2.1648]])

我们可以看到输出的softmax_input的所有行对应下标之和是1,即0.3458+0.3774+0.2768=1、0.0596+0.7503+0.1901=1,这是因为我们规定dim=0,dim=0表示对于第一个维度的对应下标之和是1,那么我们再看一下dim=1的情况:

import torch
import math
import torch.nn as nn
input=torch.Tensor([[ 0.5450, -0.6264,  1.0446],
                    [ 0.6324,  1.9069,  0.7158],
                    [ 0.3224,  0.5342,  -0.4561]])
softmax_input=nn.Softmax(dim=1)(input)
logsoftmax_input=nn.LogSoftmax(dim=1)(input)
print(softmax_input)
print(logsoftmax_input)
#tensor([[0.3381, 0.1048, 0.5572],
#        [0.1766, 0.6315, 0.1919],
#        [0.3711, 0.4586, 0.1704]])
#tensor([[-1.0845, -2.2559, -0.5849],
#        [-1.7341, -0.4596, -1.6507],
#        [-0.9914, -0.7796, -1.7699]])

我们可以看到,dim=1表示对于第二维度而言,对应下标之和为1,0.3381+0.1048+0.5572=1,0.1766+0.6315+0.1919=1,即所有列的对应下标之和为1。

  • 下一个问题是softmax_input的输出是怎么计算的,我们以dim=0为例,输出是:
#输出
#tensor([[0.3458, 0.0596, 0.5147],
#        [0.3774, 0.7503, 0.3705],
#        [0.2768, 0.1901, 0.1148]])
#tensor([[-1.0619, -2.8206, -0.6641],
#        [-0.9745, -0.2873, -0.9929],
#        [-1.2845, -1.6600, -2.1648]])

以0.3458、0.0596、0.5147为例,计算公式就是softmax的公式

print(math.exp(0.5450)/(math.exp(0.5450)+math.exp(0.6324)+math.exp(0.3224)))
print(math.exp(-0.6264)/(math.exp(-0.6264)+math.exp(1.9069)+math.exp(0.5342)))
print(math.exp(1.0446)/(math.exp(1.0446)+math.exp(0.7158)+math.exp(-0.4561)))
#输出
#0.34580919722189674
#0.05957044131350709
#0.5147313583743323

你可能感兴趣的:(算法,softmax,算法)