pytorch 笔记 torch.roll

1 基本使用方法

向某个方向滑动

torch.roll(input, shifts, dims=None) → Tensor

2 参数说明

input (Tensor) 输入张量
shifts (int 或 tuple of int)
  • 张量元素移位的位数。
  • 如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
dims  维度

3 举例说明

a=torch.arange(16).reshape(4,4)
print(a)
torch.roll(a,shifts=(1,1),dims=(0,1))
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[15, 12, 13, 14],
        [ 3,  0,  1,  2],
        [ 7,  4,  5,  6],
        [11,  8,  9, 10]])
'''

pytorch 笔记 torch.roll_第1张图片

 

a=torch.arange(16).reshape(4,4)
print(a)
torch.roll(a,shifts=(2,2),dims=(0,1))
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[10, 11,  8,  9],
        [14, 15, 12, 13],
        [ 2,  3,  0,  1],
        [ 6,  7,  4,  5]])
'''

你可能感兴趣的:(python库整理)