Python调用训练好的MNN模型时,使用PIL模块读取图片
如下列代码所示,通过PIL模块的Image.open()函数由指定路径打开对应的图片img。
def read_img(img_path): # read image & data pre-process
data = torch.randn(1, 3, 112, 112)
transform = T.Compose([
transforms.Resize(112,112),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.open(img_path).convert("RGB")
data[0, :, :, :] = transform(img)
return data
if __name__ == "__main__":
#加载模型#
#设置输入#
image = read_img('your_img_path')
tmp_input = MNN.Tensor((1, 3, 112, 112), MNN.Halide_Type_Float, image, MNN.Tensor_DimensionType_Caffe)
Exception: PyMNNTensor_init: data is not tuple/numpy
若使用MNN.Tensor()函数转换mnn模型的输入时,其输入的数据格式是tuple/numpy,而由PIL读出的img经过torchvision.transforms转换之后data数据格式转为tensor,不能直接放入MNN.Tensor()函数中
一种可行的简单的解决方法为,将tensor直接转成numpy
if __name__ == "__main__":
interpreter = MNN.Interpreter("your_mnn_path")
session = interpreter.createSession()
input_tensor = interpreter.getSessionInput(session)
image = read_img('your_img_path')
input_numpy = image.cpu().numpy().squeeze()
tmp_input = MNN.Tensor((1, 3, 112, 112), MNN.Halide_Type_Float, image, MNN.Tensor_DimensionType_Caffe)
input_tensor.copyFrom(tmp_input)