MNN classficationTopkEval

MNN classficationTopkEval.cpp 输入模型和配置文件,测试模型在 ImageNet 数据集上的分类精度。
程序分为3部分:

  • ImageProcess 将图像转换为适当格式的 Tensor;
  • Interpreter 由模型文件创建 Net 和 Session 并执行会话;
  • computeTopkAcc 计算分类准确率。

Interpreter::resizeTensor 调整输入张量,进而 Interpreter::resizeSession 调整后续所涉及的张量。

main

main
runEvaluation

MNN_PRINT 兼容 Android JNI log。
输入模型及配置文件,runEvaluation 评估模型。

    if (argc < 3) {
     
        MNN_PRINT("Usage: ./classficationTopkEval.out model.mnn preTreatConfig.json\n");
    }

    const auto modelPath          = argv[1];
    const auto preTreatConfigFile = argv[2];

    runEvaluation(modelPath, preTreatConfigFile);

    return 0;

runEvaluation

Created with Raphaël 2.2.0 runEvaluation modelPath, preTreatConfig Document Document::Parse Document::GetObject ImageProcess::Config ImageProcess::create ScheduleConfig Interpreter::createSession Interpreter::getSessionInput Interpreter::resizeTensor Interpreter::resizeSession Interpreter::getSessionOutput stbi_load Matrix Matrix::setTranslate ImageProcess::setMatrix ImageProcess::convert stbi_image_free Interpreter::runSession computeTopkAcc End

网络的输入及预处理参数存储在 json 文件中。
RapidJSON 是腾讯开源的 C++ JSON 解析器和生成器。
std::ifstream::rdbuf 获取流缓冲区。返回指向内部 filebuf 对象的指针。
GenericDocument::Parse 从只读字符串解析 json 文本(带编码转换)。

    int height, width;
    std::string imagePath;
    std::string groundTruthIdFile;
    rapidjson::Document document;
    {
     
        std::ifstream fileNames(preTreatConfig);
        std::ostringstream output;
        output << fileNames.rdbuf();
        auto outputStr = output.str();
        document.Parse(outputStr.c_str());
        if (document.HasParseError()) {
     
            MNN_ERROR("Invalid json\n");
            return 0;
        }
    }

GenericDocument::GetObject 返回 Object 对象。
ImageProcess::Config 结构体记录图像格式及变换时的参数。
GenericObject::HasMember 查询字段。
GenericObject 对象重载操作符 operator[] 。
ImageFormat 为枚举类型。
GenericValue::GetArray
GenericValue::GetFloat

    auto picObj = document.GetObject();
    ImageProcess::Config config;
    config.filterType = BILINEAR;
    // defalut input image format
    config.destFormat = BGR;
    {
     
        if (picObj.HasMember("format")) {
     
            auto format = picObj["format"].GetString();
            static std::map<std::string, ImageFormat> formatMap{
     {
     "BGR", BGR}, {
     "RGB", RGB}, {
     "GRAY", GRAY}};
            if (formatMap.find(format) != formatMap.end()) {
     
                config.destFormat = formatMap.find(format)->second;
            }
        }
    }
    config.sourceFormat = RGBA;
    {
     
        if (picObj.HasMember("mean")) {
     
            auto mean = picObj["mean"].GetArray();
            int cur   = 0;
            for (auto iter = mean.begin(); iter != mean.end(); iter++) {
     
                config.mean[cur++] = iter->GetFloat();
            }
        }
        if (picObj.HasMember("normal")) {
     
            auto normal = picObj["normal"].GetArray();
            int cur     = 0;
            for (auto iter = normal.begin(); iter != normal.end(); iter++) {
     
                config.normal[cur++] = iter->GetFloat();
            }
        }
        if (picObj.HasMember("width")) {
     
            width = picObj["width"].GetInt();
        }
        if (picObj.HasMember("height")) {
     
            height = picObj["height"].GetInt();
        }
        if (picObj.HasMember("imagePath")) {
     
            imagePath = picObj["imagePath"].GetString();
        }
        if (picObj.HasMember("groundTruthId")) {
     
            groundTruthIdFile = picObj["groundTruthId"].GetString();
        }
    }

ImageProcess::create 根据 ImageProcess::Config 创建 ImageProcess 对象。
Interpreter::createFromFile 由模型文件创建一个 Interpreter 对象。
Interpreter 持有网络数据,多个会话可以共享同一个网络。
ScheduleConfig 结构体进行会话调度表配置。
Interpreter::createSession 使用 ScheduleConfig 的配置创建会话。创建的会话将在网络中进行管理。
Interpreter::getSessionInput 获取给定名称的输入张量。
Interpreter::resizeTensor 调整输入张量。
Interpreter::resizeSession 调用此函数以准备张量。调整任何输入张量的大小后,应恢复输出张量缓冲区(host 或 deviceId)。
Interpreter::getSessionOutput 获取给定名称的输出张量。
lstat 获取有关指定文件的状态信息,并将其放置在buf参数指向的内存区域中。

    std::shared_ptr<ImageProcess> pretreat(ImageProcess::create(config));

    std::shared_ptr<Interpreter> classficationInterpreter(Interpreter::createFromFile(modelPath));
    ScheduleConfig classficationEvalConfig;
    classficationEvalConfig.type      = MNN_FORWARD_CPU;
    classficationEvalConfig.numThread = 4;
    auto classficationSession         = classficationInterpreter->createSession(classficationEvalConfig);
    auto inputTensor                  = classficationInterpreter->getSessionInput(classficationSession, nullptr);
    auto shape                        = inputTensor->shape();
    // the model has not input dimension
    if(shape.size() == 0){
     
        shape.resize(4);
        shape[0] = 1;
        shape[1] = 3;
        shape[2] = height;
        shape[3] = width;
    }
    // set batch to be 1
    shape[0] = 1;
    classficationInterpreter->resizeTensor(inputTensor, shape);
    classficationInterpreter->resizeSession(classficationSession);

    auto outputTensor = classficationInterpreter->getSessionOutput(classficationSession, nullptr);

