Python调用MNN模型bug记录和解决方法

项目场景:

Python调用训练好的MNN模型时,使用PIL模块读取图片


问题描述:

如下列代码所示,通过PIL模块的Image.open()函数由指定路径打开对应的图片img。

def read_img(img_path):     # read image & data pre-process
    data = torch.randn(1, 3, 112, 112)
    transform = T.Compose([
    	transforms.Resize(112112),
    	transforms.ToTensor(),
    	transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = Image.open(img_path).convert("RGB")
    data[0, :, :, :] = transform(img)
    return data

if __name__ == "__main__":
	#加载模型#
	#设置输入#
	image = read_img('your_img_path')
	tmp_input = MNN.Tensor((1, 3, 112, 112), MNN.Halide_Type_Float, image, MNN.Tensor_DimensionType_Caffe)
Exception: PyMNNTensor_init: data is not tuple/numpy

原因分析:

若使用MNN.Tensor()函数转换mnn模型的输入时,其输入的数据格式是tuple/numpy,而由PIL读出的img经过torchvision.transforms转换之后data数据格式转为tensor,不能直接放入MNN.Tensor()函数中


解决方案:

一种可行的简单的解决方法为,将tensor直接转成numpy

if __name__ == "__main__":
    interpreter = MNN.Interpreter("your_mnn_path")
    session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(session)
	image = read_img('your_img_path')
	input_numpy = image.cpu().numpy().squeeze()
	tmp_input = MNN.Tensor((1, 3, 112, 112), MNN.Halide_Type_Float, image, MNN.Tensor_DimensionType_Caffe)
	input_tensor.copyFrom(tmp_input)

你可能感兴趣的:(python)