【Python机器学习】实验06 贝叶斯推理 2

简介: 【Python机器学习】实验06 贝叶斯推理

8 定义概率密度函数

def gaussian_density(data,u,sigma):
    expo=np.exp(-np.power(data-u,2)/(2*sigma))
    coef=1/(np.sqrt(2*np.pi*sigma))
    return np.prod(coef*expo,axis=1)
#所有样本带入到第1个类别的高斯模型参数中得到的结果
pre_0=gaussian_density(X_train,u_0,sigma_0)*lst_pri[0]
pre_0
array([3.99415464e+000, 1.94367635e+000, 6.60889499e-097, 1.80752252e-082,
       1.44507736e-148, 8.63205906e-058, 1.77086187e-073, 1.72200357e-108,
       4.86671382e-134, 1.06674156e-132, 5.80979347e+000, 1.93582589e-001,
       6.83123642e-151, 3.80660319e-138, 3.54858798e-110, 2.47436003e+000,
       9.47627356e-114, 3.63995412e-001, 6.64092778e-003, 5.19779913e+000,
       1.15891783e-002, 5.07677505e+000, 2.86260160e+000, 2.21879073e-001,
       1.56640570e-001, 1.03157479e-131, 8.43689850e-092, 5.64628646e+000,
       3.64465774e+000, 5.22805105e+000, 5.83954842e-143, 3.24263354e+000,
       9.31529278e-001, 4.57789205e-002, 2.23448562e-161, 3.09648295e+000,
       1.00212662e+000, 5.17295325e-130, 1.09814912e-048, 1.88640805e-056,
       3.08491848e-137, 4.81085712e-001, 1.12504707e-129, 3.67995439e-002,
       3.91991816e-092, 3.70404421e+000, 1.97791635e+000, 5.18297633e+000,
       3.22002953e-109, 2.45629129e-042, 4.65684882e-078, 1.20020428e+000,
       3.47644237e-102, 5.30752338e-159, 2.67525891e-180, 2.14367370e+000,
       1.69559466e+000, 5.01330518e-065, 2.90136679e+000, 6.26263265e+000,
       9.91822069e-123, 6.08616441e-129, 7.38230838e-001, 2.42302202e-096,
       4.49573232e-170, 6.29495594e-117, 1.39322505e+000, 1.33577067e+000,
       1.49050826e-177, 1.31733476e+000, 5.16176371e-102, 4.55092123e-084,
       5.28027292e-073, 1.74659558e+000, 1.73554442e-002])
#所有样本带入到第2个类别的高斯模型参数中得到的结果
pre_1=gaussian_density(X_train,u_1,sigma_1)*lst_pri[1]
pre_1
array([6.88891263e-17, 2.52655671e-16, 6.66784142e-01, 4.39035170e-01,
       1.02097078e-01, 5.26743134e-04, 8.41179097e-02, 3.62626644e-01,
       7.91642821e-02, 1.44031642e-01, 2.76147108e-16, 6.67290518e-15,
       4.75292781e-02, 4.49054758e-01, 4.79673262e-01, 3.31237947e-16,
       4.53713921e-01, 5.07639533e-18, 8.97591672e-17, 2.14239456e-17,
       2.89264720e-18, 9.14486465e-16, 1.93935408e-16, 9.52254108e-18,
       1.72377778e-14, 4.48431308e-01, 2.11349055e-01, 6.33550524e-17,
       8.36586449e-16, 1.63398769e-16, 2.61589867e-02, 4.42217308e-16,
       2.04791994e-17, 9.81772333e-12, 2.65632115e-02, 8.48713904e-17,
       1.37974305e-13, 3.37353331e-01, 1.87800865e-03, 4.26608396e-02,
       4.58473827e-02, 3.33967704e-20, 2.47883299e-01, 1.36596674e-19,
       3.18444088e-01, 2.23261970e-16, 8.08973781e-16, 1.58016713e-16,
       6.30695919e-01, 2.54489986e-03, 1.61140759e-01, 8.06573695e-15,
       6.10877468e-01, 1.25788818e-01, 1.36687997e-02, 4.89645218e-15,
       8.15261126e-19, 3.32739495e-02, 4.87766404e-17, 4.05703434e-16,
       1.48439207e-01, 2.49686080e-01, 1.21546609e-17, 4.80883386e-01,
       1.36182282e-02, 1.75312606e-01, 4.57390205e-17, 6.63620680e-15,
       7.51872920e-02, 4.53624816e-17, 6.57207208e-01, 1.69998516e-01,
       2.35169368e-01, 4.90692552e-17, 1.93538305e-13])

9 计算训练集的预测结果

