torch.unbind()

torch.unbind(input, dim=0)

函数的作用:

  1. 删除一个 tensor 的维度。可以理解为降维。
  2. 返回一个沿给定维度的所有切片的元组。

参数:

torch.unbind(input, dim=0):
input 为输入的 tensor。
dim 为需要移除的维度。

举例:

例子1: dim = 0

import torch
a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(torch.unbind(a,0))

输出:

(tensor([1., 2., 3.]), tensor([4., 5., 6.]), tensor([7., 8., 9.]))

例子2: dim = 1

import torch
a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(torch.unbind(a,1))

输出:

(tensor([1., 4., 7.]), tensor([2., 5., 8.]), tensor([3., 6., 9.]))

也可使用如下形式,与上面等价:

a.unbind(0)

import torch
a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(a.unbind(0))

输出:

(tensor([1., 2., 3.]), tensor([4., 5., 6.]), tensor([7., 8., 9.]))

参考: pytorch 官方文档
https://pytorch.org/docs/stable/generated/torch.unbind.html?highlight=unbind#torch.unbind

你可能感兴趣的:(深度学习,python,pytorch,深度学习)