PyTorch 函数解释:torch.narrow()、torch.unbind()

torch.narrow()

PyTorch 中的narrow()函数起到了筛选一定维度上的数据作用。个人感觉与x[begin:end] 相同!

参考官网:torch.narrow()

用法:torch.narrow(input, dim, start, length) → Tensor

返回输入张量的切片操作结果。 输入tensor和返回的tensor共享内存。

参数说明:

  • input (Tensor) – 需切片的张量
  • dim (int) – 切片维度
  • start (int) – 开始的索引
  • length (int) – 切片长度

示例代码:

In [1]: import torch

In [2]: x = torch.randn(3,3)

In [3]: x
Out[3]:
tensor([[ 1.2474,  0.1820, -0.0179],
        [ 0.1388, -1.7373,  0.5934],
        [ 0.2288,  1.1102,  0.6743]])

In [4]: x.narrow(0, 1, 2) # 行切片
Out[4]:
tensor([[ 0.1388, -1.7373,  0.5934],
        [ 0.2288,  1.1102,  0.6743]])

In [5]: x.narrow(1, 1, 2) # 列切片
Out[5]:
tensor([[ 0.1820, -0.0179],
        [-1.7373,  0.5934],
        [ 1.1102,  0.6743]])


torch.unbind()

torch.unbind()移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。

参考官网:torch.unbind()

用法:torch.unbind(input, dim=0) → seq

返回指定维度切片后的元组。

代码示例:

In [6]: x
Out[6]:
tensor([[ 1.2474,  0.1820, -0.0179],
        [ 0.1388, -1.7373,  0.5934],
        [ 0.2288,  1.1102,  0.6743]])

In [7]: torch.unbind(x, 0)
Out[7]:
(tensor([ 1.2474,  0.1820, -0.0179]),
 tensor([ 0.1388, -1.7373,  0.5934]),
 tensor([0.2288, 1.1102, 0.6743]))

In [8]: torch.unbind(x, 1)
Out[8]:
(tensor([1.2474, 0.1388, 0.2288]),
 tensor([ 0.1820, -1.7373,  1.1102]),
 tensor([-0.0179,  0.5934,  0.6743]))

你可能感兴趣的:(PyTorch)