本文主要使用 OpenCV 实现图像分类器(SURF 特征 + BOW 词袋模型 + SVM),下面先给出调用示例:
int main(void)
{
//int clusters=1000;
//Classfication_SVM c(clusters);
特征聚类
//c.Train_SVM();
c.Test_SVM();
将测试图片分类
//c.category_By_svm();
SVM_Classify c(1000);
//特征聚类
//c.Train_SVM("C:\\Users\\Katrinali\\Desktop\\project data\\data\\result_image\\","C:\\Users\\Katrinali\\Desktop\\project data\\");
c.Test_SVM("C:/Users/Katrinali/Desktop/beisu","C:\\Users\\Katrinali\\Desktop\\project data\\data\\train_images\\train.txt");
//将测试图片分类
cv::Mat img = cv::imread("C:/Users/Katrinali/Desktop/project data/data/test_image/2.jpg");
c.Predict_Img("C:/Users/Katrinali/Desktop/beisu",img);
return 0;
}
SVM_Classify.h文件
#pragma once
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
//boost 库
#include
// Image classifier built from SURF features, a bag-of-words vocabulary and
// one CvSVM per category. Paths are supplied by the caller.
class SVM_Classify
{
public:
    // The int argument is the cluster count handed to the BOWKMeansTrainer.
    SVM_Classify(int);
    ~SVM_Classify(void);

    // Full training run: reads TrainPath, writes vocabulary and SVM files
    // using ModelPath as a prefix.
    void Train_SVM(std::string TrainPath, std::string ModelPath);
    // Prepares prediction: reads category names from txtFile and the
    // vocabulary from ModelPath.
    void Test_SVM(std::string ModelPath,std::string txtFile);
    // Returns the name of the best-scoring category for input_pic.
    std::string Predict_Img(std::string model_path,cv::Mat input_pic);

private:
    // Internal pipeline steps used by the public entry points.
    void make_train_set(std::string train_path);
    void create_vacab(std::string train_path);
    void create_bow_image();
    void trainSvm(std::string modelPath);
    void make_test_set(std::string file_path);
    void load_bow_image(std::string model_path);

    // Feature pipeline objects (created in the constructor).
    cv::Ptr<cv::FeatureDetector> featureDecter;
    cv::Ptr<cv::DescriptorExtractor> descriptorExtractor;
    cv::Ptr<cv::BOWKMeansTrainer> bowtrainer;
    cv::Ptr<cv::BOWImgDescriptorExtractor> bowDescriptorExtractor;
    cv::Ptr<cv::FlannBasedMatcher> descriptorMacher;
    // Clustered visual vocabulary.
    cv::Mat vocab;
    // Array of classifiers allocated by trainSvm(), indexed like category_name.
    CvSVM *stor_svms;

    // Number of entries in category_name.
    int categories_size;
    // Category names, in discovery order.
    std::vector<std::string> category_name;
    // Category name -> training images (one key, many images).
    std::multimap<std::string, cv::Mat> train_set;
    // Category name -> stacked BoW descriptors of its training images.
    std::map<std::string,cv::Mat> allsamples_bow;
};
SVM_Classify.cpp文件
#include "SVM_Classify.h"
// Builds the SURF detector/extractor, the k-means BoW trainer with `clusters`
// centers, the FLANN matcher and the BoW descriptor extractor.
// stor_svms and categories_size start out in a defined "nothing trained"
// state so that they are never read uninitialized.
SVM_Classify::SVM_Classify(int clusters)
    : stor_svms(nullptr), categories_size(0)
{
    featureDecter=new cv::SurfFeatureDetector();
    descriptorExtractor=new cv::SurfDescriptorExtractor();
    bowtrainer=new cv::BOWKMeansTrainer(clusters);
    descriptorMacher=new cv::FlannBasedMatcher();
    bowDescriptorExtractor=new cv::BOWImgDescriptorExtractor(descriptorExtractor,descriptorMacher);
}

