机器学习--RANSAC算法,排除异常数据干扰的回归算法

RANSAC算法

在机器学习中,训练数据可能会出现异常数据,部分异常数据在线性回归模型中,将会影响线性回归的拟合效果,误导模型的预测。而RANSAC算法是能够排除异常数据干扰的一种回归算法。英文名称:Random Sample Consensus,“随机采样一致”,这是一个基于线性回归思想的算法。

代码实现(演示)

# RANSAC算法是能够排除异常数据干扰的一个回归算法
import numpy as np
import random
import matplotlib.pyplot as plt

class LinearRegression():
    def fit(self,X,y):
        self.w = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)
        pass
    def predict(self,X):
        return X.dot(self.w)

    # RANSAC算法
        def RANSAC(self,X,y,N,d,k):
        m,n = X.shape
        w_list = []
        r_list = []
        t = 0
        while t<=N:
            # 随机取子集
            a,b = random.sample(range(m),2)
            if a>b:
                a,b = b,a
            # 子集数量大于k,即要求的取点数量
            if b-a>=k:
                # 根据子集建模
                self.fit(X[a:b,:],y[a:b])
                y_true = y
                y_pred = self.predict(X)
                # 计算训练数据,回归数据的误差
                B = abs(y_true-y_pred)
                Bt = []
                By = []

                # 选择出误差小于d的数据
                for i in range(len(B)):
                    if B[i]<d:
                        Bt.append(X[i,:])
                        By.append(y[i])
                Bt = np.array(Bt)
                By = np.array(By)

                # 若误差小于d的数据数量大于k,则使用这些数据建立新的墨香
                if len(Bt)>k:
                    self.fit(Bt,By)
                    y_pred =self.predict(Bt)
                    y_true = By
                # 计算均方误差,并保存本次模型与均方误差
                r =mean_squared_error(y_true,y_pred)
                w_list.append(self.w)
                r_list.append(r)
                t = t+1
                pass
            pass
        index_min = np.argmin(r_list)
        self.w = w_list[index_min]
        pass

# 均方误差
# np.average() 用于求平均
def mean_squared_error(y_true,y_pred):
    return np.average((y_true-y_pred)**2,axis=0)


# 生成异常数据
def generate_samples(m,k):
    X_normal = 2 * (np.random.rand(m,1)-0.5)
    y_normal = X_normal+np.random.normal(0,0.1,(m,1))
    X_outlier = 2 * (np.random.rand(k,1)-0.5)
    y_outlier = X_outlier+np.random.normal(3,0.1,(k,1))
    X = np.concatenate((X_normal,X_outlier),axis=0)
    y = np.concatenate((y_normal,y_outlier),axis= 0)
    return X, y

def process_features(X):
    m,n = X.shape
    X = np.c_[np.ones((m,1)),X]
    return X


np.random.seed(0)
X_original,y = generate_samples(100,5)
X = process_features(X_original)
model =LinearRegression()
model.fit(X,y)
y_pred = model.predict(X)
model.RANSAC(X,y,1000,3,30)
y_RANSAC = model.predict(X)
print(mean_squared_error(y,y_RANSAC))
plt.scatter(X_original,y)
plt.plot(X_original,y_pred)
plt.plot(X_original,y_RANSAC,color = "red")
plt.show()

该算法在线性回归的基础上需要引入3个新参数–N,d,k。算法的实现主要是通过N轮循环生成N个模型h,并从N个模型中寻找均方误差最小的模型,作为算法的模型输出。
在第t次循环中,从训练数据S中随机抽取一个子集St,通过线性回归算法计算出线性模型ht,来拟合St中的数据,并分析在该线性模型ht下,训练数据中所有与模型下预测值的误差不超过d的点的个数,如果这些点的个数超过k,则线性回归算法计算出一个新的模型ht来拟合这些点,并将本次循环得到的模型更新。
机器学习--RANSAC算法,排除异常数据干扰的回归算法_第1张图片
图像为算法运行结果,其中的蓝线代表不采用RANSAC算法建立的线性回归模型,该模型因顶部5个严重偏离的异常数据而导致模型严重上移,影响了线性回归的整体拟合效果。而红线采用了RANSAC算法,在一定程度上减弱了异常数据的影响,从而使数据拟合效果得到优化。
该算法主要通过误差d去排除异常数据,合适的d可以使拟合效果更优秀。

自我思考:

RANSAC算法应用了随机子集,通过循环后建立多个模型,选其中最优模型,那么如果直接通过误差排除掉异常数据点呢?

    # 根据误差,排除数据差异大的算法
    def RANSAC_selfthought(self,X,y,d):
        self.fit(X,y)
        y_pred = self.predict(X)
        Bt = []
        By = []
        B = abs(y - y_pred)
        for i in range(len(B)):
            if B[i]<d:
                Bt.append(X[i,:])
                By.append(y[i])
        Bt = np.array(Bt)
        By = np.array(By)
        self.fit(Bt,By)
        pass

机器学习--RANSAC算法,排除异常数据干扰的回归算法_第2张图片

你可能感兴趣的:(算法,编程实践,编程思想,机器学习)