pytorch学习经验(六)torch.where():根据条件修改张量值

今天写代码的时候遇到一个问题,网络前向过程中有一个张量A,我想把张量A中的大于0的值变成张量B中对应的值,最初的实现是:

A[A>0]=B[A>0]

然后运行起来就报错了,原因是这个操作属于in-place操作,而pytorch在涉及到求梯度的tensor时,是不允许对这些tensor做原地操作的,否则在反向传播的时候,这些张量计算出来的梯度发生变化。
所以我后来采用了torch.where()方法:

torch.where(condition, x, y) → Tensor
# 使用where方法
C = torch.where(A > 0, B, A)

condition为y的条件表达式,where方法检查y中的所有元素,对于y中满足condition的元素,用x中对应元素替换;否则,还保留y中的元素。where返回一个新的张量。

你可能感兴趣的:(pytorch学习经验(六)torch.where():根据条件修改张量值)