This post analyzes the dim parameter of the softmax function.
First, a quick review of the softmax function itself:
Let x = [1, 2, 3].
Then $\text{softmax}(x) = \left[ \frac{e^1}{e^1+e^2+e^3},\ \frac{e^2}{e^1+e^2+e^3},\ \frac{e^3}{e^1+e^2+e^3} \right]$.
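As a quick check, here is a minimal sketch that applies this definition directly and compares it against PyTorch's built-in torch.softmax:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
# Apply the definition directly: exponentiate, then normalize by the sum.
manual = torch.exp(x) / torch.exp(x).sum()
print(manual)                    # tensor([0.0900, 0.2447, 0.6652])
print(torch.softmax(x, dim=0))   # identical result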
Next, let's look at how the Softmax module in torch.nn behaves on a 3-D tensor.
import torch.nn as nn

y = torch.rand(size=[2, 2, 3])
print(y)
# y has size [2, 2, 3]: think of it as two tables, each with 2 rows and 3 columns.
tensor([[[0.7536, 0.2712, 0.4176],
         [0.2008, 0.9512, 0.4904]],

        [[0.2232, 0.7459, 0.9858],
         [0.9141, 0.0604, 0.4307]]])
net_1 = nn.Softmax(dim=0)
net_2 = nn.Softmax(dim=1)
net_3 = nn.Softmax(dim=2)
print('Result with dim=0:\n', net_1(y), '\n')
print('Result with dim=1:\n', net_2(y), '\n')
print('Result with dim=2:\n', net_3(y), '\n')
Result with dim=0:
tensor([[[0.6296, 0.3835, 0.3617],
         [0.3289, 0.7091, 0.5149]],

        [[0.3704, 0.6165, 0.6383],
         [0.6711, 0.2909, 0.4851]]])

Result with dim=1:
tensor([[[0.6348, 0.3363, 0.4818],
         [0.3652, 0.6637, 0.5182]],

        [[0.3338, 0.6650, 0.6353],
         [0.6662, 0.3350, 0.3647]]])

Result with dim=2:
tensor([[[0.4288, 0.2647, 0.3065],
         [0.2245, 0.4755, 0.3000]],

        [[0.2070, 0.3492, 0.4438],
         [0.4896, 0.2085, 0.3019]]])
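Before walking through each case, note the common invariant: whichever dim is chosen, the outputs sum to 1 along exactly that dimension. A quick sanity check, continuing the session above:

for d in range(3):
    # Summing along the softmax dimension collapses it to all ones.
    print(torch.softmax(y, dim=d).sum(dim=d))

Each line prints a tensor of ones (shapes [2, 3], [2, 3], and [2, 2] respectively).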
Result with dim=0 (repeated from above):
tensor([[[0.6296, 0.3835, 0.3617],
         [0.3289, 0.7091, 0.5149]],

        [[0.3704, 0.6165, 0.6383],
         [0.6711, 0.2909, 0.4851]]])
dim = 0 refers to the first dimension, whose size in this example is 2. As noted above, we can view that "2" as two tables. So where do 0.6296 and 0.3704 come from? Along dim=0, softmax pairs each element of the first table with the element at the same position in the second table.
The numbers in the two tables are:
for i in range(2):
    print(y[i, :, :].reshape(-1))
tensor([0.7536, 0.2712, 0.4176, 0.2008, 0.9512, 0.4904])
tensor([0.2232, 0.7459, 0.9858, 0.9141, 0.0604, 0.4307])
$0.6296 \approx \frac{e^{0.7536}}{e^{0.7536}+e^{0.2232}}$
$0.3704 \approx \frac{e^{0.2232}}{e^{0.7536}+e^{0.2232}}$
$0.6296 + 0.3704 = 1$
Likewise:
$0.3289 \approx \frac{e^{0.2008}}{e^{0.2008}+e^{0.9141}}$
$0.6711 \approx \frac{e^{0.9141}}{e^{0.2008}+e^{0.9141}}$
$0.3289 + 0.6711 = 1$
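This pairing is easy to verify in code. Continuing the session above, the following sketch recomputes the [0, 0, 0] entry of the dim=0 result by hand:

# Along dim=0, y[0, j, k] is normalized against y[1, j, k] for every (j, k).
a, b = y[0, 0, 0], y[1, 0, 0]
print(torch.exp(a) / (torch.exp(a) + torch.exp(b)))  # ≈ 0.6296 for the y above
print(net_1(y)[0, 0, 0])                             # same value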
Result with dim=1 (repeated from above):
tensor([[[0.6348, 0.3363, 0.4818],
         [0.3652, 0.6637, 0.5182]],

        [[0.3338, 0.6650, 0.6353],
         [0.6662, 0.3350, 0.3647]]])
dim = 1 refers to the second dimension, whose size in this example is 2, i.e. we can view all the numbers in the original data as forming 2 rows. Along dim=1, softmax pairs each element of the first row with the element at the same position in the second row.
The numbers in the two rows are:
for i in range(2):
    print(y[:, i, :].reshape(-1))
tensor([0.7536, 0.2712, 0.4176, 0.2232, 0.7459, 0.9858])
tensor([0.2008, 0.9512, 0.4904, 0.9141, 0.0604, 0.4307])
$0.6348 \approx \frac{e^{0.7536}}{e^{0.7536}+e^{0.2008}}$
$0.3652 \approx \frac{e^{0.2008}}{e^{0.7536}+e^{0.2008}}$
$0.6348 + 0.3652 = 1$
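As before, this can be verified by recomputing one entry by hand, continuing the session above:

# Along dim=1, y[i, 0, k] is normalized against y[i, 1, k] for every (i, k).
a, b = y[0, 0, 0], y[0, 1, 0]
print(torch.exp(a) / (torch.exp(a) + torch.exp(b)))  # ≈ 0.6348 for the y above
print(net_2(y)[0, 0, 0])                             # same value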
Result with dim=2 (repeated from above):
tensor([[[0.4288, 0.2647, 0.3065],
         [0.2245, 0.4755, 0.3000]],

        [[0.2070, 0.3492, 0.4438],
         [0.4896, 0.2085, 0.3019]]])
dim = 2 refers to the third dimension, whose size in this example is 3, i.e. we can view all the numbers in the original data as forming 3 columns. Along dim=2, softmax takes one element from each column at the same position and normalizes those three numbers together.
The numbers in the 3 columns are:
for i in range(3):
    print(y[:, :, i].reshape(-1))
tensor([0.7536, 0.2008, 0.2232, 0.9141])
tensor([0.2712, 0.9512, 0.7459, 0.0604])
tensor([0.4176, 0.4904, 0.9858, 0.4307])
$0.4288 \approx \frac{e^{0.7536}}{e^{0.7536}+e^{0.2712}+e^{0.4176}}$
$0.2647 \approx \frac{e^{0.2712}}{e^{0.7536}+e^{0.2712}+e^{0.4176}}$
$0.3065 \approx \frac{e^{0.4176}}{e^{0.7536}+e^{0.2712}+e^{0.4176}}$
$0.4288 + 0.2647 + 0.3065 = 1$
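Put differently, dim=2 normalizes each length-3 row y[i, j, :] on its own, which is the familiar per-row softmax used for classification logits. A final check, continuing the session above:

# Along dim=2, each row y[i, j, :] is normalized as a unit.
row = y[0, 0, :]
print(torch.exp(row) / torch.exp(row).sum())  # ≈ [0.4288, 0.2647, 0.3065]
print(net_3(y)[0, 0, :])                      # same values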