pytorch-retinanet

参考代码

# 编译 nms 
cd pytorch-retinanet/lib
bash build.sh

出现如下问题,是因为,参考的代码中 pytorch==0.4, 我使用 pytorch=1.0
ImportError: torch.utils.ffi is deprecated. Please use cpp extensions instead
解决方法:

  • replace nms lib: https://github.com/huaifeng1993/NMS
  • this code has been compiled.if you need compile:
cd nms
rm -rf /build
rm *.so
cd ..
python setup3.py build_ext --inplace
#at last,you need modify code in model.py:
#from lib.nms.pth_nms import pth_nms
from lib.nms.gpu_nms import gpu_nms
and
#return pth_nms(dets, thresh)
return gpu_nms(dets, thresh)
raise error:
TypeError: Argument 'dets' has incorrect type (expected numpy.ndarray, got Tensor)

我的解决方式是:

# anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5) 
anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :].cpu().numpy(), 0.5) 
# 再 visualize.py 中添加
scores = scores.cpu().numpy()

上述issue中解决方式是(这种方式可能需要重新编译):

# you need change the dets to numpy:
add in gpu_nms() :
dets = dets.numpy()

cv2 退出窗口

k = cv2.waitKey(0)  # waitkey代表读取键盘的输入,括号里的数字代表等待多长时间,单位ms。 0代表一直等待
if k == 27:  # 键盘上Esc键的键值
	cv2.destroyAllWindows()
	break  # 终止循环

你可能感兴趣的:(pytorch,目标检测)