MNN classficationTopkEval.cpp 输入模型和配置文件,测试模型在 ImageNet 数据集上的分类精度。
程序分为3部分:
Interpreter::resizeTensor 调整输入张量,进而 Interpreter::resizeSession 调整后续所涉及的张量。
MNN_PRINT 兼容 Android JNI log。
输入模型及配置文件,runEvaluation 评估模型。
if (argc < 3) {
MNN_PRINT("Usage: ./classficationTopkEval.out model.mnn preTreatConfig.json\n");
}
const auto modelPath = argv[1];
const auto preTreatConfigFile = argv[2];
runEvaluation(modelPath, preTreatConfigFile);
return 0;
网络的输入及预处理参数存储在 json 文件中。
RapidJSON 是腾讯开源的 C++ JSON 解析器和生成器。
std::ifstream::rdbuf 获取流缓冲区。返回指向内部 filebuf 对象的指针。
GenericDocument::Parse 从只读字符串解析 json 文本(带编码转换)。
int height, width;
std::string imagePath;
std::string groundTruthIdFile;
rapidjson::Document document;
{
std::ifstream fileNames(preTreatConfig);
std::ostringstream output;
output << fileNames.rdbuf();
auto outputStr = output.str();
document.Parse(outputStr.c_str());
if (document.HasParseError()) {
MNN_ERROR("Invalid json\n");
return 0;
}
}
GenericDocument::GetObject 返回 Object 对象。
ImageProcess::Config 结构体记录图像格式及变换时的参数。
GenericObject::HasMember 查询字段。
GenericObject 对象重载操作符 operator[] 。
ImageFormat 为枚举类型。
GenericValue::GetArray
GenericValue::GetFloat
auto picObj = document.GetObject();
ImageProcess::Config config;
config.filterType = BILINEAR;
// defalut input image format
config.destFormat = BGR;
{
if (picObj.HasMember("format")) {
auto format = picObj["format"].GetString();
static std::map<std::string, ImageFormat> formatMap{
{
"BGR", BGR}, {
"RGB", RGB}, {
"GRAY", GRAY}};
if (formatMap.find(format) != formatMap.end()) {
config.destFormat = formatMap.find(format)->second;
}
}
}
config.sourceFormat = RGBA;
{
if (picObj.HasMember("mean")) {
auto mean = picObj["mean"].GetArray();
int cur = 0;
for (auto iter = mean.begin(); iter != mean.end(); iter++) {
config.mean[cur++] = iter->GetFloat();
}
}
if (picObj.HasMember("normal")) {
auto normal = picObj["normal"].GetArray();
int cur = 0;
for (auto iter = normal.begin(); iter != normal.end(); iter++) {
config.normal[cur++] = iter->GetFloat();
}
}
if (picObj.HasMember("width")) {
width = picObj["width"].GetInt();
}
if (picObj.HasMember("height")) {
height = picObj["height"].GetInt();
}
if (picObj.HasMember("imagePath")) {
imagePath = picObj["imagePath"].GetString();
}
if (picObj.HasMember("groundTruthId")) {
groundTruthIdFile = picObj["groundTruthId"].GetString();
}
}
ImageProcess::create 根据 ImageProcess::Config 创建 ImageProcess 对象。
Interpreter::createFromFile 由模型文件创建一个 Interpreter 对象。
Interpreter 持有网络数据,多个会话可以共享同一个网络。
ScheduleConfig 结构体进行会话调度表配置。
Interpreter::createSession 使用 ScheduleConfig 的配置创建会话。创建的会话将在网络中进行管理。
Interpreter::getSessionInput 获取给定名称的输入张量。
Interpreter::resizeTensor 调整输入张量。
Interpreter::resizeSession 调用此函数以准备张量。调整任何输入张量的大小后,应恢复输出张量缓冲区(host 或 deviceId)。
Interpreter::getSessionOutput 获取给定名称的输出张量。
lstat 获取有关指定文件的状态信息,并将其放置在buf
参数指向的内存区域中。
std::shared_ptr<ImageProcess> pretreat(ImageProcess::create(config));
std::shared_ptr<Interpreter> classficationInterpreter(Interpreter::createFromFile(modelPath));
ScheduleConfig classficationEvalConfig;
classficationEvalConfig.type = MNN_FORWARD_CPU;
classficationEvalConfig.numThread = 4;
auto classficationSession = classficationInterpreter->createSession(classficationEvalConfig);
auto inputTensor = classficationInterpreter->getSessionInput(classficationSession, nullptr);
auto shape = inputTensor->shape();
// the model has not input dimension
if(shape.size() == 0){
shape.resize(4);
shape[0] = 1;
shape[1] = 3;
shape[2] = height;
shape[3] = width;
}
// set batch to be 1
shape[0] = 1;
classficationInterpreter->resizeTensor(inputTensor, shape);
classficationInterpreter->resizeSession(classficationSession);
auto outputTensor = classficationInterpreter->getSessionOutput(classficationSession, nullptr);
opendir 函数打开与目录名称相对应的目录流,并返回指向该目录流的指针。流位于目录中的第一项。
readdir 返回一个指向dirent
结构的指针,该结构表示dirp
指向的目录流中的下一个目录条目。在到达目录流的末尾或发生错误时,它返回NULL
。
// read ground truth label id
std::vector<int> groundTruthId;
{
std::ifstream inputOs(groundTruthIdFile);
std::string line;
while (std::getline(inputOs, line)) {
groundTruthId.emplace_back(std::atoi(line.c_str()));
}
}
// read images file path
int count = 0;
std::vector<std::string> files;
{
struct stat s;
lstat(imagePath.c_str(), &s);
struct dirent* filename;
DIR* dir;
dir = opendir(imagePath.c_str());
while ((filename = readdir(dir)) != nullptr) {
if (strcmp(filename->d_name, ".") == 0 || strcmp(filename->d_name, "..") == 0) {
continue;
}
files.push_back(filename->d_name);
count++;
}
std::cout << "total: " << count << std::endl;
std::sort(files.begin(), files.end());
}
if (count != groundTruthId.size()) {
MNN_ERROR("The number of input images is not same with ground truth id\n");
return 0;
}
stbi_load 源于 nothings/stb。
Matrix 包含一个用于转换坐标的3x3矩阵。这允许使用平移、缩放、倾斜、旋转和透视来映射点和向量。Matrix::setTranslate 设置行列的平移量。
ImageProcess::setMatrix 设置仿射变换矩阵。
ImageProcess::convert 将源数据转换为给定张量。
stbi_image_free 释放加载的图像。
Interpreter::runSession 运行网络。
输出结果排序,computeTopkAcc 计算前 k 位准确率。
int test = 0;
int top1 = 0;
int topk = 0;
const int outputTensorSize = outputTensor->elementSize();
if (outputTensorSize != TOTAL_CLASS_NUM) {
MNN_ERROR("Change the total class number, such as the result number of tensorflow mobilenetv1/v2 is 1001\n");
return 0;
}
std::vector<std::pair<int, float>> sortedResult(outputTensorSize);
for (const auto& file : files) {
const auto img = imagePath + file;
int h, w, channel;
auto inputImage = stbi_load(img.c_str(), &w, &h, &channel, 4);
if (!inputImage) {
MNN_ERROR("Can't open %s\n", img.c_str());
return 0;
}
// input image transform
Matrix trans;
// choose resize or crop
// resize method
// trans.setScale((float)(w-1) / (width-1), (float)(h-1) / (height-1));
// crop method
trans.setTranslate(16.0f, 16.0f);
pretreat->setMatrix(trans);
pretreat->convert((uint8_t*)inputImage, h, w, 0, inputTensor);
stbi_image_free(inputImage);
classficationInterpreter->runSession(classficationSession);
{
// default float value
auto outputDataPtr = outputTensor->host<float>();
for (int i = 0; i < outputTensorSize; ++i) {
sortedResult[i] = std::make_pair(i, outputDataPtr[i]);
}
std::sort(sortedResult.begin(), sortedResult.end(),
[](std::pair<int, float> a, std::pair<int, float> b) {
return a.second > b.second; });
}
computeTopkAcc(groundTruthId, sortedResult, test, &top1, &topk);
test++;
MNN_PRINT("==> tested: %f, Top1: %f, Topk: %f\n", (float)test / (float)count * 100.0,
(float)top1 / (float)test * 100.0, (float)topk / (float)test * 100.0);
}
return 0;
const int label = groundTruthId[index];
if (sortedResult[0].first == label) {
(*top1)++;
}
for (int i = 0; i < TOPK; ++i) {
if (label == sortedResult[i].first) {
(*topk)++;
break;
}
}