【Python机器学习】实验06 贝叶斯推理3

简介: 【Python机器学习】实验06 贝叶斯推理

2 数据读取–训练集和测试集的划分

#Split the data into a training set and a test set (80/20 split;
#random_state is fixed so the split is reproducible).
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=0)
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((120, 4), (30, 4), (120, 1), (30, 1))

3 数据读取–准备好每个类别各自的数据

#Find the indices of the training samples whose label is 0.
#NOTE: y_train has shape (120, 1), so np.where returns a pair
#(row_indices, col_indices); only the row indices are useful here.
np.where(y_train==0)
(array([  2,   6,  11,  13,  14,  31,  38,  39,  42,  43,  45,  48,  52,
         57,  58,  61,  63,  66,  67,  69,  70,  71,  75,  76,  77,  80,
         81,  83,  88,  90,  92,  93,  95, 104, 108, 113, 114, 115, 119],
       dtype=int64),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64))
#Build a lookup from class label -> indices of the training samples carrying
#that label; these index sets drive the per-class mean/variance computations
#further below.
dic = {label: np.where(y_train == label) for label in (0, 1, 2)}
dic
{0: (array([  2,   6,  11,  13,  14,  31,  38,  39,  42,  43,  45,  48,  52,
          57,  58,  61,  63,  66,  67,  69,  70,  71,  75,  76,  77,  80,
          81,  83,  88,  90,  92,  93,  95, 104, 108, 113, 114, 115, 119],
        dtype=int64),
  array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)),
 1: (array([  1,   5,   7,   8,   9,  15,  20,  22,  23,  28,  30,  33,  34,
          35,  36,  41,  44,  47,  49,  51,  72,  78,  79,  82,  85,  87,
          97,  98,  99, 102, 103, 105, 109, 110, 111, 112, 117], dtype=int64),
  array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)),
 2: (array([  0,   3,   4,  10,  12,  16,  17,  18,  19,  21,  24,  25,  26,
          27,  29,  32,  37,  40,  46,  50,  53,  54,  55,  56,  59,  60,
          62,  64,  65,  68,  73,  74,  84,  86,  89,  91,  94,  96, 100,
         101, 106, 107, 116, 118], dtype=int64),
  array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        dtype=int64))}

4 定义数据的均值和方差

#Per-feature statistics: one mean and one variance per column of X, so the
#returned vectors have as many entries as X has features. X may be an
#ndarray or a DataFrame-like object that numpy can reduce over axis 0.
def u_sigma(X):
    """Return (mean, variance) of X computed column-wise (axis=0)."""
    return np.mean(X, axis=0), np.var(X, axis=0)
dic[0][0]
array([  2,   6,  11,  13,  14,  31,  38,  39,  42,  43,  45,  48,  52,
        57,  58,  61,  63,  66,  67,  69,  70,  71,  75,  76,  77,  80,
        81,  83,  88,  90,  92,  93,  95, 104, 108, 113, 114, 115, 119],
      dtype=int64)
#Mean and variance of the class-0 training samples
#(presumably Iris setosa in sklearn's iris ordering — the flower names in the
#original comments look mislabeled; verify against the data source).
u_0,sigma_0=u_sigma(X_train[dic[0][0],:])
u_0,sigma_0
(array([5.02051282, 3.4025641 , 1.46153846, 0.24102564]),
 array([0.12932281, 0.1417883 , 0.02031558, 0.01113741]))
#Mean and variance of the class-1 training samples
#(presumably Iris versicolor — original comment's flower name looks off; verify).
u_1,sigma_1=u_sigma(X_train[dic[1][0],:])
u_1,sigma_1
(array([5.88648649, 2.76216216, 4.21621622, 1.32432432]),
 array([0.26387144, 0.1039737 , 0.2300073 , 0.04075968]))
#Mean and variance of the class-2 training samples
#(presumably Iris virginica — "维吉利亚" in the original is a typo; verify).
u_2,sigma_2=u_sigma(X_train[dic[2][0],:])
u_2,sigma_2
(array([6.63863636, 2.98863636, 5.56590909, 2.03181818]),
 array([0.38918905, 0.10782541, 0.29451963, 0.06444215]))

5 定义每个类别的先验概率

#Class priors: the fraction of training samples belonging to each class.
_class_sizes = [len(dic[label][0]) for label in (0, 1, 2)]
lst_pri = [size / sum(_class_sizes) for size in _class_sizes]
lst_pri
[0.325, 0.30833333333333335, 0.36666666666666664]

6 调用概率密度函数

