RANSAC的简化版代码

Figure_1.png
import numpy as np
import math
import random
import matplotlib.pyplot as plt

def compute_all_choices(x, y, max_iterations=2000, inlier_dist_thresh=0.25,
                        confidence=0.99, inlier_ratio_thresh=0.4):
    """Fit a line y = k*x + b to noisy 2-D points with a simplified RANSAC.

    Each iteration draws two distinct points at random, builds the line
    through them and counts the points whose vertical residual
    ``|k*x + b - y|`` is below ``inlier_dist_thresh``. The line with the most
    inliers seen so far is kept.

    The iteration budget adapts. Whenever a better model is found, it is
    lowered to ``log(1 - confidence) / log(1 - w**2)``, where ``w`` is that
    model's inlier ratio, but it never exceeds ``max_iterations``. The search
    also stops early once more than ``int(len(x) * inlier_ratio_thresh)``
    inliers are found, or when every point is an inlier.

    Args:
        x: 1-D sequence/array of x coordinates.
        y: 1-D sequence/array of y coordinates, same length as ``x``.
        max_iterations: upper bound on the number of sample pairs drawn.
        inlier_dist_thresh: a point is an inlier if its vertical residual is
            strictly below this value.
        confidence: desired probability of drawing at least one pair made
            only of inliers; drives the adaptive iteration budget.
        inlier_ratio_thresh: fraction of the points that, once exceeded by a
            model's inlier count, ends the search early.

    Returns:
        ``(k, b)`` of the best line found, or ``(0, 0)`` if no non-vertical
        line was ever sampled (e.g. all x values are equal).

    Raises:
        ValueError: if fewer than two points are given.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    p_count = x.shape[0]
    if p_count < 2:
        raise ValueError("at least two points are required to fit a line")

    inlier_num_thresh = int(p_count * inlier_ratio_thresh)
    iterations = max_iterations
    sample_count = 0
    preinliers = 0
    bestk = 0
    bestb = 0
    while sample_count < iterations:
        # Count every draw, including degenerate ones, so the loop always ends.
        sample_count += 1
        i1, i2 = random.sample(range(p_count), 2)
        dx = x[i1] - x[i2]
        if dx == 0:
            # Equal x values give a vertical line, which y = k*x + b cannot
            # represent; skip instead of dividing by zero.
            continue
        k = (y[i1] - y[i2]) / dx
        b = y[i1] - k * x[i1]
        total_inlier = int(
            np.count_nonzero(np.abs(k * x + b - y) < inlier_dist_thresh))
        if total_inlier > preinliers:
            preinliers = total_inlier
            bestk = k
            bestb = b
            if total_inlier == p_count:
                # Every point already fits; log(1 - 1) would be undefined and
                # no later sample can do better.
                break
            inlier_ratio = total_inlier / p_count
            # Samples needed so that, with probability `confidence`, at least
            # one 2-point draw contains only inliers. log1p keeps this finite
            # even when inlier_ratio**2 is below float epsilon.
            needed = math.log(1.0 - confidence) / math.log1p(-inlier_ratio ** 2)
            iterations = min(max_iterations, needed)
        if total_inlier > inlier_num_thresh:
            break
    return bestk, bestb


def main():
    """Demo: fit a line to noisy y = 2x + 5 data containing outliers, then plot it.

    Draws 60 x values uniformly from [1, 10] and adds uniform noise in
    [-0.3, 0.3] to every y. It then shifts the points at indices 2, 4, ..., 36
    by a uniform amount in [-15, 15] to create outliers. Finally it draws the
    RANSAC line over a scatter plot of the data.
    """
    num_points = 60
    xs = np.array([random.uniform(1, 10) for _ in range(num_points)])
    # Ground-truth line: y = 2x + 5.
    ys = 2 * xs + 5
    ys += np.array([random.uniform(-0.3, 0.3) for _ in range(num_points)])
    # Corrupt 18 points (even indices 2..36) so they act as outliers.
    for idx in range(2, 38, 2):
        ys[idx] += random.uniform(-15, 15)

    slope, intercept = compute_all_choices(xs, ys)
    fitted = xs * slope + intercept

    plt.title("demo")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.scatter(xs, ys)
    plt.plot(xs, fitted)
    plt.show()


# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

原理

https://blog.csdn.net/zhoucoolqi/article/details/105497572
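n – 拟合模型所需的最少数据点数（直线拟合中 n = 2）。
k – 算法规定的最大遍历（迭代）次数（代码中为 iterations，初始为 2000；注意代码里的变量 k 表示直线斜率，与此处的 k 含义不同）。
t – 判定数据与模型是否匹配的阈值：误差在 t 范围内的点为内点（inlier），超出范围的为外点（outlier）（代码中为 inlier_dist_thresh = 0.25）。
d – 判定模型足够好所需的最少内点数（代码中为 inlier_num_thresh，即总点数的 40%，内点数超过它即提前结束）。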

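关于最大遍历次数的更新

设当前最优模型的内点比例为 $w$。随机抽取的 $n$ 个点全部是内点的概率为 $w^n$。若希望以概率 $P$（代码中取 0.99）保证 $k$ 次抽样中至少有一次全部为内点，需满足 $1-(1-w^n)^k \ge P$。因此每找到内点更多的模型，就把最大遍历次数更新为
$$k = \frac{\log(1-P)}{\log\left(1-w^{n}\right)}.$$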

你可能感兴趣的:(RANSAC的简化版代码)