// Frees the classifier array allocated with new[] in trainSvm().
// delete[] on a null pointer is a no-op, so an untrained object is fine.
// NOTE(review): the class has no copy control, so instances must not be
// copied -- a copy would share (and double-free) stor_svms.
SVM_Classify::~SVM_Classify(void)
{
    delete[] stor_svms;
    stor_svms=nullptr;
}
// Runs the whole training pipeline.
//   TrainPath - root folder scanned by make_train_set()
//   ModelPath - prefix for "vocab.xml" and the per-category SVM files; it is
//               concatenated directly, so it should end with a separator
void SVM_Classify::Train_SVM(std::string TrainPath, std::string ModelPath)
{
    // Step 1: collect categories and training images.
    make_train_set(TrainPath);

    // Step 2: cluster the vocabulary and store it next to the models.
    create_vacab(ModelPath + std::string("vocab.xml"));

    // Step 3: compute a BoW descriptor for every training image.
    create_bow_image();

    // Step 4: train and save one SVM per category.
    trainSvm(ModelPath);
}
void SVM_Classify::make_train_set(std::string train_path)
{
std::string categor;
//递归迭代rescursive 直接定义两个迭代器:i为迭代起点(有参数),end_iter迭代终点
for(boost::filesystem::recursive_directory_iterator i(train_path),end_iter;i!=end_iter;i++)
{
// level == 0即为目录,因为TRAIN__FOLDER中设置如此
if(i.level()==0)
{
// 将类目名称设置为目录的名称
categor=(i->path()).filename().string();
category_name.push_back(categor);
}else
{
// 读取文件夹下的文件。level 1表示这是一副训练图,通过multimap容器来建立由类目名称到训练图的一对多的映射
std::string filename=std::string(train_path)+ std::string("/") +categor + std::string("/")+(i->path()).filename().string();
cv::Mat temp=cv::imread(filename,CV_LOAD_IMAGE_GRAYSCALE);
std::pair<std::string,cv::Mat> p(categor,temp);
//得到训练集
train_set.insert(p);
}
}
categories_size=category_name.size();
}
void SVM_Classify::create_vacab(std::string Model_Path)
{
//创建本地文件
cv::Mat vocab_descriptors;
// 对于每一幅模板,提取SURF算子,存入到vocab_descriptors中
std::multimap<std::string,cv::Mat> ::iterator i=train_set.begin();
for(;i!=train_set.end();i++)
{
std::vector<cv::KeyPoint> kp;
cv::Mat templ=(*i).second;
cv::Mat descrip;
featureDecter->detect(templ,kp);
descriptorExtractor->compute(templ,kp,descrip);
//push_back(Mat);在原来的Mat的最后一行后再加几行,元素为Mat时, 其类型和列的数目 必须和矩阵容器是相同的
vocab_descriptors.push_back(descrip);
}
//将每一副图的Surf特征利用add函数加入到bowTraining中去,就可以进行聚类训练了
bowtrainer->add(vocab_descriptors);
// 对SURF描述子进行聚类
vocab=bowtrainer->cluster();
//以文件格式保存词典
cv::FileStorage file_stor(Model_Path,cv::FileStorage::WRITE);
file_stor<<"vocabulary"<<vocab;
file_stor.release();
}
void SVM_Classify::create_bow_image()
{
bowDescriptorExtractor->setVocabulary(vocab);
// 对于每一幅模板,提取SURF算子,存入到vocab_descriptors中
std::multimap<std::string,cv::Mat> ::iterator i=train_set.begin();
for(;i!=train_set.end();i++)
{
std::vector<cv::KeyPoint> kp;
std::string cate_nam=(*i).first;
cv::Mat tem_image=(*i).second;
cv::Mat imageDescriptor;
featureDecter->detect(tem_image,kp);
bowDescriptorExtractor->compute(tem_image,kp,imageDescriptor);
//push_back(Mat);在原来的Mat的最后一行后再加几行,元素为Mat时, 其类型和列的数目 必须和矩阵容器是相同的
allsamples_bow[cate_nam].push_back(imageDescriptor);
}
}
void SVM_Classify::trainSvm(std::string modelPath)
{
stor_svms=new CvSVM[categories_size];
//设置训练参数
cv::SVMParams svmParams;
svmParams.svm_type = CvSVM::C_SVC;
svmParams.kernel_type = CvSVM::LINEAR;
svmParams.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);
for(int i=0;i<categories_size;i++)
{
cv::Mat tem_Samples( 0, allsamples_bow.at( category_name[i] ).cols, allsamples_bow.at( category_name[i] ).type() );
cv::Mat responses( 0, 1, CV_32SC1 );
tem_Samples.push_back( allsamples_bow.at( category_name[i] ) );
cv::Mat posResponses( allsamples_bow.at( category_name[i]).rows, 1, CV_32SC1, cv::Scalar::all(1) );
responses.push_back( posResponses );
for ( auto itr = allsamples_bow.begin(); itr != allsamples_bow.end(); ++itr )
{
if ( itr -> first == category_name[i] ) {
continue;
}
tem_Samples.push_back( itr -> second );
cv::Mat response( itr -> second.rows, 1, CV_32SC1, cv::Scalar::all( -1 ) );
responses.push_back( response );
}
stor_svms[i].train( tem_Samples, responses, cv::Mat(), cv::Mat(), svmParams );
//存储svm
std::string svm_filename=modelPath + category_name[i] + std::string("SVM.xml");
stor_svms[i].save(svm_filename.c_str());
}
}
// Prepares the object for prediction.
//   ModelPath - folder that contains "vocab.xml"
//   txtFile   - text file with one category name per line
void SVM_Classify::Test_SVM(std::string ModelPath,std::string txtFile)
{
    // Category names first ...
    make_test_set(txtFile);

    // ... then the stored vocabulary for the BoW extractor.
    load_bow_image(ModelPath);
}
// Reads category names from the text file at file_path (one name per line)
// and appends them to category_name; categories_size is updated afterwards.
// Blank lines are skipped and a trailing '\r' is stripped.
// If the file cannot be opened, no names are added.
void SVM_Classify::make_test_set(std::string file_path)
{
    std::ifstream infile;
    infile.open(file_path,std::ios::in);

    std::string line;
    // Looping on getline's result (instead of !eof()) terminates when the
    // stream failed to open and does not push a bogus entry after the last
    // line.
    while(std::getline(infile, line))
    {
        if(!line.empty() && line[line.size()-1]=='\r')
        {
            line.erase(line.size()-1);
        }
        if(line.empty())
        {
            continue;
        }
        category_name.push_back(line);
    }
    categories_size=static_cast<int>(category_name.size());
}
void SVM_Classify::load_bow_image(std::string model_path)
{
cv::FileStorage va_fs(model_path + "/vocab.xml",cv::FileStorage::READ);
//如果词典存在则直接读取
if(va_fs.isOpened())
{
cv::Mat temp_vacab;
va_fs["vocabulary"] >> temp_vacab;
bowDescriptorExtractor->setVocabulary(temp_vacab);
va_fs.release();
}
}
std::string SVM_Classify::Predict_Img(std::string model_path,cv::Mat input_pic)
{
cv::Mat gray_pic;
cv::Mat threshold_image;
std::string prediction_category;
float curConfidence;
imshow("输入图片:",input_pic);
cv::cvtColor(input_pic,gray_pic,CV_BGR2GRAY);
// 提取BOW描述子
std::vector<cv::KeyPoint> kp;
cv::Mat test;
featureDecter->detect(gray_pic,kp);
bowDescriptorExtractor->compute(gray_pic,kp,test);
int sign=0;
float best_score = -2.0f;
for(int i=0;i<categories_size;i++)
{
std::string cate_na=category_name[i];
std::string f_path=std::string(model_path) + "/" +cate_na + std::string("SVM.xml");
cv::FileStorage svm_fs(f_path,cv::FileStorage::READ);
//读取SVM.xml
if(svm_fs.isOpened())
{
svm_fs.release();
CvSVM st_svm;
st_svm.load(f_path.c_str());
if(sign==0)
{
float score_Value = st_svm.predict( test, true );
float class_Value = st_svm.predict( test, false );
sign = ( score_Value < 0.0f ) == ( class_Value < 0.0f )? 1 : -1;
}
curConfidence = sign * st_svm.predict( test, true );
}
else
{
if(sign==0)
{
float scoreValue = stor_svms[i].predict( test, true );
float classValue = stor_svms[i].predict( test, false );
sign = ( scoreValue < 0.0f ) == ( classValue < 0.0f )? 1 : -1;
}
curConfidence = sign * stor_svms[i].predict( test, true );
}
if(curConfidence>best_score)
{
best_score=curConfidence;
prediction_category=cate_na;
}
}
return prediction_category;
}
Classfication_SVM.h
#pragma once
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
//boost 库
#include
// NOTE(review): using-directives in a header affect every file that includes
// it; the declarations and definitions below rely on them for unqualified
// names.
using namespace cv;
using namespace std;
// Short alias for the boost filesystem namespace
namespace fs=boost::filesystem;
using namespace fs;
// Hard-coded folders used by Classfication_SVM:
//   DATA_FOLDER   - prefix for "vocab.xml" and the "<category>SVM.xml" files
//   TEST_FOLDER   - folder whose images are classified by category_By_svm()
//   RESULT_FOLDER - folder with one sub-folder per category for the results
#define DATA_FOLDER "C:\\Users\\Katrinali\\Desktop\\project data\\data\\"
#define TEST_FOLDER "C:\\Users\\Katrinali\\Desktop\\project data\\data\\test_image"
#define RESULT_FOLDER "C:\\Users\\Katrinali\\Desktop\\project data\\data\\result_image\\"
// Earlier variant of the classifier that works on the hard-coded folders
// DATA_FOLDER / TEST_FOLDER / RESULT_FOLDER instead of caller-supplied paths.
class Classfication_SVM
{
public:
    // The int argument is the number of clusters for the visual vocabulary.
    Classfication_SVM(int);
    ~Classfication_SVM(void);