#Score of every training sample under the class-0 Gaussian model:
#class-conditional density times the class-0 prior (unnormalized posterior).
pre_0=gaussian_density(X_train,u_0,sigma_0)*lst_pri[0]
pre_0
array([3.64205427e-225, 3.40844822e-130, 3.08530851e+000, 4.39737931e-176,
       9.32161971e-262, 1.12603195e-090, 8.19955989e-002, 5.38088810e-180,
       9.99826548e-113, 6.22294079e-089, 2.18584476e-247, 1.14681255e+000,
       2.38802541e-230, 7.48076601e-003, 1.51577355e+000, 8.84977214e-059,
       9.40380304e-226, 2.20471084e-296, 1.11546261e-168, 1.12595279e-254,
       7.13493544e-080, 0.00000000e+000, 5.43149166e-151, 7.00401162e-075,
       2.20419920e-177, 7.88959967e-176, 1.41957694e-141, 1.31858669e-191,
       4.74468428e-145, 9.39276491e-214, 2.02942932e-136, 1.40273451e+000,
       4.66850302e-197, 1.84403192e-103, 8.15997638e-072, 1.70855259e-092,
       8.50513873e-134, 1.04684523e-275, 1.95561507e+000, 5.03262010e-003,
       3.23862571e-215, 3.13715578e-099, 5.29812808e-001, 6.29658079e-003,
       1.81543604e-163, 1.32072621e+000, 1.48741944e-190, 4.61289448e-041,
       1.58979789e+000, 2.96357473e-134, 0.00000000e+000, 2.65155682e-103,
       7.05472630e-001, 1.42166693e-285, 8.68838944e-281, 4.74069911e-280,
       2.59051414e-254, 1.30709804e+000, 1.93716067e+000, 1.10437770e-205,
       2.87463392e-264, 8.77307761e-003, 6.56796757e-251, 1.82259183e+000,
       2.68966659e-196, 2.28835722e-239, 3.85005332e-001, 2.97070927e+000,
       1.54669251e-245, 2.97250230e+000, 2.51256489e-001, 7.67795136e-002,
       4.15395634e-093, 1.00997094e-298, 0.00000000e+000, 3.22193669e+000,
       2.47369004e+000, 3.01412924e+000, 5.36914976e-122, 4.87767060e-123,
       6.01262218e-001, 4.61755454e-002, 1.10260946e-111, 7.18092701e-001,
       0.00000000e+000, 4.83593087e-049, 0.00000000e+000, 1.77412583e-123,
       2.53482967e-001, 1.70832646e-168, 1.88690143e-002, 0.00000000e+000,
       1.86389396e+000, 1.35985047e+000, 8.17806813e-294, 3.28434438e+000,
       8.21098705e-277, 1.00342674e-097, 2.20897185e-083, 1.58003504e-057,
       1.61348013e-243, 3.80414054e-237, 2.15851912e-161, 1.95128444e-180,
       1.31803692e+000, 7.79858859e-067, 6.12107543e-279, 4.66850302e-197,
       3.52624721e+000, 7.63949242e-132, 3.31703393e-097, 5.37109191e-168,
       6.90508182e-119, 7.83871527e-001, 8.95165152e-001, 1.09244100e+000,
       1.04987457e-233, 1.54899418e-087, 0.00000000e+000, 1.49109871e+000])
#Score of every training sample under the class-1 Gaussian model:
#class-conditional density times the class-1 prior (unnormalized posterior).
pre_1=gaussian_density(X_train,u_1,sigma_1)*lst_pri[1]
pre_1
array([2.95633338e-04, 1.36197317e-01, 2.90318178e-16, 7.67369010e-03,
       3.75455611e-07, 1.46797523e-01, 6.95344048e-15, 3.36175041e-02,
       2.53841239e-01, 3.16199307e-01, 1.32212698e-06, 2.31912196e-17,
       6.23661197e-08, 4.43491705e-12, 9.03659728e-17, 6.06688573e-04,
       3.14945948e-04, 1.24882948e-11, 1.87288422e-02, 2.66560740e-05,
       1.30000970e-01, 2.76182931e-12, 2.07410916e-02, 7.22817433e-02,
       7.79602598e-03, 4.38522048e-02, 8.22673683e-03, 1.14220807e-03,
       1.03590806e-02, 8.19796704e-05, 2.21991209e-02, 1.91118667e-15,
       1.48027054e-03, 4.05979965e-01, 1.65444313e-01, 2.36465225e-01,
       2.30302015e-01, 4.54901890e-07, 7.37406496e-17, 2.21052310e-20,
       3.87241584e-04, 2.87187564e-01, 8.53516604e-15, 3.46342632e-18,
       9.95391379e-03, 2.43959119e-16, 4.23043625e-03, 2.34628172e-03,
       2.50262009e-16, 5.08355498e-02, 1.22369433e-14, 4.12873889e-01,
       1.33213958e-17, 2.98880456e-08, 1.95809747e-09, 6.40227550e-08,
       2.84653316e-06, 5.40191505e-17, 4.67733730e-16, 6.42382537e-05,
       1.79818302e-07, 1.09855352e-16, 2.30402853e-08, 3.51870932e-16,
       3.18554534e-04, 1.18966325e-06, 5.07486109e-18, 2.25215273e-17,
       2.37994256e-05, 9.20537370e-16, 9.71966954e-18, 1.81892177e-14,
       1.17820150e-01, 7.11741017e-10, 3.82851638e-12, 6.59703177e-17,
       8.88106613e-16, 1.68993929e-16, 3.77332955e-01, 1.22469010e-01,
       2.07501791e-17, 9.48218948e-12, 2.63666294e-01, 1.33681661e-13,
       1.13413698e-15, 1.81908946e-03, 1.46950870e-13, 6.95238806e-02,
       4.07966207e-20, 1.07543910e-02, 1.43838827e-19, 5.26740196e-12,
       2.36489470e-16, 8.55569443e-16, 4.82666780e-08, 1.63877804e-16,
       5.30883063e-10, 4.36520033e-01, 3.13721528e-01, 3.62503830e-02,
       7.75810130e-08, 1.09538068e-07, 6.27229834e-02, 4.93070200e-03,
       5.32420738e-15, 3.01096779e-02, 8.55857074e-10, 1.48027054e-03,
       4.25565651e-16, 1.22088863e-01, 3.06149212e-01, 5.75190751e-03,
       1.16325296e-01, 4.61599415e-17, 6.67684050e-15, 4.97991843e-17,
       3.11807922e-04, 1.25938919e-01, 6.63898313e-16, 5.04670598e-17])
