torch.nn.Softmax(dim=0,1,2)

文章目录

  • 1.基础概念讲解
    • 1.1 softmax函数
    • 1.2 softmax函数计算方法
    • 1.3 softmax函数公式
    • 1.4 softmax函数项目代码展示
  • 2.测试代码
  • 3.结果
    • 3.1 y
      • 3.1.1 x和y方向:
      • 3.1.2 z方向:
    • 3.2 dim=0的结果
    • 3.3 dim=1的结果
    • 3.4 dim=2的结果

1.基础概念讲解

1.1 softmax函数

softmax函数:又称归一化指数函数,是二分类函数sigmoid在多分类上的推广,目的是将多分类的结果以概率的形式展现出来
作用:模型已经有分类预测结果以后,将预测结果输入softmax函数,进行非负性和归一化处理,最后得到0-1之内的分类概率

1.2 softmax函数计算方法

图示:
torch.nn.Softmax(dim=0,1,2)_第1张图片

1.3 softmax函数公式

softmax函数:
在这里插入图片描述

1.4 softmax函数项目代码展示

以下代码只是一个项目中的一部分,主要是看softmax函数:

			sm = torch.nn.Softmax(dim=1)
			scores = sm(model(inputs))

model(inputs)输出为:
在这里插入图片描述

torch.nn.Softmax(dim=0,1,2)_第2张图片
注:512为图片张数,2为两类各自的概率
经过softmax函数,进行归一化处理,得到scores为:
在这里插入图片描述

torch.nn.Softmax(dim=0,1,2)_第3张图片

2.测试代码

先假设y为一个[2,2,3]的张量

import torch
import torch.nn as nn

y = torch.tensor([[[1.,2.,3.],[4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]]) #y.shape = torch.Size([2, 2, 3])

net_1 = nn.Softmax(dim=0)
net_2 = nn.Softmax(dim=1)
net_3 = nn.Softmax(dim=2)
print('dim=0的结果是:\n',net_1(y),"\n")
print('dim=1的结果是:\n',net_2(y),"\n")
print('dim=2的结果是:\n',net_3(y),"\n")

3.结果

3.1 y

torch.nn.Softmax(dim=0,1,2)_第4张图片
对于y可以分三个方向:

3.1.1 x和y方向:

torch.nn.Softmax(dim=0,1,2)_第5张图片

3.1.2 z方向:

torch.nn.Softmax(dim=0,1,2)_第6张图片

3.2 dim=0的结果

torch.nn.Softmax(dim=0,1,2)_第7张图片
解析:
dim = 0指第一个维度,在本例中第一个维度指[2,2,3]中的第一个2,即下图中红色的中括号,红框中包含两组数据(绿色和蓝色),每组数据大小为2*3的矩阵。dim=0也就是在下图中先沿着x轴方向考虑:
torch.nn.Softmax(dim=0,1,2)_第8张图片

将绿色2*3矩阵中的所有数据相加,求均值:
在这里插入图片描述

将蓝色2*3矩阵中的所有数据相加,求均值:

在这里插入图片描述
用softmax函数计算:
torch.nn.Softmax(dim=0,1,2)_第9张图片
torch.nn.Softmax(dim=0,1,2)_第10张图片

3.3 dim=1的结果

torch.nn.Softmax(dim=0,1,2)_第11张图片
解析:
dim = 1指第二个维度,在本例中第二个维度指[2,2,3]中的第二个2,即下图中红色的中括号,红框中包含两组数据(绿色和黄色)。dim=1也就是在下图中先沿着y轴方向考虑:
torch.nn.Softmax(dim=0,1,2)_第12张图片
将绿色数据相加,求均值:
在这里插入图片描述
将黄色数据相加,求均值:
在这里插入图片描述
用softmax函数计算:
torch.nn.Softmax(dim=0,1,2)_第13张图片
torch.nn.Softmax(dim=0,1,2)_第14张图片

3.4 dim=2的结果

torch.nn.Softmax(dim=0,1,2)_第15张图片
解析:
dim = 2指第三个维度,在本例中第三个维度指[2,2,3]中的第三个3,即下图中红色的中括号,每一列包含一组数据,一共三组数据(绿色,黄色,棕色):
torch.nn.Softmax(dim=0,1,2)_第16张图片
将绿色数据相加,求均值:

在这里插入图片描述
将黄色数据相加,求均值:
在这里插入图片描述
将棕色数据相加,求均值:
在这里插入图片描述
用softmax函数计算:
torch.nn.Softmax(dim=0,1,2)_第17张图片
torch.nn.Softmax(dim=0,1,2)_第18张图片
torch.nn.Softmax(dim=0,1,2)_第19张图片

你可能感兴趣的:(深度学习,深度学习,机器学习,分类)