2 数据读取–训练集和测试集的划分
# Split the data into a training set and a test set (80% train / 20% test).
# random_state=0 makes the split reproducible.
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((120, 4), (30, 4), (120, 1), (30, 1))
3 数据读取–准备好每个类别各自的数据
#看看哪些索引处的标签为0 np.where(y_train==0)
(array([ 2, 6, 11, 13, 14, 31, 38, 39, 42, 43, 45, 48, 52, 57, 58, 61, 63, 66, 67, 69, 70, 71, 75, 76, 77, 80, 81, 83, 88, 90, 92, 93, 95, 104, 108, 113, 114, 115, 119], dtype=int64), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64))
# Build a dict mapping each class label to the np.where result for that label,
# so the per-class mean and variance can be computed later from those indices.
dic = {label: np.where(y_train == label) for label in [0, 1, 2]}
dic
{0: (array([ 2, 6, 11, 13, 14, 31, 38, 39, 42, 43, 45, 48, 52, 57, 58, 61, 63, 66, 67, 69, 70, 71, 75, 76, 77, 80, 81, 83, 88, 90, 92, 93, 95, 104, 108, 113, 114, 115, 119], dtype=int64), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)), 1: (array([ 1, 5, 7, 8, 9, 15, 20, 22, 23, 28, 30, 33, 34, 35, 36, 41, 44, 47, 49, 51, 72, 78, 79, 82, 85, 87, 97, 98, 99, 102, 103, 105, 109, 110, 111, 112, 117], dtype=int64), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)), 2: (array([ 0, 3, 4, 10, 12, 16, 17, 18, 19, 21, 24, 25, 26, 27, 29, 32, 37, 40, 46, 50, 53, 54, 55, 56, 59, 60, 62, 64, 65, 68, 73, 74, 84, 86, 89, 91, 94, 96, 100, 101, 106, 107, 116, 118], dtype=int64), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64))}
4 定义数据的均值和方差
def u_sigma(X):
    """Return the column-wise (per-feature) mean and variance of X.

    One mean/variance entry is produced per feature, i.e. per column.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data matrix (a DataFrame or ndarray).

    Returns
    -------
    tuple of ndarray
        (u, sigma): the per-feature means and (biased, ddof=0) variances,
        each of shape (n_features,).
    """
    return np.mean(X, axis=0), np.var(X, axis=0)
