torch.softmax()

这里区别于max,argmax,这里参数始终都有dim!!!

torch.nn.functional.softmax(inputdim=None_stacklevel=3dtype=None)[SOURCE]

Applies a softmax function.

Softmax is defined as:

\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}Softmax(xi​)=∑j​exp(xj​)exp(xi​)​

It is applied to all slices along dim, and will re-scale them so that the elements lie in the range [0, 1] and sum to 1.

See Softmax for more details.

Parameters

  • input (Tensor) – input

  • dim (int) – A dimension along which softmax will be computed.

  • dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is casted to before the operation is performed. This is useful for preventing data type overflows. Default: None.dtype

NOTE

This function doesn’t work directly with NLLLoss, which expects the Log to be computed between the Softmax and itself. Use log_softmax instead (it’s faster and has better numerical properties).

softmax作用与模型应用

首先说一下Softmax函数,公式如下:

1. 三维tensor(C,H,W)

一般会设置成dim=0,1,2,-1的情况(可理解为维度索引)。其中2与-1等价,相同效果。

用一张图片来更好理解这个参数dim数值变化:

torch.softmax()_第1张图片

dim=0时, 是对每一维度相同位置的数值进行 softmax运算,和为1
dim=1时, 是对某一维度的列进行 softmax运算,和为1
dim=2时, 是对某一维度的行进行 softmax运算,和为1

准备工作:先随机生成一个(2,5,4)的矩阵,即两个维度的(5,4)矩阵

import torch 
import torch.nn.functional as F 
input= torch.randn(2,2,3))
print(input)

torch.softmax()_第2张图片

随机3维矩阵

(1) dim=0

torch.softmax()_第3张图片

dim=0

(2) dim=1

torch.softmax()_第4张图片

dim=1

(3) dim=2 或dim=-1

torch.softmax()_第5张图片

dim=2

torch.softmax()_第6张图片

dim=-1

2. 四维tensor(B,C,H,W)

是三维tensor的推广,其实三维tensor也可以是batchsize=1的四维tensor,只是dim的索引需要加1.

dim取值0,1,2,3,-1

准备工作:先随机生成一个(2,2,5,4)矩阵。其实随着dim增加(从0到3),相当于一层层剥开。

torch.softmax()_第7张图片

(1) dim=0

这时的视野应该放在整个tensor,每个batch(不同B)对应位置(相同CHW)求softmax

torch.softmax()_第8张图片

(2) dim=1

这时向里剥,每小块(不同C)对应位置(相同BHW)求softmax。

torch.softmax()_第9张图片

(3) dim=2

继续向里剥,每小块(不同H)对应位置(相同BCW)求Softmax。

torch.softmax()_第10张图片

(4) dim=3 或dim=-1

继续向里剥,也是最后一次。每个小块(不同W)对应元素(相同BCH)求softmax。

torch.softmax()_第11张图片

你可能感兴趣的:(python,深度学习,pytorch,人工智能)