8 定义概率密度函数
def gaussian_density(data, u, sigma):
    """Evaluate the naive-Bayes Gaussian likelihood of each sample.

    Every feature is modelled as an independent normal distribution, so the
    joint density of a sample is the product of its per-feature densities.

    Parameters
    ----------
    data : array_like, shape (n_samples, n_features) or (n_features,)
        Samples to evaluate. A 1-D input is treated as one single sample.
    u : array_like, shape (n_features,)
        Per-feature means.
    sigma : array_like, shape (n_features,)
        Per-feature VARIANCES (not standard deviations): the formula divides
        by ``2 * sigma`` and uses ``sqrt(2 * pi * sigma)``.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Joint density per sample, shape ``(n_samples,)``; a scalar when
        ``data`` is 1-D.

    Notes
    -----
    A zero variance in any feature divides by zero (inf/nan results).
    The product of many small per-feature densities can underflow to 0.0;
    switch to summed log-densities if that becomes a problem.
    """
    data = np.asarray(data)
    u = np.asarray(u)
    sigma = np.asarray(sigma)
    expo = np.exp(-np.power(data - u, 2) / (2 * sigma))
    coef = 1 / np.sqrt(2 * np.pi * sigma)
    # Reduce over the last (feature) axis: identical to axis=1 for a batch of
    # samples, and also works for a single 1-D sample (axis=1 would raise).
    return np.prod(coef * expo, axis=-1)
