转载:https://blog.csdn.net/weixin_41391619/article/details/104823086
这篇讲的很清楚,转一下
本篇主要分析softmax函数中的dim参数, 首先介绍一下softmax函数:
设
则
接下来分析torch.nn里面的softmax函数
y = torch.tensor([[[1.,2.,3.],[4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]])
#y的size是2,2,3。可以看成有两张表,每张表2行3列
net_1 = nn.Softmax(dim=0)
net_2 = nn.Softmax(dim=1)
net_3 = nn.Softmax(dim=2)
print('dim=0的结果是:\n',net_1(y),"\n")
print('dim=1的结果是:\n',net_2(y),"\n")
print('dim=2的结果是:\n',net_3(y),"\n")
dim = 0指第一个维度,在本例中第一个维度的size是2,如前文所说,我们把“2”看成是两张表,那么0.0025和0.9975是怎么来的呢?
第一张表中6个数的平均值是:(1+2+3+4+5+6)/6 = 3.5
第二张表中6个数的平均值是:(7+8+9+10+11+12)/6 = 9.5
dim = 1指第二个维度,在本例中第二个维度的size是2,我们可以看成是2行。
我们把所有表中的第一行的数据拿出来:1,2,3;7,8,9 求平均:5
我们把所有表中的第二行的数据拿出来:4,5,6;10,11,12 求平均:8
dim = 2指第三个维度,在本例中第三个维度的size是3,我们可以看成是3列。
我们把所有表中的第1列的数据拿出来:1,4;7,10 求平均:5.5
我们把所有表中的第2列的数据拿出来:2,5;8,11 求平均:6.5
我们把所有表中的第3列的数据拿出来:3,6;9,12 求平均:7.5