traincascade源码解读

所阅读的源码位于 opencv\sources\apps\traincascade 中,下面主要是对 traincascade 训练框架的大致理解
【1】traincascade.cpp

int main(int argc, char* argv[])
{
    CvCascadeClassifier classifier;
    string cascadeDirName, vecName, bgName;
    int numPos = 2000;
    int numNeg = 1000;
    int numStages = 20;
    int numThreads = getNumThreads();
    int precalcValBufSize = 1024,
        precalcIdxBufSize = 1024;
    bool baseFormatSave = false;
    double acceptanceRatioBreakValue = -1.0;

    CvCascadeParams cascadeParams;
    CvCascadeBoostParams stageParams;
    Ptr featureParams[] = { makePtr(),
        makePtr(),
        makePtr()
    };
    //各项参数,具体含义下面有介绍
    ......
    classifier.train(cascadeDirName,
        vecName,
        bgName,
        numPos, numNeg,
        precalcValBufSize, precalcIdxBufSize,
        numStages,
        cascadeParams,
        *featureParams[cascadeParams.featureType],
        stageParams,
        baseFormatSave,
        acceptanceRatioBreakValue);
    return 0;
}

train的各参数:
cascadeDirName, 表示训练结果输出目录
vecName, 正样本的vec文件,由 opencv_createsamples 生成。正样本可以由包含待检测物体的一张图片生成,也可由一系列标记好的图像生成。
bgName, 背景图像的描述文件,文件中包含一系列的图像文件名,这些图像将被随机选作物体的背景
numPos, numNeg, 正负样本的个数
precalcValBufSize, 缓存大小,用于存储预先计算的特征值(feature values),单位为MB。
precalcIdxBufSize 缓存大小,用于存储预先计算的特征索引(feature indices),单位为MB。内存越大,训练时间越短。
numStages, 训练的分类器的级数
cascadeParams, 级联参数,除了默认值外,还可以通过参数指定. 其中stageType只能取值BOOST, featureType则支持HAAR,LBP,HOG
*featureParams[cascadeParams.featureType], 根据featureType确定具体使用的FeatureParams
stageParams, boost分类器的参数,
-bt指定boosttype,取值
DAB=Discrete AdaBoost
RAB = Real AdaBoost,
LB = LogitBoost,
GAB = Gentle AdaBoost,默认为GENTLE AdaBoost
-minHitRate
分类器的每一级最小检测率, 默认0.995。总的检测率大约为 min_hit_rate^number_of_stages。
-maxFalseAlarmRate
分类器的每一级允许最大FPR,默认0.5。总的为 max_false_alarm_rate^number_of_stages.
-weightTrimRate
样本权重按大小排序后,累计权重超过此值的样本保留进入下一轮训练. 默认0.95。 见CvBoost::trim_weights
-maxDepth
弱分类器树的最大深度。默认是1,即决策树桩(stump),只使用一个特征。
-maxWeakCount
每一级中的弱分类器的最大数目。默认100
baseFormatSave 这个参数仅在使用Haar特征时有效。如果指定这个参数,那么级联分类器将以老的格式存储。

进入分类器的训练函数train中,该函数概述了整个Cascade的执行过程。包括训练前的初始化,各Stage的强分类器间的样本集更新及强分类器训练都可看到其踪影,最显眼的还是其中的Stage训练的for大循环。