#Score of every training sample under the class-2 Gaussian model:
#class-conditional density times the class-2 prior (unnormalized posterior).
pre_2=gaussian_density(X_train,u_2,sigma_2)*lst_pri[2]
pre_2
array([1.88926441e-01, 7.41874323e-04, 2.18905385e-26, 7.03342033e-02,
       2.07838563e-01, 6.36007282e-06, 3.75616194e-24, 2.15583340e-02,
       7.65683494e-04, 4.80086802e-06, 3.04560221e-01, 4.03768532e-28,
       1.12679216e-01, 9.72668930e-22, 1.40128825e-26, 2.07279668e-11,
       2.09922203e-01, 3.69933717e-02, 7.04823898e-04, 6.49975333e-02,
       2.90135522e-07, 2.72821894e-02, 1.19387091e-02, 7.43267743e-08,
       5.99160309e-02, 1.85609819e-02, 1.38418438e-04, 4.76244749e-02,
       2.86112072e-03, 2.53639963e-01, 3.04064364e-03, 5.04262171e-26,
       5.47700919e-02, 3.69353344e-05, 1.75987852e-06, 5.01849240e-06,
       2.09975476e-03, 8.54119142e-02, 1.00630371e-26, 1.53267285e-31,
       1.61099289e-01, 2.08157220e-05, 9.87308671e-25, 7.12483734e-27,
       1.49368318e-02, 4.76225689e-27, 7.43930795e-02, 8.62041503e-11,
       9.03427577e-27, 2.32663919e-04, 4.36377985e-03, 6.75646957e-05,
       1.81992485e-28, 1.99685684e-01, 1.36031284e-01, 2.34763950e-01,
       2.49673422e-01, 9.27207512e-27, 2.43693353e-26, 1.79134484e-01,
       1.95463733e-01, 3.06844563e-28, 6.40538684e-02, 5.34390777e-27,
       2.02012772e-02, 2.61986932e-01, 8.07090461e-29, 1.45826047e-27,
       4.70449238e-02, 5.86183174e-26, 9.92273358e-29, 9.92642821e-24,
       1.68421105e-06, 1.22514460e-01, 1.57513390e-02, 3.69159440e-27,
       2.04206384e-26, 8.30149544e-27, 2.05007234e-04, 1.47522326e-03,
       3.70249288e-28, 1.18962106e-21, 3.04482104e-04, 1.44239452e-23,
       1.07163996e-03, 5.75350754e-11, 6.13059140e-04, 1.38954915e-03,
       1.29199008e-29, 4.74148015e-02, 5.06182005e-29, 7.33590052e-03,
       3.76544259e-26, 2.67245797e-26, 7.13465644e-02, 5.26396730e-27,
       4.51771500e-02, 3.67360555e-05, 3.79694730e-06, 9.71272783e-09,
       1.26212878e-01, 1.49245747e-01, 4.92630412e-03, 8.08794435e-02,
       1.30436645e-25, 8.74375374e-09, 1.07798580e-01, 5.47700919e-02,
       2.29068907e-26, 1.01895184e-03, 3.35870705e-05, 3.23117267e-02,
       4.91416425e-05, 3.49183358e-27, 1.03729239e-24, 1.10117672e-27,
       1.80129089e-01, 6.09942673e-07, 3.30717488e-04, 1.01366241e-27])

7 计算训练集的预测结果

