torch.unbind(input, dim=0):
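Returns a tuple of all slices of input along dimension dim, with that dimension removed.
input: the input tensor.
dim: the dimension to remove (default 0).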
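A quick way to see the shape rule is a 3-D tensor. This is a minimal sketch: the tensor built with torch.arange and its (2, 3, 4) shape are arbitrary choices for illustration. Unbinding along dim=1 should give x.size(1) pieces, each with dim 1 dropped.

```python
import torch

# A 3-D tensor of shape (2, 3, 4), filled with 0..23 for readability
x = torch.arange(24).reshape(2, 3, 4)

# unbind along dim=1: returns x.size(1) == 3 tensors,
# each with dim 1 removed, i.e. shape (2, 4)
parts = torch.unbind(x, dim=1)

print(len(parts))       # 3
print(parts[0].shape)   # torch.Size([2, 4])

# Each slice matches plain indexing along that dimension
print(torch.equal(parts[2], x[:, 2, :]))  # True
```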
import torch
# Build a 3x3 float tensor and unbind it along dim 0 (one tensor per row)
a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
print(torch.unbind(a, 0))
Output:
(tensor([1., 2., 3.]), tensor([4., 5., 6.]), tensor([7., 8., 9.]))
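The word "remove" is the key difference from splitting functions. As a comparison only (torch.chunk is brought in here and is not part of the original text), a sketch: torch.chunk keeps the split dimension with size 1, while unbind drops it.

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

rows_unbind = torch.unbind(a, 0)    # dim 0 removed: each piece has shape (3,)
rows_chunk = torch.chunk(a, 3, 0)   # dim 0 kept:    each piece has shape (1, 3)

print(rows_unbind[0].shape)  # torch.Size([3])
print(rows_chunk[0].shape)   # torch.Size([1, 3])

# Squeezing the kept dimension turns a chunk into the matching unbind slice
print(torch.equal(rows_chunk[0].squeeze(0), rows_unbind[0]))  # True
```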
import torch
# Same tensor, unbound along dim 1 (one tensor per column)
a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
print(torch.unbind(a, 1))
Output:
(tensor([1., 4., 7.]), tensor([2., 5., 8.]), tensor([3., 6., 9.]))
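The dim=1 result is the same as indexing each column by hand, and torch.stack along the same dimension puts the removed dimension back. A small sketch checking both on the example tensor (torch.stack is an addition for illustration, not from the original post):

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

cols = torch.unbind(a, 1)  # one 1-D tensor per column

# Same result as indexing each column by hand
manual = tuple(a[:, i] for i in range(a.size(1)))
print(all(torch.equal(c, m) for c, m in zip(cols, manual)))  # True

# torch.stack along the same dim re-inserts the removed dimension
restored = torch.stack(cols, dim=1)
print(torch.equal(restored, a))  # True
```

In this sense stack(..., dim=d) undoes unbind(..., dim=d).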
The Tensor method form a.unbind(0) is equivalent to torch.unbind(a, 0):
import torch
# Method form: call unbind directly on the tensor
a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
print(a.unbind(0))
Output:
(tensor([1., 2., 3.]), tensor([4., 5., 6.]), tensor([7., 8., 9.]))
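Because the result is an ordinary tuple, it can be unpacked straight into variables, and dim also accepts negative values counted from the last dimension. A short sketch on the same 3x3 tensor (the variable names are arbitrary):

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

# The result is a tuple, so it can be unpacked directly
r0, r1, r2 = a.unbind(0)
print(r1)  # tensor([4., 5., 6.])

# Negative dims count from the end: for a 2-D tensor, dim=-1 is dim=1
c_neg = a.unbind(-1)
c_pos = torch.unbind(a, 1)
print(all(torch.equal(x, y) for x, y in zip(c_neg, c_pos)))  # True
```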
Reference: PyTorch official documentation
https://pytorch.org/docs/stable/generated/torch.unbind.html?highlight=unbind#torch.unbind