pytorch中ndimension()用法

pytorch中ndimension()用法

Tensor.ndimension(),返回tensor的维度(整数)

import torch

a = torch.zeros([3])
b = torch.zeros([1,2,3])
print(a,'\n',b)

a_dim = a.ndimension()
b_dim = b.ndimension()
print(a_dim,'\n',b_dim)

###返回值
tensor([0., 0., 0.]) 
tensor([[[0., 0., 0.],
         [0., 0., 0.]]])
1 
3

你可能感兴趣的:(python,pytorch)