先上代码：
import torch.nn as nn
import torch
# Demonstrate what the `dim` argument of nn.Softmax selects: softmax is
# computed independently over each 1-D slice taken along axis `dim`, so in
# the result every such slice sums to 1.
m = nn.Softmax(dim=0)
n = nn.Softmax(dim=1)
k = nn.Softmax(dim=2)
v = nn.Softmax(dim=3)

# A random 4-D tensor of shape (2, 2, 2, 3). It is named `x` instead of
# `input` so the built-in input() function is not shadowed.
x = torch.randn(2, 2, 2, 3)
print(x)

# Apply each module in turn; enumerate() yields the axis each one normalises,
# matching the dim= argument it was constructed with above.
for dim, softmax in enumerate((m, n, k, v)):
    print(f'dim = {dim}: ', softmax(x))
输出：
tensor([[[[ 0.4846, 1.1182, 0.8734],
          [-0.2782, 0.3953, -0.8374]],

         [[-0.1427, -1.0001, -1.3453],
          [-2.2306, -2.5269, -1.2058]]],


        [[[ 0.6038, -0.1763, -0.8722],
          [ 0.1515, -0.5093, 0.3797]],

         [[-0.7225, -0.3317, 0.1577],
          [-1.5082, 2.1467, -1.1849]]]])
dim = 0: tensor([[[[0.4702, 0.7849, 0.8514],
          [0.3942, 0.7119, 0.2284]],

         [[0.6410, 0.3389, 0.1820],
          [0.3268, 0.0093, 0.4948]]],


        [[[0.5298, 0.2151, 0.1486],
          [0.6058, 0.2881, 0.7716]],

         [[0.3590, 0.6611, 0.8180],
          [0.6732, 0.9907, 0.5052]]]])
dim = 1: tensor([[[[0.6519, 0.8927, 0.9019],
          [0.8757, 0.9489, 0.5911]],

         [[0.3481, 0.1073, 0.0981],
          [0.1243, 0.0511, 0.4089]]],


        [[[0.7902, 0.5388, 0.2631],
          [0.8402, 0.0656, 0.8270]],

         [[0.2098, 0.4612, 0.7369],
          [0.1598, 0.9344, 0.1730]]]])
dim = 2: tensor([[[[0.6820, 0.6732, 0.8469],
          [0.3180, 0.3268, 0.1531]],

         [[0.8897, 0.8215, 0.4652],
          [0.1103, 0.1785, 0.5348]]],


        [[[0.6112, 0.5825, 0.2224],
          [0.3888, 0.4175, 0.7776]],

         [[0.6869, 0.0774, 0.7929],
          [0.3131, 0.9226, 0.2071]]]])
dim = 3: tensor([[[[0.2294, 0.4322, 0.3384],
          [0.2831, 0.5551, 0.1618]],

         [[0.5798, 0.2460, 0.1742],
          [0.2207, 0.1641, 0.6151]]],


        [[[0.5928, 0.2717, 0.1355],
          [0.3606, 0.1862, 0.4531]],

         [[0.2045, 0.3023, 0.4932],
          [0.0244, 0.9420, 0.0337]]]])
三维张量的情况可参考：
https://blog.csdn.net/sunyueqinghit/article/details/101113251