TypeError: where() missing 2 required positional argument: “input“, “other“

1.问题描述

程序报错的源代码为:

	for _ in range(attack_iters):
	   output = model(X + delta)
	   index = torch.where(output.max(1)[1] == y)
	   if len(index[0]) == 0:
	       break
	   loss = F.cross_entropy(output, y)
	   if opt is not None:
	       with amp.scale_loss(loss, opt) as scaled_loss:
	           scaled_loss.backward()
	   else:
	       loss.backward()

报错的信息如下所示:
TypeError: where() missing 2 required positional argument: “input“, “other“_第1张图片

2.解决方法

torch.where()在1.0.1版本之前只能接受3个参数(以下介绍用法1),在1.0.1版本之后即可以接受3个参数,也可以接受1个参数(以下用法1和用法2)。解决办法就是,把pytorch的版本从1.0.1升级到1.1.0版本以上即可。

用法1:torch.where(condition,x,y)

包含三个参数的用法和程序实例代码如下所示:
TypeError: where() missing 2 required positional argument: “input“, “other“_第2张图片

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)

用法2:torch.where(condition)

包含1个参数的用法和程序实例代码如下所示,该用法跟
TypeError: where() missing 2 required positional argument: “input“, “other“_第3张图片

>>> torch.where(torch.tensor([1,1,1,0,1]))
(tensor([0, 1, 2, 4]),)
>>> torch.where(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0, -0.4]]))
(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
(tensor([0, 1, 2, 4]),)
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))

你可能感兴趣的:(python)