【Python机器学习】实验15 将Lenet5应用于Cifar10数据集3

简介: 【Python机器学习】实验15 将Lenet5应用于Cifar10数据集

12. 采用pandas可视化数据

import pandas as pd 
table=pd.DataFrame(zip(pre_10000,label_10000))
table
0 1
0 [-0.49338394, -1.098238, 0.40724754, 1.7330961... 3
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0
4 [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... 6
... ... ...
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8
9996 [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... 3
9997 [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1
9999 [-0.51935434, -2.6184506, 1.1929085, 0.1288419... 7

10000 rows × 2 columns

table[0].values
array([array([-0.49338394, -1.098238  ,  0.40724754,  1.7330961 , -0.4455951 ,
               1.6433077 ,  0.1720748 , -0.40375623, -1.165497  , -0.820113  ],
             dtype=float32)                                                    ,
       array([ 4.0153656 ,  4.4736323 , -0.29209492, -3.2882178 , -1.6234205 ,
              -4.481386  , -3.1240807 , -2.392501  ,  4.317573  ,  4.190993  ],
             dtype=float32)                                                    ,
       array([ 1.3858219 ,  3.2021556 , -0.70040375, -1.0123051 , -1.7393746 ,
              -1.6656632 , -3.2578242 , -2.2767155 ,  3.092283  ,  2.373978  ],
             dtype=float32)                                                    ,
       ...,
       array([-1.9550545 , -3.808494  ,  1.7917161 ,  2.6365147 ,  0.37311587,
               3.545672  , -0.43889195,  2.110389  , -2.9572597 , -1.7386926 ],
             dtype=float32)                                                    ,
       array([ 0.66809845, -0.5327946 ,  0.30590305, -0.18204585,  2.0045712 ,
               0.47369143, -0.3122899 ,  0.11701592, -2.5236375 , -0.5746133 ],
             dtype=float32)                                                    ,
       array([-0.51935434, -2.6184506 ,  1.1929085 ,  0.1288419 ,  1.8770852 ,
               0.4296908 , -0.22015049,  3.7748828 , -2.3134274 , -1.5123445 ],
             dtype=float32)                                                    ],
      dtype=object)
table["pred"]=[np.argmax(table[0][i]) for i in range(table.shape[0])]
table
0 1 pred
0 [-0.49338394, -1.098238, 0.40724754, 1.7330961... 3 3
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8 1
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8 1
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0 8
4 [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... 6 6
... ... ... ...
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8 5
9996 [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... 3 3
9997 [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... 5 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1 4
9999 [-0.51935434, -2.6184506, 1.1929085, 0.1288419... 7 7

10000 rows × 3 columns

13. 对预测错误的样本点进行可视化

mismatch=table[table[1]!=table["pred"]]
mismatch
0 1 pred
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8 1
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8 1
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0 8
8 [0.02641207, -3.6653092, 2.294829, 2.2884543, ... 3 5
12 [-1.4556388, -1.7955011, -0.6100754, 1.169481,... 5 6
... ... ... ...
9989 [-0.2553262, -2.8777533, 3.4579017, 0.3079242,... 2 4
9993 [-0.077826336, -3.14616, 0.8994149, 3.5604722,... 5 3
9994 [-1.2543154, -2.4472265, 0.6754027, 2.0582433,... 3 6
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1 4

4657 rows × 3 columns

from matplotlib import pyplot as plt
plt.scatter(mismatch[1],mismatch["pred"])
<matplotlib.collections.PathCollection at 0x1b3a92ef910>

14. 看看错误样本被预测为哪些数据?

mismatch[mismatch[1]==9].sort_values("pred").index


Int64Index([2129, 1465, 2907,  787, 2902, 2307, 4588, 5737, 8276, 8225,
            ...
            7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344],
           dtype='int64', length=396)
idx_lst=mismatch[mismatch[1]==9].sort_values("pred").index.values
idx_lst,len(idx_lst)
(array([2129, 1465, 2907,  787, 2902, 2307, 4588, 5737, 8276, 8225, 8148,
        4836, 1155, 7218, 8034, 7412, 5069, 1629, 5094, 5109, 7685, 5397,
        1427, 5308, 8727, 2960, 2491, 6795, 1997, 6686, 9449, 6545, 8985,
        9401, 3564, 6034,  383, 9583, 9673,  507, 3288, 6868, 9133, 9085,
         577, 4261, 6974,  411, 6290, 5416, 5350, 5950, 5455, 5498, 6143,
        5964, 5864, 5877, 6188, 5939,   14, 5300, 3501, 3676, 3770, 3800,
        3850, 3893, 3902, 4233, 4252, 4253, 4276, 5335, 4297, 4418, 4445,
        4536, 4681, 6381, 4929, 4945, 5067, 5087, 5166, 5192, 4364, 4928,
        7024, 6542, 8144, 8312, 8385, 8406, 8453, 8465, 8521, 8585, 8673,
        8763, 8946, 9067, 9069, 9199, 9209, 9217, 9280, 9403, 9463, 9518,
        9692, 9743, 9871, 9875, 9881, 8066, 6509, 8057, 7826, 6741, 6811,
        6814, 6840, 6983, 7007, 3492, 7028, 7075, 7121, 7232, 7270, 7424,
        7431, 7444, 7492, 7499, 7501, 7578, 7639, 7729, 7767, 7792, 7818,
        7824, 7942, 3459, 4872, 1834, 1487, 1668, 1727, 1732, 1734, 1808,
        1814, 1815, 1831, 1927, 2111, 2126, 2190, 2246, 2290, 2433, 2596,
        2700, 2714, 1439, 1424, 1376, 1359,   28,  151,  172,  253,  259,
         335,  350,  591,  625, 2754,  734,  940,  951,  970, 1066, 1136,
        1177, 1199, 1222, 1231,  853, 2789, 9958, 2946, 3314, 3307, 2876,
        3208, 3166, 2944, 2817, 2305, 7522, 7155, 7220, 4590, 2899, 2446,
        2186, 7799, 9492, 3163, 4449, 2027, 2387, 1064, 3557, 2177,  654,
        9791, 2670, 2514, 2495, 3450, 8972, 3210, 3755, 2756, 7967, 3970,
        4550, 6017,  938,  744, 6951, 3397, 4852, 3133, 7931,  707, 3312,
        7470, 6871, 8292, 7100, 9529, 9100, 3853, 9060, 9732, 2521, 3789,
        2974, 5311, 3218, 5736, 3055, 7076, 1220, 9147, 1344,  532, 8218,
        3569, 1008, 8475, 8877, 1582, 8936, 4758, 1837, 9517,  252, 5832,
        1916, 6369, 4979, 9324, 6218, 9777, 7923, 4521, 2868,  213, 8083,
        5952, 5579, 4508, 5488, 2460, 5332, 5180, 8323, 8345, 3776, 2568,
        5151, 4570, 2854, 8488, 4874,  680, 2810, 1285, 6136, 3339, 9143,
        6852, 1906, 7067, 7073, 2975, 1924, 6804, 6755, 9299, 2019, 9445,
        9560,  360, 1601, 7297, 9122, 6377, 9214, 6167, 3980,  394, 7491,
        7581, 9349, 8953,  222,  139,  530, 3577, 9868,  247, 9099, 9026,
         209,  538, 3229, 9258,  585, 9204, 9643, 1492, 3609, 6570, 6561,
        6469, 6435, 6419, 2155, 6275, 4481, 2202, 1987, 2271, 2355, 2366,
        2432, 5400, 2497, 2727, 4931, 4619, 9884, 5902, 8796, 6848, 6960,
        8575, 8413,  981, 8272, 8145, 3172, 1221, 3168, 1256, 1889, 1291,
        3964, 7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344],
       dtype=int64),
 396)
import numpy as np
img=np.stack(list(test_dataset[idx_lst[i]][0][0] for i in range(5)),axis=1).reshape(32,32*5)
plt.imshow(img)
plt.axis('off')
(-0.5, 159.5, 31.5, -0.5)

51b620f1d16b42cd90faf82a9e30d3bb.png

#显示4行
import numpy as np
img20=np.stack(
    tuple(np.stack(
            tuple(test_dataset[idx_lst[i+j*5]][0][0] for i in range(5)),
        axis=1).reshape(32,32*5) for j in range(4)),axis=0).reshape(32*4,32*5)
plt.imshow(img20)
plt.axis('off')
(-0.5, 159.5, 127.5, -0.5)

0c783bf29137476c835285acc890ac80.png

15.输出错误的模型类别

idx_lst=mismatch[mismatch[1]==9].index.values
table.iloc[idx_lst[:], 2].values
array([1, 1, 8, 1, 1, 8, 7, 8, 8, 6, 1, 1, 1, 1, 7, 0, 7, 0, 0, 8, 6, 8,
       0, 8, 1, 1, 3, 7, 5, 1, 4, 0, 1, 4, 1, 1, 1, 8, 6, 3, 1, 1, 0, 1,
       1, 6, 8, 1, 1, 8, 7, 8, 6, 1, 1, 1, 0, 1, 0, 1, 8, 6, 7, 8, 0, 8,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 8, 7, 6, 7, 1, 8, 0, 7, 3, 1, 1, 0,
       8, 3, 3, 1, 8, 1, 8, 1, 2, 0, 8, 8, 3, 8, 1, 3, 7, 0, 3, 8, 3, 5,
       7, 1, 3, 1, 1, 8, 1, 3, 1, 7, 1, 7, 7, 1, 3, 0, 0, 1, 1, 0, 5, 7,
       6, 4, 3, 1, 8, 8, 1, 3, 5, 8, 0, 1, 5, 1, 7, 8, 4, 3, 1, 1, 1, 3,
       0, 6, 8, 8, 1, 3, 1, 7, 5, 1, 1, 5, 1, 1, 8, 8, 4, 7, 8, 8, 1, 1,
       1, 0, 1, 1, 1, 1, 1, 3, 8, 7, 7, 1, 4, 7, 0, 2, 8, 1, 6, 0, 4, 1,
       7, 1, 1, 8, 1, 6, 1, 0, 1, 0, 0, 7, 1, 7, 1, 1, 0, 5, 7, 1, 1, 0,
       8, 1, 1, 7, 1, 7, 5, 0, 6, 1, 1, 8, 1, 1, 7, 1, 4, 0, 7, 1, 7, 1,
       6, 8, 1, 6, 7, 1, 8, 8, 8, 1, 1, 0, 8, 8, 0, 1, 7, 0, 7, 1, 1, 1,
       8, 7, 0, 5, 4, 8, 0, 1, 1, 1, 1, 7, 7, 1, 6, 5, 1, 2, 8, 0, 2, 1,
       1, 7, 0, 1, 1, 1, 5, 7, 1, 1, 1, 2, 8, 8, 1, 7, 8, 1, 0, 1, 1, 1,
       3, 1, 1, 1, 7, 4, 1, 4, 0, 1, 1, 7, 1, 8, 0, 6, 0, 8, 0, 5, 1, 7,
       7, 1, 1, 8, 1, 1, 6, 7, 1, 8, 1, 1, 0, 1, 8, 6, 6, 1, 8, 3, 0, 8,
       5, 1, 1, 0, 8, 5, 7, 0, 7, 6, 1, 8, 1, 7, 1, 8, 1, 7, 6, 8, 0, 1,
       7, 0, 1, 3, 6, 1, 5, 7, 0, 8, 0, 1, 5, 1, 6, 3, 8, 1, 1, 1, 8, 1],
      dtype=int64)
arr2=table.iloc[idx_lst[:], 2].values
print('错误模型共' + str(len(arr2)) + '个')
for i in range(33):
    for j in range(12):
        print(classes[arr2[j+i*12]],end=" ")
    print()
错误模型共396个
car car ship car car ship horse ship ship frog car car 
car car horse plane horse plane plane ship frog ship plane ship 
car car cat horse dog car deer plane car deer car car 
car ship frog cat car car plane car car frog ship car 
car ship horse ship frog car car car plane car plane car 
ship frog horse ship plane ship car car car car car car 
car car car frog ship horse frog horse car ship plane horse 
cat car car plane ship cat cat car ship car ship car 
bird plane ship ship cat ship car cat horse plane cat ship 
cat dog horse car cat car car ship car cat car horse 
car horse horse car cat plane plane car car plane dog horse 
frog deer cat car ship ship car cat dog ship plane car 
dog car horse ship deer cat car car car cat plane frog 
ship ship car cat car horse dog car car dog car car 
ship ship deer horse ship ship car car car plane car car 
car car car cat ship horse horse car deer horse plane bird 
ship car frog plane deer car horse car car ship car frog 
car plane car plane plane horse car horse car car plane dog 
horse car car plane ship car car horse car horse dog plane 
frog car car ship car car horse car deer plane horse car 
horse car frog ship car frog horse car ship ship ship car 
car plane ship ship plane car horse plane horse car car car 
ship horse plane dog deer ship plane car car car car horse 
horse car frog dog car bird ship plane bird car car horse 
plane car car car dog horse car car car bird ship ship 
car horse ship car plane car car car cat car car car 
horse deer car deer plane car car horse car ship plane frog 
plane ship plane dog car horse horse car car ship car car 
frog horse car ship car car plane car ship frog frog car 
ship cat plane ship dog car car plane ship dog horse plane 
horse frog car ship car horse car ship car horse frog ship 
plane car horse plane car cat frog car dog horse plane ship 
plane car dog car frog cat ship car car car ship car 


目录
相关文章
|
11天前
|
机器学习/深度学习 人工智能 算法
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
昆虫识别系统,使用Python作为主要开发语言。通过TensorFlow搭建ResNet50卷积神经网络算法(CNN)模型。通过对10种常见的昆虫图片数据集('蜜蜂', '甲虫', '蝴蝶', '蝉', '蜻蜓', '蚱蜢', '蛾', '蝎子', '蜗牛', '蜘蛛')进行训练,得到一个识别精度较高的H5格式模型文件,然后使用Django搭建Web网页端可视化操作界面,实现用户上传一张昆虫图片识别其名称。
144 7
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
|
9天前
|
机器学习/深度学习 算法 数据挖掘
Python机器学习10大经典算法的讲解和示例
为了展示10个经典的机器学习算法的最简例子,我将为每个算法编写一个小的示例代码。这些算法将包括线性回归、逻辑回归、K-最近邻(KNN)、支持向量机(SVM)、决策树、随机森林、朴素贝叶斯、K-均值聚类、主成分分析(PCA)、和梯度提升(Gradient Boosting)。我将使用常见的机器学习库,如 scikit-learn,numpy 和 pandas 来实现这些算法。
|
15天前
|
机器学习/深度学习 数据采集 算法
【机器学习】Scikit-Learn:Python机器学习的瑞士军刀
【机器学习】Scikit-Learn:Python机器学习的瑞士军刀
34 3
|
15天前
|
机器学习/深度学习 机器人 Python
实践指南,终于有大佬把Python和机器学习讲明白了!
机器学习正在迅速成为数据驱动型世界的一个必备模块。许多不同的领域,如机器人、医学、零售和出版等,都需要依赖这门技术。 机器学习是近年来渐趋热门的一个领域,同时 Python 语言经过一段时间的发展也已逐渐成为主流的编程语言之一。今天给小伙伴们分享的这份手册结合了机器学习和 Python 语言两个热门的领域,通过易于理解的项目详细讲述了如何构建真实的机器学习应用程序。
|
17天前
|
机器学习/深度学习 人工智能 监控
【机器学习】Python与深度学习的完美结合——深度学习在医学影像诊断中的惊人表现
【机器学习】Python与深度学习的完美结合——深度学习在医学影像诊断中的惊人表现
35 3
|
7天前
|
数据采集 机器学习/深度学习 算法
机器学习方法之决策树算法
决策树算法是一种常用的机器学习方法,可以应用于分类和回归任务。通过递归地将数据集划分为更小的子集,从而形成一棵树状的结构模型。每个内部节点代表一个特征的判断,每个分支代表这个特征的某个取值或范围,每个叶节点则表示预测结果。
25 1
|
11天前
|
机器学习/深度学习 人工智能 算法
算法金 | 统计学的回归和机器学习中的回归有什么差别?
**摘要:** 统计学回归重在解释,使用线性模型分析小数据集,强调假设检验与解释性。机器学习回归目标预测,处理大数据集,模型复杂多样,关注泛化能力和预测误差。两者在假设、模型、数据量和评估标准上有显著差异,分别适用于解释性研究和预测任务。
38 8
算法金 | 统计学的回归和机器学习中的回归有什么差别?
|
1天前
|
机器学习/深度学习 数据采集 人工智能
|
8天前
|
机器学习/深度学习 人工智能 自然语言处理
机器学习算法入门:从K-means到神经网络
【6月更文挑战第26天】机器学习入门:从K-means到神经网络。文章涵盖了K-means聚类、逻辑回归、决策树和神经网络的基础原理及应用场景。K-means用于数据分组,逻辑回归适用于二分类,决策树通过特征划分做决策,神经网络则在复杂任务如图像和语言处理中大显身手。是初学者的算法导览。
|
12天前
|
机器学习/深度学习 人工智能 Dart
AI - 机器学习GBDT算法
梯度提升决策树(Gradient Boosting Decision Tree),是一种集成学习的算法,它通过构建多个决策树来逐步修正之前模型的错误,从而提升模型整体的预测性能。