Fisher线性判别

在理解Fisher线性分类的参考代码基础上(matlab代码),改用python代码完成Fisher判别的推导。重点理解“群内离散度”(样本类内离散矩阵)、“群间离散度”(总类内离散矩阵)的概念和几何意义

      • 一、Fisher算法描述
        • (1) W的确定
        • (2) 阈值的确定
        • (3)Fisher线性判别的决策规则
      • 二、Python实现
        • 1. 代码
        • 2. 结果显示:

一、Fisher算法描述

Fisher线性判别分析的基本思想:选择一个投影方向(线性变换,线性组合),将高维问题降低到一维问题来解决,同时变换后的一维数据满足每一类内部的样本尽可能聚集在一起,不同类的样本相隔尽可能地远。
Fisher线性判别分析,就是通过给定的训练数据,确定投影方向W和阈值w0, 即确定线性判别函数,然后根据这个线性判别函数,对测试数据进行测试,得到测试数据的类别。
线性判别函数的一般形式可以表示为: g ( x ) = W T X + w 0 g(x)=W^TX+w_0 g(x)=WTX+w0其中:
X = ( x 1 ⋮ x d ) , W = ( w 1 w 2 ⋮ w d ) X= \left( \begin{matrix} x_1\\ \vdots \\ x_d\\ \end{matrix} \right), W= \left( \begin{matrix} w_1\\ w_2\\ \vdots \\ w_d\\ \end{matrix} \right) X=x1xd,W=w1w2wd
Fisher选择投影方向W的原则,即使原样本向量在该方向上的投影能兼顾类间分布尽可能分开,类内样本投影尽可能密集的要求。 如下为具体步骤:

(1) W的确定

各类样本均值向量 m i m_i mi:
m i = 1 N i ∑ i ∈ X i x , i = 1 , 2 m_i = \frac {1}{N_i}\sum_{i\in X_i} x, i=1,2 mi=Ni1iXix,i=1,2
样本类内离散度矩阵 S i S_i Si和总类内离散度矩阵 S w S_w Sw:
S i = ∑ x ∈ X i ( x − m i ) ( x − m i ) T , i = 1 , 2 S_i=\sum_{x \in X_i}(x-m_i)(x-m_i)^T,i=1,2 Si=xXi(xmi)(xmi)T,i=1,2
S w = S 1 + S 2 S_w = S_1+S_2 Sw=S1+S2
样本类间离散度矩阵 S b S_b Sb:
S b = ( m 1 − m 2 ) ( m 1 − m 2 ) T S_b=(m_1-m_2)(m_1-m_2)^T Sb=(m1m2)(m1m2)T

在投影后的一维空间中,各类样本均值 m i ′ = W T m i m_i'=W^Tm_i mi=WTmi

样本类内离散度和总类内离散度 S i ′ = W T S i W S w ′ = W T S w W S_i'=W^TS_iWS_w'=W^TS_wW Si=WTSiWSw=WTSwW

样本类间离散度 S b ′ = W T S b W S_b'=W^TS_bW Sb=WTSbW

Fisher准则函数为max J F ( W ) = ( m 1 ~ − m 2 ~ ) 2 S 1 ~ 2 + S 2 ~ 2 J_F(W)=\frac{(\tilde{m_1} - \tilde{m_2})^2}{\tilde{S_1}^2+\tilde{S_2}^2} JF(W)=S1~2+S2~2(m1~m2~)2

(2) 阈值的确定

w 0 w_0 w0是个常数,称为阈值权,对于连累问题的线性分类器可以采用下属决策规则:
g ( x ) = g 1 ( x ) − g 2 ( x ) g(x)=g_1(x)-g_2(x) g(x)=g1(x)g2(x),则:
如果g(x)>0,则决策x属于 w 1 w_1 w1;若g(x)<0,则决策x属于 w 2 w_2 w2;如果g(x)=0,则可将x任意分到某一类,或拒绝。

(3)Fisher线性判别的决策规则

Fisher准则函数满足两个性质:
1.投影后,各类样本内部尽可能密集,即总类内离散度越小越好。
2.投影后,各类样本尽可能离得远,即样本类间离散度越大越好。
根据这个性质确定准则函数,根据使准则函数取得最大值,可求出W:
W = S w − 1 ( m 1 − m 2 ) W=S_w^{-1}(m_1-m_2) W=Sw1(m1m2)
这就是Fisher判别准则下的最优投影方向。

