阿里云天池大赛工业蒸汽预测代码学习(2)

0```python
#查看异常值的代码
def find_outliers(model,X,y,sigma=3):
#predict y
try:
y_pred=pd.Series(model.predict(X),index=y.index)
except:
model.fit(X,y)
y_pred=pd.Series(model.predict(X),index=y.index)
#用模型预测

resid=y-y_pred#计算残差
mean_resid=resid.mean()#平均残差
std_resid=resid.std()#计算标准差

z=(resid-mean_resid)/std_resid
outliers=z[abs(z)>sigma].index#异常值判定标准

print('R2=',model.score(X,y))
#计算R2系数
print("mse=",mean_squared_error(y,y_pred))#计算均方误差
print('-----------------------------')
print('mean of residuals:',mean_resid)
print('std of residuals:',std_resid)
print('-----------------------------')
print(len(outliers),'outliers:')
print(outliers.tolist())#输出异常值

plt.figure(figsize=(15,5))
ax_131=plt.subplot(1,3,1)
plt.plot(y_pred,'.')
plt.plot(y.loc[outliers],y_pred.loc[outliers],'ro')
plt.legend(['Accepted','Outlier'])
plt.xlabel('y')
plt.ylabel('y_pred');

ax_132=plt.subplot(1,3,2)
plt.plot(y,y-y_pred,'.')

ax_133=plt.subplot(1,3,3)
z.plot.hist(bins=50,ax=ax_133)
z.loc[outliers].plot.hist(color='r',bins=50,ax=ax_133)
plt.legend(['Accepted','Outlier'])
plt.xlabel('z')

plt.savefig('outliers.png')
return outliers

X_train=train_data.iloc[:,0:-1]
y_train=train_data.iloc[:,-1]
outliers=find_outliers(Ridge(),X_train,y_train)

pd.Series():Series是⼀种类似于⼀维数组的对象,它由⼀组数据(各种
NumPy数据类型)以及⼀组与之相关的数据标签(即索引)组
成。仅由⼀组数据即可产⽣最简单的Series。详细见《利用python进行数据分析》

model.score()函数:返回预测的确定系数R²。
R²:系数R^2定义为(1-u/v),其中u是平方的剩余和((y_真-y_pred)²),v是平方的总和((y_-true-y_-true.mean())²)最好的可能分数是1.0,可能是负数(因为模型可能会任意变差)。如果一个常数模型总是预测y的期望值,而不考虑输入特征,则R^2得分为0.0。
plt.subplot():调用 subplot() 函数可以创建一个子图,然后程序就可以在子图上进行绘制。subplot(nrows, ncols, index, **kwargs) 函数的 nrows 参数指定将数据图区域分成多少行;ncols 参数指定将数据图区域分成多少列;index 参数指定获取第几个区域。例如上述代码中的三次调用就是将区域分为一行,三列,分别在第一二三个区域绘图。

plt.plot()函数:plt.plot(x,y,format_string,**kwargs),参数分别为x轴数据,y轴数据,控制曲线的格式字符串(可选)
异常值函数的得到的outliers为误差数据的索引,这里:
 plt.plot(y.loc[outliers],y_pred.loc[outliers],'ro')绘制了真实值,预测值中的异常值
 plt.legend():函数主要的作用就是给图加上图例,plt.legend([x,y,z])里面的参数使用的是list的的形式将图表的的名称喂给这和函数。
 
 总结:以上代码中,判断异常值的标准是选取 z=(resid mean_resid)/std_resid的绝对值大于sigma。
 所以只需要先得到z的索引列表再进行数据可视化即可。


你可能感兴趣的:(机器学习,学习,python,机器学习)