# Evaluate every training sample under the first class's Gaussian parameters
# and weight the result by lst_pri[0] (presumably the class-0 prior).
pre_0 = lst_pri[0] * gaussian_density(X_train, u_0, sigma_0)
pre_0
array([3.99415464e+000, 1.94367635e+000, 6.60889499e-097, 1.80752252e-082, 1.44507736e-148, 8.63205906e-058, 1.77086187e-073, 1.72200357e-108, 4.86671382e-134, 1.06674156e-132, 5.80979347e+000, 1.93582589e-001, 6.83123642e-151, 3.80660319e-138, 3.54858798e-110, 2.47436003e+000, 9.47627356e-114, 3.63995412e-001, 6.64092778e-003, 5.19779913e+000, 1.15891783e-002, 5.07677505e+000, 2.86260160e+000, 2.21879073e-001, 1.56640570e-001, 1.03157479e-131, 8.43689850e-092, 5.64628646e+000, 3.64465774e+000, 5.22805105e+000, 5.83954842e-143, 3.24263354e+000, 9.31529278e-001, 4.57789205e-002, 2.23448562e-161, 3.09648295e+000, 1.00212662e+000, 5.17295325e-130, 1.09814912e-048, 1.88640805e-056, 3.08491848e-137, 4.81085712e-001, 1.12504707e-129, 3.67995439e-002, 3.91991816e-092, 3.70404421e+000, 1.97791635e+000, 5.18297633e+000, 3.22002953e-109, 2.45629129e-042, 4.65684882e-078, 1.20020428e+000, 3.47644237e-102, 5.30752338e-159, 2.67525891e-180, 2.14367370e+000, 1.69559466e+000, 5.01330518e-065, 2.90136679e+000, 6.26263265e+000, 9.91822069e-123, 6.08616441e-129, 7.38230838e-001, 2.42302202e-096, 4.49573232e-170, 6.29495594e-117, 1.39322505e+000, 1.33577067e+000, 1.49050826e-177, 1.31733476e+000, 5.16176371e-102, 4.55092123e-084, 5.28027292e-073, 1.74659558e+000, 1.73554442e-002])
# Evaluate every training sample under the second class's Gaussian parameters
# and weight the result by lst_pri[1] (presumably the class-1 prior).
pre_1 = lst_pri[1] * gaussian_density(X_train, u_1, sigma_1)
pre_1
array([6.88891263e-17, 2.52655671e-16, 6.66784142e-01, 4.39035170e-01, 1.02097078e-01, 5.26743134e-04, 8.41179097e-02, 3.62626644e-01, 7.91642821e-02, 1.44031642e-01, 2.76147108e-16, 6.67290518e-15, 4.75292781e-02, 4.49054758e-01, 4.79673262e-01, 3.31237947e-16, 4.53713921e-01, 5.07639533e-18, 8.97591672e-17, 2.14239456e-17, 2.89264720e-18, 9.14486465e-16, 1.93935408e-16, 9.52254108e-18, 1.72377778e-14, 4.48431308e-01, 2.11349055e-01, 6.33550524e-17, 8.36586449e-16, 1.63398769e-16, 2.61589867e-02, 4.42217308e-16, 2.04791994e-17, 9.81772333e-12, 2.65632115e-02, 8.48713904e-17, 1.37974305e-13, 3.37353331e-01, 1.87800865e-03, 4.26608396e-02, 4.58473827e-02, 3.33967704e-20, 2.47883299e-01, 1.36596674e-19, 3.18444088e-01, 2.23261970e-16, 8.08973781e-16, 1.58016713e-16, 6.30695919e-01, 2.54489986e-03, 1.61140759e-01, 8.06573695e-15, 6.10877468e-01, 1.25788818e-01, 1.36687997e-02, 4.89645218e-15, 8.15261126e-19, 3.32739495e-02, 4.87766404e-17, 4.05703434e-16, 1.48439207e-01, 2.49686080e-01, 1.21546609e-17, 4.80883386e-01, 1.36182282e-02, 1.75312606e-01, 4.57390205e-17, 6.63620680e-15, 7.51872920e-02, 4.53624816e-17, 6.57207208e-01, 1.69998516e-01, 2.35169368e-01, 4.90692552e-17, 1.93538305e-13])
9 计算训练集的预测结果
# Place the per-class scores side by side: column k holds every training
# sample's score under class k, giving an (n_samples, 2) array.
pre_all = np.column_stack((pre_0, pre_1))
pre_all
array([[3.99415464e+000, 6.88891263e-017], [1.94367635e+000, 2.52655671e-016], [6.60889499e-097, 6.66784142e-001], [1.80752252e-082, 4.39035170e-001], [1.44507736e-148, 1.02097078e-001], [8.63205906e-058, 5.26743134e-004], [1.77086187e-073, 8.41179097e-002], [1.72200357e-108, 3.62626644e-001], [4.86671382e-134, 7.91642821e-002], [1.06674156e-132, 1.44031642e-001], [5.80979347e+000, 2.76147108e-016], [1.93582589e-001, 6.67290518e-015], [6.83123642e-151, 4.75292781e-002], [3.80660319e-138, 4.49054758e-001], [3.54858798e-110, 4.79673262e-001], [2.47436003e+000, 3.31237947e-016], [9.47627356e-114, 4.53713921e-001], [3.63995412e-001, 5.07639533e-018], [6.64092778e-003, 8.97591672e-017], [5.19779913e+000, 2.14239456e-017], [1.15891783e-002, 2.89264720e-018], [5.07677505e+000, 9.14486465e-016], [2.86260160e+000, 1.93935408e-016], [2.21879073e-001, 9.52254108e-018], [1.56640570e-001, 1.72377778e-014], [1.03157479e-131, 4.48431308e-001], [8.43689850e-092, 2.11349055e-001], [5.64628646e+000, 6.33550524e-017], [3.64465774e+000, 8.36586449e-016], [5.22805105e+000, 1.63398769e-016], [5.83954842e-143, 2.61589867e-002], [3.24263354e+000, 4.42217308e-016], [9.31529278e-001, 2.04791994e-017], [4.57789205e-002, 9.81772333e-012], [2.23448562e-161, 2.65632115e-002], [3.09648295e+000, 8.48713904e-017], [1.00212662e+000, 1.37974305e-013], [5.17295325e-130, 3.37353331e-001], [1.09814912e-048, 1.87800865e-003], [1.88640805e-056, 4.26608396e-002], [3.08491848e-137, 4.58473827e-002], [4.81085712e-001, 3.33967704e-020], [1.12504707e-129, 2.47883299e-001], [3.67995439e-002, 1.36596674e-019], [3.91991816e-092, 3.18444088e-001], [3.70404421e+000, 2.23261970e-016], [1.97791635e+000, 8.08973781e-016], [5.18297633e+000, 1.58016713e-016], [3.22002953e-109, 6.30695919e-001], [2.45629129e-042, 2.54489986e-003], [4.65684882e-078, 1.61140759e-001], [1.20020428e+000, 8.06573695e-015], [3.47644237e-102, 6.10877468e-001], [5.30752338e-159, 1.25788818e-001], [2.67525891e-180, 1.36687997e-002], 
[2.14367370e+000, 4.89645218e-015], [1.69559466e+000, 8.15261126e-019], [5.01330518e-065, 3.32739495e-002], [2.90136679e+000, 4.87766404e-017], [6.26263265e+000, 4.05703434e-016], [9.91822069e-123, 1.48439207e-001], [6.08616441e-129, 2.49686080e-001], [7.38230838e-001, 1.21546609e-017], [2.42302202e-096, 4.80883386e-001], [4.49573232e-170, 1.36182282e-002], [6.29495594e-117, 1.75312606e-001], [1.39322505e+000, 4.57390205e-017], [1.33577067e+000, 6.63620680e-015], [1.49050826e-177, 7.51872920e-002], [1.31733476e+000, 4.53624816e-017], [5.16176371e-102, 6.57207208e-001], [4.55092123e-084, 1.69998516e-001], [5.28027292e-073, 2.35169368e-001], [1.74659558e+000, 4.90692552e-017], [1.73554442e-002, 1.93538305e-013]])
array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0], dtype=int64)
# Ground-truth training labels, flattened to 1-D for comparison.
np.ravel(y_train)
array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0])
# Element-wise check: True where the highest-scoring class equals the true label.
pre_all.argmax(axis=1) == np.ravel(y_train)
array([ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True])
# Training-set ACCURACY: the fraction of samples whose highest-scoring class
# equals the true label. (This is accuracy, not precision -- precision would
# only consider the samples predicted as one particular class.)
np.mean(np.argmax(pre_all, axis=1) == y_train.ravel())
1.0
10 计算测试集的预测准确率
def predict(X_test, y_test, u_0, sigma_0, u_1, sigma_1, lst_pri):
    """Return the accuracy of the two-class Gaussian naive Bayes model on a test set.

    Despite its name, this function does not return the predictions. It
    predicts a label for each row of ``X_test`` and returns the fraction of
    those predictions that equal ``y_test``.

    Parameters
    ----------
    X_test : array_like
        Test samples, passed straight to ``gaussian_density``.
    y_test : array_like
        True labels, flattened before comparison, so shapes such as (n,),
        (n, 1) and (1, n) all work. Labels are compared against the column
        index 0/1 chosen by ``argmax``.
    u_0, sigma_0 : array_like
        Per-feature means and variances for class 0.
    u_1, sigma_1 : array_like
        Per-feature means and variances for class 1.
    lst_pri : sequence
        Multiplicative weights for class 0 and class 1 (presumably the
        class priors).

    Returns
    -------
    numpy.float64
        Accuracy in [0, 1].

    Raises
    ------
    ValueError
        If ``y_test`` contains no labels.
    """
    y_true = np.ravel(y_test)
    if y_true.size == 0:
        raise ValueError("y_test must contain at least one label")
    pre_0 = gaussian_density(X_test, u_0, sigma_0) * lst_pri[0]
    pre_1 = gaussian_density(X_test, u_1, sigma_1) * lst_pri[1]
    # Column k holds each sample's score under class k.
    pre_all = np.column_stack((pre_0, pre_1))
    y_pred = np.argmax(pre_all, axis=1)
    # Average over the flattened labels instead of dividing by len(y_test):
    # len() counts only the first axis, which is wrong for e.g. a (1, n) array.
    return np.mean(y_pred == y_true)