# The first element of the np.where tuple: the row indices of class-0 samples.
dic[0][0]
array([ 2, 6, 11, 13, 14, 31, 38, 39, 42, 43, 45, 48, 52, 57, 58, 61, 63, 66, 67, 69, 70, 71, 75, 76, 77, 80, 81, 83, 88, 90, 92, 93, 95, 104, 108, 113, 114, 115, 119], dtype=int64)
#计算类别0(普通鸢尾花)的均值和方差 u_0,sigma_0=u_sigma(X_train[dic[0][0],:]) u_0,sigma_0
(array([5.02051282, 3.4025641 , 1.46153846, 0.24102564]), array([0.12932281, 0.1417883 , 0.02031558, 0.01113741]))
#计算类别1(山鸢尾花)的均值和方差 u_1,sigma_1=u_sigma(X_train[dic[1][0],:]) u_1,sigma_1
(array([5.88648649, 2.76216216, 4.21621622, 1.32432432]), array([0.26387144, 0.1039737 , 0.2300073 , 0.04075968]))
#计算类别2(维吉利亚尾花)的均值和方差 u_2,sigma_2=u_sigma(X_train[dic[2][0],:]) u_2,sigma_2
(array([6.63863636, 2.98863636, 5.56590909, 2.03181818]), array([0.38918905, 0.10782541, 0.29451963, 0.06444215]))
5 定义每个类别的先验概率
# Prior probability of each class: its training-sample count divided by the
# total number of training samples.
counts = [len(dic[label][0]) for label in [0, 1, 2]]
lst_pri = [count / sum(counts) for count in counts]
lst_pri
[0.325, 0.30833333333333335, 0.36666666666666664]
6 调用概率密度函数
#所有样本带入到第1个类别的高斯模型参数中得到的结果 pre_0=gaussian_density(X_train,u_0,sigma_0)*lst_pri[0] pre_0
array([3.64205427e-225, 3.40844822e-130, 3.08530851e+000, 4.39737931e-176, 9.32161971e-262, 1.12603195e-090, 8.19955989e-002, 5.38088810e-180, 9.99826548e-113, 6.22294079e-089, 2.18584476e-247, 1.14681255e+000, 2.38802541e-230, 7.48076601e-003, 1.51577355e+000, 8.84977214e-059, 9.40380304e-226, 2.20471084e-296, 1.11546261e-168, 1.12595279e-254, 7.13493544e-080, 0.00000000e+000, 5.43149166e-151, 7.00401162e-075, 2.20419920e-177, 7.88959967e-176, 1.41957694e-141, 1.31858669e-191, 4.74468428e-145, 9.39276491e-214, 2.02942932e-136, 1.40273451e+000, 4.66850302e-197, 1.84403192e-103, 8.15997638e-072, 1.70855259e-092, 8.50513873e-134, 1.04684523e-275, 1.95561507e+000, 5.03262010e-003, 3.23862571e-215, 3.13715578e-099, 5.29812808e-001, 6.29658079e-003, 1.81543604e-163, 1.32072621e+000, 1.48741944e-190, 4.61289448e-041, 1.58979789e+000, 2.96357473e-134, 0.00000000e+000, 2.65155682e-103, 7.05472630e-001, 1.42166693e-285, 8.68838944e-281, 4.74069911e-280, 2.59051414e-254, 1.30709804e+000, 1.93716067e+000, 1.10437770e-205, 2.87463392e-264, 8.77307761e-003, 6.56796757e-251, 1.82259183e+000, 2.68966659e-196, 2.28835722e-239, 3.85005332e-001, 2.97070927e+000, 1.54669251e-245, 2.97250230e+000, 2.51256489e-001, 7.67795136e-002, 4.15395634e-093, 1.00997094e-298, 0.00000000e+000, 3.22193669e+000, 2.47369004e+000, 3.01412924e+000, 5.36914976e-122, 4.87767060e-123, 6.01262218e-001, 4.61755454e-002, 1.10260946e-111, 7.18092701e-001, 0.00000000e+000, 4.83593087e-049, 0.00000000e+000, 1.77412583e-123, 2.53482967e-001, 1.70832646e-168, 1.88690143e-002, 0.00000000e+000, 1.86389396e+000, 1.35985047e+000, 8.17806813e-294, 3.28434438e+000, 8.21098705e-277, 1.00342674e-097, 2.20897185e-083, 1.58003504e-057, 1.61348013e-243, 3.80414054e-237, 2.15851912e-161, 1.95128444e-180, 1.31803692e+000, 7.79858859e-067, 6.12107543e-279, 4.66850302e-197, 3.52624721e+000, 7.63949242e-132, 3.31703393e-097, 5.37109191e-168, 6.90508182e-119, 7.83871527e-001, 8.95165152e-001, 1.09244100e+000, 1.04987457e-233, 
1.54899418e-087, 0.00000000e+000, 1.49109871e+000])
#所有样本带入到第2个类别的高斯模型参数中得到的结果 pre_1=gaussian_density(X_train,u_1,sigma_1)*lst_pri[1] pre_1
array([2.95633338e-04, 1.36197317e-01, 2.90318178e-16, 7.67369010e-03, 3.75455611e-07, 1.46797523e-01, 6.95344048e-15, 3.36175041e-02, 2.53841239e-01, 3.16199307e-01, 1.32212698e-06, 2.31912196e-17, 6.23661197e-08, 4.43491705e-12, 9.03659728e-17, 6.06688573e-04, 3.14945948e-04, 1.24882948e-11, 1.87288422e-02, 2.66560740e-05, 1.30000970e-01, 2.76182931e-12, 2.07410916e-02, 7.22817433e-02, 7.79602598e-03, 4.38522048e-02, 8.22673683e-03, 1.14220807e-03, 1.03590806e-02, 8.19796704e-05, 2.21991209e-02, 1.91118667e-15, 1.48027054e-03, 4.05979965e-01, 1.65444313e-01, 2.36465225e-01, 2.30302015e-01, 4.54901890e-07, 7.37406496e-17, 2.21052310e-20, 3.87241584e-04, 2.87187564e-01, 8.53516604e-15, 3.46342632e-18, 9.95391379e-03, 2.43959119e-16, 4.23043625e-03, 2.34628172e-03, 2.50262009e-16, 5.08355498e-02, 1.22369433e-14, 4.12873889e-01, 1.33213958e-17, 2.98880456e-08, 1.95809747e-09, 6.40227550e-08, 2.84653316e-06, 5.40191505e-17, 4.67733730e-16, 6.42382537e-05, 1.79818302e-07, 1.09855352e-16, 2.30402853e-08, 3.51870932e-16, 3.18554534e-04, 1.18966325e-06, 5.07486109e-18, 2.25215273e-17, 2.37994256e-05, 9.20537370e-16, 9.71966954e-18, 1.81892177e-14, 1.17820150e-01, 7.11741017e-10, 3.82851638e-12, 6.59703177e-17, 8.88106613e-16, 1.68993929e-16, 3.77332955e-01, 1.22469010e-01, 2.07501791e-17, 9.48218948e-12, 2.63666294e-01, 1.33681661e-13, 1.13413698e-15, 1.81908946e-03, 1.46950870e-13, 6.95238806e-02, 4.07966207e-20, 1.07543910e-02, 1.43838827e-19, 5.26740196e-12, 2.36489470e-16, 8.55569443e-16, 4.82666780e-08, 1.63877804e-16, 5.30883063e-10, 4.36520033e-01, 3.13721528e-01, 3.62503830e-02, 7.75810130e-08, 1.09538068e-07, 6.27229834e-02, 4.93070200e-03, 5.32420738e-15, 3.01096779e-02, 8.55857074e-10, 1.48027054e-03, 4.25565651e-16, 1.22088863e-01, 3.06149212e-01, 5.75190751e-03, 1.16325296e-01, 4.61599415e-17, 6.67684050e-15, 4.97991843e-17, 3.11807922e-04, 1.25938919e-01, 6.63898313e-16, 5.04670598e-17])
#所有样本带入到第3个类别的高斯模型参数中得到的结果 pre_2=gaussian_density(X_train,u_2,sigma_2)*lst_pri[2] pre_2
array([1.88926441e-01, 7.41874323e-04, 2.18905385e-26, 7.03342033e-02, 2.07838563e-01, 6.36007282e-06, 3.75616194e-24, 2.15583340e-02, 7.65683494e-04, 4.80086802e-06, 3.04560221e-01, 4.03768532e-28, 1.12679216e-01, 9.72668930e-22, 1.40128825e-26, 2.07279668e-11, 2.09922203e-01, 3.69933717e-02, 7.04823898e-04, 6.49975333e-02, 2.90135522e-07, 2.72821894e-02, 1.19387091e-02, 7.43267743e-08, 5.99160309e-02, 1.85609819e-02, 1.38418438e-04, 4.76244749e-02, 2.86112072e-03, 2.53639963e-01, 3.04064364e-03, 5.04262171e-26, 5.47700919e-02, 3.69353344e-05, 1.75987852e-06, 5.01849240e-06, 2.09975476e-03, 8.54119142e-02, 1.00630371e-26, 1.53267285e-31, 1.61099289e-01, 2.08157220e-05, 9.87308671e-25, 7.12483734e-27, 1.49368318e-02, 4.76225689e-27, 7.43930795e-02, 8.62041503e-11, 9.03427577e-27, 2.32663919e-04, 4.36377985e-03, 6.75646957e-05, 1.81992485e-28, 1.99685684e-01, 1.36031284e-01, 2.34763950e-01, 2.49673422e-01, 9.27207512e-27, 2.43693353e-26, 1.79134484e-01, 1.95463733e-01, 3.06844563e-28, 6.40538684e-02, 5.34390777e-27, 2.02012772e-02, 2.61986932e-01, 8.07090461e-29, 1.45826047e-27, 4.70449238e-02, 5.86183174e-26, 9.92273358e-29, 9.92642821e-24, 1.68421105e-06, 1.22514460e-01, 1.57513390e-02, 3.69159440e-27, 2.04206384e-26, 8.30149544e-27, 2.05007234e-04, 1.47522326e-03, 3.70249288e-28, 1.18962106e-21, 3.04482104e-04, 1.44239452e-23, 1.07163996e-03, 5.75350754e-11, 6.13059140e-04, 1.38954915e-03, 1.29199008e-29, 4.74148015e-02, 5.06182005e-29, 7.33590052e-03, 3.76544259e-26, 2.67245797e-26, 7.13465644e-02, 5.26396730e-27, 4.51771500e-02, 3.67360555e-05, 3.79694730e-06, 9.71272783e-09, 1.26212878e-01, 1.49245747e-01, 4.92630412e-03, 8.08794435e-02, 1.30436645e-25, 8.74375374e-09, 1.07798580e-01, 5.47700919e-02, 2.29068907e-26, 1.01895184e-03, 3.35870705e-05, 3.23117267e-02, 4.91416425e-05, 3.49183358e-27, 1.03729239e-24, 1.10117672e-27, 1.80129089e-01, 6.09942673e-07, 3.30717488e-04, 1.01366241e-27])
7 计算训练集的预测结果
# Combine the three per-class score vectors into an (n_samples, 3) matrix,
# one column per class; column_stack turns each 1-D vector into a column.
pre_all = np.column_stack([pre_0, pre_1, pre_2])
pre_all
array([[3.64205427e-225, 2.95633338e-004, 1.88926441e-001], [3.40844822e-130, 1.36197317e-001, 7.41874323e-004], [3.08530851e+000, 2.90318178e-016, 2.18905385e-026], [4.39737931e-176, 7.67369010e-003, 7.03342033e-002], [9.32161971e-262, 3.75455611e-007, 2.07838563e-001], [1.12603195e-090, 1.46797523e-001, 6.36007282e-006], [8.19955989e-002, 6.95344048e-015, 3.75616194e-024], [5.38088810e-180, 3.36175041e-002, 2.15583340e-002], [9.99826548e-113, 2.53841239e-001, 7.65683494e-004], [6.22294079e-089, 3.16199307e-001, 4.80086802e-006], [2.18584476e-247, 1.32212698e-006, 3.04560221e-001], [1.14681255e+000, 2.31912196e-017, 4.03768532e-028], [2.38802541e-230, 6.23661197e-008, 1.12679216e-001], [7.48076601e-003, 4.43491705e-012, 9.72668930e-022], [1.51577355e+000, 9.03659728e-017, 1.40128825e-026], [8.84977214e-059, 6.06688573e-004, 2.07279668e-011], [9.40380304e-226, 3.14945948e-004, 2.09922203e-001], [2.20471084e-296, 1.24882948e-011, 3.69933717e-002], [1.11546261e-168, 1.87288422e-002, 7.04823898e-004], [1.12595279e-254, 2.66560740e-005, 6.49975333e-002], [7.13493544e-080, 1.30000970e-001, 2.90135522e-007], [0.00000000e+000, 2.76182931e-012, 2.72821894e-002], [5.43149166e-151, 2.07410916e-002, 1.19387091e-002], [7.00401162e-075, 7.22817433e-002, 7.43267743e-008], [2.20419920e-177, 7.79602598e-003, 5.99160309e-002], [7.88959967e-176, 4.38522048e-002, 1.85609819e-002], [1.41957694e-141, 8.22673683e-003, 1.38418438e-004], [1.31858669e-191, 1.14220807e-003, 4.76244749e-002], [4.74468428e-145, 1.03590806e-002, 2.86112072e-003], [9.39276491e-214, 8.19796704e-005, 2.53639963e-001], [2.02942932e-136, 2.21991209e-002, 3.04064364e-003], [1.40273451e+000, 1.91118667e-015, 5.04262171e-026], [4.66850302e-197, 1.48027054e-003, 5.47700919e-002], [1.84403192e-103, 4.05979965e-001, 3.69353344e-005], [8.15997638e-072, 1.65444313e-001, 1.75987852e-006], [1.70855259e-092, 2.36465225e-001, 5.01849240e-006], [8.50513873e-134, 2.30302015e-001, 2.09975476e-003], [1.04684523e-275, 
4.54901890e-007, 8.54119142e-002], [1.95561507e+000, 7.37406496e-017, 1.00630371e-026], [5.03262010e-003, 2.21052310e-020, 1.53267285e-031], [3.23862571e-215, 3.87241584e-004, 1.61099289e-001], [3.13715578e-099, 2.87187564e-001, 2.08157220e-005], [5.29812808e-001, 8.53516604e-015, 9.87308671e-025], [6.29658079e-003, 3.46342632e-018, 7.12483734e-027], [1.81543604e-163, 9.95391379e-003, 1.49368318e-002], [1.32072621e+000, 2.43959119e-016, 4.76225689e-027], [1.48741944e-190, 4.23043625e-003, 7.43930795e-002], [4.61289448e-041, 2.34628172e-003, 8.62041503e-011], [1.58979789e+000, 2.50262009e-016, 9.03427577e-027], [2.96357473e-134, 5.08355498e-002, 2.32663919e-004], [0.00000000e+000, 1.22369433e-014, 4.36377985e-003], [2.65155682e-103, 4.12873889e-001, 6.75646957e-005], [7.05472630e-001, 1.33213958e-017, 1.81992485e-028], [1.42166693e-285, 2.98880456e-008, 1.99685684e-001], [8.68838944e-281, 1.95809747e-009, 1.36031284e-001], [4.74069911e-280, 6.40227550e-008, 2.34763950e-001], [2.59051414e-254, 2.84653316e-006, 2.49673422e-001], [1.30709804e+000, 5.40191505e-017, 9.27207512e-027], [1.93716067e+000, 4.67733730e-016, 2.43693353e-026], [1.10437770e-205, 6.42382537e-005, 1.79134484e-001], [2.87463392e-264, 1.79818302e-007, 1.95463733e-001], [8.77307761e-003, 1.09855352e-016, 3.06844563e-028], [6.56796757e-251, 2.30402853e-008, 6.40538684e-002], [1.82259183e+000, 3.51870932e-016, 5.34390777e-027], [2.68966659e-196, 3.18554534e-004, 2.02012772e-002], [2.28835722e-239, 1.18966325e-006, 2.61986932e-001], [3.85005332e-001, 5.07486109e-018, 8.07090461e-029], [2.97070927e+000, 2.25215273e-017, 1.45826047e-027], [1.54669251e-245, 2.37994256e-005, 4.70449238e-002], [2.97250230e+000, 9.20537370e-016, 5.86183174e-026], [2.51256489e-001, 9.71966954e-018, 9.92273358e-029], [7.67795136e-002, 1.81892177e-014, 9.92642821e-024], [4.15395634e-093, 1.17820150e-001, 1.68421105e-006], [1.00997094e-298, 7.11741017e-010, 1.22514460e-001], [0.00000000e+000, 3.82851638e-012, 1.57513390e-002], 
[3.22193669e+000, 6.59703177e-017, 3.69159440e-027], [2.47369004e+000, 8.88106613e-016, 2.04206384e-026], [3.01412924e+000, 1.68993929e-016, 8.30149544e-027], [5.36914976e-122, 3.77332955e-001, 2.05007234e-004], [4.87767060e-123, 1.22469010e-001, 1.47522326e-003], [6.01262218e-001, 2.07501791e-017, 3.70249288e-028], [4.61755454e-002, 9.48218948e-012, 1.18962106e-021], [1.10260946e-111, 2.63666294e-001, 3.04482104e-004], [7.18092701e-001, 1.33681661e-013, 1.44239452e-023], [0.00000000e+000, 1.13413698e-015, 1.07163996e-003], [4.83593087e-049, 1.81908946e-003, 5.75350754e-011], [0.00000000e+000, 1.46950870e-013, 6.13059140e-004], [1.77412583e-123, 6.95238806e-002, 1.38954915e-003], [2.53482967e-001, 4.07966207e-020, 1.29199008e-029], [1.70832646e-168, 1.07543910e-002, 4.74148015e-002], [1.88690143e-002, 1.43838827e-019, 5.06182005e-029], [0.00000000e+000, 5.26740196e-012, 7.33590052e-003], [1.86389396e+000, 2.36489470e-016, 3.76544259e-026], [1.35985047e+000, 8.55569443e-016, 2.67245797e-026], [8.17806813e-294, 4.82666780e-008, 7.13465644e-002], [3.28434438e+000, 1.63877804e-016, 5.26396730e-027], [8.21098705e-277, 5.30883063e-010, 4.51771500e-002], [1.00342674e-097, 4.36520033e-001, 3.67360555e-005], [2.20897185e-083, 3.13721528e-001, 3.79694730e-006], [1.58003504e-057, 3.62503830e-002, 9.71272783e-009], [1.61348013e-243, 7.75810130e-008, 1.26212878e-001], [3.80414054e-237, 1.09538068e-007, 1.49245747e-001], [2.15851912e-161, 6.27229834e-002, 4.92630412e-003], [1.95128444e-180, 4.93070200e-003, 8.08794435e-002], [1.31803692e+000, 5.32420738e-015, 1.30436645e-025], [7.79858859e-067, 3.01096779e-002, 8.74375374e-009], [6.12107543e-279, 8.55857074e-010, 1.07798580e-001], [4.66850302e-197, 1.48027054e-003, 5.47700919e-002], [3.52624721e+000, 4.25565651e-016, 2.29068907e-026], [7.63949242e-132, 1.22088863e-001, 1.01895184e-003], [3.31703393e-097, 3.06149212e-001, 3.35870705e-005], [5.37109191e-168, 5.75190751e-003, 3.23117267e-002], [6.90508182e-119, 1.16325296e-001, 
4.91416425e-005], [7.83871527e-001, 4.61599415e-017, 3.49183358e-027], [8.95165152e-001, 6.67684050e-015, 1.03729239e-024], [1.09244100e+000, 4.97991843e-017, 1.10117672e-027], [1.04987457e-233, 3.11807922e-004, 1.80129089e-001], [1.54899418e-087, 1.25938919e-001, 6.09942673e-007], [0.00000000e+000, 6.63898313e-016, 3.30717488e-004], [1.49109871e+000, 5.04670598e-017, 1.01366241e-027]])
#判断多少预测正确了 np.argmax(pre_all,axis=1)==y_train.ravel()
array([ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True])
#计算精确率 np.sum(np.argmax(pre_all,axis=1)==y_train.ravel())/len(y_train.ravel())
0.95
8 计算测试集的预测结果
def predict(X_test, y_test, u_0, sigma_0, u_1, sigma_1, u_2, sigma_2, lst_pri):
    """Score the test set under all three class Gaussians and return accuracy.

    Parameters
    ----------
    X_test, y_test : test features and true labels.
    u_0, sigma_0, u_1, sigma_1, u_2, sigma_2 : per-class feature means and
        variances produced by ``u_sigma``.
    lst_pri : list of the three class prior probabilities.

    Returns
    -------
    float
        Fraction of test samples whose argmax class matches the true label.
    """
    class_params = [(u_0, sigma_0), (u_1, sigma_1), (u_2, sigma_2)]
    # One score column per class: likelihood under that class times its prior.
    scores = np.column_stack([
        gaussian_density(X_test, u, sigma) * prior
        for (u, sigma), prior in zip(class_params, lst_pri)
    ])
    predictions = np.argmax(scores, axis=1)
    return np.sum(predictions == y_test.ravel()) / len(y_test)

