在PyTorch中,sum()函数用于对输入张量的所有元素进行求和操作。该函数的语法如下:
torch.sum(input, dim=None, keepdim=False, dtype=None)
具体而言,sum()函数会对输入张量的所有元素进行求和操作,并返回一个标量值。
如果指定了dim参数,则会沿着指定的维度对输入张量进行求和操作,并返回一个形状与输入张量除了指定维度之外的维度相同的张量。
如果指定了keepdim参数,则返回的张量会保持与输入张量相同的维度数,并且在指定维度上的大小为1。
具体算维度,采用消解法,也就是 dim=0,那就是消去第0维,dim=1,消去第1维
比如一个矩阵维度(2,3,4),dim=0,那就是消去第0维,变成了 (3,4),消去第1维,变成了 (2,4),消去第2维,变成了 (2,3),dim=-1,也就是最后一维,在这个矩阵中也就是第二维。
import torch
x = torch.arange(24).view(2,3,4)
print(x)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
y = torch.sum(x)
print(y)
tensor(276)
y = torch.sum(x,dim=0)
print(y)
print(y.shape)
tensor([[12, 14, 16, 18],
[20, 22, 24, 26],
[28, 30, 32, 34]])
torch.Size([3, 4])
y = torch.sum(x, dim=1)
print(y)
print(y.shape)
tensor([[12, 15, 18, 21],
[48, 51, 54, 57]])
torch.Size([2, 4])
y = torch.sum(x, dim=2)
print(y)
print(y.shape)
tensor([[ 6, 22, 38],
[54, 70, 86]])
torch.Size([2, 3])
y = torch.sum(x, dim=-1)
print(y)
print(y.shape)
tensor([[ 6, 22, 38],
[54, 70, 86]])
torch.Size([2, 3])