Pytorch——torch.roll()函数使用方法

文章目录

  • 一、torch.roll()函数参数
  • 二、使用示例
  • 总结
    • 参考


一、torch.roll()函数参数

torch.roll(input, shifts, dims=None)
功能:按照指定的维度滚动tensor, 如果元素超出了维度,则回归到最初的位置。
input:输入的tensor
shifts:可以为int,也可以是int型的元组。张量的元素移位的位数。如果移位是一个元组,则dim必须是相同大小的元组,并且每个维度将按相应的值滚动。
dims:roll的维度,沿着那个维度滚动。

二、使用示例

代码如下(示例):

>>>import torch
>>>x = torch.arange(1, 17).view(4, 4)
>>>x
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])
>>>y=torch.roll(x,shifts=1,dims=0)#按第0维,移位为1(即顺序),滚动
>>>y
tensor([[13, 14, 15, 16],
        [ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
>>>y1=torch.roll(x,shifts=-1,dims=0)#按第0维,移位为-1(即逆序),滚动
>>>y1
tensor([[ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16],
        [ 1,  2,  3,  4]])
>>>y2=torch.roll(x,shifts=2,dims=0)#按第0维,移位为2,滚动
>>>y2
tensor([[ 9, 10, 11, 12],
        [13, 14, 15, 16],
        [ 1,  2,  3,  4],
        [ 5,  6,  7,  8]])
>>>z=torch.roll(x,shifts=1,dims=1)#按第1维,移位为1,滚动
>>>z
tensor([[ 4,  1,  2,  3],
        [ 8,  5,  6,  7],
        [12,  9, 10, 11],
        [16, 13, 14, 15]])
>>>z1=torch.roll(x,shifts=-1,dims=1)#按第1维,移位为-1,滚动
>>>z1
tensor([[ 2,  3,  4,  1],
        [ 6,  7,  8,  5],
        [10, 11, 12,  9],
        [14, 15, 16, 13]])
>>>z2=torch.roll(x,shifts=2,dims=1)#按第1维,移位为2,滚动
>>>z2
tensor([[ 3,  4,  1,  2],
        [ 7,  8,  5,  6],
        [11, 12,  9, 10],
        [15, 16, 13, 14]])
>>>o=torch.roll(x,shifts=(2,1),dims=(0,1))#第0维,移位为2滚动;第1维,移位为1滚动
>>>o
tensor([[12,  9, 10, 11],
        [16, 13, 14, 15],
        [ 4,  1,  2,  3],
        [ 8,  5,  6,  7]])

总结

总结了torch.roll()函数的用法。

参考

torch.roll方法官方解释
PyTorch基础(14)-- torch.roll()方法

你可能感兴趣的:(pytorch基础,pytorch)