#Stack the three per-class score vectors into an (n_samples, 3) matrix;
#column k holds each sample's unnormalized posterior for class k.
pre_all=np.hstack([pre_0.reshape(len(pre_0),1),pre_1.reshape(pre_1.shape[0],1),pre_2.reshape(pre_2.shape[0],1)])
pre_all
array([[3.64205427e-225, 2.95633338e-004, 1.88926441e-001],
       [3.40844822e-130, 1.36197317e-001, 7.41874323e-004],
       [3.08530851e+000, 2.90318178e-016, 2.18905385e-026],
       [4.39737931e-176, 7.67369010e-003, 7.03342033e-002],
       [9.32161971e-262, 3.75455611e-007, 2.07838563e-001],
       [1.12603195e-090, 1.46797523e-001, 6.36007282e-006],
       [8.19955989e-002, 6.95344048e-015, 3.75616194e-024],
       [5.38088810e-180, 3.36175041e-002, 2.15583340e-002],
       [9.99826548e-113, 2.53841239e-001, 7.65683494e-004],
       [6.22294079e-089, 3.16199307e-001, 4.80086802e-006],
       [2.18584476e-247, 1.32212698e-006, 3.04560221e-001],
       [1.14681255e+000, 2.31912196e-017, 4.03768532e-028],
       [2.38802541e-230, 6.23661197e-008, 1.12679216e-001],
       [7.48076601e-003, 4.43491705e-012, 9.72668930e-022],
       [1.51577355e+000, 9.03659728e-017, 1.40128825e-026],
       [8.84977214e-059, 6.06688573e-004, 2.07279668e-011],
       [9.40380304e-226, 3.14945948e-004, 2.09922203e-001],
       [2.20471084e-296, 1.24882948e-011, 3.69933717e-002],
       [1.11546261e-168, 1.87288422e-002, 7.04823898e-004],
       [1.12595279e-254, 2.66560740e-005, 6.49975333e-002],
       [7.13493544e-080, 1.30000970e-001, 2.90135522e-007],
       [0.00000000e+000, 2.76182931e-012, 2.72821894e-002],
       [5.43149166e-151, 2.07410916e-002, 1.19387091e-002],
       [7.00401162e-075, 7.22817433e-002, 7.43267743e-008],
       [2.20419920e-177, 7.79602598e-003, 5.99160309e-002],
       [7.88959967e-176, 4.38522048e-002, 1.85609819e-002],
       [1.41957694e-141, 8.22673683e-003, 1.38418438e-004],
       [1.31858669e-191, 1.14220807e-003, 4.76244749e-002],
       [4.74468428e-145, 1.03590806e-002, 2.86112072e-003],
       [9.39276491e-214, 8.19796704e-005, 2.53639963e-001],
       [2.02942932e-136, 2.21991209e-002, 3.04064364e-003],
       [1.40273451e+000, 1.91118667e-015, 5.04262171e-026],
       [4.66850302e-197, 1.48027054e-003, 5.47700919e-002],
       [1.84403192e-103, 4.05979965e-001, 3.69353344e-005],
       [8.15997638e-072, 1.65444313e-001, 1.75987852e-006],
       [1.70855259e-092, 2.36465225e-001, 5.01849240e-006],
       [8.50513873e-134, 2.30302015e-001, 2.09975476e-003],
       [1.04684523e-275, 4.54901890e-007, 8.54119142e-002],
       [1.95561507e+000, 7.37406496e-017, 1.00630371e-026],
       [5.03262010e-003, 2.21052310e-020, 1.53267285e-031],
       [3.23862571e-215, 3.87241584e-004, 1.61099289e-001],
       [3.13715578e-099, 2.87187564e-001, 2.08157220e-005],
       [5.29812808e-001, 8.53516604e-015, 9.87308671e-025],
       [6.29658079e-003, 3.46342632e-018, 7.12483734e-027],
       [1.81543604e-163, 9.95391379e-003, 1.49368318e-002],
       [1.32072621e+000, 2.43959119e-016, 4.76225689e-027],
       [1.48741944e-190, 4.23043625e-003, 7.43930795e-002],
       [4.61289448e-041, 2.34628172e-003, 8.62041503e-011],
       [1.58979789e+000, 2.50262009e-016, 9.03427577e-027],
       [2.96357473e-134, 5.08355498e-002, 2.32663919e-004],
       [0.00000000e+000, 1.22369433e-014, 4.36377985e-003],
       [2.65155682e-103, 4.12873889e-001, 6.75646957e-005],
       [7.05472630e-001, 1.33213958e-017, 1.81992485e-028],
       [1.42166693e-285, 2.98880456e-008, 1.99685684e-001],
       [8.68838944e-281, 1.95809747e-009, 1.36031284e-001],
       [4.74069911e-280, 6.40227550e-008, 2.34763950e-001],
       [2.59051414e-254, 2.84653316e-006, 2.49673422e-001],
       [1.30709804e+000, 5.40191505e-017, 9.27207512e-027],
       [1.93716067e+000, 4.67733730e-016, 2.43693353e-026],
       [1.10437770e-205, 6.42382537e-005, 1.79134484e-001],
       [2.87463392e-264, 1.79818302e-007, 1.95463733e-001],
       [8.77307761e-003, 1.09855352e-016, 3.06844563e-028],
       [6.56796757e-251, 2.30402853e-008, 6.40538684e-002],
       [1.82259183e+000, 3.51870932e-016, 5.34390777e-027],
       [2.68966659e-196, 3.18554534e-004, 2.02012772e-002],
       [2.28835722e-239, 1.18966325e-006, 2.61986932e-001],
       [3.85005332e-001, 5.07486109e-018, 8.07090461e-029],
       [2.97070927e+000, 2.25215273e-017, 1.45826047e-027],
       [1.54669251e-245, 2.37994256e-005, 4.70449238e-002],
       [2.97250230e+000, 9.20537370e-016, 5.86183174e-026],
       [2.51256489e-001, 9.71966954e-018, 9.92273358e-029],
       [7.67795136e-002, 1.81892177e-014, 9.92642821e-024],
       [4.15395634e-093, 1.17820150e-001, 1.68421105e-006],
       [1.00997094e-298, 7.11741017e-010, 1.22514460e-001],
       [0.00000000e+000, 3.82851638e-012, 1.57513390e-002],
       [3.22193669e+000, 6.59703177e-017, 3.69159440e-027],
       [2.47369004e+000, 8.88106613e-016, 2.04206384e-026],
       [3.01412924e+000, 1.68993929e-016, 8.30149544e-027],
       [5.36914976e-122, 3.77332955e-001, 2.05007234e-004],
       [4.87767060e-123, 1.22469010e-001, 1.47522326e-003],
       [6.01262218e-001, 2.07501791e-017, 3.70249288e-028],
       [4.61755454e-002, 9.48218948e-012, 1.18962106e-021],
       [1.10260946e-111, 2.63666294e-001, 3.04482104e-004],
       [7.18092701e-001, 1.33681661e-013, 1.44239452e-023],
       [0.00000000e+000, 1.13413698e-015, 1.07163996e-003],
       [4.83593087e-049, 1.81908946e-003, 5.75350754e-011],
       [0.00000000e+000, 1.46950870e-013, 6.13059140e-004],
       [1.77412583e-123, 6.95238806e-002, 1.38954915e-003],
       [2.53482967e-001, 4.07966207e-020, 1.29199008e-029],
       [1.70832646e-168, 1.07543910e-002, 4.74148015e-002],
       [1.88690143e-002, 1.43838827e-019, 5.06182005e-029],
       [0.00000000e+000, 5.26740196e-012, 7.33590052e-003],
       [1.86389396e+000, 2.36489470e-016, 3.76544259e-026],
       [1.35985047e+000, 8.55569443e-016, 2.67245797e-026],
       [8.17806813e-294, 4.82666780e-008, 7.13465644e-002],
       [3.28434438e+000, 1.63877804e-016, 5.26396730e-027],
       [8.21098705e-277, 5.30883063e-010, 4.51771500e-002],
       [1.00342674e-097, 4.36520033e-001, 3.67360555e-005],
       [2.20897185e-083, 3.13721528e-001, 3.79694730e-006],
       [1.58003504e-057, 3.62503830e-002, 9.71272783e-009],
       [1.61348013e-243, 7.75810130e-008, 1.26212878e-001],
       [3.80414054e-237, 1.09538068e-007, 1.49245747e-001],
       [2.15851912e-161, 6.27229834e-002, 4.92630412e-003],
       [1.95128444e-180, 4.93070200e-003, 8.08794435e-002],
       [1.31803692e+000, 5.32420738e-015, 1.30436645e-025],
       [7.79858859e-067, 3.01096779e-002, 8.74375374e-009],
       [6.12107543e-279, 8.55857074e-010, 1.07798580e-001],
       [4.66850302e-197, 1.48027054e-003, 5.47700919e-002],
       [3.52624721e+000, 4.25565651e-016, 2.29068907e-026],
       [7.63949242e-132, 1.22088863e-001, 1.01895184e-003],
       [3.31703393e-097, 3.06149212e-001, 3.35870705e-005],
       [5.37109191e-168, 5.75190751e-003, 3.23117267e-002],
       [6.90508182e-119, 1.16325296e-001, 4.91416425e-005],
       [7.83871527e-001, 4.61599415e-017, 3.49183358e-027],
       [8.95165152e-001, 6.67684050e-015, 1.03729239e-024],
       [1.09244100e+000, 4.97991843e-017, 1.10117672e-027],
       [1.04987457e-233, 3.11807922e-004, 1.80129089e-001],
       [1.54899418e-087, 1.25938919e-001, 6.09942673e-007],
       [0.00000000e+000, 6.63898313e-016, 3.30717488e-004],
       [1.49109871e+000, 5.04670598e-017, 1.01366241e-027]])
