LDA线性判别分析Python

原理:

     线性判别分析(Linear Discriminant Analysis,简称LDA)是一种经典的监督学习的数据降维方法,也叫做Fisher线性判别(Fisher Linear Discriminant,FLD),是模式识别的经典算法 ,它是在1996年由Belhumeur引入模式识别和人工智能领域的。LDA的主要思想是将一个高维空间中的数据投影到一个较低维的空间中,且投影后要保证各个类别的类内方差小而类间均值差别大,这意味着同一类的高维数据投影到低维空间后相同类别的聚在一起,而不同类别之间相距较远。如下图将二维数据投影到一维直线上

 LDA降维的目标:将带有标签的数据降维,投影到低维空间同时满足三个条件:

尽可能多的保留数据样本的信息(即选择最大的特征是对应的特征向量所代表的方向)。
寻找使样本尽可能好分的最佳投影方向。
投影后使得同类样本尽可能近,不同类样本尽可能远。

案例

论证过程

sklearn实现(鸢尾花为例)
         在scikit-learn中, LDA类是sklearn.discriminant_analysis.LinearDiscriminantAnalysis。既可以用于分类又可以用于降维。当然,应用场景最多的还是降维。和PCA类似,LDA降维基本也不用调参,只需要指定降维到的维数即可 

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

iris = datasets.load_iris()
X = iris.data
Y = iris.target
target_names = iris.target_names
lda = LinearDiscriminantAnalysis(n_components=2)
X_r2 = lda.fit(X, Y).transform(X)
colors = ['navy', 'turquoise', 'darkorange']

for color, i, target_name in zip(colors, [0, 1, 2], target_names):
    plt.scatter(X_r2[Y == i, 0], X_r2[Y == i, 1], alpha=.8, color=color,
                label=target_name)
plt.legend(loc='best', shadow=False, scatterpoints=1)
plt.title('LDA of IRIS dataset by sklearn')

plt.show()

你可能感兴趣的:(多元统计,大数据)