假使需要从一个含有较多杂乱数据的数据集中提取理想的模型(比如有70%的数据都不符合模型)时,最小二乘法就很难拟合出与其余30%有效数据误差最小的模型,因为它会把全部数据(包括异常数据)都纳入拟合。
RANSAC算法可以很好地处理异常数据,包括:一维数据剔除异常值、二维数据点剔除离群点并拟合直线、三维空间点剔除异常点并拟合平面。下面介绍如何通过RANSAC算法拟合直线,将一维数据划分为两类(内点与异常值),实现点拟合和异常值剔除功能,并将结果可视化到图像上。
RANSAC所需的采样(迭代)次数由 $K = \dfrac{\log(1-z)}{\log(1-w^{n})}$ 给出,其中 $K$ 为需要采样的次数;$z$ 为至少获取一个好样本(样本中全部为内点)的概率,一般设为99%;$w$ 为点集中内点的比例,一般可以在初始时设置一个较小值如0.1,然后迭代更新;$n$ 为模型参数估计需要的最小点个数,直线拟合最少需要2个点。
1. RANSAC的迭代次数没有固定上限;如果设置的迭代次数太少,得到的结果可能不是最优的,甚至可能是错误的。得到最优结果的概率随迭代次数增加而提高(至少抽到一个好样本的概率为 $1-(1-w^{n})^{K}$),但并不与迭代次数成正比。
2. RANSAC可以剔除局外点,这一点是比较好的,但它也并不完美:RANSAC拟合的直线必须经过数据中的2个点,当数据分布不够随机时,结果可能不理想。比如拟合沿两条近似平行的直线分布的点时,拟合出来的直线(下图红色)会与数据分布的直线不平行,而理想的结果(下图蓝色)应与之平行。如下图所示,绿色是内点,黑色是外点;
考虑将RANSAC和最小二乘结合使用:先用RANSAC剔除局外点,再对内点做最小二乘拟合,这样可以综合二者的优势,得到较为理想的结果。最小二乘也有其局限性:一是它没有剔除局外点,拟合的是针对全部数据的全局最优解,当局外点误差过大时,会极大地影响结果,导致拟合错误;
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;
void calcLinePara(vector<Point2f> pts, double &a, double &b, double &c, double &res)
{
res = 0;
Vec4f line;
vector<Point2f> ptsF;
for (unsigned int i = 0; i < pts.size(); i++)
ptsF.push_back(pts[i]);
fitLine(ptsF, line, CV_DIST_L2, 0, 1e-2, 1e-2);
a = line[1];
b = -line[0];
c = line[0] * line[3] - line[1] * line[2];
for (unsigned int i = 0; i < pts.size(); i++)
{
double resid_ = fabs(pts[i].x * a + pts[i].y * b + c);
res += resid_;
}
res /= pts.size();
}
/// Draws two distinct elements of `set` uniformly at random.
///
/// @param set   pool of candidate ids (the caller passes point indices).
/// @param sset  the two drawn elements are appended to it.
/// @return false, leaving `sset` untouched, when `set` holds fewer than two
///         elements; true otherwise.
///
/// Randomness comes from std::rand(); seeding is left to the caller.
bool getSample(const std::vector<int> &set, std::vector<int> &sset)
{
    const int count = static_cast<int>(set.size());
    // Two elements are enough to define a line, so a pool of exactly two is valid.
    if (count < 2)
        return false;

    // rand() % count reaches every position, including the last one; the old
    // int(x * (count - 1)) form hit it only when rand() == RAND_MAX.
    const int first = std::rand() % count;
    int second = std::rand() % count;
    // A line needs two different points: redraw until the positions differ.
    while (second == first)
        second = std::rand() % count;

    // Return the elements themselves rather than their positions in `set`.
    sset.push_back(set[first]);
    sset.push_back(set[second]);
    return true;
}
bool verifyComposition(const vector<Point2f> pts)
{
cv::Point2f pt1 = pts[0];
cv::Point2f pt2 = pts[1];
if (abs(pt1.x - pt2.x) < 5 && abs(pt1.y - pt2.y) < 5)
return false;
return true;
}
void fitLine(vector<Point2f> ptSet, double &a, double &b, double &c, vector<bool> &inlierFlag)
{
double residual_error = 3;
int sample_count = 0;
int N = 300;
double res = 0;
bool stop_loop = false;
int maximum = 0;
inlierFlag = vector<bool>(ptSet.size(), false);
vector<double> resids_(ptSet.size(), 3);
srand((unsigned int)time(NULL));
vector<int> ptsID;
for (unsigned int i = 0; i < ptSet.size(); i++)
ptsID.push_back(i);
while (N > sample_count && !stop_loop)
{
vector<bool> inlierstemp;
vector<double> residualstemp;
vector<int> ptss;
int inlier_count = 0;
if (!getSample(ptsID, ptss))
{
stop_loop = true;
continue;
}
vector<Point2f> pt_sam;
pt_sam.push_back(ptSet[ptss[0]]);
pt_sam.push_back(ptSet[ptss[1]]);
if (!verifyComposition(pt_sam))
{
++sample_count;
continue;
}
calcLinePara(pt_sam, a, b, c, res);
for (unsigned int i = 0; i < ptSet.size(); i++)
{
Point2f pt = ptSet[i];
double resid_ = fabs(pt.x * a + pt.y * b + c);
residualstemp.push_back(resid_);
inlierstemp.push_back(false);
if (resid_ < residual_error)
{
++inlier_count;
inlierstemp[i] = true;
}
}
if (inlier_count >= maximum)
{
maximum = inlier_count;
resids_ = residualstemp;
inlierFlag = inlierstemp;
}
if (inlier_count == 0)
{
N = 500;
}
else
{
double epsilon = 1.0 - double(inlier_count) / (double)ptSet.size();
double p = 0.99;
double s = 2.0;
N = int(log(1.0 - p) / log(1.0 - pow((1.0 - epsilon), s)));
}
++sample_count;
}
vector<Point2f> pset;
for (unsigned int i = 0; i < ptSet.size(); i++)
{
if (inlierFlag[i])
pset.push_back(ptSet[i]);
}
calcLinePara(pset, a, b, c, res);
}
void Ransac(vector<float>data, vector<bool> &isLinePoint)
{
int width = 180;
int height = 120;
vector<Point2f> dataToPoint;
srand((unsigned int)time(NULL));
for (size_t i = 0; i < data.size(); i++)
{
double x = rand() / (double)RAND_MAX;
//data[i] = data[i] * height ; //注意数据尽量在0-height范围,否则,先归一化数据再*height
Point2f pt(x*width, data[i]);
dataToPoint.push_back(pt);
}
double A, B, C;
fitLine(dataToPoint, A, B, C, isLinePoint);
float k = -A/B;
float b = -C/B;
cout << "k,b= " << k <<","<<b<< endl;
Mat img(height, width, CV_8UC3, Scalar(255, 255, 255));
for (unsigned int i = 0; i < dataToPoint.size(); i++) {
if (isLinePoint[i])
circle(img, dataToPoint[i], 1, Scalar(0, 255, 0), 2, 16);
else
circle(img, dataToPoint[i], 1, Scalar(0, 0, 255), 2, 16);
}
//show line
Point2f P1;
P1.x = 0;
P1.y = b;
Point2f P2;
P2.x = width;
P2.y = k*width + b;
line(img, P1, P2, Scalar(0, 0, 0), 1, 16);
rotate(img, img, ROTATE_180);
imshow("fit_line", img);
}
void main(int argc, char* argv[]) {
vector<float> data = { 29, 39, 115, 109, 23, 24, 64, 34, 23, 24, 23, 25, 51, 89, 21,
24, 24, 26, 21, 20, 21, 22, 22, 22, 24, 50, 12, 11, 55, 79,
56, 87, 23, 26, 25, 99 ,22, 112, 22, 33, 25, 18, 64, 48};
vector<bool> lables;
Ransac(data, lables);
waitKey();
}