#Element-wise comparison of the predicted class (argmax over the three score
#columns) with the true labels; y_train is flattened from (120, 1) to (120,).
np.argmax(pre_all,axis=1)==y_train.ravel()
array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False,  True,  True,  True,  True,  True,  True, False, False,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True, False,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True,  True, False,  True,  True,  True,  True,  True,
        True,  True,  True])
#Training-set accuracy (the original comment said "precision", but this is
#accuracy: the fraction of correctly classified samples).
np.sum(np.argmax(pre_all,axis=1)==y_train.ravel())/len(y_train.ravel())
0.95

8 计算测试集的预测结果

def predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,u_2,sigma_2,lst_pri):
    """Return the test-set accuracy of the hand-rolled Gaussian naive Bayes.

    For each class k the score is the class-conditional density under
    (u_k, sigma_k) times the prior lst_pri[k]; every sample is assigned to
    its highest-scoring class and compared against y_test.
    """
    scores = [
        gaussian_density(X_test, u_0, sigma_0) * lst_pri[0],
        gaussian_density(X_test, u_1, sigma_1) * lst_pri[1],
        gaussian_density(X_test, u_2, sigma_2) * lst_pri[2],
    ]
    # column_stack builds the same (n_samples, 3) matrix the reshape+hstack
    # dance produced.
    posterior = np.column_stack(scores)
    predicted = np.argmax(posterior, axis=1)
    return np.sum(predicted == y_test.ravel()) / len(y_test)
