torch.unbind()

作用:对某一个维度进行长度为1的切片,并将所有切片结果返回。

举个例子就知道了:

import torch
x = torch.tensor([[1, 2],[ 3,4]])
torch.unbind(x,0)#第0个维度上进行长度为1的切片。

结果:

(tensor([1, 2]), tensor([3, 4]))

应用减少代码量:

m,n=torch.unbind(x,0)
print(m)
print(n)
m,n=x[0],x[1]#虽然这个也可以,但是如果x的第一个维度很大,这个就很繁琐。
print(m)
print(n)

结果:

tensor([1, 2])
tensor([3, 4])
tensor([1, 2])
tensor([3, 4])

补充另外一种torch.unbind()的等价形式:

x.unbind(0)

结果:

(tensor([1, 2]), tensor([3, 4]))

你可能感兴趣的:(Pytorch深入理解与实战,pytorch,python)