    // High-level entry points.
    void Train_SVM();
    void Test_SVM();
    // Classifies the test images.
    void category_By_svm();

    // Individual pipeline steps.
    void create_vacab();
    void create_bow_image();
    // Trains the classifiers.
    void trainSvm();
    void load_vacab();
    void load_bow_image();

private:
    // Builds the training set.
    void make_train_set(string train_path);
    void make_test_set(string train_path);
    // Strips a file extension (used to organise templates into categories).
    string remove_extention(string);

    // BoW descriptors of all training images, per category.
    map<string,Mat> allsamples_bow;
    // Category name -> training images; keys may repeat.
    multimap<string,Mat> train_set;
    // Trained SVMs.
    CvSVM *stor_svms;
    // Category names, i.e. the directory names found in the training folder.
    vector<string> category_name;
    // Number of categories.
    int categories_size;
    // Cluster count used to build the visual vocabulary from SURF features.
    int clusters;
    // Vocabulary learned from the training images.
    Mat vocab;
    // Feature detector, descriptor extractor and BoW helpers (Ptr handles).
    Ptr<FeatureDetector> featureDecter;
    Ptr<DescriptorExtractor> descriptorExtractor;
    Ptr<BOWKMeansTrainer> bowtrainer;
    Ptr<BOWImgDescriptorExtractor> bowDescriptorExtractor;
    Ptr<FlannBasedMatcher> descriptorMacher;
};
Classfication_SVM.cpp
#include "Classfication_SVM.h"
// Stores the cluster count and builds the SURF detector/extractor, the
// k-means BoW trainer, the FLANN matcher and the BoW descriptor extractor.
// stor_svms and categories_size start out in a defined "nothing trained"
// state so that they are never read uninitialized.
Classfication_SVM::Classfication_SVM(int _clusters)
    : stor_svms(nullptr), categories_size(0), clusters(_clusters)
{
    cout<<"开始初始化..."<<endl;
    featureDecter=new SurfFeatureDetector();
    descriptorExtractor=new SurfDescriptorExtractor();
    bowtrainer=new BOWKMeansTrainer(clusters);
    descriptorMacher=new FlannBasedMatcher();
    bowDescriptorExtractor=new BOWImgDescriptorExtractor(descriptorExtractor,descriptorMacher);
    cout<<"初始化完毕..."<<endl;
}