predict(X_test, y_test, u_0, sigma_0, u_1, sigma_1, u_2, sigma_2, lst_pri)
0.9666666666666667
9 scikit-learn实例
# For comparison: the same model via scikit-learn's built-in Gaussian
# naive Bayes classifier.
from sklearn.naive_bayes import GaussianNB
clf=GaussianNB()
# Print the full API documentation for GaussianNB.
help(GaussianNB)
Help on class GaussianNB in module sklearn.naive_bayes: class GaussianNB(_BaseNB) | GaussianNB(*, priors=None, var_smoothing=1e-09) | | Gaussian Naive Bayes (GaussianNB). | | Can perform online updates to model parameters via :meth:`partial_fit`. | For details on algorithm used to update feature means and variance online, | see Stanford CS tech report STAN-CS-79-773 by Chan, Golub, and LeVeque: | | http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf | | Read more in the :ref:`User Guide <gaussian_naive_bayes>`. | | Parameters | ---------- | priors : array-like of shape (n_classes,) | Prior probabilities of the classes. If specified the priors are not | adjusted according to the data. | | var_smoothing : float, default=1e-9 | Portion of the largest variance of all features that is added to | variances for calculation stability. | | .. versionadded:: 0.20 | | Attributes | ---------- | class_count_ : ndarray of shape (n_classes,) | number of training samples observed in each class. | | class_prior_ : ndarray of shape (n_classes,) | probability of each class. | | classes_ : ndarray of shape (n_classes,) | class labels known to the classifier. | | epsilon_ : float | absolute additive value to variances. | | n_features_in_ : int | Number of features seen during :term:`fit`. | | .. versionadded:: 0.24 | | feature_names_in_ : ndarray of shape (`n_features_in_`,) | Names of features seen during :term:`fit`. Defined only when `X` | has feature names that are all strings. | | .. versionadded:: 1.0 | | sigma_ : ndarray of shape (n_classes, n_features) | Variance of each feature per class. | | .. deprecated:: 1.0 | `sigma_` is deprecated in 1.0 and will be removed in 1.2. | Use `var_` instead. | | var_ : ndarray of shape (n_classes, n_features) | Variance of each feature per class. | | .. versionadded:: 1.0 | | theta_ : ndarray of shape (n_classes, n_features) | mean of each feature per class. 
| | See Also | -------- | BernoulliNB : Naive Bayes classifier for multivariate Bernoulli models. | CategoricalNB : Naive Bayes classifier for categorical features. | ComplementNB : Complement Naive Bayes classifier. | MultinomialNB : Naive Bayes classifier for multinomial models. | | Examples | -------- | >>> import numpy as np | >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) | >>> Y = np.array([1, 1, 1, 2, 2, 2]) | >>> from sklearn.naive_bayes import GaussianNB | >>> clf = GaussianNB() | >>> clf.fit(X, Y) | GaussianNB() | >>> print(clf.predict([[-0.8, -1]])) | [1] | >>> clf_pf = GaussianNB() | >>> clf_pf.partial_fit(X, Y, np.unique(Y)) | GaussianNB() | >>> print(clf_pf.predict([[-0.8, -1]])) | [1] | | Method resolution order: | GaussianNB | _BaseNB | sklearn.base.ClassifierMixin | sklearn.base.BaseEstimator | builtins.object | | Methods defined here: | | __init__(self, *, priors=None, var_smoothing=1e-09) | Initialize self. See help(type(self)) for accurate signature. | | fit(self, X, y, sample_weight=None) | Fit Gaussian Naive Bayes according to X, y. | | Parameters | ---------- | X : array-like of shape (n_samples, n_features) | Training vectors, where `n_samples` is the number of samples | and `n_features` is the number of features. | | y : array-like of shape (n_samples,) | Target values. | | sample_weight : array-like of shape (n_samples,), default=None | Weights applied to individual samples (1. for unweighted). | | .. versionadded:: 0.17 | Gaussian Naive Bayes supports fitting with *sample_weight*. | | Returns | ------- | self : object | Returns the instance itself. | | partial_fit(self, X, y, classes=None, sample_weight=None) | Incremental fit on a batch of samples. | | This method is expected to be called several times consecutively | on different chunks of a dataset so as to implement out-of-core | or online learning. | | This is especially useful when the whole dataset is too big to fit in | memory at once. 
| | This method has some performance and numerical stability overhead, | hence it is better to call partial_fit on chunks of data that are | as large as possible (as long as fitting in the memory budget) to | hide the overhead. | | Parameters | ---------- | X : array-like of shape (n_samples, n_features) | Training vectors, where `n_samples` is the number of samples and | `n_features` is the number of features. | | y : array-like of shape (n_samples,) | Target values. | | classes : array-like of shape (n_classes,), default=None | List of all the classes that can possibly appear in the y vector. | | Must be provided at the first call to partial_fit, can be omitted | in subsequent calls. | | sample_weight : array-like of shape (n_samples,), default=None | Weights applied to individual samples (1. for unweighted). | | .. versionadded:: 0.17 | | Returns | ------- | self : object | Returns the instance itself. | | ---------------------------------------------------------------------- | Readonly properties defined here: | | sigma_ | DEPRECATED: Attribute `sigma_` was deprecated in 1.0 and will be removed in1.2. Use `var_` instead. | | ---------------------------------------------------------------------- | Data and other attributes defined here: | | __abstractmethods__ = frozenset() | | ---------------------------------------------------------------------- | Methods inherited from _BaseNB: | | predict(self, X) | Perform classification on an array of test vectors X. | | Parameters | ---------- | X : array-like of shape (n_samples, n_features) | The input samples. | | Returns | ------- | C : ndarray of shape (n_samples,) | Predicted target values for X. | | predict_log_proba(self, X) | Return log-probability estimates for the test vector X. | | Parameters | ---------- | X : array-like of shape (n_samples, n_features) | The input samples. 
| | Returns | ------- | C : array-like of shape (n_samples, n_classes) | Returns the log-probability of the samples for each class in | the model. The columns correspond to the classes in sorted | order, as they appear in the attribute :term:`classes_`. | | predict_proba(self, X) | Return probability estimates for the test vector X. | | Parameters | ---------- | X : array-like of shape (n_samples, n_features) | The input samples. | | Returns | ------- | C : array-like of shape (n_samples, n_classes) | Returns the probability of the samples for each class in | the model. The columns correspond to the classes in sorted | order, as they appear in the attribute :term:`classes_`. | | ---------------------------------------------------------------------- | Methods inherited from sklearn.base.ClassifierMixin: | | score(self, X, y, sample_weight=None) | Return the mean accuracy on the given test data and labels. | | In multi-label classification, this is the subset accuracy | which is a harsh metric since you require for each sample that | each label set be correctly predicted. | | Parameters | ---------- | X : array-like of shape (n_samples, n_features) | Test samples. | | y : array-like of shape (n_samples,) or (n_samples, n_outputs) | True labels for `X`. | | sample_weight : array-like of shape (n_samples,), default=None | Sample weights. | | Returns | ------- | score : float | Mean accuracy of ``self.predict(X)`` wrt. `y`. | | ---------------------------------------------------------------------- | Data descriptors inherited from sklearn.base.ClassifierMixin: | | __dict__ | dictionary for instance variables (if defined) | | __weakref__ | list of weak references to the object (if defined) | | ---------------------------------------------------------------------- | Methods inherited from sklearn.base.BaseEstimator: | | __getstate__(self) | | __repr__(self, N_CHAR_MAX=700) | Return repr(self). 
| | __setstate__(self, state) | | get_params(self, deep=True) | Get parameters for this estimator. | | Parameters | ---------- | deep : bool, default=True | If True, will return the parameters for this estimator and | contained subobjects that are estimators. | | Returns | ------- | params : dict | Parameter names mapped to their values. | | set_params(self, **params) | Set the parameters of this estimator. | | The method works on simple estimators as well as on nested objects | (such as :class:`~sklearn.pipeline.Pipeline`). The latter have | parameters of the form ``<component>__<parameter>`` so that it's | possible to update each component of a nested object. | | Parameters | ---------- | **params : dict | Estimator parameters. | | Returns | ------- | self : estimator instance | Estimator instance.
# Sanity check on the full dataset shapes before fitting.
X.shape,y.shape
((150, 4), (150, 1))
# Fit on the training split; ravel() flattens y_train from (n, 1) to the
# (n,) shape that the fit signature documents for target values.
clf.fit(X_train,y_train.ravel())
GaussianNB()
# Mean accuracy on the held-out test split.
clf.score(X_test,y_test.ravel())
0.9666666666666667