pytorch中函数参数dim的理解

对于刚入门的新手来说,Torch API中的维度真的很迷惑人.例如 torch.sum(x, dim=0) 是按着行相加呢,还是列相加?

还有TopK,softmax等函数

在我潜意识的认为dim=0就是按着行操作,sum就是按着行相加,然而我的潜意识骗了我,或者说设计API的人不按常规套路出牌、

我们来看一个例子:

import numpy as np
import torch
x = torch.tensor([
        [1,2,3],
        [4,5,6]
    ])
# 我们可以看到"行"是dim=0, "列"是dim=1
print(x.shape)
'''输出
torch.Size([2, 3])
'''
# 但是我们按照dim=0求和, 是按照列相加
torch.sum(x, dim=0)
'''输出
tensor([5, 7, 9])
'''
# 但是我们按照dim=1求和, 是按照行相加1
torch.sum(x, dim=1)
'''输出
tensor([ 6, 15])
'''

pytorch中函数参数dim的理解_第1张图片

 对于size是1的除外,我们可以将其想象为dim=0的时候, 对行进行挤压, 于是变为了一行, 也就是每列元素相加。dim=1的时候就是对列挤压,于是就变成了一列,也就是对每行元素相加。

# 看一下三维的
x = torch.tensor([
        [
         [1,2,3],
         [4,5,6]
        ],
        [
         [1,2,3],
         [4,5,6]
        ],
        [
         [1,2,3],
         [4,5,6]
        ]
    ])
# 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
print(x.shape)
'''输出
torch.Size([3, 2, 3])
'''

pytorch中函数参数dim的理解_第2张图片

根据以上的结论,我们看看TopK的用法

# 看一下2维的
x = torch.tensor([
    [0.1, 0.2 ,0.5, 0.2],
    [0.4, 0.3, 0.2, 0.1],
    [0.1, 0.2, 0.5, 0.1],
])
# 我们可以看到"行"是dim=0, 列是dim=1
print(x.shape)
'''
torch.Size([3, 4])
'''
a, b = x.topk(1, dim=1)
print(a)
print('-'*10)
print(b)
print(b.shape)
'''输出
tensor([[0.5000],
        [0.4000],
        [0.5000]])
----------
tensor([[2],
        [0],
        [2]])
torch.Size([3, 1])
'''
a, b = x.topk(1, dim=0)
print(a)
print('-'*10)
print(b)
print(b.shape)
'''输出
tensor([[0.4000, 0.3000, 0.5000, 0.2000]])
----------
tensor([[1, 1, 0, 0]])
torch.Size([1, 4])
'''

我们再来看softmax

import torch
import numpy as np
import torch.nn.functional as F

data = np.array([[0.1, 0.3, 0.6], [1.5,2.1 ,0.55]])
t_data = torch.from_numpy(data)
print(t_data)
print(t_data.shape)
#print(t_data.type())
print("**************************************************************************")
prob = F.softmax(t_data,dim=0) # dim = 0,在列上进行Softmax;dim=1,在行上进行Softmax
print(prob)
print(prob.shape)
#print(prob.type())
print("**************************************************************************")
prob = F.softmax(t_data,dim=1) # dim = 0,在列上进行Softmax;dim=1,在行上进行Softmax
print(prob)
print(prob.shape)
#print(prob.type())

pytorch中函数参数dim的理解_第3张图片

 当dim=0时候按着列进行softmax(0.1978+0.8022 = 1),当dim=1的时候按着行精选softmax(0.2584+0.3165+0.4260 = 1)

现在是不是有所感悟,你细品

总结:size不等于1,dim 指定沿着某一维度挤压,例如dim=0,就是行被压缩,按着列操作。如dim=1,就是列被压缩,按着行操作。

你可能感兴趣的:(pytorch,pytorch,softmax,sum,深度学习)