bool CvCascadeClassifier::train(const string _cascadeDirName,
    const string _posFilename,
    const string _negFilename,
    int _numPos, int _numNeg,
    int _precalcValBufSize, int _precalcIdxBufSize,
    int _numStages,
    const CvCascadeParams& _cascadeParams,
    const CvFeatureParams& _featureParams,
    const CvCascadeBoostParams& _stageParams,
    bool baseFormatSave,
    double acceptanceRatioBreakValue)
{
    // Start recording clock ticks for training time output
    const clock_t begin_time = clock();

    if (_cascadeDirName.empty() || _posFilename.empty() || _negFilename.empty())
        CV_Error(CV_StsBadArg, "_cascadeDirName or _bgfileName or _vecFileName is NULL");

    string dirName;
    if (_cascadeDirName.find_last_of("/\\") == (_cascadeDirName.length() - 1))
        dirName = _cascadeDirName;
    else
        dirName = _cascadeDirName + '/';

    numPos = _numPos;
    numNeg = _numNeg;
    numStages = _numStages;
    //进行输入检查
    if (!imgReader.create(_posFilename, _negFilename, _cascadeParams.winSize))
    {
        cout << "Image reader can not be created from -vec " << _posFilename
            << " and -bg " << _negFilename << "." << endl;
        return false;
    }
    if (!load(dirName))
    {
        cascadeParams = _cascadeParams;
        featureParams = CvFeatureParams::create(cascadeParams.featureType);
        featureParams->init(_featureParams);
        stageParams = makePtr();
        *stageParams = _stageParams;
        featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);
        featureEvaluator->init(featureParams, numPos + numNeg, cascadeParams.winSize);
        stageClassifiers.reserve(numStages);
    }
    else{
        // Make sure that if model parameters are preloaded, that people are aware of this,
        // even when passing other parameters to the training command
        cout << "---------------------------------------------------------------------------------" << endl;
        cout << "Training parameters are pre-loaded from the parameter file in data folder!" << endl;
        cout << "Please empty this folder if you want to use a NEW set of training parameters." << endl;
        cout << "---------------------------------------------------------------------------------" << endl;
    }
    // 打印输出
    cout << "PARAMETERS:" << endl;
    cout << "cascadeDirName: " << _cascadeDirName << endl;
    cout << "vecFileName: " << _posFilename << endl;
    cout << "bgFileName: " << _negFilename << endl;
    cout << "numPos: " << _numPos << endl;
    cout << "numNeg: " << _numNeg << endl;
    cout << "numStages: " << numStages << endl;
    cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl;
    cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl;
    cout << "acceptanceRatioBreakValue : " << acceptanceRatioBreakValue << endl;
    cascadeParams.printAttrs();
    stageParams->printAttrs();
    featureParams->printAttrs();

    int startNumStages = (int)stageClassifiers.size();
    if (startNumStages > 1)
        cout << endl << "Stages 0-" << startNumStages - 1 << " are loaded" << endl;
    else if (startNumStages == 1)
        cout << endl << "Stage 0 is loaded" << endl;

    double requiredLeafFARate = pow((double)stageParams->maxFalseAlarm, (double)numStages) /
        (double)stageParams->max_depth;
    double tempLeafFARate;
    // 进入numstages级分类器的每一级的训练大循环
    for (int i = startNumStages; i < numStages; i++)
    {
        cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
        cout << " << endl;

        if (!updateTrainingSet(tempLeafFARate))// 更新训练数据集
        {
            cout << "Train dataset for temp stage can not be filled. "
                "Branch training terminated." << endl;
            break;
        }
        if (tempLeafFARate <= requiredLeafFARate)//满足要求的虚警率
        {
            cout << "Required leaf false alarm rate achieved. "
                "Branch training terminated." << endl;
            break;
        }
        if ((tempLeafFARate <= acceptanceRatioBreakValue) && (acceptanceRatioBreakValue >= 0))//满足接受率
        {
            cout << "The required acceptanceRatio for the model has been reached to avoid overfitting of trainingdata. "
                "Branch training terminated." << endl;
            break;
        }

        Ptr tempStage = makePtr();
        //进入每一级stage的训练,重要
        bool isStageTrained = tempStage->train(featureEvaluator,
            curNumSamples, _precalcValBufSize, _precalcIdxBufSize,
            *stageParams);
        cout << "END>" << endl;

        if (!isStageTrained)
            break;

        stageClassifiers.push_back(tempStage);// 保存每一级的分类器

        // save params
        if (i == 0)
        {
            std::string paramsFilename = dirName + CC_PARAMS_FILENAME;
            FileStorage fs(paramsFilename, FileStorage::WRITE);
            if (!fs.isOpened())
            {
                cout << "Parameters can not be written, because file " << paramsFilename
                    << " can not be opened." << endl;
                return false;
            }
            fs << FileStorage::getDefaultObjectName(paramsFilename) << "{";
            writeParams(fs);
            fs << "}";
        }
        // save current stage
        char buf[10];
        sprintf(buf, "%s%d", "stage", i);
        string stageFilename = dirName + buf + ".xml";
        FileStorage fs(stageFilename, FileStorage::WRITE);
        if (!fs.isOpened())
        {
            cout << "Current stage can not be written, because file " << stageFilename
                << " can not be opened." << endl;
            return false;
        }
        fs << FileStorage::getDefaultObjectName(stageFilename) << "{";
        tempStage->write(fs, Mat());
        fs << "}";

        // Output training time up till now
        float seconds = float(clock() - begin_time) / CLOCKS_PER_SEC;
        int days = int(seconds) / 60 / 60 / 24;
        int hours = (int(seconds) / 60 / 60) % 24;
        int minutes = (int(seconds) / 60) % 60;
        int seconds_left = int(seconds) % 60;
        cout << "Training until now has taken " << days << " days " << hours << " hours " << minutes << " minutes " << seconds_left << " seconds." << endl;
    }// 一个stage训练完成

    if (stageClassifiers.size() == 0)
    {
        cout << "Cascade classifier can't be trained. Check the used training parameters." << endl;
        return false;
    }
    // 保存这时生成的训练器.xml文件
    save(dirName + CC_CASCADE_FILENAME, baseFormatSave);

    return true;
}

上述的进入每一级的stage的训练过程如下:


bool CvCascadeBoost::train(const CvFeatureEvaluator* _featureEvaluator,
    int _numSamples,
    int _precalcValBufSize, int _precalcIdxBufSize,
    const CvCascadeBoostParams& _params)
{
    bool isTrained = false;
    CV_Assert(!data);
    clear();
    data = new CvCascadeBoostTrainData(_featureEvaluator, _numSamples,
        _precalcValBufSize, _precalcIdxBufSize, _params);
    CvMemStorage *storage = cvCreateMemStorage();
    weak = cvCreateSeq(0, sizeof(CvSeq), sizeof(CvBoostTree*), storage);// 多个弱分类器的序列
    storage = 0;

    set_params(_params);
    if ((_params.boost_type == LOGIT) || (_params.boost_type == GENTLE))
        data->do_responses_copy();

    update_weights(0);

    cout << "+----+---------+---------+" << endl;
    cout << "|  N |    HR   |    FA   |" << endl;
    cout << "+----+---------+---------+" << endl;
    //每一个弱分类器的训练过程
    do
    {
        CvCascadeBoostTree* tree = new CvCascadeBoostTree;
        if (!tree->train(data, subsample_mask, this))
        //进入一个决策树的训练
        {
            delete tree;
            break;
        }
        cvSeqPush(weak, &tree);//将该弱分类器加入序列
        update_weights(tree);//更新样本权重
        trim_weights();
        if (cvCountNonZero(subsample_mask) == 0)
            break;
    } while (!isErrDesired() && (weak->total < params.weak_count));

    if (weak->total > 0)
    {
        data->is_classifier = true;
        data->free_train_data();
        isTrained = true;
    }
    else
        clear();

    return isTrained;
}

对于细节的理解还不甚清楚,在使用时只需要在命令行中输入对应的参数,应用中可能也会遇到各种问题,后面将会有实际操作过程

你可能感兴趣的:(物体检测)