Pytorch中函数参数dim的理解

一.dim的定义

dim的不同值代表不同的维度,例如在二维张量中dim=0代表的是行,dim=1代表的是列。广泛的说,在多维张量(d_{0},d_{1},d_{2},...,d_{n})中,dim=0就是指d_{0},dim=n是指d_{n}

二.例子

torch.sum()

Pytorch中函数参数dim的理解_第1张图片

 input:输入的张量

dim:需要消减的维度

keepdim:输出张量中是否保存指定dim维的张量

eg1:

b = torch.arange(3 * 2 * 2).view(3, 2, 2)
print(b)
print(torch.sum(b, (1, 2)))

输出结果为:

Pytorch中函数参数dim的理解_第2张图片

这里的输出结果是按照第0维进行相加的,原因是因为dim=(1, 2)将这两维进行消减,从而根据剩下的一维进行求和计算。

eg2:

b = torch.arange(3 * 2 * 2).view(3, 2, 2)
print(b)
print(torch.sum(b, 1))

 Pytorch中函数参数dim的理解_第3张图片

 这里将第一维消减后还有两维,所以最终的输出结果将按照第0维以及第2维进行计算,所以最终张量的维度为3*2.

你可能感兴趣的:(pytorch,python,深度学习)