Boolean behavior of Variable and Tensor in PyTorch

Boolean behavior of Variable

>>> vx
Variable containing:
 0.3535  0.5137  0.4131  0.5732
 0.8076  0.4160  1.0000  0.6436
 0.3682  0.8086  0.4863  0.9268
[torch.cuda.FloatTensor of size 3x4 (GPU 0)]
>>> vx > 0.5
Variable containing:
 0  1  0  1
 1  0  1  1
 0  1  0  1
[torch.cuda.ByteTensor of size 3x4 (GPU 0)]
# np.where cannot be used on a Variable!
>>> np.where(vx>0.5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/anaconda/envs/th36/lib/python3.6/site-packages/torch/autograd/variable.py", line 125, in __bool__
    torch.typename(self.data) + " is ambiguous")
RuntimeError: bool value of Variable objects containing non-empty torch.cuda.ByteTensor is ambiguous
# np.bool cannot be used on a Variable!
>>> np.bool(vx > 0.5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/anaconda/envs/th36/lib/python3.6/site-packages/torch/autograd/variable.py", line 125, in __bool__
    torch.typename(self.data) + " is ambiguous")
RuntimeError: bool value of Variable objects containing non-empty torch.cuda.ByteTensor is ambiguous
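Both errors come from asking for the truth value of a mask that holds more than one element. Since PyTorch 0.4, Variable has been merged into Tensor, so the same rule now applies to plain tensors. The sketch below assumes a recent PyTorch running on the CPU, with the values copied from the transcript (in current versions the mask has dtype torch.bool rather than ByteTensor). It shows that the remedy is to reduce the mask explicitly with .any() or .all(); the helper truth_value is hypothetical and exists only for illustration.

```python
import torch

# Same values as the transcript above, kept on the CPU (assumption).
x = torch.tensor([[0.3535, 0.5137, 0.4131, 0.5732],
                  [0.8076, 0.4160, 1.0000, 0.6436],
                  [0.3682, 0.8086, 0.4863, 0.9268]])
mask = x > 0.5  # 3x4 element-wise mask


def truth_value(t):
    """Hypothetical helper: bool(t), or None if PyTorch calls it ambiguous."""
    try:
        return bool(t)
    except RuntimeError:
        return None


ambiguous = truth_value(mask)         # None: 12 elements, no single truth value
any_true = bool(mask.any())           # True: at least one element exceeds 0.5
all_true = bool(mask.all())           # False: e.g. 0.3535 does not
single = truth_value(x[1, 2] > 0.5)   # True: a one-element tensor is unambiguous

print(ambiguous, any_true, all_true, single)
```

Reducing first makes the intent ("any element?" vs. "every element?") explicit, which is exactly the choice PyTorch refuses to guess.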

Boolean behavior of Tensor

>>> vx.data > 0.5
 0  1  0  1
 1  0  1  1
 0  1  0  1
[torch.cuda.ByteTensor of size 3x4 (GPU 0)]
# np.where works on the underlying Tensor!
>>> np.where(vx.data > 0.5)
(array([0, 0, 1, 1, 1, 2, 2]), array([1, 3, 0, 2, 3, 1, 3]))
# np.bool still cannot be used on a Tensor!
>>> np.bool(vx.data > 0.5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/anaconda/envs/th36/lib/python3.6/site-packages/torch/tensor.py", line 163, in __bool__
    " objects is ambiguous")
RuntimeError: bool value of non-empty torch.cuda.ByteTensor objects is ambiguous
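The same indices that np.where returned can also be obtained without leaving PyTorch. This is a minimal sketch assuming PyTorch 1.3 or later (where torch.nonzero accepts as_tuple=True) and a CPU tensor built from the transcript's values. For a CUDA tensor, current PyTorch requires moving the data to host memory with .cpu() before .numpy(), which is why the NumPy route below includes it.

```python
import numpy as np
import torch

# Values copied from the transcript above; CPU tensor (assumption).
x = torch.tensor([[0.3535, 0.5137, 0.4131, 0.5732],
                  [0.8076, 0.4160, 1.0000, 0.6436],
                  [0.3682, 0.8086, 0.4863, 0.9268]])
mask = x > 0.5

# Torch-native: one index tensor per dimension, like np.where(cond).
rows, cols = torch.nonzero(mask, as_tuple=True)

# NumPy route: .cpu() is a no-op for CPU tensors but required for CUDA ones;
# np.where with a single argument behaves like np.nonzero.
np_rows, np_cols = np.where(mask.cpu().numpy())

print(rows.tolist(), cols.tolist())
```

Both routes should give rows [0, 0, 1, 1, 1, 2, 2] and columns [1, 3, 0, 2, 3, 1, 3], matching the np.where output in the transcript.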

Comparing element values

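import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)
# Iterating over the first row yields 0-dim tensors, so `i == 1` is a
# one-element mask whose truth value is unambiguous and `if` accepts it.
for i in a[0, :]:
    if i == 1:
        print('Yes')
    else:
        print('No')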

Output:

Yes
No
No
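The loop above works only because each element is a single-value tensor. A sketch of two common alternatives, assuming a recent PyTorch: convert each element to a Python number with .item(), or compare the whole row at once and turn the resulting mask into a Python list with .tolist(). The variable names are illustrative, not from the original.

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)
row = a[0, :]

# Option 1: .item() turns each 0-dim tensor into a plain Python float.
loop_labels = ['Yes' if v.item() == 1 else 'No' for v in row]

# Option 2: one vectorized comparison, then map the bool mask to labels.
mask = row == 1
vec_labels = ['Yes' if m else 'No' for m in mask.tolist()]

for label in vec_labels:
    print(label)
```

Both lists come out as ['Yes', 'No', 'No'], the same as the loop's output; the vectorized form avoids a Python-level comparison per element.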

The first part is based on http://www.studyai.com/article/ab883768
