本文地址: http://blog.csdn.net/caroline_wendy/article/details/26261173
需要加载(load)SVM的模型, 然后将结点转换为SVM的格式, 即索引(index)+数据(value)的形式;
释放SVM的model有专用的函数: svm_free_and_destroy_model, 否则容易内存泄露;
可以预测数据的概率, 则需要模型是概率模型, 返回的是一个类别数组(2分类, 则为2个值的数组), 即各个标签的概率值;
注意: 标签即概率值较大的部分, 所以在训练时, 应注意正负样本的顺序,
正样本在前, 下标0, 为正样本的概率, 下标1, 为负样本的概率; 反之亦然.
代码:
/*! @file ******************************************************************************** <PRE> 模块名 : 分类器 文件名 : SvmClassifier.cpp 相关文件 : SvmClassifier.h 文件实现功能 : SVM分类器类实现 作者 : C.L.Wang Email: [email protected] 版本 : 1.0 -------------------------------------------------------------------------------- 多线程安全性 : 是 异常时安全性 : 是 -------------------------------------------------------------------------------- 备注 : 无 -------------------------------------------------------------------------------- 修改记录 : 日 期 版本 修改人 修改内容 2014/03/27 1.0 C.L.Wang Create </PRE> ******************************************************************************** * 版权所有(c) C.L.Wang, 保留所有权利 *******************************************************************************/ #include "stdafx.h" #include "SvmClassifier.h" #include <opencv.hpp> using namespace std; using namespace cv; using namespace vd; const std::string SvmClassifier::NORM_NAME = "normalization.xml"; //归一化模型 const std::string SvmClassifier::SVM_MODEL_NAME = "hvd.model"; //Svm模型 bool SvmClassifier::m_mutex = true; //互斥锁 /*! @function ******************************************************************************** <PRE> 函数名 : SvmClassifier 功能 : 参数构造函数 参数 : const Mat& _videoFeature, 视频特征; const string& _modelPath, 模型路径; 返回值 : 无 抛出异常 : 无 -------------------------------------------------------------------------------- 复杂度 : 无 备注 : 无 典型用法 : SvmClassifier iSF(_videoFeature, _modelPath); -------------------------------------------------------------------------------- 作者 : C.L.Wang </PRE> *******************************************************************************/ SvmClassifier::SvmClassifier ( const cv::Mat& _videoFeature, /*特征*/ const std::string& _modelPath /*模型路径*/ ) : Classifier(_videoFeature, _modelPath), m_model(nullptr), m_node(nullptr) { return; } /*! @function ******************************************************************************** <PRE> 函数名 : ~SvmClassifier 功能 : 析构函数 参数 : void 返回值 : 无 抛出异常 : 无 -------------------------------------------------------------------------------- 复杂度 : 无 备注 : 无 典型用法 : iSC.~SvmClassifier(); -------------------------------------------------------------------------------- 作者 : C.L.Wang </PRE> *******************************************************************************/ SvmClassifier::~SvmClassifier (void) { if (m_model != nullptr) { svm_free_and_destroy_model(&m_model); } if (m_node != nullptr) { delete[] m_node; m_node = nullptr; } return; } /*! @function ******************************************************************************** <PRE> 函数名 : calculateResult 功能 : 计算分类结果 参数 : void 返回值 : const double, 分类结果 抛出异常 : 无 -------------------------------------------------------------------------------- 复杂度 : 无 备注 : 无 典型用法 : result = iSC.calculateResult(); -------------------------------------------------------------------------------- 作者 : C.L.Wang </PRE> *******************************************************************************/ const double SvmClassifier::calculateResult (void) { double result(0.0); while(1) { if (m_mutex == true) { m_mutex = false; _initModel(); result = _predictValue(); if (m_model != nullptr) { svm_free_and_destroy_model(&m_model); } if (m_node != nullptr) { delete[] m_node; m_node = nullptr; } m_mutex = true; break; } } return result; } /*! @function ******************************************************************************** <PRE> 函数名 : _predictValue 功能 : 预测值 参数 : void 返回值 : const double, 预测值; 抛出异常 : 无 -------------------------------------------------------------------------------- 复杂度 : 无 备注 : 无 典型用法 : result = _predictValue(); -------------------------------------------------------------------------------- 作者 : C.L.Wang </PRE> *******************************************************************************/ const double SvmClassifier::_predictValue (void) const { double label (0.0); double prop (0.0); const int nr_class (2); double* prob_estimates = (double *) malloc(nr_class*sizeof(double)); label = svm_predict_probability(m_model, m_node, prob_estimates); prop = prob_estimates[0]; //返回预测概率值 delete[] prob_estimates; return prop; } /*! @function ******************************************************************************** <PRE> 函数名 : _initModel 功能 : 初始化模型 参数 : void 返回值 : void 抛出异常 : 无 -------------------------------------------------------------------------------- 复杂度 : 无 备注 : 无 典型用法 : _initModel(); -------------------------------------------------------------------------------- 作者 : C.L.Wang </PRE> *******************************************************************************/ void SvmClassifier::_initModel (void) { /*完整路径*/ std::string modelName (m_modelPath); //模型名称 std::string normName (m_modelPath); //归一化名称 const std::string slash("/"); modelName.append(slash); modelName.append(SVM_MODEL_NAME); normName.append(slash); normName.append(NORM_NAME); std::ifstream ifs; ifs.open(modelName, ios::in); if (ifs.fail()) { __printLog(std::cerr, "Failed to open the model file!"); } ifs.close(); ifs.open(normName, ios::in); if (ifs.fail()) { __printLog(std::cerr, "Failed to open the model file!"); } ifs.close(); if (m_model != nullptr) { svm_free_and_destroy_model(&m_model); } m_model = svm_load_model(modelName.c_str()); __transSvmNode(normName); return; } /*! @function ******************************************************************************** <PRE> 函数名 : __transSvmNode 功能 : 转换Svm结点 参数 : const string& normName, 归一化模型路径 返回值 : void 抛出异常 : 无 -------------------------------------------------------------------------------- 复杂度 : 无 备注 : 无 典型用法 : __transSvmNode(normName); -------------------------------------------------------------------------------- 作者 : C.L.Wang </PRE> *******************************************************************************/ void SvmClassifier::__transSvmNode (const std::string& _normName) { cv::FileStorage fs(_normName, FileStorage::READ); cv::Mat maxNorm; fs["normalization"] >> maxNorm; fs.release(); /*归一化视频特征*/ cv::Mat normFeature = cv::Mat::zeros(1, maxNorm.cols-2, CV_64FC1); for (int j=2; j<m_videoFeature.cols; ++j) { for(int i=0; i<m_videoFeature.rows; ++i) { normFeature.at<double>(0, j-2) += m_videoFeature.at<double>(i, j); } } for (int j=0; j<normFeature.cols; ++j) { normFeature.at<double>(0, j) /= m_videoFeature.rows; if (maxNorm.at<double>(0, j+2) > 0.0001) normFeature.at<double>(0, j) /= maxNorm.at<double>(0, j+2); } normFeature.at<double>(0,0) = 0.0; if (m_node != nullptr) { delete[] m_node; m_node = nullptr; } m_node = new svm_node[normFeature.cols]; for (int j=1; j < normFeature.cols; ++j) { m_node[j-1].index = j; m_node[j-1].value = normFeature.at<double>(0, j); } m_node[normFeature.cols-1].index = -1; m_node[normFeature.cols-1].value = 0; return; }