#得到训练集的预测结果
pre_all=np.hstack([pre_0.reshape(len(pre_0),1),pre_1.reshape(pre_1.shape[0],1)])
pre_all
array([[3.99415464e+000, 6.88891263e-017],
       [1.94367635e+000, 2.52655671e-016],
       [6.60889499e-097, 6.66784142e-001],
       [1.80752252e-082, 4.39035170e-001],
       [1.44507736e-148, 1.02097078e-001],
       [8.63205906e-058, 5.26743134e-004],
       [1.77086187e-073, 8.41179097e-002],
       [1.72200357e-108, 3.62626644e-001],
       [4.86671382e-134, 7.91642821e-002],
       [1.06674156e-132, 1.44031642e-001],
       [5.80979347e+000, 2.76147108e-016],
       [1.93582589e-001, 6.67290518e-015],
       [6.83123642e-151, 4.75292781e-002],
       [3.80660319e-138, 4.49054758e-001],
       [3.54858798e-110, 4.79673262e-001],
       [2.47436003e+000, 3.31237947e-016],
       [9.47627356e-114, 4.53713921e-001],
       [3.63995412e-001, 5.07639533e-018],
       [6.64092778e-003, 8.97591672e-017],
       [5.19779913e+000, 2.14239456e-017],
       [1.15891783e-002, 2.89264720e-018],
       [5.07677505e+000, 9.14486465e-016],
       [2.86260160e+000, 1.93935408e-016],
       [2.21879073e-001, 9.52254108e-018],
       [1.56640570e-001, 1.72377778e-014],
       [1.03157479e-131, 4.48431308e-001],
       [8.43689850e-092, 2.11349055e-001],
       [5.64628646e+000, 6.33550524e-017],
       [3.64465774e+000, 8.36586449e-016],
       [5.22805105e+000, 1.63398769e-016],
       [5.83954842e-143, 2.61589867e-002],
       [3.24263354e+000, 4.42217308e-016],
       [9.31529278e-001, 2.04791994e-017],
       [4.57789205e-002, 9.81772333e-012],
       [2.23448562e-161, 2.65632115e-002],
       [3.09648295e+000, 8.48713904e-017],
       [1.00212662e+000, 1.37974305e-013],
       [5.17295325e-130, 3.37353331e-001],
       [1.09814912e-048, 1.87800865e-003],
       [1.88640805e-056, 4.26608396e-002],
       [3.08491848e-137, 4.58473827e-002],
       [4.81085712e-001, 3.33967704e-020],
       [1.12504707e-129, 2.47883299e-001],
       [3.67995439e-002, 1.36596674e-019],
       [3.91991816e-092, 3.18444088e-001],
       [3.70404421e+000, 2.23261970e-016],
       [1.97791635e+000, 8.08973781e-016],
       [5.18297633e+000, 1.58016713e-016],
       [3.22002953e-109, 6.30695919e-001],
       [2.45629129e-042, 2.54489986e-003],
       [4.65684882e-078, 1.61140759e-001],
       [1.20020428e+000, 8.06573695e-015],
       [3.47644237e-102, 6.10877468e-001],
       [5.30752338e-159, 1.25788818e-001],
       [2.67525891e-180, 1.36687997e-002],
       [2.14367370e+000, 4.89645218e-015],
       [1.69559466e+000, 8.15261126e-019],
       [5.01330518e-065, 3.32739495e-002],
       [2.90136679e+000, 4.87766404e-017],
       [6.26263265e+000, 4.05703434e-016],
       [9.91822069e-123, 1.48439207e-001],
       [6.08616441e-129, 2.49686080e-001],
       [7.38230838e-001, 1.21546609e-017],
       [2.42302202e-096, 4.80883386e-001],
       [4.49573232e-170, 1.36182282e-002],
       [6.29495594e-117, 1.75312606e-001],
       [1.39322505e+000, 4.57390205e-017],
       [1.33577067e+000, 6.63620680e-015],
       [1.49050826e-177, 7.51872920e-002],
       [1.31733476e+000, 4.53624816e-017],
       [5.16176371e-102, 6.57207208e-001],
       [4.55092123e-084, 1.69998516e-001],
       [5.28027292e-073, 2.35169368e-001],
       [1.74659558e+000, 4.90692552e-017],
       [1.73554442e-002, 1.93538305e-013]])
array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0,
       1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1,
       0, 0, 1, 0, 1, 1, 1, 0, 0], dtype=int64)
#真实情况为
y_train.ravel()
array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0,
       1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1,
       0, 0, 1, 0, 1, 1, 1, 0, 0])
#判断多少预测正确了
np.argmax(pre_all,axis=1)==y_train.ravel()
array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True])
#计算精确率
np.sum(np.argmax(pre_all,axis=1)==y_train.ravel())/len(y_train.ravel())
1.0

10 计算测试集的预测结果

def predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,lst_pri):
    pre_0=gaussian_density(X_test,u_0,sigma_0)*lst_pri[0]
    pre_1=gaussian_density(X_test,u_1,sigma_1)*lst_pri[1]
    pre_all=np.hstack([pre_0.reshape(len(pre_0),1),pre_1.reshape(pre_1.shape[0],1)])
    return np.sum(np.argmax(pre_all,axis=1)==y_test.ravel())/len(y_test)
predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,lst_pri)
1.0

试试sklearn-1 高斯分布

# 1 导入包
from sklearn.naive_bayes import GaussianNB, BernoulliNB,MultinomialNB
# 2建立模型
clf=GaussianNB()
# 3 拟合模型
clf.fit(X_train,y_train.ravel())
GaussianNB()
# 4 测试模型
clf.score(X_test,y_test)
1.0

试试sklearn-3 多项式分布

# 1 导入包
from sklearn.naive_bayes import GaussianNB, BernoulliNB,MultinomialNB
# 建立模型
clf=MultinomialNB()
# 3 拟合模型
clf.fit(X_train,y_train.ravel())
MultinomialNB()
# 4 测试模型
clf.score(X_test,y_test)
1.0

实验:使用完整的鸢尾花数据集来进行朴素贝叶斯分类

1 数据准备

from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
#导入鸢尾花数据集
iris=load_iris()
#获得特征X,和相应的标签y
X=iris["data"]
y=iris["target"]
#查看X,y的形状
X.shape,y.shape
((150, 4), (150,))
#将y转换为二维数组
y=y.reshape((150,-1))
y.shape
(150, 1)
#通过数据框可视化
df=pd.DataFrame(np.hstack([X,y]),columns=iris.feature_names+["target"])
df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 0.0
1 4.9 3.0 1.4 0.2 0.0
2 4.7 3.2 1.3 0.2 0.0
3 4.6 3.1 1.5 0.2 0.0
4 5.0 3.6 1.4 0.2 0.0
... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2.0
146 6.3 2.5 5.0 1.9 2.0
147 6.5 3.0 5.2 2.0 2.0
148 6.2 3.4 5.4 2.3 2.0
149 5.9 3.0 5.1 1.8 2.0

150 rows × 5 columns

#把标签列转为整型
df["target"]=df["target"].astype("int")
df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2

150 rows × 5 columns

#看看0,1,2类别分别是哪些列
index_0=df[df["target"]==0].index
index_0,len(index_0)
(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
             17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
             34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
            dtype='int64'),
 50)


目录
相关文章
|
2天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
11 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
5天前
|
机器学习/深度学习 数据采集 人工智能
探索机器学习:从理论到Python代码实践
【10月更文挑战第36天】本文将深入浅出地介绍机器学习的基本概念、主要算法及其在Python中的实现。我们将通过实际案例,展示如何使用scikit-learn库进行数据预处理、模型选择和参数调优。无论你是初学者还是有一定基础的开发者,都能从中获得启发和实践指导。
11 2
|
6天前
|
机器学习/深度学习 数据采集 搜索推荐
利用Python和机器学习构建电影推荐系统
利用Python和机器学习构建电影推荐系统
20 1
|
6天前
|
机器学习/深度学习 算法 PyTorch
用Python实现简单机器学习模型:以鸢尾花数据集为例
用Python实现简单机器学习模型:以鸢尾花数据集为例
19 1
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
【Python机器学习】文本特征提取及文本向量化讲解和实战(图文解释 附源码)
【Python机器学习】文本特征提取及文本向量化讲解和实战(图文解释 附源码)
405 0
|
6月前
|
机器学习/深度学习 算法 数据挖掘
【Python机器学习】K-Means对文本聚类和半环形数据聚类实战(附源码和数据集)
【Python机器学习】K-Means对文本聚类和半环形数据聚类实战(附源码和数据集)
182 0
|
1月前
|
机器学习/深度学习 算法 数据挖掘
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧1
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧
50 5
|
1月前
|
机器学习/深度学习 数据采集 分布式计算
【Python篇】深入机器学习核心:XGBoost 从入门到实战
【Python篇】深入机器学习核心:XGBoost 从入门到实战
85 3
|
1月前
|
机器学习/深度学习 算法 数据可视化
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧2
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧
38 1
|
2月前
|
机器学习/深度学习 算法 Python
决策树下的智慧果实:Python机器学习实战,轻松摘取数据洞察的果实
【9月更文挑战第7天】当我们身处数据海洋,如何提炼出有价值的洞察?决策树作为一种直观且强大的机器学习算法,宛如智慧之树,引领我们在繁复的数据中找到答案。通过Python的scikit-learn库,我们可以轻松实现决策树模型,对数据进行分类或回归分析。本教程将带领大家从零开始,通过实际案例掌握决策树的原理与应用,探索数据中的秘密。
47 1