1)公式定义
公式理解:分母为输入向量中的所有元素按照指数方式求和,然后将输入中的每个元素按照指数方式除以分母得到计算结果。
2)调用方式
该函数输入/输出为n维向量,目的是将其输入重新缩放,使其所有元素皆属于[0,1],并且此时所有元素总和为1。
torch.nn.Softmax(dim=None) # 参数dim表示softmax进行计算的维度
3)示例
手动实现计算公式,从输出结果验证内部实现:
# 1.n维向量
softmax=nn.Softmax(dim=0)
inp=torch.tensor([10,20,5,3],dtype=torch.float32)
print("inp:",inp)
out=softmax(inp)
print("out:",out)
total_sum=torch.sum(out)
print("sum:",total_sum)
print(math.exp(inp[0])/torch.sum(torch.exp(inp)))
inp: tensor([10., 20., 5., 3.])
out: tensor([4.5398e-05, 9.9995e-01, 3.0589e-07, 4.1397e-08])
sum: tensor(1.)
tensor(4.5398e-05)
该函数使得输入中所有数值变为[0,1]范围内,下列代码随机生成形状为(2,3,4)的输入特征,分别输出dim=0,1,2时的结果,对其进行分析:
inp=torch.randn(size=(2,3,4))
print(inp)
print("*"*20)
fun1=nn.Softmax(dim=0)
fun2=nn.Softmax(dim=1)
fun3=nn.Softmax(dim=2)
out1=fun1(inp)
out2=fun2(inp)
out3=fun3(inp)
print(out1)
print("*"*20)
print(out2)
print("*"*20)
print(out3)
输入张量形状为(2,3,4),不同的dim表示,softmax操作将会在指定维度上进行运算,使通过该维度划分的切片所有元素和为1.
输入:
tensor([[[ 1.0691, -0.0940, 0.9542, -0.1487],
[ 0.8161, 0.3944, 0.5836, 0.1312],
[-1.4869, -1.8152, 0.7426, -0.7678]],
[[ 1.1018, -1.4451, -0.0344, -0.9289],
[-0.2549, 0.2586, -0.2612, -0.4508],
[-0.8926, -0.0369, 1.4809, 0.2047]]])
********************
1)dim=0表示该操作将在输入张量的第0维进行计算,而此张量第0维长度为2,故计算实际是沿着第0维方向,将整个张量划分为3×4个切片,每个切片包含两个元素,并分别对每个切片进行softmax计算,使得这两个元素和为1.
dim=0,输出:
tensor([[[0.4918, 0.7943, 0.7288, 0.6857],
[0.7448, 0.5339, 0.6995, 0.6415],
[0.3557, 0.1445, 0.3234, 0.2744]],
[[0.5082, 0.2057, 0.2712, 0.3143],
[0.2552, 0.4661, 0.3005, 0.3585],
[0.6443, 0.8555, 0.6766, 0.7256]]])
********************
2)dim=1表示该操作将在输入张量的第1维进行计算,而此张量第1维长度为3,故计算实际是沿着第1维方向,将整个张量划分为2×4个切片,每个切片包含三个元素,并分别对每个切片进行softmax计算,使得这三个元素和为1.
dim=1,输出:
tensor([[[0.5394, 0.3561, 0.4001, 0.3495],
[0.4188, 0.5802, 0.2762, 0.4624],
[0.0419, 0.0637, 0.3238, 0.1882]],
[[0.7176, 0.0945, 0.1575, 0.1748],
[0.1848, 0.5192, 0.1256, 0.2820],
[0.0977, 0.3863, 0.7169, 0.5432]]])
********************
3)dim=2表示该操作将在输入张量的第2维进行计算,而此张量第2维长度为4,故计算实际是沿着第2维方向,将整个张量划分为2×3个切片,每个切片包含四个元素,并分别对每个切片进行softmax计算,使得这四个元素和为1.
dim=2,输出:
tensor([[[0.4000, 0.1250, 0.3566, 0.1184],
[0.3387, 0.2221, 0.2684, 0.1707],
[0.0765, 0.0551, 0.7113, 0.1571]],
[[0.6533, 0.0512, 0.2098, 0.0857],
[0.2229, 0.3724, 0.2215, 0.1832],
[0.0585, 0.1377, 0.6284, 0.1754]]])