首先我们要了解softmax的计算公式:
例如一列数组[1,2,3,4,5,6,7,8,9,10],代入到softamx计算公式之中,我们可以从公式中看出元素的数值越大,softmax算出的值也就越大,对应在图像处理中也就是概率越大。
import torch
a = torch.randn(2,3,4)
print(a)
a是tensor型,维度为(2,3,4)的,输出为:
tensor([[[-0.5947, 0.9496, -1.9366, 0.0580],
[ 0.8346, 2.2958, 1.4638, -0.0390],
[ 0.1379, 1.2085, 1.2671, -1.2156]],
[[-0.1980, -0.8921, -1.1665, -1.2323],
[-0.8220, -1.9701, 1.7183, -0.2698],
[ 0.1977, -0.1474, -1.5431, -2.2217]]])
dim=0时,对应维度数值为2,因此每次只有两个元素进行Softmax计算
softmax0 = nn.Softmax(dim=0)
t0 = softmax0(a)
上述两式子合成一个: nn.Softmax(dim=0)(a)
dim=1时,对应维度数值为3,因此每次只有3个元素进行Softmax计算
softmax1 = nn.Softmax(dim=1)
t1 = softmax1(a)
上述两式子合成一个: nn.Softmax(dim=1)(a)
dim=2时,对应维度数值为4,因此每次只有4个元素进行Softmax计算
softmax2 = nn.Softmax(dim=2)
t2 = softmax2(a)
上述两式子合成一个: nn.Softmax(dim=2)(a)
我们可以把上述例子看成 [batch, h , w], dim=1是对n*m矩阵的逐列进行softmax,dim=2是对每行进行softmax操作,并且所有的操作对不同的 batch 都是独立进行的。