predict(X_test,y_test,u_0,sigma_0,u_1,sigma_1,u_2,sigma_2,lst_pri)
0.9666666666666667

9 scikit-learn实例

#For comparison, fit the same model with scikit-learn's built-in
#Gaussian naive Bayes.
from sklearn.naive_bayes import GaussianNB
clf=GaussianNB()
help(GaussianNB)
Help on class GaussianNB in module sklearn.naive_bayes:
class GaussianNB(_BaseNB)
 |  GaussianNB(*, priors=None, var_smoothing=1e-09)
 |  
 |  Gaussian Naive Bayes (GaussianNB).
 |  
 |  Can perform online updates to model parameters via :meth:`partial_fit`.
 |  For details on algorithm used to update feature means and variance online,
 |  see Stanford CS tech report STAN-CS-79-773 by Chan, Golub, and LeVeque:
 |  
 |      http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
 |  
 |  Read more in the :ref:`User Guide <gaussian_naive_bayes>`.
 |  
 |  Parameters
 |  ----------
 |  priors : array-like of shape (n_classes,)
 |      Prior probabilities of the classes. If specified the priors are not
 |      adjusted according to the data.
 |  
 |  var_smoothing : float, default=1e-9
 |      Portion of the largest variance of all features that is added to
 |      variances for calculation stability.
 |  
 |      .. versionadded:: 0.20
 |  
 |  Attributes
 |  ----------
 |  class_count_ : ndarray of shape (n_classes,)
 |      number of training samples observed in each class.
 |  
 |  class_prior_ : ndarray of shape (n_classes,)
 |      probability of each class.
 |  
 |  classes_ : ndarray of shape (n_classes,)
 |      class labels known to the classifier.
 |  
 |  epsilon_ : float
 |      absolute additive value to variances.
 |  
 |  n_features_in_ : int
 |      Number of features seen during :term:`fit`.
 |  
 |      .. versionadded:: 0.24
 |  
 |  feature_names_in_ : ndarray of shape (`n_features_in_`,)
 |      Names of features seen during :term:`fit`. Defined only when `X`
 |      has feature names that are all strings.
 |  
 |      .. versionadded:: 1.0
 |  
 |  sigma_ : ndarray of shape (n_classes, n_features)
 |      Variance of each feature per class.
 |  
 |      .. deprecated:: 1.0
 |         `sigma_` is deprecated in 1.0 and will be removed in 1.2.
 |         Use `var_` instead.
 |  
 |  var_ : ndarray of shape (n_classes, n_features)
 |      Variance of each feature per class.
 |  
 |      .. versionadded:: 1.0
 |  
 |  theta_ : ndarray of shape (n_classes, n_features)
 |      mean of each feature per class.
 |  
 |  See Also
 |  --------
 |  BernoulliNB : Naive Bayes classifier for multivariate Bernoulli models.
 |  CategoricalNB : Naive Bayes classifier for categorical features.
 |  ComplementNB : Complement Naive Bayes classifier.
 |  MultinomialNB : Naive Bayes classifier for multinomial models.
 |  
 |  Examples
 |  --------
 |  >>> import numpy as np
 |  >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
 |  >>> Y = np.array([1, 1, 1, 2, 2, 2])
 |  >>> from sklearn.naive_bayes import GaussianNB
 |  >>> clf = GaussianNB()
 |  >>> clf.fit(X, Y)
 |  GaussianNB()
 |  >>> print(clf.predict([[-0.8, -1]]))
 |  [1]
 |  >>> clf_pf = GaussianNB()
 |  >>> clf_pf.partial_fit(X, Y, np.unique(Y))
 |  GaussianNB()
 |  >>> print(clf_pf.predict([[-0.8, -1]]))
 |  [1]
 |  
 |  Method resolution order:
 |      GaussianNB
 |      _BaseNB
 |      sklearn.base.ClassifierMixin
 |      sklearn.base.BaseEstimator
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, *, priors=None, var_smoothing=1e-09)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  fit(self, X, y, sample_weight=None)
 |      Fit Gaussian Naive Bayes according to X, y.
 |      
 |      Parameters
 |      ----------
 |      X : array-like of shape (n_samples, n_features)
 |          Training vectors, where `n_samples` is the number of samples
 |          and `n_features` is the number of features.
 |      
 |      y : array-like of shape (n_samples,)
 |          Target values.
 |      
 |      sample_weight : array-like of shape (n_samples,), default=None
 |          Weights applied to individual samples (1. for unweighted).
 |      
 |          .. versionadded:: 0.17
 |             Gaussian Naive Bayes supports fitting with *sample_weight*.
 |      
 |      Returns
 |      -------
 |      self : object
 |          Returns the instance itself.
 |  
 |  partial_fit(self, X, y, classes=None, sample_weight=None)
 |      Incremental fit on a batch of samples.
 |      
 |      This method is expected to be called several times consecutively
 |      on different chunks of a dataset so as to implement out-of-core
 |      or online learning.
 |      
 |      This is especially useful when the whole dataset is too big to fit in
 |      memory at once.
 |      
 |      This method has some performance and numerical stability overhead,
 |      hence it is better to call partial_fit on chunks of data that are
 |      as large as possible (as long as fitting in the memory budget) to
 |      hide the overhead.
 |      
 |      Parameters
 |      ----------
 |      X : array-like of shape (n_samples, n_features)
 |          Training vectors, where `n_samples` is the number of samples and
 |          `n_features` is the number of features.
 |      
 |      y : array-like of shape (n_samples,)
 |          Target values.
 |      
 |      classes : array-like of shape (n_classes,), default=None
 |          List of all the classes that can possibly appear in the y vector.
 |      
 |          Must be provided at the first call to partial_fit, can be omitted
 |          in subsequent calls.
 |      
 |      sample_weight : array-like of shape (n_samples,), default=None
 |          Weights applied to individual samples (1. for unweighted).
 |      
 |          .. versionadded:: 0.17
 |      
 |      Returns
 |      -------
 |      self : object
 |          Returns the instance itself.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties defined here:
 |  
 |  sigma_
 |      DEPRECATED: Attribute `sigma_` was deprecated in 1.0 and will be removed in1.2. Use `var_` instead.
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __abstractmethods__ = frozenset()
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from _BaseNB:
 |  
 |  predict(self, X)
 |      Perform classification on an array of test vectors X.
 |      
 |      Parameters
 |      ----------
 |      X : array-like of shape (n_samples, n_features)
 |          The input samples.
 |      
 |      Returns
 |      -------
 |      C : ndarray of shape (n_samples,)
 |          Predicted target values for X.
 |  
 |  predict_log_proba(self, X)
 |      Return log-probability estimates for the test vector X.
 |      
 |      Parameters
 |      ----------
 |      X : array-like of shape (n_samples, n_features)
 |          The input samples.
 |      
 |      Returns
 |      -------
 |      C : array-like of shape (n_samples, n_classes)
 |          Returns the log-probability of the samples for each class in
 |          the model. The columns correspond to the classes in sorted
 |          order, as they appear in the attribute :term:`classes_`.
 |  
 |  predict_proba(self, X)
 |      Return probability estimates for the test vector X.
 |      
 |      Parameters
 |      ----------
 |      X : array-like of shape (n_samples, n_features)
 |          The input samples.
 |      
 |      Returns
 |      -------
 |      C : array-like of shape (n_samples, n_classes)
 |          Returns the probability of the samples for each class in
 |          the model. The columns correspond to the classes in sorted
 |          order, as they appear in the attribute :term:`classes_`.
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from sklearn.base.ClassifierMixin:
 |  
 |  score(self, X, y, sample_weight=None)
 |      Return the mean accuracy on the given test data and labels.
 |      
 |      In multi-label classification, this is the subset accuracy
 |      which is a harsh metric since you require for each sample that
 |      each label set be correctly predicted.
 |      
 |      Parameters
 |      ----------
 |      X : array-like of shape (n_samples, n_features)
 |          Test samples.
 |      
 |      y : array-like of shape (n_samples,) or (n_samples, n_outputs)
 |          True labels for `X`.
 |      
 |      sample_weight : array-like of shape (n_samples,), default=None
 |          Sample weights.
 |      
 |      Returns
 |      -------
 |      score : float
 |          Mean accuracy of ``self.predict(X)`` wrt. `y`.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from sklearn.base.ClassifierMixin:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from sklearn.base.BaseEstimator:
 |  
 |  __getstate__(self)
 |  
 |  __repr__(self, N_CHAR_MAX=700)
 |      Return repr(self).
 |  
 |  __setstate__(self, state)
 |  
 |  get_params(self, deep=True)
 |      Get parameters for this estimator.
 |      
 |      Parameters
 |      ----------
 |      deep : bool, default=True
 |          If True, will return the parameters for this estimator and
 |          contained subobjects that are estimators.
 |      
 |      Returns
 |      -------
 |      params : dict
 |          Parameter names mapped to their values.
 |  
 |  set_params(self, **params)
 |      Set the parameters of this estimator.
 |      
 |      The method works on simple estimators as well as on nested objects
 |      (such as :class:`~sklearn.pipeline.Pipeline`). The latter have
 |      parameters of the form ``<component>__<parameter>`` so that it's
 |      possible to update each component of a nested object.
 |      
 |      Parameters
 |      ----------
 |      **params : dict
 |          Estimator parameters.
 |      
 |      Returns
 |      -------
 |      self : estimator instance
 |          Estimator instance.