最后得到决策规则:
若 g ( x ) = w T ( x − 1 2 ( m 1 + m 2 ) ) 大 于 或 小 于 l o g P ( w 2 ) P ( w 1 ) , 则 x ∈ { w 1 w 2 若g(x)=w^T(x-\frac 1 2(m_1+m_2))大于或小于log\frac {P_{(w_2)}}{P_{(w_1)}},则x\in \left\{ \begin{array}{lr} w_1 \\ \\ w_2 \end{array} \right. g(x)=wT(x21(m1+m2))logP(w1)P(w2)xw1w2
对于某一个未知类别的样本向量x,如果 y = W T , x > y 0 y=W^T,x>y_0 y=WT,x>y0,则 x ∈ w 1 x\in w_1 xw1;否则 x ∈ w 2 x\in w_2 xw2

二、Python实现

1. 代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
path=r'media/Iris.csv'
df = pd.read_csv(path, header=0)
Iris1=df.values[0:50,0:4]
Iris2=df.values[50:100,0:4]
Iris3=df.values[100:150,0:4]
m1=np.mean(Iris1,axis=0)
m2=np.mean(Iris2,axis=0)
m3=np.mean(Iris3,axis=0)
s1=np.zeros((4,4))
s2=np.zeros((4,4))
s3=np.zeros((4,4))
for i in range(0,30,1):
    a=Iris1[i,:]-m1
    a=np.array([a])
    b=a.T
    s1=s1+np.dot(b,a)    
for i in range(0,30,1):
    c=Iris2[i,:]-m2
    c=np.array([c])
    d=c.T
    s2=s2+np.dot(d,c) 
    #s2=s2+np.dot((Iris2[i,:]-m2).T,(Iris2[i,:]-m2))
for i in range(0,30,1):
    a=Iris3[i,:]-m3
    a=np.array([a])
    b=a.T
    s3=s3+np.dot(b,a) 
sw12=s1+s2
sw13=s1+s3
sw23=s2+s3
#投影方向
a=np.array([m1-m2])
sw12=np.array(sw12,dtype='float')
sw13=np.array(sw13,dtype='float')
sw23=np.array(sw23,dtype='float')
#判别函数以及T
#需要先将m1-m2转化成矩阵才能进行求其转置矩阵
a=m1-m2
a=np.array([a])
a=a.T
b=m1-m3
b=np.array([b])
b=b.T
c=m2-m3
c=np.array([c])
c=c.T
w12=(np.dot(np.linalg.inv(sw12),a)).T
w13=(np.dot(np.linalg.inv(sw13),b)).T
w23=(np.dot(np.linalg.inv(sw23),c)).T
#print(m1+m2) #1x4维度  invsw12 4x4维度  m1-m2 4x1维度
T12=-0.5*(np.dot(np.dot((m1+m2),np.linalg.inv(sw12)),a))
T13=-0.5*(np.dot(np.dot((m1+m3),np.linalg.inv(sw13)),b))
T23=-0.5*(np.dot(np.dot((m2+m3),np.linalg.inv(sw23)),c))
kind1=0
kind2=0
kind3=0
newiris1=[]
newiris2=[]
newiris3=[]
for i in range(30,49):
    x=Iris1[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
        kind1=kind1+1
    elif g12<0 and g23>0:
        newiris2.extend(x)
    elif g13<0 and g23<0 :
        newiris3.extend(x)
#print(newiris1)
for i in range(30,49):
    x=Iris2[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
    elif g12<0 and g23>0:
 
        newiris2.extend(x)
        kind2=kind2+1
    elif g13<0 and g23<0 :
        newiris3.extend(x)
for i in range(30,50):
    x=Iris3[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
    elif g12<0 and g23>0:     
        newiris2.extend(x)
    elif g13<0 and g23<0 :
        newiris3.extend(x)
        kind3=kind3+1
correct=(kind1+kind2+kind3)/60
print("样本类内离散度矩阵S1:",s1,'\n')
print("样本类内离散度矩阵S2:",s2,'\n')
print("样本类内离散度矩阵S3:",s3,'\n')
print('-----------------------------------------------------------------------------------------------')
print("总体类内离散度矩阵Sw12:",sw12,'\n')
print("总体类内离散度矩阵Sw13:",sw13,'\n')
print("总体类内离散度矩阵Sw23:",sw23,'\n')
print('-----------------------------------------------------------------------------------------------')
print('判断出来的综合正确率:',correct*100,'%')

2. 结果显示:

Fisher线性判别_第1张图片
判别的方法还有许多,比如贝叶斯算法,BP神经网络算法,K-means算法等不同的算法,所得到的结果也不一样。fisher判别的算法,感觉还是比较准确了,其他算法的话可以去搜索相关的文章看看。

本次fisher线性判别算法,就结束了。

你可能感兴趣的:(人工智能作业,机器学习,Fisher线性判决)