sum()去掉0_(pytorch)torch.sum的用法及dim参数的使用

torch.sum()对输入的tensor数据的某一维度求和,一共两种用法

1.torch.sum(input, dtype=None)

2.torch.sum(input, list: dim, bool: keepdim=False, dtype=None) → Tensor

input:输入一个tensor

dim:要求和的维度,可以是一个列表

keepdim:求和之后这个dim的元素个数为1,所以要被去掉,如果要保留这个维度,则应当keepdim=True


dim参数的使用(用图来表示)

sum()去掉0_(pytorch)torch.sum的用法及dim参数的使用_第1张图片
import torch

a = torch.ones((2, 3))
a1 =  torch.sum(a)
a2 =  torch.sum(a, dim=0)
a3 =  torch.sum(a, dim=1)

print(a)
print(a1)
print(a2)
print(a3)

sum()去掉0_(pytorch)torch.sum的用法及dim参数的使用_第2张图片

此处也可将dim=0理解为纵向压缩,如:

2=1+1

2=1+1

2=1+1

故a被压缩成了 [2,2,2] size也从2*3压缩成了1*3

将dim=0理解为横向压缩,同纵向压缩,此处不再赘述


keepdim参数的使用

求和之后这个dim的元素个数为1,所以要被去掉,如果要保留这个维度,则应当keepdim=True

此处以一个例子来说明

import torch

a = torch.ones((2, 3))
print(a)
print('a.size=',a.size())

a1 =  torch.sum(a, dim=0)
print(a1)
print('a1.size=',a1.size())

a2 =  torch.sum(a, dim=0,keepdim=True)
print(a2)
print('a2.size=',a2.size())

sum()去掉0_(pytorch)torch.sum的用法及dim参数的使用_第3张图片

a1 经过纵向压缩后,被压缩成维度是 1 的张量,但令keepdim=true后,维度仍然是2维

PS:张量只能想象到2维,3维的还理解不了

理解PyTorch中维度的概念​mathpretty.com
sum()去掉0_(pytorch)torch.sum的用法及dim参数的使用_第4张图片

3维可以参考此文章

你可能感兴趣的:(sum()去掉0)