1. The naive Bayes method is a typical generative learning method. A generative method learns the joint probability distribution $P(X,Y)$ from the training data and then derives the posterior distribution $P(Y|X)$. Concretely, it estimates $P(X|Y)$ and $P(Y)$ from the training data and obtains the joint distribution

$$P(X,Y)=P(Y)P(X|Y)$$

The probabilities can be estimated by maximum likelihood estimation or by Bayesian estimation.
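As a concrete illustration of the two estimators applied to the class prior $P(Y=c_k)$, the sketch below contrasts maximum likelihood (plain frequency counts) with Bayesian estimation (Laplace smoothing); the toy labels and variable names are illustrative only, not from the notebook that follows.

import numpy as np
# Toy labels: two classes, five samples
y_toy=np.array([0,0,1,1,1])
classes,counts=np.unique(y_toy,return_counts=True)
# Maximum likelihood estimate: P(Y=c_k)=N_k/N
mle=counts/len(y_toy)
# Bayesian estimate with Laplace smoothing (lambda=1): (N_k+lambda)/(N+lambda*K)
lam=1.0
bayes=(counts+lam)/(len(y_toy)+lam*len(classes))
mle,bayes  # (array([0.4, 0.6]), array([0.42857143, 0.57142857]))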
2. The basic assumption of the naive Bayes method is conditional independence:

$$P(X=x \mid Y=c_{k})=P\left(X^{(1)}=x^{(1)}, \cdots, X^{(n)}=x^{(n)} \mid Y=c_{k}\right)=\prod_{j=1}^{n} P\left(X^{(j)}=x^{(j)} \mid Y=c_{k}\right)$$

This is a rather strong assumption. Because of it, the number of conditional probabilities the model must estimate drops sharply (with $n$ features taking $S$ values each and $K$ classes, from on the order of $K S^{n}$ joint entries to $K n S$), so learning and prediction are greatly simplified. Naive Bayes is therefore efficient and easy to implement; the drawback is that its classification performance is not necessarily high.
3. For classification, naive Bayes applies Bayes' theorem to the learned joint probability model:

$$P(Y \mid X)=\frac{P(X, Y)}{P(X)}=\frac{P(Y) P(X \mid Y)}{\sum_{Y} P(Y) P(X \mid Y)}$$

The input $x$ is assigned to the class $y$ with the largest posterior probability:

$$y=\arg \max _{c_{k}} P\left(Y=c_{k}\right) \prod_{j=1}^{n} P\left(X^{(j)}=x^{(j)} \mid Y=c_{k}\right)$$

Maximizing the posterior probability is equivalent to minimizing the expected risk under the 0-1 loss function.
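Why the equivalence holds (a standard one-step derivation, added for completeness): under the 0-1 loss, the conditional expected risk of predicting class $f(x)$ at a point $x$ is

$$\sum_{k=1}^{K} \mathbb{1}\left(c_{k} \neq f(x)\right) P\left(Y=c_{k} \mid X=x\right)=1-P\left(Y=f(x) \mid X=x\right)$$

so minimizing the expected risk pointwise is exactly choosing $f(x)=\arg \max _{c_{k}} P\left(Y=c_{k} \mid X=x\right)$.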
Model:
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
# Load the iris dataset
iris=load_iris()
# Extract the features X and the corresponding labels y
X=iris["data"]
y=iris["target"]
iris
{'data': array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.6, 1.4, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]]),
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'frame': None,
'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
...}
# Check the shapes of X and y
X.shape,y.shape
((150, 4), (150,))
# Reshape y into a 2-D array
y=y.reshape((150,-1))
y.shape
(150, 1)
# Inspect the data as a DataFrame
df=pd.DataFrame(np.hstack([X,y]),columns=iris.feature_names+["target"])
df
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0.0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0.0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0.0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0.0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0.0 |
| ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | 2.0 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | 2.0 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | 2.0 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | 2.0 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | 2.0 |
150 rows × 5 columns
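As a quick sanity check on class balance (an aside, not part of the original run), each of the three classes should contain 50 samples:

# Count samples per class; expect 50 each for the three iris species
df["target"].value_counts()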
# Convert the label column to integer type
df["target"]=df["target"].astype("int")
df
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
| ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
# Split the data into training and test sets; only the first 100 samples (classes 0 and 1, a binary problem) are used here
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(X[:100],y[:100],test_size=0.25,random_state=0)
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((75, 4), (25, 4), (75, 1), (25, 1))
y_train
array([[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[0],
[0],
[1],
[1],
[1],
[0],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[0],
[0],
[0],
[1],
[0],
[0],
[0],
[1],
[0],
[0],
[1],
[1],
[1],
[1],
[0],
[1],
[0],
[1],
[0],
[0],
[0],
[1],
[1],
[1],
[0],
[1],
[1],
[1],
[0],
[0],
[1],
[0],
[0],
[1],
[1],
[0],
[1],
[1],
[1],
[0],
[0],
[1],
[0],
[1],
[1],
[1],
[0],
[0]])
# Find the indices whose label is 0
np.where(y_train==0)
(array([ 0, 1, 10, 11, 15, 17, 18, 19, 20, 21, 22, 23, 24, 27, 28, 29, 31,
32, 33, 35, 36, 41, 43, 45, 46, 47, 51, 55, 56, 58, 59, 62, 66, 67,
69, 73, 74], dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64))
np.where(y_train==1)
(array([ 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 16, 25, 26, 30, 34, 37,
38, 39, 40, 42, 44, 48, 49, 50, 52, 53, 54, 57, 60, 61, 63, 64, 65,
68, 70, 71, 72], dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64))
# Build a dictionary that stores the row indices belonging to each label,
# so that the per-class means and variances can be computed separately later
dic={}
for i in [0,1]:
    dic[i]=np.where(y_train==i)
dic
{0: (array([ 0, 1, 10, 11, 15, 17, 18, 19, 20, 21, 22, 23, 24, 27, 28, 29, 31,
32, 33, 35, 36, 41, 43, 45, 46, 47, 51, 55, 56, 58, 59, 62, 66, 67,
69, 73, 74], dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)),
1: (array([ 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 16, 25, 26, 30, 34, 37,
38, 39, 40, 42, 44, 48, 49, 50, 52, 53, 54, 57, 60, 61, 63, 64, 65,
68, 70, 71, 72], dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64))}
# Compute the mean and variance of every feature (along the column axis);
# the mean and variance vectors therefore have one element per feature
# X is a 2-D sample array
def u_sigma(X):
    u=np.mean(X,axis=0)
    sigma=np.var(X,axis=0)
    return u,sigma
# The list holds two tuples: the first is the (mean, variance) pair for class 0,
# the second is the pair for class 1
lst=[]
for key,value in dic.items():
    lst.append(u_sigma(X_train[value[0]]))
lst
[(array([5.06486486, 3.45135135, 1.47297297, 0.24054054]),
array([0.11200877, 0.14195763, 0.02197224, 0.00889701])),
(array([5.92368421, 2.78684211, 4.26578947, 1.33947368]),
array([0.27496537, 0.09956371, 0.23646122, 0.04081025]))]
# Unpack the tuples to verify the values
u_0,sigma_0=lst[0]
u_1,sigma_1=lst[1]
u_0,sigma_0,u_1,sigma_1
(array([5.06486486, 3.45135135, 1.47297297, 0.24054054]),
array([0.11200877, 0.14195763, 0.02197224, 0.00889701]),
array([5.92368421, 2.78684211, 4.26578947, 1.33947368]),
array([0.27496537, 0.09956371, 0.23646122, 0.04081025]))
GaussianNB (Gaussian naive Bayes) assumes the likelihood of each feature is Gaussian.

Probability density function:

$$P(x_i \mid y_k)=\frac{1}{\sqrt{2\pi\sigma_{y_k}^{2}}}\exp\left(-\frac{(x_i-\mu_{y_k})^{2}}{2\sigma_{y_k}^{2}}\right)$$

Mean: $\mu$

Variance: $\sigma^{2}=\frac{\sum(X-\mu)^{2}}{N}$
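As a quick numerical check of this density formula (an aside; scipy is assumed to be available), the manual computation should match scipy.stats.norm.pdf:

from scipy.stats import norm
# A sample value, a mean, and a variance, chosen arbitrarily for the check
x,u,sig2=5.1,5.0,0.11
manual=np.exp(-(x-u)**2/(2*sig2))/np.sqrt(2*np.pi*sig2)
# norm.pdf takes the standard deviation, hence the square root
np.isclose(manual,norm.pdf(x,loc=u,scale=np.sqrt(sig2)))  # True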
# Compute the mean and variance of class 0 (setosa)
u_0,sigma_0=u_sigma(X_train[dic[0][0],:])
u_0,sigma_0
(array([5.06486486, 3.45135135, 1.47297297, 0.24054054]),
array([0.11200877, 0.14195763, 0.02197224, 0.00889701]))
# Compute the mean and variance of class 1 (versicolor)
u_1,sigma_1=u_sigma(X_train[dic[1][0],:])
u_1,sigma_1
(array([5.92368421, 2.78684211, 4.26578947, 1.33947368]),
array([0.27496537, 0.09956371, 0.23646122, 0.04081025]))
len(dic[0][0]),len(dic[1][0])
(37, 38)
dic[0][0]
array([ 0, 1, 10, 11, 15, 17, 18, 19, 20, 21, 22, 23, 24, 27, 28, 29, 31,
32, 33, 35, 36, 41, 43, 45, 46, 47, 51, 55, 56, 58, 59, 62, 66, 67,
69, 73, 74], dtype=int64)
# Compute the prior probability of each class
lst_pri=[]
for i in [0,1]:
    lst_pri.append(len(dic[i][0]))
lst_pri=[item/sum(lst_pri) for item in lst_pri]
lst_pri
[0.49333333333333335, 0.5066666666666667]
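The same priors can be obtained in one line with numpy (an equivalent alternative, not in the original notebook):

# Counts of labels 0 and 1 divided by the total number of samples
np.bincount(y_train.ravel())/len(y_train)  # array([0.49333333, 0.50666667])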
def gaussian_density(data,u,sigma):
    # Per-feature Gaussian density, multiplied across features
    # (the naive conditional-independence assumption); sigma holds variances
    expo=np.exp(-np.power(data-u,2)/(2*sigma))
    coef=1/(np.sqrt(2*np.pi*sigma))
    return np.prod(coef*expo,axis=1)
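Multiplying many per-feature densities underflows easily (several scores below, and more in the three-class run later, print as exactly 0.0); working in log space is numerically safer. A minimal sketch of the same computation, not part of the original notebook:

def log_gaussian_score(data,u,sigma,prior):
    # Sum of per-feature log densities plus the log prior; monotone in the
    # posterior, so argmax over classes gives the same predictions without underflow
    log_dens=-0.5*np.log(2*np.pi*sigma)-np.power(data-u,2)/(2*sigma)
    return np.sum(log_dens,axis=1)+np.log(prior)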
# Scores of all training samples under the class-0 Gaussian parameters (class-conditional density times prior)
pre_0=gaussian_density(X_train,u_0,sigma_0)*lst_pri[0]
pre_0
array([3.99415464e+000, 1.94367635e+000, 6.60889499e-097, 1.80752252e-082,
1.44507736e-148, 8.63205906e-058, 1.77086187e-073, 1.72200357e-108,
4.86671382e-134, 1.06674156e-132, 5.80979347e+000, 1.93582589e-001,
6.83123642e-151, 3.80660319e-138, 3.54858798e-110, 2.47436003e+000,
9.47627356e-114, 3.63995412e-001, 6.64092778e-003, 5.19779913e+000,
1.15891783e-002, 5.07677505e+000, 2.86260160e+000, 2.21879073e-001,
1.56640570e-001, 1.03157479e-131, 8.43689850e-092, 5.64628646e+000,
3.64465774e+000, 5.22805105e+000, 5.83954842e-143, 3.24263354e+000,
9.31529278e-001, 4.57789205e-002, 2.23448562e-161, 3.09648295e+000,
1.00212662e+000, 5.17295325e-130, 1.09814912e-048, 1.88640805e-056,
3.08491848e-137, 4.81085712e-001, 1.12504707e-129, 3.67995439e-002,
3.91991816e-092, 3.70404421e+000, 1.97791635e+000, 5.18297633e+000,
3.22002953e-109, 2.45629129e-042, 4.65684882e-078, 1.20020428e+000,
3.47644237e-102, 5.30752338e-159, 2.67525891e-180, 2.14367370e+000,
1.69559466e+000, 5.01330518e-065, 2.90136679e+000, 6.26263265e+000,
9.91822069e-123, 6.08616441e-129, 7.38230838e-001, 2.42302202e-096,
4.49573232e-170, 6.29495594e-117, 1.39322505e+000, 1.33577067e+000,
1.49050826e-177, 1.31733476e+000, 5.16176371e-102, 4.55092123e-084,
5.28027292e-073, 1.74659558e+000, 1.73554442e-002])
# Scores of all training samples under the class-1 Gaussian parameters (class-conditional density times prior)
pre_1=gaussian_density(X_train,u_1,sigma_1)*lst_pri[1]
pre_1
array([6.88891263e-17, 2.52655671e-16, 6.66784142e-01, 4.39035170e-01,
1.02097078e-01, 5.26743134e-04, 8.41179097e-02, 3.62626644e-01,
7.91642821e-02, 1.44031642e-01, 2.76147108e-16, 6.67290518e-15,
4.75292781e-02, 4.49054758e-01, 4.79673262e-01, 3.31237947e-16,
4.53713921e-01, 5.07639533e-18, 8.97591672e-17, 2.14239456e-17,
2.89264720e-18, 9.14486465e-16, 1.93935408e-16, 9.52254108e-18,
1.72377778e-14, 4.48431308e-01, 2.11349055e-01, 6.33550524e-17,
8.36586449e-16, 1.63398769e-16, 2.61589867e-02, 4.42217308e-16,
2.04791994e-17, 9.81772333e-12, 2.65632115e-02, 8.48713904e-17,
1.37974305e-13, 3.37353331e-01, 1.87800865e-03, 4.26608396e-02,
4.58473827e-02, 3.33967704e-20, 2.47883299e-01, 1.36596674e-19,
3.18444088e-01, 2.23261970e-16, 8.08973781e-16, 1.58016713e-16,
6.30695919e-01, 2.54489986e-03, 1.61140759e-01, 8.06573695e-15,
6.10877468e-01, 1.25788818e-01, 1.36687997e-02, 4.89645218e-15,
8.15261126e-19, 3.32739495e-02, 4.87766404e-17, 4.05703434e-16,
1.48439207e-01, 2.49686080e-01, 1.21546609e-17, 4.80883386e-01,
1.36182282e-02, 1.75312606e-01, 4.57390205e-17, 6.63620680e-15,
7.51872920e-02, 4.53624816e-17, 6.57207208e-01, 1.69998516e-01,
2.35169368e-01, 4.90692552e-17, 1.93538305e-13])
# Stack the per-class score columns to get the training-set predictions
pre_all=np.hstack([pre_0.reshape(len(pre_0),1),pre_1.reshape(pre_1.shape[0],1)])
pre_all
array([[3.99415464e+000, 6.88891263e-017],
[1.94367635e+000, 2.52655671e-016],
[6.60889499e-097, 6.66784142e-001],
[1.80752252e-082, 4.39035170e-001],
[1.44507736e-148, 1.02097078e-001],
[8.63205906e-058, 5.26743134e-004],
[1.77086187e-073, 8.41179097e-002],
[1.72200357e-108, 3.62626644e-001],
[4.86671382e-134, 7.91642821e-002],
[1.06674156e-132, 1.44031642e-001],
[5.80979347e+000, 2.76147108e-016],
[1.93582589e-001, 6.67290518e-015],
[6.83123642e-151, 4.75292781e-002],
[3.80660319e-138, 4.49054758e-001],
[3.54858798e-110, 4.79673262e-001],
[2.47436003e+000, 3.31237947e-016],
[9.47627356e-114, 4.53713921e-001],
[3.63995412e-001, 5.07639533e-018],
[6.64092778e-003, 8.97591672e-017],
[5.19779913e+000, 2.14239456e-017],
[1.15891783e-002, 2.89264720e-018],
[5.07677505e+000, 9.14486465e-016],
[2.86260160e+000, 1.93935408e-016],
[2.21879073e-001, 9.52254108e-018],
[1.56640570e-001, 1.72377778e-014],
[1.03157479e-131, 4.48431308e-001],
[8.43689850e-092, 2.11349055e-001],
[5.64628646e+000, 6.33550524e-017],
[3.64465774e+000, 8.36586449e-016],
[5.22805105e+000, 1.63398769e-016],
[5.83954842e-143, 2.61589867e-002],
[3.24263354e+000, 4.42217308e-016],
[9.31529278e-001, 2.04791994e-017],
[4.57789205e-002, 9.81772333e-012],
[2.23448562e-161, 2.65632115e-002],
[3.09648295e+000, 8.48713904e-017],
[1.00212662e+000, 1.37974305e-013],
[5.17295325e-130, 3.37353331e-001],
[1.09814912e-048, 1.87800865e-003],
[1.88640805e-056, 4.26608396e-002],
[3.08491848e-137, 4.58473827e-002],
[4.81085712e-001, 3.33967704e-020],
[1.12504707e-129, 2.47883299e-001],
[3.67995439e-002, 1.36596674e-019],
[3.91991816e-092, 3.18444088e-001],
[3.70404421e+000, 2.23261970e-016],
[1.97791635e+000, 8.08973781e-016],
[5.18297633e+000, 1.58016713e-016],
[3.22002953e-109, 6.30695919e-001],
[2.45629129e-042, 2.54489986e-003],
[4.65684882e-078, 1.61140759e-001],
[1.20020428e+000, 8.06573695e-015],
[3.47644237e-102, 6.10877468e-001],
[5.30752338e-159, 1.25788818e-001],
[2.67525891e-180, 1.36687997e-002],
[2.14367370e+000, 4.89645218e-015],
[1.69559466e+000, 8.15261126e-019],
[5.01330518e-065, 3.32739495e-002],
[2.90136679e+000, 4.87766404e-017],
[6.26263265e+000, 4.05703434e-016],
[9.91822069e-123, 1.48439207e-001],
[6.08616441e-129, 2.49686080e-001],
[7.38230838e-001, 1.21546609e-017],
[2.42302202e-096, 4.80883386e-001],
[4.49573232e-170, 1.36182282e-002],
[6.29495594e-117, 1.75312606e-001],
[1.39322505e+000, 4.57390205e-017],
[1.33577067e+000, 6.63620680e-015],
[1.49050826e-177, 7.51872920e-002],
[1.31733476e+000, 4.53624816e-017],
[5.16176371e-102, 6.57207208e-001],
[4.55092123e-084, 1.69998516e-001],
[5.28027292e-073, 2.35169368e-001],
[1.74659558e+000, 4.90692552e-017],
[1.73554442e-002, 1.93538305e-013]])
np.argmax(pre_all,axis=1)
array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0,
1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1,
0, 0, 1, 0, 1, 1, 1, 0, 0], dtype=int64)
# The ground-truth labels are
y_train.ravel()
array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0,
1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1,
0, 0, 1, 0, 1, 1, 1, 0, 0])
# Check how many predictions are correct
np.argmax(pre_all,axis=1)==y_train.ravel()
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True])
# Compute the accuracy
np.sum(np.argmax(pre_all,axis=1)==y_train.ravel())/len(y_train.ravel())
1.0
def predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,lst_pri):
    # Score the test samples under each class and return the accuracy
    pre_0=gaussian_density(X_test,u_0,sigma_0)*lst_pri[0]
    pre_1=gaussian_density(X_test,u_1,sigma_1)*lst_pri[1]
    pre_all=np.hstack([pre_0.reshape(len(pre_0),1),pre_1.reshape(pre_1.shape[0],1)])
    return np.sum(np.argmax(pre_all,axis=1)==y_test.ravel())/len(y_test)
predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,lst_pri)
1.0
# 1 Import the packages
from sklearn.naive_bayes import GaussianNB, BernoulliNB,MultinomialNB
# 2 Build the model
clf=GaussianNB()
# 3 Fit the model
clf.fit(X_train,y_train.ravel())
GaussianNB()
# 4 Evaluate the model
clf.score(X_test,y_test)
1.0
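Besides the accuracy, the fitted model also exposes per-class posterior probabilities; a quick look (an aside, using the clf fitted above):

# One row per sample; columns are ordered as in clf.classes_
clf.predict_proba(X_test[:3])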
# 1 Import the packages
from sklearn.naive_bayes import GaussianNB, BernoulliNB,MultinomialNB
# 2 Build the model
clf=MultinomialNB()
# 3 Fit the model
clf.fit(X_train,y_train.ravel())
MultinomialNB()
# 4 Evaluate the model
clf.score(X_test,y_test)
1.0
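BernoulliNB was imported above but never used. For completeness, a sketch (illustrative only): it models binary features, so the continuous measurements must be thresholded first, which its binarize parameter does internally; the global-mean threshold below is a crude, hypothetical choice.

clf=BernoulliNB(binarize=float(X_train.mean()))
clf.fit(X_train,y_train.ravel())
clf.score(X_test,y_test.ravel())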
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
# Load the iris dataset
iris=load_iris()
# Extract the features X and the corresponding labels y
X=iris["data"]
y=iris["target"]
# Check the shapes of X and y
X.shape,y.shape
((150, 4), (150,))
# Reshape y into a 2-D array
y=y.reshape((150,-1))
y.shape
(150, 1)
# Inspect the data as a DataFrame
df=pd.DataFrame(np.hstack([X,y]),columns=iris.feature_names+["target"])
df
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0.0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0.0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0.0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0.0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0.0 |
| ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | 2.0 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | 2.0 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | 2.0 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | 2.0 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | 2.0 |
150 rows × 5 columns
# Convert the label column to integer type
df["target"]=df["target"].astype("int")
df
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
| ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
# Check which rows belong to classes 0, 1 and 2 (shown here for class 0)
index_0=df[df["target"]==0].index
index_0,len(index_0)
(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
dtype='int64'),
50)
# Split the data into training and test sets; this time all 150 samples and all three classes are used
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=0)
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((120, 4), (30, 4), (120, 1), (30, 1))
# Find the indices whose label is 0
np.where(y_train==0)
(array([ 2, 6, 11, 13, 14, 31, 38, 39, 42, 43, 45, 48, 52,
57, 58, 61, 63, 66, 67, 69, 70, 71, 75, 76, 77, 80,
81, 83, 88, 90, 92, 93, 95, 104, 108, 113, 114, 115, 119],
dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64))
# Build a dictionary that stores the row indices belonging to each label,
# so that the per-class means and variances can be computed separately later
dic={}
for i in [0,1,2]:
    dic[i]=np.where(y_train==i)
dic
{0: (array([ 2, 6, 11, 13, 14, 31, 38, 39, 42, 43, 45, 48, 52,
57, 58, 61, 63, 66, 67, 69, 70, 71, 75, 76, 77, 80,
81, 83, 88, 90, 92, 93, 95, 104, 108, 113, 114, 115, 119],
dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)),
1: (array([ 1, 5, 7, 8, 9, 15, 20, 22, 23, 28, 30, 33, 34,
35, 36, 41, 44, 47, 49, 51, 72, 78, 79, 82, 85, 87,
97, 98, 99, 102, 103, 105, 109, 110, 111, 112, 117], dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)),
2: (array([ 0, 3, 4, 10, 12, 16, 17, 18, 19, 21, 24, 25, 26,
27, 29, 32, 37, 40, 46, 50, 53, 54, 55, 56, 59, 60,
62, 64, 65, 68, 73, 74, 84, 86, 89, 91, 94, 96, 100,
101, 106, 107, 116, 118], dtype=int64),
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
dtype=int64))}
# Compute the mean and variance of every feature (along the column axis);
# the mean and variance vectors therefore have one element per feature
# X is a 2-D sample array
def u_sigma(X):
    u=np.mean(X,axis=0)
    sigma=np.var(X,axis=0)
    return u,sigma
dic[0][0]
array([ 2, 6, 11, 13, 14, 31, 38, 39, 42, 43, 45, 48, 52,
57, 58, 61, 63, 66, 67, 69, 70, 71, 75, 76, 77, 80,
81, 83, 88, 90, 92, 93, 95, 104, 108, 113, 114, 115, 119],
dtype=int64)
# Compute the mean and variance of class 0 (setosa)
u_0,sigma_0=u_sigma(X_train[dic[0][0],:])
u_0,sigma_0
(array([5.02051282, 3.4025641 , 1.46153846, 0.24102564]),
array([0.12932281, 0.1417883 , 0.02031558, 0.01113741]))
# Compute the mean and variance of class 1 (versicolor)
u_1,sigma_1=u_sigma(X_train[dic[1][0],:])
u_1,sigma_1
(array([5.88648649, 2.76216216, 4.21621622, 1.32432432]),
array([0.26387144, 0.1039737 , 0.2300073 , 0.04075968]))
# Compute the mean and variance of class 2 (virginica)
u_2,sigma_2=u_sigma(X_train[dic[2][0],:])
u_2,sigma_2
(array([6.63863636, 2.98863636, 5.56590909, 2.03181818]),
array([0.38918905, 0.10782541, 0.29451963, 0.06444215]))
# Compute the prior probability of each class
lst_pri=[]
for i in [0,1,2]:
    lst_pri.append(len(dic[i][0]))
lst_pri=[item/sum(lst_pri) for item in lst_pri]
lst_pri
[0.325, 0.30833333333333335, 0.36666666666666664]
# Scores of all training samples under the class-0 Gaussian parameters (class-conditional density times prior)
pre_0=gaussian_density(X_train,u_0,sigma_0)*lst_pri[0]
pre_0
array([3.64205427e-225, 3.40844822e-130, 3.08530851e+000, 4.39737931e-176,
9.32161971e-262, 1.12603195e-090, 8.19955989e-002, 5.38088810e-180,
9.99826548e-113, 6.22294079e-089, 2.18584476e-247, 1.14681255e+000,
2.38802541e-230, 7.48076601e-003, 1.51577355e+000, 8.84977214e-059,
9.40380304e-226, 2.20471084e-296, 1.11546261e-168, 1.12595279e-254,
7.13493544e-080, 0.00000000e+000, 5.43149166e-151, 7.00401162e-075,
2.20419920e-177, 7.88959967e-176, 1.41957694e-141, 1.31858669e-191,
4.74468428e-145, 9.39276491e-214, 2.02942932e-136, 1.40273451e+000,
4.66850302e-197, 1.84403192e-103, 8.15997638e-072, 1.70855259e-092,
8.50513873e-134, 1.04684523e-275, 1.95561507e+000, 5.03262010e-003,
3.23862571e-215, 3.13715578e-099, 5.29812808e-001, 6.29658079e-003,
1.81543604e-163, 1.32072621e+000, 1.48741944e-190, 4.61289448e-041,
1.58979789e+000, 2.96357473e-134, 0.00000000e+000, 2.65155682e-103,
7.05472630e-001, 1.42166693e-285, 8.68838944e-281, 4.74069911e-280,
2.59051414e-254, 1.30709804e+000, 1.93716067e+000, 1.10437770e-205,
2.87463392e-264, 8.77307761e-003, 6.56796757e-251, 1.82259183e+000,
2.68966659e-196, 2.28835722e-239, 3.85005332e-001, 2.97070927e+000,
1.54669251e-245, 2.97250230e+000, 2.51256489e-001, 7.67795136e-002,
4.15395634e-093, 1.00997094e-298, 0.00000000e+000, 3.22193669e+000,
2.47369004e+000, 3.01412924e+000, 5.36914976e-122, 4.87767060e-123,
6.01262218e-001, 4.61755454e-002, 1.10260946e-111, 7.18092701e-001,
0.00000000e+000, 4.83593087e-049, 0.00000000e+000, 1.77412583e-123,
2.53482967e-001, 1.70832646e-168, 1.88690143e-002, 0.00000000e+000,
1.86389396e+000, 1.35985047e+000, 8.17806813e-294, 3.28434438e+000,
8.21098705e-277, 1.00342674e-097, 2.20897185e-083, 1.58003504e-057,
1.61348013e-243, 3.80414054e-237, 2.15851912e-161, 1.95128444e-180,
1.31803692e+000, 7.79858859e-067, 6.12107543e-279, 4.66850302e-197,
3.52624721e+000, 7.63949242e-132, 3.31703393e-097, 5.37109191e-168,
6.90508182e-119, 7.83871527e-001, 8.95165152e-001, 1.09244100e+000,
1.04987457e-233, 1.54899418e-087, 0.00000000e+000, 1.49109871e+000])
# Scores of all training samples under the class-1 Gaussian parameters (class-conditional density times prior)
pre_1=gaussian_density(X_train,u_1,sigma_1)*lst_pri[1]
pre_1
array([2.95633338e-04, 1.36197317e-01, 2.90318178e-16, 7.67369010e-03,
3.75455611e-07, 1.46797523e-01, 6.95344048e-15, 3.36175041e-02,
2.53841239e-01, 3.16199307e-01, 1.32212698e-06, 2.31912196e-17,
6.23661197e-08, 4.43491705e-12, 9.03659728e-17, 6.06688573e-04,
3.14945948e-04, 1.24882948e-11, 1.87288422e-02, 2.66560740e-05,
1.30000970e-01, 2.76182931e-12, 2.07410916e-02, 7.22817433e-02,
7.79602598e-03, 4.38522048e-02, 8.22673683e-03, 1.14220807e-03,
1.03590806e-02, 8.19796704e-05, 2.21991209e-02, 1.91118667e-15,
1.48027054e-03, 4.05979965e-01, 1.65444313e-01, 2.36465225e-01,
2.30302015e-01, 4.54901890e-07, 7.37406496e-17, 2.21052310e-20,
3.87241584e-04, 2.87187564e-01, 8.53516604e-15, 3.46342632e-18,
9.95391379e-03, 2.43959119e-16, 4.23043625e-03, 2.34628172e-03,
2.50262009e-16, 5.08355498e-02, 1.22369433e-14, 4.12873889e-01,
1.33213958e-17, 2.98880456e-08, 1.95809747e-09, 6.40227550e-08,
2.84653316e-06, 5.40191505e-17, 4.67733730e-16, 6.42382537e-05,
1.79818302e-07, 1.09855352e-16, 2.30402853e-08, 3.51870932e-16,
3.18554534e-04, 1.18966325e-06, 5.07486109e-18, 2.25215273e-17,
2.37994256e-05, 9.20537370e-16, 9.71966954e-18, 1.81892177e-14,
1.17820150e-01, 7.11741017e-10, 3.82851638e-12, 6.59703177e-17,
8.88106613e-16, 1.68993929e-16, 3.77332955e-01, 1.22469010e-01,
2.07501791e-17, 9.48218948e-12, 2.63666294e-01, 1.33681661e-13,
1.13413698e-15, 1.81908946e-03, 1.46950870e-13, 6.95238806e-02,
4.07966207e-20, 1.07543910e-02, 1.43838827e-19, 5.26740196e-12,
2.36489470e-16, 8.55569443e-16, 4.82666780e-08, 1.63877804e-16,
5.30883063e-10, 4.36520033e-01, 3.13721528e-01, 3.62503830e-02,
7.75810130e-08, 1.09538068e-07, 6.27229834e-02, 4.93070200e-03,
5.32420738e-15, 3.01096779e-02, 8.55857074e-10, 1.48027054e-03,
4.25565651e-16, 1.22088863e-01, 3.06149212e-01, 5.75190751e-03,
1.16325296e-01, 4.61599415e-17, 6.67684050e-15, 4.97991843e-17,
3.11807922e-04, 1.25938919e-01, 6.63898313e-16, 5.04670598e-17])
# Scores of all training samples under the class-2 Gaussian parameters (class-conditional density times prior)
pre_2=gaussian_density(X_train,u_2,sigma_2)*lst_pri[2]
pre_2
array([1.88926441e-01, 7.41874323e-04, 2.18905385e-26, 7.03342033e-02,
2.07838563e-01, 6.36007282e-06, 3.75616194e-24, 2.15583340e-02,
7.65683494e-04, 4.80086802e-06, 3.04560221e-01, 4.03768532e-28,
1.12679216e-01, 9.72668930e-22, 1.40128825e-26, 2.07279668e-11,
2.09922203e-01, 3.69933717e-02, 7.04823898e-04, 6.49975333e-02,
2.90135522e-07, 2.72821894e-02, 1.19387091e-02, 7.43267743e-08,
5.99160309e-02, 1.85609819e-02, 1.38418438e-04, 4.76244749e-02,
2.86112072e-03, 2.53639963e-01, 3.04064364e-03, 5.04262171e-26,
5.47700919e-02, 3.69353344e-05, 1.75987852e-06, 5.01849240e-06,
2.09975476e-03, 8.54119142e-02, 1.00630371e-26, 1.53267285e-31,
1.61099289e-01, 2.08157220e-05, 9.87308671e-25, 7.12483734e-27,
1.49368318e-02, 4.76225689e-27, 7.43930795e-02, 8.62041503e-11,
9.03427577e-27, 2.32663919e-04, 4.36377985e-03, 6.75646957e-05,
1.81992485e-28, 1.99685684e-01, 1.36031284e-01, 2.34763950e-01,
2.49673422e-01, 9.27207512e-27, 2.43693353e-26, 1.79134484e-01,
1.95463733e-01, 3.06844563e-28, 6.40538684e-02, 5.34390777e-27,
2.02012772e-02, 2.61986932e-01, 8.07090461e-29, 1.45826047e-27,
4.70449238e-02, 5.86183174e-26, 9.92273358e-29, 9.92642821e-24,
1.68421105e-06, 1.22514460e-01, 1.57513390e-02, 3.69159440e-27,
2.04206384e-26, 8.30149544e-27, 2.05007234e-04, 1.47522326e-03,
3.70249288e-28, 1.18962106e-21, 3.04482104e-04, 1.44239452e-23,
1.07163996e-03, 5.75350754e-11, 6.13059140e-04, 1.38954915e-03,
1.29199008e-29, 4.74148015e-02, 5.06182005e-29, 7.33590052e-03,
3.76544259e-26, 2.67245797e-26, 7.13465644e-02, 5.26396730e-27,
4.51771500e-02, 3.67360555e-05, 3.79694730e-06, 9.71272783e-09,
1.26212878e-01, 1.49245747e-01, 4.92630412e-03, 8.08794435e-02,
1.30436645e-25, 8.74375374e-09, 1.07798580e-01, 5.47700919e-02,
2.29068907e-26, 1.01895184e-03, 3.35870705e-05, 3.23117267e-02,
4.91416425e-05, 3.49183358e-27, 1.03729239e-24, 1.10117672e-27,
1.80129089e-01, 6.09942673e-07, 3.30717488e-04, 1.01366241e-27])
# Stack the three score columns to get the training-set predictions
pre_all=np.hstack([pre_0.reshape(len(pre_0),1),pre_1.reshape(pre_1.shape[0],1),pre_2.reshape(pre_2.shape[0],1)])
pre_all
array([[3.64205427e-225, 2.95633338e-004, 1.88926441e-001],
[3.40844822e-130, 1.36197317e-001, 7.41874323e-004],
[3.08530851e+000, 2.90318178e-016, 2.18905385e-026],
[4.39737931e-176, 7.67369010e-003, 7.03342033e-002],
[9.32161971e-262, 3.75455611e-007, 2.07838563e-001],
[1.12603195e-090, 1.46797523e-001, 6.36007282e-006],
[8.19955989e-002, 6.95344048e-015, 3.75616194e-024],
[5.38088810e-180, 3.36175041e-002, 2.15583340e-002],
[9.99826548e-113, 2.53841239e-001, 7.65683494e-004],
[6.22294079e-089, 3.16199307e-001, 4.80086802e-006],
[2.18584476e-247, 1.32212698e-006, 3.04560221e-001],
[1.14681255e+000, 2.31912196e-017, 4.03768532e-028],
[2.38802541e-230, 6.23661197e-008, 1.12679216e-001],
[7.48076601e-003, 4.43491705e-012, 9.72668930e-022],
[1.51577355e+000, 9.03659728e-017, 1.40128825e-026],
[8.84977214e-059, 6.06688573e-004, 2.07279668e-011],
[9.40380304e-226, 3.14945948e-004, 2.09922203e-001],
[2.20471084e-296, 1.24882948e-011, 3.69933717e-002],
[1.11546261e-168, 1.87288422e-002, 7.04823898e-004],
[1.12595279e-254, 2.66560740e-005, 6.49975333e-002],
[7.13493544e-080, 1.30000970e-001, 2.90135522e-007],
[0.00000000e+000, 2.76182931e-012, 2.72821894e-002],
[5.43149166e-151, 2.07410916e-002, 1.19387091e-002],
[7.00401162e-075, 7.22817433e-002, 7.43267743e-008],
[2.20419920e-177, 7.79602598e-003, 5.99160309e-002],
[7.88959967e-176, 4.38522048e-002, 1.85609819e-002],
[1.41957694e-141, 8.22673683e-003, 1.38418438e-004],
[1.31858669e-191, 1.14220807e-003, 4.76244749e-002],
[4.74468428e-145, 1.03590806e-002, 2.86112072e-003],
[9.39276491e-214, 8.19796704e-005, 2.53639963e-001],
[2.02942932e-136, 2.21991209e-002, 3.04064364e-003],
[1.40273451e+000, 1.91118667e-015, 5.04262171e-026],
[4.66850302e-197, 1.48027054e-003, 5.47700919e-002],
[1.84403192e-103, 4.05979965e-001, 3.69353344e-005],
[8.15997638e-072, 1.65444313e-001, 1.75987852e-006],
[1.70855259e-092, 2.36465225e-001, 5.01849240e-006],
[8.50513873e-134, 2.30302015e-001, 2.09975476e-003],
[1.04684523e-275, 4.54901890e-007, 8.54119142e-002],
[1.95561507e+000, 7.37406496e-017, 1.00630371e-026],
[5.03262010e-003, 2.21052310e-020, 1.53267285e-031],
[3.23862571e-215, 3.87241584e-004, 1.61099289e-001],
[3.13715578e-099, 2.87187564e-001, 2.08157220e-005],
[5.29812808e-001, 8.53516604e-015, 9.87308671e-025],
[6.29658079e-003, 3.46342632e-018, 7.12483734e-027],
[1.81543604e-163, 9.95391379e-003, 1.49368318e-002],
[1.32072621e+000, 2.43959119e-016, 4.76225689e-027],
[1.48741944e-190, 4.23043625e-003, 7.43930795e-002],
[4.61289448e-041, 2.34628172e-003, 8.62041503e-011],
[1.58979789e+000, 2.50262009e-016, 9.03427577e-027],
[2.96357473e-134, 5.08355498e-002, 2.32663919e-004],
[0.00000000e+000, 1.22369433e-014, 4.36377985e-003],
[2.65155682e-103, 4.12873889e-001, 6.75646957e-005],
[7.05472630e-001, 1.33213958e-017, 1.81992485e-028],
[1.42166693e-285, 2.98880456e-008, 1.99685684e-001],
[8.68838944e-281, 1.95809747e-009, 1.36031284e-001],
[4.74069911e-280, 6.40227550e-008, 2.34763950e-001],
[2.59051414e-254, 2.84653316e-006, 2.49673422e-001],
[1.30709804e+000, 5.40191505e-017, 9.27207512e-027],
[1.93716067e+000, 4.67733730e-016, 2.43693353e-026],
[1.10437770e-205, 6.42382537e-005, 1.79134484e-001],
[2.87463392e-264, 1.79818302e-007, 1.95463733e-001],
[8.77307761e-003, 1.09855352e-016, 3.06844563e-028],
[6.56796757e-251, 2.30402853e-008, 6.40538684e-002],
[1.82259183e+000, 3.51870932e-016, 5.34390777e-027],
[2.68966659e-196, 3.18554534e-004, 2.02012772e-002],
[2.28835722e-239, 1.18966325e-006, 2.61986932e-001],
[3.85005332e-001, 5.07486109e-018, 8.07090461e-029],
[2.97070927e+000, 2.25215273e-017, 1.45826047e-027],
[1.54669251e-245, 2.37994256e-005, 4.70449238e-002],
[2.97250230e+000, 9.20537370e-016, 5.86183174e-026],
[2.51256489e-001, 9.71966954e-018, 9.92273358e-029],
[7.67795136e-002, 1.81892177e-014, 9.92642821e-024],
[4.15395634e-093, 1.17820150e-001, 1.68421105e-006],
[1.00997094e-298, 7.11741017e-010, 1.22514460e-001],
[0.00000000e+000, 3.82851638e-012, 1.57513390e-002],
[3.22193669e+000, 6.59703177e-017, 3.69159440e-027],
[2.47369004e+000, 8.88106613e-016, 2.04206384e-026],
[3.01412924e+000, 1.68993929e-016, 8.30149544e-027],
[5.36914976e-122, 3.77332955e-001, 2.05007234e-004],
[4.87767060e-123, 1.22469010e-001, 1.47522326e-003],
[6.01262218e-001, 2.07501791e-017, 3.70249288e-028],
[4.61755454e-002, 9.48218948e-012, 1.18962106e-021],
[1.10260946e-111, 2.63666294e-001, 3.04482104e-004],
[7.18092701e-001, 1.33681661e-013, 1.44239452e-023],
[0.00000000e+000, 1.13413698e-015, 1.07163996e-003],
[4.83593087e-049, 1.81908946e-003, 5.75350754e-011],
[0.00000000e+000, 1.46950870e-013, 6.13059140e-004],
[1.77412583e-123, 6.95238806e-002, 1.38954915e-003],
[2.53482967e-001, 4.07966207e-020, 1.29199008e-029],
[1.70832646e-168, 1.07543910e-002, 4.74148015e-002],
[1.88690143e-002, 1.43838827e-019, 5.06182005e-029],
[0.00000000e+000, 5.26740196e-012, 7.33590052e-003],
[1.86389396e+000, 2.36489470e-016, 3.76544259e-026],
[1.35985047e+000, 8.55569443e-016, 2.67245797e-026],
[8.17806813e-294, 4.82666780e-008, 7.13465644e-002],
[3.28434438e+000, 1.63877804e-016, 5.26396730e-027],
[8.21098705e-277, 5.30883063e-010, 4.51771500e-002],
[1.00342674e-097, 4.36520033e-001, 3.67360555e-005],
[2.20897185e-083, 3.13721528e-001, 3.79694730e-006],
[1.58003504e-057, 3.62503830e-002, 9.71272783e-009],
[1.61348013e-243, 7.75810130e-008, 1.26212878e-001],
[3.80414054e-237, 1.09538068e-007, 1.49245747e-001],
[2.15851912e-161, 6.27229834e-002, 4.92630412e-003],
[1.95128444e-180, 4.93070200e-003, 8.08794435e-002],
[1.31803692e+000, 5.32420738e-015, 1.30436645e-025],
[7.79858859e-067, 3.01096779e-002, 8.74375374e-009],
[6.12107543e-279, 8.55857074e-010, 1.07798580e-001],
[4.66850302e-197, 1.48027054e-003, 5.47700919e-002],
[3.52624721e+000, 4.25565651e-016, 2.29068907e-026],
[7.63949242e-132, 1.22088863e-001, 1.01895184e-003],
[3.31703393e-097, 3.06149212e-001, 3.35870705e-005],
[5.37109191e-168, 5.75190751e-003, 3.23117267e-002],
[6.90508182e-119, 1.16325296e-001, 4.91416425e-005],
[7.83871527e-001, 4.61599415e-017, 3.49183358e-027],
[8.95165152e-001, 6.67684050e-015, 1.03729239e-024],
[1.09244100e+000, 4.97991843e-017, 1.10117672e-027],
[1.04987457e-233, 3.11807922e-004, 1.80129089e-001],
[1.54899418e-087, 1.25938919e-001, 6.09942673e-007],
[0.00000000e+000, 6.63898313e-016, 3.30717488e-004],
[1.49109871e+000, 5.04670598e-017, 1.01366241e-027]])
# Check how many predictions are correct
np.argmax(pre_all,axis=1)==y_train.ravel()
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
False, True, True, True, True, True, True, False, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True])
# Compute the accuracy
np.sum(np.argmax(pre_all,axis=1)==y_train.ravel())/len(y_train.ravel())
0.95
def predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,u_2,sigma_2,lst_pri):
    # Score the test samples under each of the three classes and return the accuracy
    pre_0=gaussian_density(X_test,u_0,sigma_0)*lst_pri[0]
    pre_1=gaussian_density(X_test,u_1,sigma_1)*lst_pri[1]
    pre_2=gaussian_density(X_test,u_2,sigma_2)*lst_pri[2]
    pre_all=np.hstack([pre_0.reshape(len(pre_0),1),pre_1.reshape(pre_1.shape[0],1),pre_2.reshape(pre_2.shape[0],1)])
    return np.sum(np.argmax(pre_all,axis=1)==y_test.ravel())/len(y_test)
predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,u_2,sigma_2,lst_pri)
0.9666666666666667
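As a cross-check (an aside, assuming scikit-learn >= 1.0, where the fitted per-class variances are exposed as var_), the hand-computed parameters should essentially match what GaussianNB estimates, up to its small var_smoothing term:

from sklearn.naive_bayes import GaussianNB
check=GaussianNB().fit(X_train,y_train.ravel())
# theta_ holds the per-class means, var_ the per-class variances (plus smoothing)
np.allclose(check.theta_,np.vstack([u_0,u_1,u_2])),np.allclose(check.var_,np.vstack([sigma_0,sigma_1,sigma_2]),atol=1e-6)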
from sklearn.naive_bayes import GaussianNB
clf=GaussianNB()
help(GaussianNB)
Help on class GaussianNB in module sklearn.naive_bayes:
class GaussianNB(_BaseNB)
| GaussianNB(*, priors=None, var_smoothing=1e-09)
|
| Gaussian Naive Bayes (GaussianNB).
|
| Can perform online updates to model parameters via :meth:`partial_fit`.
| For details on algorithm used to update feature means and variance online,
| see Stanford CS tech report STAN-CS-79-773 by Chan, Golub, and LeVeque:
|
| http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
|
| Read more in the :ref:`User Guide <gaussian_naive_bayes>`.
|
| Parameters
| ----------
| priors : array-like of shape (n_classes,)
| Prior probabilities of the classes. If specified the priors are not
| adjusted according to the data.
|
| var_smoothing : float, default=1e-9
| Portion of the largest variance of all features that is added to
| variances for calculation stability.
|
| .. versionadded:: 0.20
|
| Attributes
| ----------
| class_count_ : ndarray of shape (n_classes,)
| number of training samples observed in each class.
|
| class_prior_ : ndarray of shape (n_classes,)
| probability of each class.
|
| classes_ : ndarray of shape (n_classes,)
| class labels known to the classifier.
|
| epsilon_ : float
| absolute additive value to variances.
|
| n_features_in_ : int
| Number of features seen during :term:`fit`.
|
| .. versionadded:: 0.24
|
| feature_names_in_ : ndarray of shape (`n_features_in_`,)
| Names of features seen during :term:`fit`. Defined only when `X`
| has feature names that are all strings.
|
| .. versionadded:: 1.0
|
| sigma_ : ndarray of shape (n_classes, n_features)
| Variance of each feature per class.
|
| .. deprecated:: 1.0
| `sigma_` is deprecated in 1.0 and will be removed in 1.2.
| Use `var_` instead.
|
| var_ : ndarray of shape (n_classes, n_features)
| Variance of each feature per class.
|
| .. versionadded:: 1.0
|
| theta_ : ndarray of shape (n_classes, n_features)
| mean of each feature per class.
|
| See Also
| --------
| BernoulliNB : Naive Bayes classifier for multivariate Bernoulli models.
| CategoricalNB : Naive Bayes classifier for categorical features.
| ComplementNB : Complement Naive Bayes classifier.
| MultinomialNB : Naive Bayes classifier for multinomial models.
|
| Examples
| --------
| >>> import numpy as np
| >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
| >>> Y = np.array([1, 1, 1, 2, 2, 2])
| >>> from sklearn.naive_bayes import GaussianNB
| >>> clf = GaussianNB()
| >>> clf.fit(X, Y)
| GaussianNB()
| >>> print(clf.predict([[-0.8, -1]]))
| [1]
| >>> clf_pf = GaussianNB()
| >>> clf_pf.partial_fit(X, Y, np.unique(Y))
| GaussianNB()
| >>> print(clf_pf.predict([[-0.8, -1]]))
| [1]
|
| Method resolution order:
| GaussianNB
| _BaseNB
| sklearn.base.ClassifierMixin
| sklearn.base.BaseEstimator
| builtins.object
|
| Methods defined here:
|
| __init__(self, *, priors=None, var_smoothing=1e-09)
| Initialize self. See help(type(self)) for accurate signature.
|
| fit(self, X, y, sample_weight=None)
| Fit Gaussian Naive Bayes according to X, y.
|
| Parameters
| ----------
| X : array-like of shape (n_samples, n_features)
| Training vectors, where `n_samples` is the number of samples
| and `n_features` is the number of features.
|
| y : array-like of shape (n_samples,)
| Target values.
|
| sample_weight : array-like of shape (n_samples,), default=None
| Weights applied to individual samples (1. for unweighted).
|
| .. versionadded:: 0.17
| Gaussian Naive Bayes supports fitting with *sample_weight*.
|
| Returns
| -------
| self : object
| Returns the instance itself.
|
| partial_fit(self, X, y, classes=None, sample_weight=None)
| Incremental fit on a batch of samples.
|
| This method is expected to be called several times consecutively
| on different chunks of a dataset so as to implement out-of-core
| or online learning.
|
| This is especially useful when the whole dataset is too big to fit in
| memory at once.
|
| This method has some performance and numerical stability overhead,
| hence it is better to call partial_fit on chunks of data that are
| as large as possible (as long as fitting in the memory budget) to
| hide the overhead.
|
| Parameters
| ----------
| X : array-like of shape (n_samples, n_features)
| Training vectors, where `n_samples` is the number of samples and
| `n_features` is the number of features.
|
| y : array-like of shape (n_samples,)
| Target values.
|
| classes : array-like of shape (n_classes,), default=None
| List of all the classes that can possibly appear in the y vector.
|
| Must be provided at the first call to partial_fit, can be omitted
| in subsequent calls.
|
| sample_weight : array-like of shape (n_samples,), default=None
| Weights applied to individual samples (1. for unweighted).
|
| .. versionadded:: 0.17
|
| Returns
| -------
| self : object
| Returns the instance itself.
|
| ----------------------------------------------------------------------
| Readonly properties defined here:
|
| sigma_
| DEPRECATED: Attribute `sigma_` was deprecated in 1.0 and will be removed in 1.2. Use `var_` instead.
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| __abstractmethods__ = frozenset()
|
| ----------------------------------------------------------------------
| Methods inherited from _BaseNB:
|
| predict(self, X)
| Perform classification on an array of test vectors X.
|
| Parameters
| ----------
| X : array-like of shape (n_samples, n_features)
| The input samples.
|
| Returns
| -------
| C : ndarray of shape (n_samples,)
| Predicted target values for X.
|
| predict_log_proba(self, X)
| Return log-probability estimates for the test vector X.
|
| Parameters
| ----------
| X : array-like of shape (n_samples, n_features)
| The input samples.
|
| Returns
| -------
| C : array-like of shape (n_samples, n_classes)
| Returns the log-probability of the samples for each class in
| the model. The columns correspond to the classes in sorted
| order, as they appear in the attribute :term:`classes_`.
|
| predict_proba(self, X)
| Return probability estimates for the test vector X.
|
| Parameters
| ----------
| X : array-like of shape (n_samples, n_features)
| The input samples.
|
| Returns
| -------
| C : array-like of shape (n_samples, n_classes)
| Returns the probability of the samples for each class in
| the model. The columns correspond to the classes in sorted
| order, as they appear in the attribute :term:`classes_`.
|
| ----------------------------------------------------------------------
| Methods inherited from sklearn.base.ClassifierMixin:
|
| score(self, X, y, sample_weight=None)
| Return the mean accuracy on the given test data and labels.
|
| In multi-label classification, this is the subset accuracy
| which is a harsh metric since you require for each sample that
| each label set be correctly predicted.
|
| Parameters
| ----------
| X : array-like of shape (n_samples, n_features)
| Test samples.
|
| y : array-like of shape (n_samples,) or (n_samples, n_outputs)
| True labels for `X`.
|
| sample_weight : array-like of shape (n_samples,), default=None
| Sample weights.
|
| Returns
| -------
| score : float
| Mean accuracy of ``self.predict(X)`` wrt. `y`.
|
| ----------------------------------------------------------------------
| Data descriptors inherited from sklearn.base.ClassifierMixin:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Methods inherited from sklearn.base.BaseEstimator:
|
| __getstate__(self)
|
| __repr__(self, N_CHAR_MAX=700)
| Return repr(self).
|
| __setstate__(self, state)
|
| get_params(self, deep=True)
| Get parameters for this estimator.
|
| Parameters
| ----------
| deep : bool, default=True
| If True, will return the parameters for this estimator and
| contained subobjects that are estimators.
|
| Returns
| -------
| params : dict
| Parameter names mapped to their values.
|
| set_params(self, **params)
| Set the parameters of this estimator.
|
| The method works on simple estimators as well as on nested objects
| (such as :class:`~sklearn.pipeline.Pipeline`). The latter have
| parameters of the form ``<component>__<parameter>`` so that it's
| possible to update each component of a nested object.
|
| Parameters
| ----------
| **params : dict
| Estimator parameters.
|
| Returns
| -------
| self : estimator instance
| Estimator instance.
X.shape,y.shape
((150, 4), (150, 1))
clf.fit(X_train,y_train.ravel())
GaussianNB()
clf.score(X_test,y_test.ravel())
0.9666666666666667
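Finally, the fitted classifier can also produce labels directly (a quick illustration with the objects above):

# Predicted labels for the first five test samples
clf.predict(X_test[:5])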