在机器学习中,训练数据可能会出现异常数据,部分异常数据在线性回归模型中,将会影响线性回归的拟合效果,误导模型的预测。而RANSAC算法是能够排除异常数据干扰的一种回归算法。英文名称:Random Sample Consensus,“随机采样一致”,这是一个基于线性回归思想的算法。
# RANSAC算法是能够排除异常数据干扰的一个回归算法
import numpy as np
import random
import matplotlib.pyplot as plt
class LinearRegression():
def fit(self,X,y):
self.w = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)
pass
def predict(self,X):
return X.dot(self.w)
# RANSAC算法
def RANSAC(self,X,y,N,d,k):
m,n = X.shape
w_list = []
r_list = []
t = 0
while t<=N:
# 随机取子集
a,b = random.sample(range(m),2)
if a>b:
a,b = b,a
# 子集数量大于k,即要求的取点数量
if b-a>=k:
# 根据子集建模
self.fit(X[a:b,:],y[a:b])
y_true = y
y_pred = self.predict(X)
# 计算训练数据,回归数据的误差
B = abs(y_true-y_pred)
Bt = []
By = []
# 选择出误差小于d的数据
for i in range(len(B)):
if B[i]<d:
Bt.append(X[i,:])
By.append(y[i])
Bt = np.array(Bt)
By = np.array(By)
# 若误差小于d的数据数量大于k,则使用这些数据建立新的墨香
if len(Bt)>k:
self.fit(Bt,By)
y_pred =self.predict(Bt)
y_true = By
# 计算均方误差,并保存本次模型与均方误差
r =mean_squared_error(y_true,y_pred)
w_list.append(self.w)
r_list.append(r)
t = t+1
pass
pass
index_min = np.argmin(r_list)
self.w = w_list[index_min]
pass
# 均方误差
# np.average() 用于求平均
def mean_squared_error(y_true,y_pred):
return np.average((y_true-y_pred)**2,axis=0)
# 生成异常数据
def generate_samples(m,k):
X_normal = 2 * (np.random.rand(m,1)-0.5)
y_normal = X_normal+np.random.normal(0,0.1,(m,1))
X_outlier = 2 * (np.random.rand(k,1)-0.5)
y_outlier = X_outlier+np.random.normal(3,0.1,(k,1))
X = np.concatenate((X_normal,X_outlier),axis=0)
y = np.concatenate((y_normal,y_outlier),axis= 0)
return X, y
def process_features(X):
m,n = X.shape
X = np.c_[np.ones((m,1)),X]
return X
np.random.seed(0)
X_original,y = generate_samples(100,5)
X = process_features(X_original)
model =LinearRegression()
model.fit(X,y)
y_pred = model.predict(X)
model.RANSAC(X,y,1000,3,30)
y_RANSAC = model.predict(X)
print(mean_squared_error(y,y_RANSAC))
plt.scatter(X_original,y)
plt.plot(X_original,y_pred)
plt.plot(X_original,y_RANSAC,color = "red")
plt.show()
该算法在线性回归的基础上需要引入3个新参数–N,d,k。算法的实现主要是通过N轮循环生成N个模型h,并从N个模型中寻找均方误差最小的模型,作为算法的模型输出。
在第t次循环中,从训练数据S中随机抽取一个子集St,通过线性回归算法计算出线性模型ht,来拟合St中的数据,并分析在该线性模型ht下,训练数据中所有与模型下预测值的误差不超过d的点的个数,如果这些点的个数超过k,则线性回归算法计算出一个新的模型ht来拟合这些点,并将本次循环得到的模型更新。
图像为算法运行结果,其中的蓝线代表不采用RANSAC算法建立的线性回归模型,该模型因顶部5个严重偏离的异常数据而导致模型严重上移,影响了线性回归的整体拟合效果。而红线采用了RANSAC算法,在一定程度上减弱了异常数据的影响,从而使数据拟合效果得到优化。
该算法主要通过误差d去排除异常数据,合适的d可以使拟合效果更优秀。
RANSAC算法应用了随机子集,通过循环后建立多个模型,选其中最优模型,那么如果直接通过误差排除掉异常数据点呢?
# 根据误差,排除数据差异大的算法
def RANSAC_selfthought(self,X,y,d):
self.fit(X,y)
y_pred = self.predict(X)
Bt = []
By = []
B = abs(y - y_pred)
for i in range(len(B)):
if B[i]<d:
Bt.append(X[i,:])
By.append(y[i])
Bt = np.array(Bt)
By = np.array(By)
self.fit(Bt,By)
pass