KNN 分类器类 CvKNearest 在 ml.h 头文件中声明，代码如下：
KNN 类的实现位于 mlknearest.cpp 中（下面先列出 ml.h 中的类声明——该声明被重复贴出了两次——随后才是 mlknearest.cpp 的实现代码），代码如下：
/****************************************************************************************\
* K-Nearest Neighbour Classifier *
\****************************************************************************************/
// k Nearest Neighbors
//
// k-nearest-neighbour classifier / regressor. Training samples are stored as a
// singly linked list of CvVectors blocks (every train() call prepends one block);
// prediction is a brute-force scan over all stored samples.
class CV_EXPORTS CvKNearest : public CvStatModel
{
public:
    /// Creates an empty, untrained model.
    CvKNearest();
    /// Releases all stored training samples.
    virtual ~CvKNearest();

    /// Creates the model and immediately trains it (train(..., _update_base=false)).
    CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );

    /// Trains the model on the rows of _train_data or, when _update_base is true,
    /// appends them to the samples already stored. Returns true on success.
    virtual bool train( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx=0, bool is_regression=false,
                        int _max_k=32, bool _update_base=false );

    /// Predicts a response for every row of _samples from its k nearest training
    /// samples (1 <= k <= max_k). Optional outputs: results (one value per row),
    /// neighbors (pointers to the neighbours' feature vectors), neighbor_responses
    /// and dist (rows x k). Returns the prediction for the first row.
    virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
        const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;

    /// Frees all stored samples and resets the model to the untrained state.
    virtual void clear();

    /// Largest k accepted by find_nearest().
    int get_max_k() const;
    /// Number of features per sample.
    int get_var_count() const;
    /// Total number of stored training samples.
    int get_sample_count() const;
    /// True if the model averages responses instead of voting on labels.
    bool is_regression() const;

protected:
    // Turns gathered neighbour lists into predictions and fills the optional outputs.
    virtual float write_results( int k, int k1, int start, int end,
        const float* neighbor_responses, const float* dist, CvMat* _results,
        CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;

    // Brute-force search of the k nearest stored samples for rows [start, end).
    virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
        float* neighbor_responses, const float** neighbors, float* dist ) const;

    int max_k, var_count;
    int total;            // number of stored samples over all blocks
    bool regression;
    CvVectors* samples;   // head of the owned list of sample blocks

private:
    // The model owns the `samples` list and frees it in clear()/~CvKNearest().
    // Implicitly generated copy operations would duplicate the raw pointer and
    // lead to a double free, so copying is disabled (declared, never defined).
    CvKNearest( const CvKNearest& );
    CvKNearest& operator = ( const CvKNearest& );
};
/****************************************************************************************\
* K-Nearest Neighbour Classifier *
\****************************************************************************************/
// k Nearest Neighbors
//
// k-nearest-neighbour classifier / regressor. Training samples are stored as a
// singly linked list of CvVectors blocks (every train() call prepends one block);
// prediction is a brute-force scan over all stored samples.
class CV_EXPORTS CvKNearest : public CvStatModel
{
public:
    /// Creates an empty, untrained model.
    CvKNearest();
    /// Releases all stored training samples.
    virtual ~CvKNearest();

    /// Creates the model and immediately trains it (train(..., _update_base=false)).
    CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );

    /// Trains the model on the rows of _train_data or, when _update_base is true,
    /// appends them to the samples already stored. Returns true on success.
    virtual bool train( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx=0, bool is_regression=false,
                        int _max_k=32, bool _update_base=false );

    /// Predicts a response for every row of _samples from its k nearest training
    /// samples (1 <= k <= max_k). Optional outputs: results (one value per row),
    /// neighbors (pointers to the neighbours' feature vectors), neighbor_responses
    /// and dist (rows x k). Returns the prediction for the first row.
    virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
        const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;

    /// Frees all stored samples and resets the model to the untrained state.
    virtual void clear();

    /// Largest k accepted by find_nearest().
    int get_max_k() const;
    /// Number of features per sample.
    int get_var_count() const;
    /// Total number of stored training samples.
    int get_sample_count() const;
    /// True if the model averages responses instead of voting on labels.
    bool is_regression() const;

protected:
    // Turns gathered neighbour lists into predictions and fills the optional outputs.
    virtual float write_results( int k, int k1, int start, int end,
        const float* neighbor_responses, const float* dist, CvMat* _results,
        CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;

    // Brute-force search of the k nearest stored samples for rows [start, end).
    virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
        float* neighbor_responses, const float** neighbors, float* dist ) const;

    int max_k, var_count;
    int total;            // number of stored samples over all blocks
    bool regression;
    CvVectors* samples;   // head of the owned list of sample blocks

