The Elements of Statistical Learning 第五章figure5-8python绘图 (修改版)

这里的数据不是来自5-8的数据,需要大家根据书中的进行调整,所用的编程方法只针对问题。
#导入数据和数据预处理
#生成0-1均匀分布的随机数
import numpy as np#调用numpy库,目的是使用矩阵,对数据矩阵进行操作
np.random.seed(12345)
import matplotlib.pyplot as plt#调用matplotlib中的函数pyplot,为画图做准备
uni_data=np.random.uniform(0,1,size=(100,1))
input_x=np.sort(uni_data,axis=0)
#构造一个取正值函数
def pv(x):
    if x<0:
        y=0
    else:
        y=x
    return y
#计算样本矩阵
#因为i的取值最大为100-2,所以循环的次数在98
def getSmoothSpline(x):
    n=[]
    for i in np.arange(98):
        y1=(pv(x-input_x[i])**3-pv(x-input_x[99])**3)/(input_x[99]-input_x[i])
        y2=(pv(x-input_x[98])**3-pv(x-input_x[99])**3)/(input_x[99]-input_x[98])
        a=(y1-y2)[0]
        n.append(a)
    return(n)
N1=np.zeros((100,98))
for j in np.arange(100):
    N1[j,:]=np.array(getSmoothSpline(input_x[j])).reshape(1,98)
N2=np.hstack((np.ones((100,1)),input_x))
N=np.hstack((N2,N1))
###########################求惩罚项的矩阵
##求积分的过程中需要用到两两点横坐标之间的比较,这里先比较大小。
W1=np.zeros((100,100))
for i in np.arange(100):
    for j in np.arange(100):
        W1[i,j]=max(input_x[i],input_x[j])
"""
由于前两个基的二阶导等于0,其余基的二阶导乘以这两项的积分始终等于零,先考虑剩余的基
,最后利用矩阵的方法,合并成最后的惩罚矩阵。下面给出的是,二阶导之后两两两个基分段积分
下面的积分计算过于冗长。
"""
W2=np.zeros((98,98))
for i in np.arange(98):
    for j in np.arange(98):
        h1=36/((input_x[99]-input_x[j])*(input_x[99]-input_x[i]))
        h2=(1/3)*(input_x[98])**3
        h3=(1/2)*(input_x[i]+input_x[j])*input_x[98]**2
        h4=input_x[i]*input_x[j]*(input_x[98])
        h21=(1/3)*(W1[i,j])**3
        h31=(1/2)*(input_x[i]+input_x[j])*(W1[i,j]**2)
        h41=input_x[i]*input_x[j]*(W1[i,j])
        h=h1*(h2-h3+h4-h21+h31-h41)
        H11=36*(input_x[98]-input_x[i])*(input_x[98]-input_x[j])
        H12=(input_x[99]-input_x[98])**2*(input_x[99]-input_x[i])*(input_x[99]-input_x[j])
        H1=H11/H12
        H2=(1/3)*(input_x[99]-input_x[98])**3
        H=H1*H2
        W2[i,j]=H+h
W3=np.zeros((100,100))
W3[2:,2:]=W2
W=W3
#矩阵的运算,求出光滑矩阵,其中S表示的为光滑矩阵,lambda的值取得是0.2
th=N.T.dot(N)+0.2*W
pinvMatrix=np.linalg.pinv(th)
S=(N.dot(pinvMatrix)).dot(N.T)
##画出热图
import seaborn as sns
sns.heatmap(S, cmap='ocean')
plt.show()
The Elements of Statistical Learning 第五章figure5-8python绘图 (修改版)_第1张图片

你可能感兴趣的:(The Elements of Statistical Learning 第五章figure5-8python绘图 (修改版))