【基础知识】pytorch:nn.Softmax()

转载:https://blog.csdn.net/weixin_41391619/article/details/104823086

这篇讲的很清楚,转一下 

本篇主要分析softmax函数中的dim参数, 首先介绍一下softmax函数:
x=\left [ 1,2,3 \right ]
则 softmax(x)=\left [ \frac{e^1}{e^1+e^2+e^3},\frac{e^2}{e^1+e^2+e^3},\frac{e^3}{e^1+e^2+e^3}\right ]

接下来分析torch.nn里面的softmax函数

y = torch.tensor([[[1.,2.,3.],[4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]])
#y的size是2,2,3。可以看成有两张表,每张表2行3列
net_1 = nn.Softmax(dim=0)
net_2 = nn.Softmax(dim=1)
net_3 = nn.Softmax(dim=2)
print('dim=0的结果是:\n',net_1(y),"\n")
print('dim=1的结果是:\n',net_2(y),"\n")
print('dim=2的结果是:\n',net_3(y),"\n")

                        【基础知识】pytorch:nn.Softmax()_第1张图片

dim = 0:

dim = 0指第一个维度,在本例中第一个维度的size是2,如前文所说,我们把“2”看成是两张表,那么0.0025和0.9975是怎么来的呢?
第一张表中6个数的平均值是:(1+2+3+4+5+6)/6 = 3.5
第二张表中6个数的平均值是:(7+8+9+10+11+12)/6 = 9.5

0.0025\approx \frac{e^{3.5}}{e^{3.5}+e^{9.5}}

0.9975\approx \frac{e^{9.5}}{e^{3.5}+e^{9.5}}

dim = 1:

dim = 1指第二个维度,在本例中第二个维度的size是2,我们可以看成是2行。
我们把所有表中的第一行的数据拿出来:1,2,3;7,8,9 求平均:5
我们把所有表中的第二行的数据拿出来:4,5,6;10,11,12 求平均:8

0.0474\approx \frac{e^{5}}{e^{5}+e^{8}}

0.9526\approx \frac{e^{8}}{e^{5}+e^{8}}

dim = 2:

dim = 2指第三个维度,在本例中第三个维度的size是3,我们可以看成是3列。
我们把所有表中的第1列的数据拿出来:1,4;7,10 求平均:5.5
我们把所有表中的第2列的数据拿出来:2,5;8,11 求平均:6.5
我们把所有表中的第3列的数据拿出来:3,6;9,12 求平均:7.5
0.009\approx \frac{e^{5.5}}{e^{5.5}+e^{6.5}+e^{7.5}}

0.2447\approx \frac{e^{6.5}}{e^{5.5}+e^{6.5}+e^{7.5}}

0.6652\approx \frac{e^{7.5}}{e^{5.5}+e^{6.5}+e^{7.5}}

你可能感兴趣的:(【基础知识】pytorch:nn.Softmax())