pytorch 中,使用到 dim 参数的 api 都是跟集合有关的,比如 max(), min(), mean(), softmax() 等。当指定某个 dim 时,表示使用该维度的所有元素进行集合运算,一个 tensor 的 shape 为 (3, 4, 5),分别对应的 dim 如下所示
dim | shape |
---|---|
0 | 3 |
1 | 4 |
2 | 5 |
当使用 max(dim=1) 时,表示使用第二个维度中全部四个元素中的每个元素参与求最大值计算,计算后的 shape 变为 (3,5),因为只从 四个中求得最大的那个作为结果。如果 shape 的长度为 3,则 dim 的取值只能在区间 [-3, 2],否则将报错。
Example
>>> a = torch.randn(3,4,5)
# 求得第二个维度的最大值
>>> torch.max(a,1)
torch.return_types.max(
values=tensor([[0.7700, 0.1390, 0.6952, 1.9428, 0.8477],
[1.0085, 0.7961, 0.9462, 2.1287, 0.9356],
[1.1520, 2.1478, 0.8291, 1.0854, 0.7780]]),
indices=tensor([[1, 1, 2, 2, 0],
[1, 2, 2, 3, 0],
[0, 1, 3, 3, 3]]))
# 第二个维度缩减为只有一个元素,即 (3,1,5),api 将维度为 1 的去掉了
>>> torch.max(a,1).values.shape
torch.Size([3, 5])
# 第三个维度缩减为只有一个元素,即 (3,4,1),api 将维度为 1 的去掉了
>>> torch.max(a,2).values.shape
torch.Size([3, 4])
# 超出 dim 范围,报错
>>> torch.max(a,3).values.shape
Traceback (most recent call last):
File "" , line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
总结:
1、dim 是一种集合运算的参数,表示将某个维度的所有元素参与集合运算
2、dim 的取值和 shape 的长度密切相关,dim 的取值为 [-len(shape), len(shape)-1]