最近看论文代码看到SK_Net代码时对softmax的dim这个参数不太理解,就写了个简单的代码输出看了一下,其实意思就是使得在softmax操作之后在dim这个维度相加等于1:
import torch
import torch.nn as nn
x = torch.rand(3,2,4)
print(x)
print("----------------------------")
softmax=nn.Softmax(dim = 0)
y = softmax(x)
print(y)
print(y.sum(0))
print(y.sum(1))
print(y.sum(2))
print("----------------------------")
softmax=nn.Softmax(dim = 1)
y = softmax(x)
print(y)
print(y.sum(0))
print(y.sum(1))
print(y.sum(2))
print("----------------------------")
softmax=nn.Softmax(dim = 2)
y = softmax(x)
print(y)
print(y.sum(0))
print(y.sum(1))
print(y.sum(2))