// Frees the classifier array allocated with new[] in trainSvm().
// delete[] on a null pointer is a no-op, so an untrained object is fine.
// NOTE(review): the class has no copy control, so instances must not be
// copied -- a copy would share (and double-free) stor_svms.
Classfication_SVM::~Classfication_SVM(void)
{
    delete[] stor_svms;
    stor_svms=nullptr;
}
void Classfication_SVM::Train_SVM()
{
//读取训练集,得到种类
make_train_set("C:\\Users\\Katrinali\\Desktop\\project data\\data\\train_images\\");
create_vacab();
create_bow_image();
trainSvm();
}
void Classfication_SVM::Test_SVM()
{
make_test_set("C:\\Users\\Katrinali\\Desktop\\project data\\data\\train_images\\");
load_vacab();
load_bow_image();
}
//构造训练集合
// Scans train_path: every directory directly below it becomes a category, and
// every readable image file found inside that directory (at any depth) is
// added to train_set under the category's name.
// Plain files at the top level and unreadable files are ignored.
void Classfication_SVM::make_train_set(string train_path)
{
    cout<<"读取训练集..."<<endl;
    string current_category;
    // Default-constructed iterator is the end marker.
    for(recursive_directory_iterator it(train_path),end_iter;it!=end_iter;++it)
    {
        if(it.level()==0)
        {
            // Only directories define categories; a stray top-level file
            // must not become one.
            if(fs::is_directory(it->status()))
            {
                current_category=(it->path()).filename().string();
                category_name.push_back(current_category);
            }
            continue;
        }

        // Entries below level 0 are only reached by descending into a
        // category directory, so current_category is set here.
        if(!fs::is_regular_file(it->status()))
        {
            continue;
        }

        // Use the iterator's own path: it is correct regardless of whether
        // train_path ends with a separator and for nested files.
        Mat image=imread((it->path()).string(),CV_LOAD_IMAGE_GRAYSCALE);
        if(image.empty())
        {
            // Not an image, or it could not be decoded.
            continue;
        }
        train_set.insert(pair<string,Mat>(current_category,image));
    }
    categories_size=static_cast<int>(category_name.size());
    cout<<"发现 "<<categories_size<<"种类别物体..."<<endl;
}
void Classfication_SVM::create_vacab()
{
//创建本地文件
Mat vocab_descriptors;
// 对于每一幅模板,提取SURF算子,存入到vocab_descriptors中
multimap<string,Mat> ::iterator i=train_set.begin();
for(;i!=train_set.end();i++)
{
vector<KeyPoint>kp;
Mat templ=(*i).second;
Mat descrip;
featureDecter->detect(templ,kp);
descriptorExtractor->compute(templ,kp,descrip);
//push_back(Mat);在原来的Mat的最后一行后再加几行,元素为Mat时, 其类型和列的数目 必须和矩阵容器是相同的
vocab_descriptors.push_back(descrip);
}
cout << "训练图片开始聚类..." << endl;
//将每一副图的Surf特征利用add函数加入到bowTraining中去,就可以进行聚类训练了
bowtrainer->add(vocab_descriptors);
// 对SURF描述子进行聚类
vocab=bowtrainer->cluster();
cout<<"聚类完毕,得出词典..."<<endl;
//以文件格式保存词典
FileStorage file_stor(DATA_FOLDER "vocab.xml",FileStorage::WRITE);
file_stor<<"vocabulary"<<vocab;
file_stor.release();
}
void Classfication_SVM::create_bow_image()
{
bowDescriptorExtractor->setVocabulary(vocab);
// 对于每一幅模板,提取SURF算子,存入到vocab_descriptors中
multimap<string,Mat> ::iterator i=train_set.begin();
for(;i!=train_set.end();i++)
{
vector<KeyPoint>kp;
string cate_nam=(*i).first;
Mat tem_image=(*i).second;
Mat imageDescriptor;
featureDecter->detect(tem_image,kp);
bowDescriptorExtractor->compute(tem_image,kp,imageDescriptor);
//push_back(Mat);在原来的Mat的最后一行后再加几行,元素为Mat时, 其类型和列的数目 必须和矩阵容器是相同的
allsamples_bow[cate_nam].push_back(imageDescriptor);
}
}
//训练分类器
void Classfication_SVM::trainSvm()
{
stor_svms=new CvSVM[categories_size];
//设置训练参数
SVMParams svmParams;
svmParams.svm_type = CvSVM::C_SVC;
svmParams.kernel_type = CvSVM::LINEAR;
svmParams.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);
cout<<"训练分类器..."<<endl;
for(int i=0;i<categories_size;i++)
{
Mat tem_Samples( 0, allsamples_bow.at( category_name[i] ).cols, allsamples_bow.at( category_name[i] ).type() );
Mat responses( 0, 1, CV_32SC1 );
tem_Samples.push_back( allsamples_bow.at( category_name[i] ) );
Mat posResponses( allsamples_bow.at( category_name[i]).rows, 1, CV_32SC1, Scalar::all(1) );
responses.push_back( posResponses );
for ( auto itr = allsamples_bow.begin(); itr != allsamples_bow.end(); ++itr )
{
if ( itr -> first == category_name[i] ) {
continue;
}
tem_Samples.push_back( itr -> second );
Mat response( itr -> second.rows, 1, CV_32SC1, Scalar::all( -1 ) );
responses.push_back( response );
}
stor_svms[i].train( tem_Samples, responses, Mat(), Mat(), svmParams );
//存储svm
string svm_filename=string(DATA_FOLDER) + category_name[i] + string("SVM.xml");
stor_svms[i].save(svm_filename.c_str());
}
cout<<"分类器训练完毕..."<<endl;
}
// Collects the category names: every directory directly inside train_path
// contributes its name to category_name; categories_size is updated
// afterwards. Plain files in train_path are ignored.
void Classfication_SVM::make_test_set(string train_path)
{
    // Only the top level matters here, so a non-recursive iterator avoids
    // walking every training image. Default-constructed iterator = end.
    for(fs::directory_iterator it(train_path),end_iter;it!=end_iter;++it)
    {
        if(!fs::is_directory(it->status()))
        {
            // A stray file (e.g. a class-list text file) is not a category.
            continue;
        }
        category_name.push_back((it->path()).filename().string());
    }
    categories_size=static_cast<int>(category_name.size());
    cout<<"发现 "<<categories_size<<"种类别物体..."<<endl;
}
void Classfication_SVM::load_vacab()
{
FileStorage vacab_fs(DATA_FOLDER "vocab.xml",FileStorage::READ);
//如果之前已经生成好,就不需要重新聚类生成词典
if(vacab_fs.isOpened())
{
cout<<"图片已经聚类,词典已经存在.."<<endl;
vacab_fs.release();
}
}
//构造bag of words
void Classfication_SVM::load_bow_image()
{
cout<<"构造bag of words..."<<endl;
FileStorage va_fs(DATA_FOLDER "vocab.xml",FileStorage::READ);
//如果词典存在则直接读取
if(va_fs.isOpened())
{
Mat temp_vacab;
va_fs["vocabulary"] >> temp_vacab;
bowDescriptorExtractor->setVocabulary(temp_vacab);
va_fs.release();
}
}
//对测试图片进行分类
void Classfication_SVM::category_By_svm()
{
cout<<"物体分类开始..."<<endl;
Mat gray_pic;
Mat threshold_image;
string prediction_category;
float curConfidence;
directory_iterator begin_train(TEST_FOLDER);
directory_iterator end_train;
for(;begin_train!=end_train;++begin_train)
{
//获取该目录下的图片名
string train_pic_name=(begin_train->path()).filename().string();
string train_pic_path=string(TEST_FOLDER)+string("/")+(begin_train->path()).filename().string();
//读取图片
cout<<train_pic_path<<endl;
Mat input_pic=imread(train_pic_path);
imshow("输入图片:",input_pic);
cvtColor(input_pic,gray_pic,CV_BGR2GRAY);
// 提取BOW描述子
vector<KeyPoint>kp;
Mat test;
featureDecter->detect(gray_pic,kp);
bowDescriptorExtractor->compute(gray_pic,kp,test);
int sign=0;
float best_score = -2.0f;
for(int i=0;i<categories_size;i++)
{
string cate_na=category_name[i];
string f_path=string(DATA_FOLDER)+cate_na + string("SVM.xml");
FileStorage svm_fs(f_path,FileStorage::READ);
//读取SVM.xml
if(svm_fs.isOpened())
{
svm_fs.release();
CvSVM st_svm;
st_svm.load(f_path.c_str());
if(sign==0)
{
float score_Value = st_svm.predict( test, true );
float class_Value = st_svm.predict( test, false );
sign = ( score_Value < 0.0f ) == ( class_Value < 0.0f )? 1 : -1;
}
curConfidence = sign * st_svm.predict( test, true );
}
else
{
if(sign==0)
{
float scoreValue = stor_svms[i].predict( test, true );
float classValue = stor_svms[i].predict( test, false );
sign = ( scoreValue < 0.0f ) == ( classValue < 0.0f )? 1 : -1;
}
curConfidence = sign * stor_svms[i].predict( test, true );
}
if(curConfidence>best_score)
{
best_score=curConfidence;
prediction_category=cate_na;
}
}
//将图片写入相应的文件夹下
directory_iterator begin_iterater(RESULT_FOLDER);
directory_iterator end_iterator;
//获取该目录下的文件名
for(;begin_iterater!=end_iterator;++begin_iterater)
{
if(begin_iterater->path().filename().string()==prediction_category)
{
string filename=string(RESULT_FOLDER)+prediction_category+string("/")+train_pic_name;
imwrite(filename,input_pic);
}
}
//显示输出
//namedWindow("Dectect Object");
cout<<"这张图属于: "<<prediction_category<<endl;
//imshow("Dectect Object",result_objects[prediction_category]);
waitKey(0);
}
}
完整源码已在上文给出,也可以直接下载整个工程:
源码