RuntimeError: derivative for floor_divide is not implemented

RuntimeError: derivative for floor_divide is not implemented

指的是:pytorch还没有实现floor(取下运算),backward报错

例子:

logits = (dist // dist.median())  # 将值转换为01

# 由:tensor([ 46.8038,  47.6287,  64.6071,  93.4675, 113.4134, 119.4478,  34.6805,
#         74.9144,  92.9773, 104.6233], device='cuda:0', grad_fn=)
# 转:tensor([0., 0., 0., 1., 1., 1., 0., 1., 1., 1.], device='cuda:0',
#      grad_fn=)

# loss
loss = criterion(logits, label)
optimizer.zero_grad()
loss.backward() 
optimizer.step()

此时,可以寻找其他近似运算替代

import torch
 
a = torch.tensor(3.14)
print(a.floor(), a.ceil(), a.trunc(), a.frac())  # 取下,取上,取整数,取小数
b = torch.tensor(3.49)
c = torch.tensor(3.5)
print(b.round(), c.round())  # 四舍五入

这里选用取整数运算.trunc(),backward不再报错

logits = (dist / dist.median()).trunc()  # 将值转换为01

# loss
loss = criterion(logits, label)
optimizer.zero_grad()
loss.backward() 
optimizer.step()

 

你可能感兴趣的:(pytorch)