opendir 函数打开与目录名称相对应的目录流,并返回指向该目录流的指针。流位于目录中的第一项。
readdir 返回一个指向dirent结构的指针,该结构表示dirp指向的目录流中的下一个目录条目。在到达目录流的末尾或发生错误时,它返回NULL

    // read ground truth label id
    std::vector<int> groundTruthId;
    {
     
        std::ifstream inputOs(groundTruthIdFile);
        std::string line;
        while (std::getline(inputOs, line)) {
     
            groundTruthId.emplace_back(std::atoi(line.c_str()));
        }
    }

    // read images file path
    int count = 0;
    std::vector<std::string> files;
    {
     
        struct stat s;
        lstat(imagePath.c_str(), &s);
        struct dirent* filename;
        DIR* dir;
        dir = opendir(imagePath.c_str());
        while ((filename = readdir(dir)) != nullptr) {
     
            if (strcmp(filename->d_name, ".") == 0 || strcmp(filename->d_name, "..") == 0) {
     
                continue;
            }
            files.push_back(filename->d_name);
            count++;
        }
        std::cout << "total: " << count << std::endl;
        std::sort(files.begin(), files.end());
    }

    if (count != groundTruthId.size()) {
     
        MNN_ERROR("The number of input images is not same with ground truth id\n");
        return 0;
    }

stbi_load 源于 nothings/stb。
Matrix 包含一个用于转换坐标的3x3矩阵。这允许使用平移、缩放、倾斜、旋转和透视来映射点和向量。Matrix::setTranslate 设置行列的平移量。
ImageProcess::setMatrix 设置仿射变换矩阵。
ImageProcess::convert 将源数据转换为给定张量。
stbi_image_free 释放加载的图像。
Interpreter::runSession 运行网络。

输出结果排序,computeTopkAcc 计算前 k 位准确率。

    int test = 0;
    int top1 = 0;
    int topk = 0;

    const int outputTensorSize = outputTensor->elementSize();
    if (outputTensorSize != TOTAL_CLASS_NUM) {
     
        MNN_ERROR("Change the total class number, such as the result number of tensorflow mobilenetv1/v2 is 1001\n");
        return 0;
    }

    std::vector<std::pair<int, float>> sortedResult(outputTensorSize);
    for (const auto& file : files) {
     
        const auto img = imagePath + file;
        int h, w, channel;
        auto inputImage = stbi_load(img.c_str(), &w, &h, &channel, 4);
        if (!inputImage) {
     
            MNN_ERROR("Can't open %s\n", img.c_str());
            return 0;
        }

        // input image transform
        Matrix trans;
        // choose resize or crop
        // resize method
        // trans.setScale((float)(w-1) / (width-1), (float)(h-1) / (height-1));
        // crop method
        trans.setTranslate(16.0f, 16.0f);
        pretreat->setMatrix(trans);
        pretreat->convert((uint8_t*)inputImage, h, w, 0, inputTensor);
        stbi_image_free(inputImage);
        classficationInterpreter->runSession(classficationSession);

        {
     
            // default float value
            auto outputDataPtr = outputTensor->host<float>();
            for (int i = 0; i < outputTensorSize; ++i) {
     
                sortedResult[i] = std::make_pair(i, outputDataPtr[i]);
            }
            std::sort(sortedResult.begin(), sortedResult.end(),
                      [](std::pair<int, float> a, std::pair<int, float> b) {
      return a.second > b.second; });
        }
        computeTopkAcc(groundTruthId, sortedResult, test, &top1, &topk);
        test++;
        MNN_PRINT("==> tested: %f, Top1: %f, Topk: %f\n", (float)test / (float)count * 100.0,
                  (float)top1 / (float)test * 100.0, (float)topk / (float)test * 100.0);
    }

    return 0;

computeTopkAcc

    const int label = groundTruthId[index];
    if (sortedResult[0].first == label) {
     
        (*top1)++;
    }
    for (int i = 0; i < TOPK; ++i) {
     
        if (label == sortedResult[i].first) {
     
            (*topk)++;
            break;
        }
    }

参考资料:

  • MNN推理过程源码分析笔记(一)主流程
  • 详解MNN的tflite-MobilenetSSD-c++部署流程
  • Android Stuido Ndk-Jni 开发(二):Jni中打印log信息
  • Constraints and concepts (since C++20)
  • lstat(), lstat64() — Get status of file or symbolic link
  • 深入浅出 FlatBuffers 之 Schema
  • 深入浅出 FlatBuffers 之 Encode

你可能感兴趣的:(DeepLearning,MNN,深度学习,计算机视觉)