X.shape,y.shape
((150, 4), (150, 1))
#Fit on the training set; ravel() flattens y from (120, 1) to (120,), the
#1-D shape sklearn expects for targets.
clf.fit(X_train,y_train.ravel())
GaussianNB()
#Mean test-set accuracy — matches the hand-rolled implementation above.
clf.score(X_test,y_test.ravel())
0.9666666666666667



目录
相关文章
|
24天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
68 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
26天前
|
机器学习/深度学习 数据采集 人工智能
探索机器学习:从理论到Python代码实践
【10月更文挑战第36天】本文将深入浅出地介绍机器学习的基本概念、主要算法及其在Python中的实现。我们将通过实际案例,展示如何使用scikit-learn库进行数据预处理、模型选择和参数调优。无论你是初学者还是有一定基础的开发者,都能从中获得启发和实践指导。
41 2
|
28天前
|
机器学习/深度学习 数据采集 搜索推荐
利用Python和机器学习构建电影推荐系统
利用Python和机器学习构建电影推荐系统
52 1
|
28天前
|
机器学习/深度学习 算法 PyTorch
用Python实现简单机器学习模型:以鸢尾花数据集为例
用Python实现简单机器学习模型:以鸢尾花数据集为例
74 1
|
1月前
|
机器学习/深度学习 数据采集 算法
Python机器学习:Scikit-learn库的高效使用技巧
【10月更文挑战第28天】Scikit-learn 是 Python 中最受欢迎的机器学习库之一,以其简洁的 API、丰富的算法和良好的文档支持而受到开发者喜爱。本文介绍了 Scikit-learn 的高效使用技巧,包括数据预处理(如使用 Pipeline 和 ColumnTransformer)、模型选择与评估(如交叉验证和 GridSearchCV)以及模型持久化(如使用 joblib)。通过这些技巧,你可以在机器学习项目中事半功倍。
40 3
|
1月前
|
机器学习/深度学习 人工智能 算法
机器学习基础:使用Python和Scikit-learn入门
机器学习基础:使用Python和Scikit-learn入门
33 1
|
2月前
|
机器学习/深度学习 算法 Java
机器学习、基础算法、python常见面试题必知必答系列大全:(面试问题持续更新)
机器学习、基础算法、python常见面试题必知必答系列大全:(面试问题持续更新)
|
2月前
|
机器学习/深度学习 人工智能 算法
机器学习基础:使用Python和Scikit-learn入门
【10月更文挑战第12天】本文介绍了如何使用Python和Scikit-learn进行机器学习的基础知识和入门实践。首先概述了机器学习的基本概念,包括监督学习、无监督学习和强化学习。接着详细讲解了Python和Scikit-learn的安装、数据处理、模型训练和评估等步骤,并提供了代码示例。通过本文,读者可以掌握机器学习的基本流程,并为深入学习打下坚实基础。
24 1
|
2月前
|
机器学习/深度学习 API 计算机视觉
基于Python_opencv人脸录入、识别系统(应用dlib机器学习库)(下)
基于Python_opencv人脸录入、识别系统(应用dlib机器学习库)(下)
31 2
|
2月前
|
机器学习/深度学习 存储 算法
基于Python_opencv人脸录入、识别系统(应用dlib机器学习库)(上)
基于Python_opencv人脸录入、识别系统(应用dlib机器学习库)(上)
40 1