pytorch查看tensor中包含nan的方法

a = torch.randn(2,3,4,4)
a[1][2][3][1] = np.nan
a[1][2][3][2] = np.nan
result = torch.nonzero(torch.isnan(a)==True)
print(result)
# tensor([[1, 2, 3, 1],
#        [1, 2, 3, 2]])

那么就可以查看result.shape[0] 是否 > 0,如果 > 0,就代表a这个tensor里肯定有nan了

你可能感兴趣的:(pytorch,深度学习,python)