private:
    // The model owns the `samples` list and frees it in clear()/~CvKNearest().
    // Implicitly generated copy operations would duplicate the raw pointer and
    // lead to a double free, so copying is disabled (declared, never defined).
    CvKNearest( const CvKNearest& );
    CvKNearest& operator = ( const CvKNearest& );
};
#include "_ml.h"
/****************************************************************************************\
* K-Nearest Neighbors Classifier *
\****************************************************************************************/
// k Nearest Neighbors
// Creates an empty, untrained model. `samples` has to be null before clear()
// runs, because clear() walks and frees that list.
CvKNearest::CvKNearest()
    : samples( 0 )
{
    clear();
}
// Releases every stored block of training samples. The call is qualified:
// during destruction the dynamic type is already CvKNearest, so this is exactly
// the function a virtual call would reach anyway.
CvKNearest::~CvKNearest()
{
    CvKNearest::clear();
}
// Builds a model directly from training data by delegating to train() with
// _update_base == false. `samples` must be null first because train() starts by
// clearing the model. Whether training succeeded is not reported here.
CvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx, bool _is_regression, int _max_k )
    : samples( 0 )
{
    const bool update_base = false;
    train( _train_data, _responses, _sample_idx, _is_regression, _max_k, update_base );
}
void CvKNearest::clear()
{
while( samples )
{
CvVectors* next_samples = samples->next;
cvFree( &samples->data.fl );
cvFree( &samples );
samples = next_samples;
}
var_count = 0;
total = 0;
max_k = 0;
}
int CvKNearest::get_max_k() const { return max_k; }
int CvKNearest::get_var_count() const { return var_count; }
bool CvKNearest::is_regression() const { return regression; }
int CvKNearest::get_sample_count() const { return total; }
bool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses,
const CvMat* _sample_idx, bool _is_regression,
int _max_k, bool _update_base )
{
bool ok = false;
CvMat* responses = 0;
CV_FUNCNAME( "CvKNearest::train" );
__BEGIN__;
CvVectors* _samples;
float** _data;
int _count, _dims, _dims_all, _rsize;
if( !_update_base )
clear();
// Prepare training data and related parameters.
// Treat categorical responses as ordered - to prevent class label compression and
// to enable entering new classes in the updates
CV_CALL( cvPrepareTrainData( "CvKNearest::train", _train_data, CV_ROW_SAMPLE,
_responses, CV_VAR_ORDERED, 0, _sample_idx, true, (const float***)&_data,
&_count, &_dims, &_dims_all, &responses, 0, 0 ));
if( _update_base && _dims != var_count )
CV_ERROR( CV_StsBadArg, "The newly added data have different dimensionality" );
if( !_update_base )
{
if( _max_k < 1 )
CV_ERROR( CV_StsOutOfRange, "max_k must be a positive number" );
regression = _is_regression;
var_count = _dims;
max_k = _max_k;
}
_rsize = _count*sizeof(float);
CV_CALL( _samples = (CvVectors*)cvAlloc( sizeof(*_samples) + _rsize ));
_samples->next = samples;
_samples->type = CV_32F;
_samples->data.fl = _data;
_samples->count = _count;
total += _count;
samples = _samples;
memcpy( _samples + 1, responses->data.fl, _rsize );
ok = true;
__END__;
return ok;
}
/*
 * Brute-force k-nearest-neighbour search for rows [start, end) of _samples.
 *
 * For every query row i (0 <= i < end-start) up to k candidates are kept in
 * ascending order of distance:
 *   dist[i*k + m]               - squared Euclidean distance (stored as float)
 *   neighbor_responses[i*k + m] - response of that training sample
 *   neighbors[(start+i)*k + m]  - pointer to its feature vector (only if neighbors != 0)
 *
 * dist and neighbor_responses must hold at least (end-start)*k floats. Slots are
 * filled as training samples are visited; once k samples have been seen, all k
 * slots are valid (fewer when the model stores fewer than k samples).
 */
void CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
    float* neighbor_responses, const float** neighbors, float* dist ) const
{
    // k1: number of slots already filled (identical for every query, since each
    //     training sample is compared against all queries before k1 advances);
    // k2: min(k1, k-1), the highest slot index an existing entry may be shifted into.
    int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count;
    CvVectors* s = samples;
    // Walk every stored block of training samples...
    for( ; s != 0; s = s->next )
    {
        int n = s->count;
        for( j = 0; j < n; j++ )
        {
            // ...and compare training sample j against each query of the batch.
            for( i = 0; i < count; i++ )
            {
                double sum = 0;
                Cv32suf si;
                const float* v = s->data.fl[j];
                const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i));
                Cv32suf* dd = (Cv32suf*)(dist + i*k);
                float* nr;
                const float** nn;
                int t, ii, ii1;
                // Squared Euclidean distance, unrolled by 4, then the remainder.
                for( t = 0; t <= d - 4; t += 4 )
                {
                    double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1];
                    double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3];
                    sum += t0*t0 + t1*t1 + t2*t2 + t3*t3;
                }
                for( ; t < d; t++ )
                {
                    double t0 = u[t] - v[t];
                    sum += t0*t0;
                }
                si.f = (float)sum;
                // The distance is non-negative, so comparing the IEEE-754 bit patterns
                // as ints orders values like comparing the floats. Find the last filled
                // slot whose distance is strictly smaller than the new one.
                for( ii = k1-1; ii >= 0; ii-- )
                    if( si.i > dd[ii].i )
                        break;
                // The new candidate would go to slot ii+1 >= k: not among the k nearest.
                if( ii >= k-1 )
                    continue;
                nr = neighbor_responses + i*k;
                nn = neighbors ? neighbors + (start + i)*k : 0;
                // Shift slots ii+1..k2-1 one position right; when all k slots are
                // full, the entry that was in slot k-1 is overwritten (dropped).
                for( ii1 = k2 - 1; ii1 > ii; ii1-- )
                {
                    dd[ii1+1].i = dd[ii1].i;
                    nr[ii1+1] = nr[ii1];
                    if( nn ) nn[ii1+1] = nn[ii1];
                }
                // Insert the new candidate at slot ii+1; the block's responses are
                // stored right after its CvVectors header.
                dd[ii+1].i = si.i;
                nr[ii+1] = ((float*)(s + 1))[j];
                if( nn )
                    nn[ii+1] = v;
            }
            k1 = MIN( k1+1, k );
            k2 = MIN( k1, k-1 );
        }
    }
}
float CvKNearest::write_results( int k, int k1, int start, int end,
const float* neighbor_responses, const float* dist,
CvMat* _results, CvMat* _neighbor_responses,
CvMat* _dist, Cv32suf* sort_buf ) const
{
float result = 0.f;
int i, j, j1, count = end - start;
double inv_scale = 1./k1;
int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1;
for( i = 0; i < count; i++ )
{
const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k);
float* dst;
float r;
if( _results || start+i == 0 )
{
if( regression )
{
double s = 0;
for( j = 0; j < k1; j++ )
s += nr[j].f;
r = (float)(s*inv_scale);
}
else
{
int prev_start = 0, best_count = 0, cur_count;
Cv32suf best_val;
for( j = 0; j < k1; j++ )
sort_buf[j].i = nr[j].i;
for( j = k1-1; j > 0; j-- )
{
bool swap_fl = false;
for( j1 = 0; j1 < j; j1++ )
if( sort_buf[j1].i > sort_buf[j1+1].i )
{
int t;
CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t );
swap_fl = true;
}
if( !swap_fl )
break;
}
best_val.i = 0;
for( j = 1; j <= k1; j++ )
if( j == k1 || sort_buf[j].i != sort_buf[j-1].i )
{
cur_count = j - prev_start;
if( best_count < cur_count )
{
best_count = cur_count;
best_val.i = sort_buf[j-1].i;
}
prev_start = j;
}
r = best_val.f;
}
if( start+i == 0 )
result = r;
if( _results )
_results->data.fl[(start + i)*rstep] = r;
}
if( _neighbor_responses )
{
dst = (float*)(_neighbor_responses->data.ptr +
(start + i)*_neighbor_responses->step);
for( j = 0; j < k1; j++ )
dst[j] = nr[j].f;
for( ; j < k; j++ )
dst[j] = 0.f;
}
if( _dist )
{
dst = (float*)(_dist->data.ptr + (start + i)*_dist->step);
for( j = 0; j < k1; j++ )
dst[j] = dist[j + i*k];
for( ; j < k; j++ )
dst[j] = 0.f;
}
}
return result;
}
/*
 * Predicts a response for every row of _samples from its k nearest stored samples.
 *
 *   _samples            - CV_32FC1 matrix, one sample per row, var_count columns
 *   k                   - number of neighbours, 1 <= k <= max_k
 *   _results            - optional 1D vector with _samples->rows elements
 *                         (CV_32FC1, or CV_32SC1 for classification)
 *   _neighbors          - optional array of _samples->rows*k pointers; receives
 *                         pointers to the neighbours' feature vectors, which are
 *                         owned by the model
 *   _neighbor_responses - optional CV_32FC1 matrix, _samples->rows x k
 *   _dist               - optional CV_32FC1 matrix, _samples->rows x k
 *                         (squared Euclidean distances)
 *
 * Returns the prediction for the first sample (0 when there are no samples).
 */
float CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* _results,
    const float** _neighbors, CvMat* _neighbor_responses, CvMat* _dist ) const
{
    float result = 0.f;
    bool local_alloc = false;
    float* buf = 0;
    // Samples are processed in blocks; these bound the block size and (roughly)
    // the size in bytes of the per-block scratch buffer.
    const int max_blk_count = 128, max_buf_sz = 1 << 12;

    CV_FUNCNAME( "CvKNearest::find_nearest" );

    __BEGIN__;

    int i, count, count_scale, blk_count0, blk_count = 0, buf_sz, k1;

    if( !samples )
        CV_ERROR( CV_StsError, "The search tree must be constructed first using train method" );

    if( !CV_IS_MAT(_samples) ||
        CV_MAT_TYPE(_samples->type) != CV_32FC1 ||
        _samples->cols != var_count )
        CV_ERROR( CV_StsBadArg, "Input samples must be floating-point matrix (<num_samples>x<var_count>)" );

    if( _results && (!CV_IS_MAT(_results) ||
        (_results->cols != 1 && _results->rows != 1) ||
        _results->cols + _results->rows - 1 != _samples->rows) )
        CV_ERROR( CV_StsBadArg,
        "The results must be 1d vector containing as much elements as the number of samples" );

    if( _results && CV_MAT_TYPE(_results->type) != CV_32FC1 &&
        (CV_MAT_TYPE(_results->type) != CV_32SC1 || regression))
        CV_ERROR( CV_StsUnsupportedFormat,
        "The results must be floating-point or integer (in case of classification) vector" );

    if( k < 1 || k > max_k )
        CV_ERROR( CV_StsOutOfRange, "k must be within 1..max_k range" );

    if( _neighbor_responses )
    {
        if( !CV_IS_MAT(_neighbor_responses) || CV_MAT_TYPE(_neighbor_responses->type) != CV_32FC1 ||
            _neighbor_responses->rows != _samples->rows || _neighbor_responses->cols != k )
            CV_ERROR( CV_StsBadArg,
            "The neighbor responses (if present) must be floating-point matrix of <num_samples> x <k> size" );
    }

    if( _dist )
    {
        if( !CV_IS_MAT(_dist) || CV_MAT_TYPE(_dist->type) != CV_32FC1 ||
            _dist->rows != _samples->rows || _dist->cols != k )
            CV_ERROR( CV_StsBadArg,
            "The distances from the neighbors (if present) must be floating-point matrix of <num_samples> x <k> size" );
    }

    count = _samples->rows;
    // Scratch needed per sample: k neighbour responses + k distances.
    count_scale = k*2*sizeof(float);
    blk_count0 = MIN( count, max_blk_count );
    buf_sz = MIN( blk_count0 * count_scale, max_buf_sz );
    blk_count0 = MAX( buf_sz/count_scale, 1 );
    blk_count0 += blk_count0 % 2;       // round the block size up to an even number
    blk_count0 = MIN( blk_count0, count );
    // Plus k elements for the sort buffer used by write_results().
    buf_sz = blk_count0 * count_scale + k*sizeof(float);

    // With fewer than k stored samples only that many neighbours exist.
    k1 = get_sample_count();
    k1 = MIN( k1, k );

    // Small buffers live on the stack; larger ones are heap-allocated and freed below.
    if( buf_sz <= CV_MAX_LOCAL_SIZE )
    {
        buf = (float*)cvStackAlloc( buf_sz );
        local_alloc = true;
    }
    else
        CV_CALL( buf = (float*)cvAlloc( buf_sz ));

    for( i = 0; i < count; i += blk_count )
    {
        blk_count = MIN( count - i, blk_count0 );
        float* neighbor_responses = buf;
        float* dist = buf + blk_count*k;
        Cv32suf* sort_buf = (Cv32suf*)(dist + blk_count*k);

        find_neighbors_direct( _samples, k, i, i + blk_count,
                    neighbor_responses, _neighbors, dist );

        float r = write_results( k, k1, i, i + blk_count, neighbor_responses, dist,
                                 _results, _neighbor_responses, _dist, sort_buf );
        if( i == 0 )
            result = r;
    }

    __END__;

    if( !local_alloc )
        cvFree( &buf );

    return result;
}
/* End of file */