PaddleDetection code walkthrough: an attempt to simplify infer.py

2021SC@SDUSC  

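PaddleDetection is a code library built on top of PaddlePaddle. It brings common object detection models together in one place and makes them easy to train and run inference with, so you can get started quickly.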

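After reading through PaddleDetection's overall project structure and most of its code, I found that infer.py could be optimized. So, having read the source, I wanted to go beyond simply understanding it and set about optimizing infer.py.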

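However, as soon as you want to optimize the code, or pull one of its modules out for use elsewhere, a solid grasp of what the code does internally becomes essential. This library's wrapping relies on Python decorators and reflection, the model structure is built recursively, and even simple things like defining the learner and loading the model are wrapped many layers deep. If users cannot follow the code's behavior while debugging, they cannot optimize it or reuse its modules. As a result, this optimization took me quite a bit of effort.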

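In this optimization, the rewritten code follows a few principles: a linear flow, minimal wrapping, easy debugging, and straightforward behavior, so that it can truly be "taken out and used as is". My optimization approach is described below.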

Optimization process:

1. Optimizing the input data preparation:

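    image = cv2.imread('demo/orange_71.jpg')                      # HxWx3 uint8, BGR channel order
    image = cv2.resize(image, dsize=(512, 512)).astype(np.float32)
    image = image / 255                                           # scale to [0, 1]
    # ImageNet mean in RGB order. cv2.imread returns BGR, and the model's training
    # config may also convert to RGB and divide by the std [0.229, 0.224, 0.225];
    # check the model's .yml config before relying on this exact recipe
    mean = np.array([0.485, 0.456, 0.406]).astype(np.float32)
    image = image - mean[np.newaxis, np.newaxis, :]
    image = image.transpose((2, 0, 1))                            # HWC -> CHW
    image = image[np.newaxis, :, :, :]                            # add batch dim -> NCHW
    image_size = np.array([512, 512])
    image_size = image_size[np.newaxis, :]                        # shape (1, 2)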
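
The preprocessing steps above can be wrapped into a single function, which makes them easier to reuse and to check. The sketch below uses numpy only: `resize_nearest` is a hypothetical nearest-neighbour stand-in for `cv2.resize` (which uses bilinear interpolation by default), and the `to_rgb` / `use_std` switches are assumptions added for experimenting with the BGR/RGB and std questions. With both switches off, it follows the recipe above apart from the resize method.

```python
import numpy as np

# ImageNet statistics in RGB order; the mean is the one used in the article.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def resize_nearest(image, size):
    """Hypothetical nearest-neighbour resize of an HxWxC array to size x size
    (dependency-free stand-in for cv2.resize)."""
    h, w = image.shape[:2]
    rows = np.arange(size) * h // size     # source row for each output row
    cols = np.arange(size) * w // size     # source column for each output column
    return image[rows[:, None], cols[None, :]]


def preprocess(image_bgr, size=512, to_rgb=False, use_std=False):
    """Turn an HxWx3 uint8 image into the (image, im_size) feed pair.

    to_rgb and use_std are assumed, optional switches; both default to off.
    """
    image = resize_nearest(image_bgr, size).astype(np.float32) / 255.0
    if to_rgb:
        image = image[:, :, ::-1]          # BGR -> RGB
    image = image - IMAGENET_MEAN          # broadcasts over H and W
    if use_std:
        image = image / IMAGENET_STD
    image = image.transpose((2, 0, 1))     # HWC -> CHW
    image = image[np.newaxis, ...]         # add batch dim -> NCHW
    im_size = np.array([[size, size]])     # shape (1, 2), as in the article
    return np.ascontiguousarray(image, dtype=np.float32), im_size
```

A contiguous float32 array is returned because the transposed view is not contiguous in memory; making it contiguous up front avoids surprises when it is fed to the executor.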

2. Optimizing the device and program definitions:

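    place = fluid.core.CPUPlace()          # run on the CPU
    exe = fluid.Executor(place)
    startup_program = fluid.Program()      # receives the parameter-initialization ops
    inference_program = fluid.Program()    # receives the forward (inference) graph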

3. Optimizing the network construction:

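    backbone = MobileNet()
    head = YOLOv3Head(num_classes=3)       # the fruit model has 3 classes
    model = YOLOv3(backbone=backbone, yolo_head=head)
    # Ops created inside program_guard are recorded into inference_program,
    # and parameter initializers into startup_program
    with fluid.program_guard(inference_program, startup_program):
        data = fluid.layers.data(name='image', shape=[-1, 3, 512, 512], dtype='float32')
        im_size = fluid.layers.data(name='im_size', shape=[-1, 2], dtype='float32')
        test_fetches = model.test({'image': data, 'im_size': im_size})   # dict with a 'bbox' entry
    # clone(for_test=True) switches ops such as batch norm to inference mode
    inference_program = inference_program.clone(for_test=True)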

4. Loading the pretrained model. You need to download and extract it in advance; the download address is:

https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_fruit.tar 
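
Instead of hard-coding a local path, a small helper can fetch and unpack the archive once and reuse it afterwards. This is a sketch using only the Python standard library; `ensure_weights` is a hypothetical name, and it assumes the .tar unpacks into a folder named after the archive (`yolov3_mobilenet_v1_fruit/`), which is consistent with the directory later passed to `load_persistables`. The default cache directory mirrors the `.cache/paddle/weights` path used in the full code below.

```python
import os
import tarfile
import urllib.request

WEIGHTS_URL = ("https://paddlemodels.bj.bcebos.com/object_detection/"
               "yolov3_mobilenet_v1_fruit.tar")


def ensure_weights(url=WEIGHTS_URL,
                   cache_dir=os.path.join("~", ".cache", "paddle", "weights")):
    """Download and unpack a weights .tar once; return the extracted directory.

    Assumes the archive unpacks into a folder named after the .tar file.
    Only use it with archives from a trusted source.
    """
    cache_dir = os.path.expanduser(cache_dir)
    name = os.path.splitext(os.path.basename(url))[0]
    target = os.path.join(cache_dir, name)
    if os.path.isdir(target):
        return target                      # already downloaded and unpacked
    os.makedirs(cache_dir, exist_ok=True)
    tar_path = os.path.join(cache_dir, name + ".tar")
    urllib.request.urlretrieve(url, tar_path)
    with tarfile.open(tar_path) as tar:
        tar.extractall(cache_dir)
    os.remove(tar_path)                    # keep only the extracted folder
    return target
```

The returned path can then be passed as `dirname` to `fluid.io.load_persistables`, so the script no longer depends on one particular user's home directory.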

5. Optimizing the forward inference and result output:

    # return_numpy=False keeps the fetched bbox LoDTensor as-is; np.array() then converts it
    output = exe.run(inference_program, fetch_list=[test_fetches['bbox']], feed={'image': image, 'im_size': image_size}, return_numpy=False)
    output = np.array(output[0])
    print(output)
    # cv2.rectangle(...)  # Note: resize the original orange_71.jpg to 512*512 before visualizing
    # cv2.putText(...)
    # cv2.imwrite(...)
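
Because `im_size` is fed as `[512, 512]`, the predicted boxes are in the 512x512 frame, which is why the original image has to be resized before drawing. An alternative is to scale the boxes back to the original resolution instead. A minimal numpy sketch, assuming each row is `[label, score, x1, y1, x2, y2]` and that the image was resized independently along each axis as in step 1; `rescale_boxes` is a hypothetical helper.

```python
import numpy as np


def rescale_boxes(dets, orig_h, orig_w, input_h=512, input_w=512):
    """Map [label, score, x1, y1, x2, y2] rows from the network input size
    back to the original image size (per-axis scaling)."""
    dets = np.asarray(dets, dtype=np.float32)
    if dets.ndim != 2 or dets.shape[1] != 6:
        return np.zeros((0, 6), dtype=np.float32)   # e.g. an empty result
    dets = dets.copy()                     # do not modify the caller's array
    dets[:, [2, 4]] *= orig_w / input_w    # x1, x2
    dets[:, [3, 5]] *= orig_h / input_h    # y1, y2
    return dets
```

For example, with `img = cv2.imread('demo/orange_71.jpg')`, calling `rescale_boxes(output, *img.shape[:2])` gives boxes that can be drawn with `cv2.rectangle` directly on `img`, without resizing it first.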

 

Complete optimized code:

    from ppdet.modeling.backbones.mobilenet import MobileNet
    from ppdet.modeling.anchor_heads.yolo_head import YOLOv3Head
    from ppdet.modeling.architectures.yolo import YOLOv3
    import paddle.fluid as fluid
    import numpy as np
    import cv2

    if __name__ == "__main__":
        # Prepare the input data (BGR image, ImageNet mean subtracted, NCHW layout)
        image = cv2.imread('demo/orange_71.jpg')
        image = cv2.resize(image, dsize=(512, 512)).astype(np.float32)
        image = image / 255
        mean = np.array([0.485, 0.456, 0.406]).astype(np.float32)
        image = image - mean[np.newaxis, np.newaxis, :]
        image = image.transpose((2, 0, 1))
        image = image[np.newaxis, :, :, :]
        image_size = np.array([512, 512])
        image_size = image_size[np.newaxis, :]

        # Define the device and programs
        place = fluid.core.CPUPlace()
        exe = fluid.Executor(place)
        startup_program = fluid.Program()
        inference_program = fluid.Program()

        # Build the network structure
        backbone = MobileNet()
        head = YOLOv3Head(num_classes=3)
        model = YOLOv3(backbone=backbone, yolo_head=head)
        with fluid.program_guard(inference_program, startup_program):
            data = fluid.layers.data(name='image', shape=[-1, 3, 512, 512], dtype='float32')
            im_size = fluid.layers.data(name='im_size', shape=[-1, 2], dtype='float32')
            test_fetches = model.test({'image': data, 'im_size': im_size})
        inference_program = inference_program.clone(for_test=True)

        # Load the pretrained model. Download and extract it in advance from
        # https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1_fruit.tar
        # (the address can also be found in the original infer.py).
        # Replace dirname with the directory you extracted the archive to.
        fluid.io.load_persistables(executor=exe, dirname='C:\\Users\\Administrator\\.cache\\paddle\\weights\\yolov3_mobilenet_v1_fruit', main_program=inference_program)

        # Run forward inference and print the result
        output = exe.run(inference_program, fetch_list=[test_fetches['bbox']], feed={'image': image, 'im_size': image_size}, return_numpy=False)
        output = np.array(output[0])
        print(output)
        # cv2.rectangle(...)  # Note: resize the original orange_71.jpg to 512*512 before visualizing
        # cv2.putText(...)
        # cv2.imwrite(...)

    # Interpreting the output: each row is one detection made of 6 numbers:
    # the class id, the score, then the 4 box coordinates (x1, y1, x2, y2).
    # My output was:
    # output
    # array([[  0.        ,   0.02032653,  28.338654  ,  91.62413   , 339.2822    , 375.45032   ],
    #        [  1.        ,   0.34385982,  28.338654  ,  91.62413   , 339.2822    , 375.45032   ],
    #        [  1.        ,   0.06452157, 298.2945    , 108.12228   , 437.34155   , 232.94336   ],
    #        [  1.        ,   0.04201436, 263.78772   , 244.61458   , 480.8169    , 432.98602   ],
    #        [  1.        ,   0.0129063 , 312.90588   , 133.76407   , 505.23004   , 352.65436   ],
    #        [  2.        ,   0.8128051 ,  28.338654  ,  91.62413   , 339.2822    , 375.45032   ]], dtype=float32)
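
To make the raw array easier to read, each row can be mapped to a class name and filtered by score. This is a sketch: the label list `('apple', 'banana', 'orange')` is an assumption about the fruit dataset's class order (check the dataset's label list before relying on it), `decode_detections` is a hypothetical helper, and the 0.3 threshold is an arbitrary example value.

```python
import numpy as np

# Assumed class order of the fruit dataset (hypothetical; verify against the label list).
FRUIT_LABELS = ("apple", "banana", "orange")


def decode_detections(dets, labels=FRUIT_LABELS, score_thresh=0.3):
    """Turn [label, score, x1, y1, x2, y2] rows into (name, score, box) tuples,
    keeping rows with score >= score_thresh, sorted by descending score."""
    dets = np.asarray(dets, dtype=np.float32)
    if dets.ndim != 2 or dets.shape[1] != 6:
        return []                          # e.g. an empty result
    results = []
    for row in dets:
        label, score = int(row[0]), float(row[1])
        if score < score_thresh:
            continue
        name = labels[label] if 0 <= label < len(labels) else str(label)
        results.append((name, score, tuple(float(v) for v in row[2:])))
    results.sort(key=lambda r: r[1], reverse=True)
    return results
```

Applied to the output above with the default threshold, it keeps two detections on the same box (about 28.3, 91.6, 339.3, 375.5): orange with a score of about 0.81 and banana with about 0.34.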

 
