torch. nn.Softmax(dim=1)

最近看论文代码看到SK_Net代码时对softmax的dim这个参数不太理解,就写了个简单的代码输出看了一下,其实意思就是使得在softmax操作之后在dim这个维度相加等于1:

import torch
import torch.nn as nn
x = torch.rand(3,2,4)
print(x)
print("----------------------------")
softmax=nn.Softmax(dim = 0)
y = softmax(x)
print(y)
print(y.sum(0))
print(y.sum(1))
print(y.sum(2))
print("----------------------------")
softmax=nn.Softmax(dim = 1)
y = softmax(x)
print(y)
print(y.sum(0))
print(y.sum(1))
print(y.sum(2))
print("----------------------------")
softmax=nn.Softmax(dim = 2)
y = softmax(x)
print(y)
print(y.sum(0))
print(y.sum(1))
print(y.sum(2))

torch. nn.Softmax(dim=1)_第1张图片
torch. nn.Softmax(dim=1)_第2张图片
torch. nn.Softmax(dim=1)_第3张图片
torch. nn.Softmax(dim=1)_第4张图片
torch. nn.Softmax(dim=1)_第5张图片

你可能感兴趣的:(python,python)