Python-线性判别分析(Fisher判别分析)使用鸢尾花数据集 Iris

代码实现

例如鸢尾花数据集,将数据集分为三类样本,然后得到三个总体类离散度矩阵,三个总体类离散度矩阵根据上述公式计算即可。
IRIS数据集以鸢尾花的特征作为数据来源,数据集包含150个数据集,有4维,分为3 类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

解决数据读出来只有149行

pandas因为版本的问题(我猜的),现在会默认的不读取第一行数据,也就是表头,但iris数据中,没有表头,全是数据.

所以,我们需要使用如下的方式,读取数据集

df = pd.read_csv(r'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',header = None)

实话实话,我还是喜欢用http的方式直接引用iris数据集,下载的话,实在是太麻烦了.

完整代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns

df = pd.read_csv(r'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',header = None)
Iris=df.values
display(len(list(df.values)))


Iris1=df.values[0:50,0:4]
Iris2=df.values[50:100,0:4]
Iris3=df.values[100:150,0:4]

m1=np.mean(Iris1,axis=0)
m2=np.mean(Iris2,axis=0)
m3=np.mean(Iris3,axis=0)

s1=np.zeros((4,4))
s2=np.zeros((4,4))
s3=np.zeros((4,4))
for i in range(0,31):
    a=Iris1[i,:]-m1
    a=np.array([a])
    b=a.T
    s1=s1+np.dot(b,a)    
for i in range(0,31):
    c=Iris2[i,:]-m2
    c=np.array([c])
    d=c.T
    s2=s2+np.dot(d,c) 
for i in range(0,30,1):
    a=Iris3[i,:]-m3
    a=np.array([a])
    b=a.T
    s3=s3+np.dot(b,a) 
sw12=s1+s2
sw13=s1+s3
sw23=s2+s3


# 投影方向
a=np.array([m1-m2])
sw12=np.array(sw12,dtype='float')
sw13=np.array(sw13,dtype='float')
sw23=np.array(sw23,dtype='float')
#  判别函数以及T
#  需要先将m1-m2转化成矩阵才能进行求其转置矩阵
a=m1-m2
a=np.array([a])
a=a.T
b=m1-m3
b=np.array([b])
b=b.T
c=m2-m3
c=np.array([c])
c=c.T
w12=(np.dot(np.linalg.inv(sw12),a)).T
w13=(np.dot(np.linalg.inv(sw13),b)).T
w23=(np.dot(np.linalg.inv(sw23),c)).T
#  print(m1+m2) #1x4维度  invsw12 4x4维度  m1-m2 4x1维度
T12=-0.5*(np.dot(np.dot((m1+m2),np.linalg.inv(sw12)),a))
T13=-0.5*(np.dot(np.dot((m1+m3),np.linalg.inv(sw13)),b))
T23=-0.5*(np.dot(np.dot((m2+m3),np.linalg.inv(sw23)),c))
kind1=0
kind2=0
kind3=0
newiris1=[]
newiris2=[]
newiris3=[]
for i in range(30,50):
    x=Iris1[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
        kind1=kind1+1
    elif g12<0 and g23>0:
        newiris2.extend(x)
    elif g13<0 and g23<0 :
        newiris3.extend(x)
#print(newiris1)
for i in range(30,50):
    x=Iris2[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
    elif g12<0 and g23>0:
 
        newiris2.extend(x)
        kind2=kind2+1
    elif g13<0 and g23<0 :
        newiris3.extend(x)
for i in range(30,50):
    x=Iris3[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
    elif g12<0 and g23>0:     
        newiris2.extend(x)
    elif g13<0 and g23<0 :
        newiris3.extend(x)
        kind3=kind3+1
correct=(kind1+kind2+kind3)/60
display("样本类内离散度矩阵S1:",s1,'\n')
display("样本类内离散度矩阵S2:",s2,'\n')
display("样本类内离散度矩阵S3:",s3,'\n')
display('------------------------------------------------------------------------------')
display("总体类内离散度矩阵Sw12:",sw12,'\n')
display("总体类内离散度矩阵Sw13:",sw13,'\n')
display("总体类内离散度矩阵Sw23:",sw23,'\n')
display('------------------------------------------------------------------------------')
display('判断出来的综合正确率:',correct*100,'%')

代码的执行效果
Python-线性判别分析(Fisher判别分析)使用鸢尾花数据集 Iris_第1张图片

Python-线性判别分析(Fisher判别分析)使用鸢尾花数据集 Iris_第2张图片

数据可视化

散点图

from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
 
iris = datasets.load_iris()
irisFeatures = iris["data"]
irisFeaturesName = iris["feature_names"]
irisLabels = iris["target"]
 
def scatter_plot(dim1, dim2):
    for t,marker,color in zip(range(3),">ox","rgb"):
           # zip()接受任意多个序列参数,返回一个元组tuple列表
        # 用不同的标记和颜色画出每种品种iris花朵的前两维数据
        # We plot each class on its own to get different colored markers
        plt.scatter(irisFeatures[irisLabels == t,dim1],
                    irisFeatures[irisLabels == t,dim2],marker=marker,c=color)
    dim_meaning = {0:'setal length',1:'setal width',2:'petal length',3:'petal width'}
    plt.xlabel(dim_meaning.get(dim1))
    plt.ylabel(dim_meaning.get(dim2))
 
plt.subplot(231)
scatter_plot(0,1)
plt.subplot(232)
scatter_plot(0,2)
plt.subplot(233)
scatter_plot(0,3)
plt.subplot(234)
scatter_plot(1,2)
plt.subplot(235)
scatter_plot(1,3)
plt.subplot(236)
scatter_plot(2,3)
 
plt.show()

代码执行的效果
Python-线性判别分析(Fisher判别分析)使用鸢尾花数据集 Iris_第3张图片

你可能感兴趣的:(Python-线性判别分析(Fisher判别分析)使用鸢尾花数据集 Iris)