torch.cumsum() 和 torch.cumprod()

 

import torch as t
a = t.arange(0, 6).view(2,3)
print(a)
a.cumsum(dim=0)

 

a = t.arange(0, 6).view(2,3)
print(a)
a.cumsum(dim=1)

 

对于二维输入a,dim=0(第1行不动,将第1行累加到其他行);dim=1(进入最内层,转化成列处理。第1列不动,将第1列累加到其他列;从第一列开始后面的每一列都是前面对应行元素的累加和),运行结果如下:

torch.cumsum() 和 torch.cumprod()_第1张图片

torch.cumsum() 和 torch.cumprod()_第2张图片

 

参数dim,用来指定这些操作是在哪个维度上执行的。关于dim(对应于Numpy中的axis)有提供一个简单的记忆方式:

假设输入的形状是(m, n, k)

  • 如果指定dim=0,输出的形状就是(1, n, k)或者(n, k)
  • 如果指定dim=1,输出的形状就是(m, 1, k)或者(m, k)
  • 如果指定dim=2,输出的形状就是(m, n, 1)或者(m, n)

size中是否有"1",取决于参数keepdimkeepdim=True会保留维度1。注意,以上只是经验总结,并非所有函数都符合这种形状变化方式,如cumsum

 

同理,torch.cumprod()

dim=1时,第一列不变,后面的每列将前面列的元素乘起来,如 12 = 3 * 4 ,60 = 3 * 4 * 5 。

dim=0时,第一行不变,后面每行将前面行对应元素乘起来,如 10 = 2 * 5 。

torch.cumsum() 和 torch.cumprod()_第3张图片

你可能感兴趣的:(pytorch,pytorch)