While running a PyTorch model (a ModelScope image-inpainting pipeline), I got the following error:
Traceback (most recent call last):
  File "F:\project\modelscope-master\modelscope-master\modelscope.py", line 17, in <module>
    result = inpainting(input)
  File "F:\project\modelscope-master\modelscope-master\modelscope\pipelines\base.py", line 219, in __call__
    output = self._process_single(input, *args, **kwargs)
  File "F:\project\modelscope-master\modelscope-master\modelscope\pipelines\base.py", line 247, in _process_single
    out = self.preprocess(input, **preprocess_params)
  File "F:\project\modelscope-master\modelscope-master\modelscope\pipelines\cv\image_inpainting_pipeline.py", line 107, in preprocess
    result = self.perform_inference(result)
  File "F:\project\modelscope-master\modelscope-master\modelscope\pipelines\cv\image_inpainting_pipeline.py", line 119, in perform_inference
    cur_res = refine_predict(
  File "F:\project\modelscope-master\modelscope-master\modelscope\models\cv\image_inpainting\refinement.py", line 385, in refine_predict
    image_inpainted = _infer(image, mask, forward_front, forward_rears,
  File "F:\project\modelscope-master\modelscope-master\modelscope\models\cv\image_inpainting\refinement.py", line 214, in _infer
    mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
  File "F:\project\modelscope-master\modelscope-master\modelscope\models\cv\image_inpainting\refinement.py", line 109, in _erode_mask
    mask = erosion(mask, ekernel)
  File "I:\conda\envs\lama\lib\site-packages\skimage\morphology\misc.py", line 39, in func_out
    return func(image, selem=selem, *args, **kwargs)
  File "I:\conda\envs\lama\lib\site-packages\skimage\morphology\grey.py", line 181, in erosion
    selem = np.array(selem)
  File "I:\conda\envs\lama\lib\site-packages\torch\_tensor.py", line 757, in __array__
    return self.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
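NumPy cannot read data that lives in GPU memory, so a CUDA tensor cannot be turned into a NumPy array directly. It must first be copied to host memory as a CPU tensor with `.cpu()`, and only then converted with `.numpy()`. (There is no need to convert it to a float tensor: `.cpu()` keeps the tensor's dtype.)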
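A minimal sketch of that conversion. `to_numpy` is a hypothetical helper name, not part of PyTorch. It also calls `.detach()`, because `.numpy()` additionally refuses tensors that require grad. The device falls back to the CPU when no GPU is present, so the snippet runs on any machine.

```python
import numpy as np
import torch


def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: copy a tensor (CPU or CUDA) into a NumPy array."""
    # detach() drops the autograd graph; cpu() copies CUDA memory to host memory.
    # Both return the tensor unchanged if it is already detached / on the CPU.
    return t.detach().cpu().numpy()


device = "cuda:0" if torch.cuda.is_available() else "cpu"
x = torch.arange(4, dtype=torch.float32, device=device, requires_grad=True)

# Calling x.numpy() directly would fail here: on a CUDA device it raises the
# TypeError shown above, and requires_grad=True is rejected as well.
arr = to_numpy(x)
print(type(arr), arr)
```

The dtype is preserved, so an integer tensor comes back as an integer array.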
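Changing `self.numpy()` to `self.cpu().numpy()` at the line that raised the error fixes it. Note that this line is inside PyTorch itself (`torch\_tensor.py`, line 757, in `__array__`), so the change patches the `torch` package installed in the `lama` environment.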
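Editing the installed PyTorch package works, but it changes `torch` for every project in that environment and is overwritten when `torch` is reinstalled or upgraded. A less invasive option, sketched below under assumptions, is to copy the tensors to the CPU at the call site before they reach scikit-image's `erosion` (the function the traceback shows being called from `_erode_mask` in `refinement.py`). `erode_mask_on_cpu` is a hypothetical helper, not ModelScope code, and it assumes a 2-D `(H, W)` mask with a 2-D structuring element.

```python
import numpy as np
import torch
from skimage.morphology import erosion


def erode_mask_on_cpu(mask: torch.Tensor, ekernel: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run skimage's erosion on tensors that may live on the GPU.

    Assumes `mask` is a 2-D (H, W) tensor and `ekernel` a 2-D structuring element.
    """
    device = mask.device
    # skimage only understands NumPy arrays, so copy both inputs to host memory first.
    mask_np = mask.detach().cpu().numpy()
    kernel_np = ekernel.detach().cpu().numpy()
    # Pass the kernel positionally: older skimage calls it `selem`, newer `footprint`.
    eroded = erosion(mask_np, kernel_np)
    # Move the result back to the original device for the rest of the pipeline.
    return torch.from_numpy(eroded).to(device)


device = "cuda:0" if torch.cuda.is_available() else "cpu"
mask = torch.zeros(7, 7, device=device)
mask[1:6, 1:6] = 1.0  # a 5x5 square of ones
kernel = torch.ones(3, 3, dtype=torch.bool, device=device)
eroded = erode_mask_on_cpu(mask, kernel)
print(eroded)
```

Eroding the 5x5 square with a 3x3 kernel shrinks it to a 3x3 square, and the result stays on the same device as the input mask.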
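The error above occurred on server 101, in a virtual environment created with Python 3.7. Today I ran the same experiment on server 100 and got no error; the virtual environment there uses Python 3.8. My guess is that the Python version is the cause, so Python 3.8 seems to be the better choice here. That said, Python 3.7 versus 3.8 by itself does not change how PyTorch converts tensors, so the different package versions installed in the two environments (for example `torch` and scikit-image) may be the real difference.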
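To check what actually differs between the two servers, one option is to print the interpreter and library versions in each environment and compare them. `env_report` below is a hypothetical helper, not from the original post. It reads each module's `__version__` attribute instead of using `importlib.metadata`, which is only in the standard library from Python 3.8 onward, so it runs in both environments.

```python
import importlib
import platform


def env_report(packages=("numpy", "torch", "skimage")):
    """Hypothetical helper: collect interpreter and package versions for comparison."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            module = importlib.import_module(name)
            # Not every module defines __version__, so fall back to "unknown".
            report[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in env_report().items():
        print(f"{key}: {value}")
```

Running it with the Python of each virtual environment gives two short listings that can be diffed line by line.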