# Test-set accuracy of the hand-written two-class model.
predict(X_test, y_test, u_0=u_0, sigma_0=sigma_0, u_1=u_1, sigma_1=sigma_1, lst_pri=lst_pri)
1.0
试试sklearn-1 高斯分布
# 1 导入包 from sklearn.naive_bayes import GaussianNB, BernoulliNB,MultinomialNB
# 2 Build the model: Gaussian naive Bayes.
clf = GaussianNB()
# 3 Fit the model; the labels are flattened to 1-D before fitting.
clf.fit(X=X_train, y=np.ravel(y_train))
GaussianNB()
# 4 Evaluate the fitted model on the test set.
clf.score(X=X_test, y=y_test)
1.0
试试sklearn-3 多项式分布
# 1 导入包 from sklearn.naive_bayes import GaussianNB, BernoulliNB,MultinomialNB
# 2 Build the model: multinomial naive Bayes.
# NOTE(review): MultinomialNB is intended for count-like, non-negative
# features; confirm X_train satisfies that before relying on this model.
clf = MultinomialNB()
# 3 Fit the model; the labels are flattened to 1-D before fitting.
clf.fit(X=X_train, y=np.ravel(y_train))
MultinomialNB()
# 4 Evaluate the fitted model on the test set.
clf.score(X=X_test, y=y_test)
1.0
实验：使用完整的鸢尾花数据集进行朴素贝叶斯分类
1 数据准备
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

# Load the iris dataset.
iris = load_iris()
# Feature matrix X and the corresponding label vector y.
X, y = iris.data, iris.target
# Inspect the shapes of the feature matrix and the label vector.
(X.shape, y.shape)
((150, 4), (150,))
# Turn y into a 2-D column vector of shape (n_samples, 1).
# Let numpy infer the row count (-1) instead of hard-coding 150, so this
# still produces a single column if the number of samples changes.
y = y.reshape(-1, 1)
y.shape
(150, 1)
# Put the features and the label side by side in a DataFrame for inspection.
df = pd.DataFrame(np.column_stack((X, y)), columns=[*iris.feature_names, "target"])
df
| sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0.0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0.0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0.0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0.0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0.0 |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | 2.0 |
146 | 6.3 | 2.5 | 5.0 | 1.9 | 2.0 |
147 | 6.5 | 3.0 | 5.2 | 2.0 | 2.0 |
148 | 6.2 | 3.4 | 5.4 | 2.3 | 2.0 |
149 | 5.9 | 3.0 | 5.1 | 1.8 | 2.0 |
150 rows × 5 columns
# Stacking with the float features turned the label column into floats
# (0.0, 1.0, 2.0); cast it back to integers.
df["target"] = df["target"].astype(int)
df
| sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
# Row labels of all class-0 samples, together with how many there are.
index_0 = df.index[df["target"] == 0]
index_0, len(index_0)
(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], dtype='int64'), 50)