【Python机器学习】实验15 将Lenet5应用于Cifar10数据集3

简介: 【Python机器学习】实验15 将Lenet5应用于Cifar10数据集

12. 采用pandas可视化数据

import pandas as pd 
table=pd.DataFrame(zip(pre_10000,label_10000))
table
0 1
0 [-0.49338394, -1.098238, 0.40724754, 1.7330961... 3
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0
4 [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... 6
... ... ...
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8
9996 [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... 3
9997 [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1
9999 [-0.51935434, -2.6184506, 1.1929085, 0.1288419... 7

10000 rows × 2 columns

table[0].values
array([array([-0.49338394, -1.098238  ,  0.40724754,  1.7330961 , -0.4455951 ,
               1.6433077 ,  0.1720748 , -0.40375623, -1.165497  , -0.820113  ],
             dtype=float32)                                                    ,
       array([ 4.0153656 ,  4.4736323 , -0.29209492, -3.2882178 , -1.6234205 ,
              -4.481386  , -3.1240807 , -2.392501  ,  4.317573  ,  4.190993  ],
             dtype=float32)                                                    ,
       array([ 1.3858219 ,  3.2021556 , -0.70040375, -1.0123051 , -1.7393746 ,
              -1.6656632 , -3.2578242 , -2.2767155 ,  3.092283  ,  2.373978  ],
             dtype=float32)                                                    ,
       ...,
       array([-1.9550545 , -3.808494  ,  1.7917161 ,  2.6365147 ,  0.37311587,
               3.545672  , -0.43889195,  2.110389  , -2.9572597 , -1.7386926 ],
             dtype=float32)                                                    ,
       array([ 0.66809845, -0.5327946 ,  0.30590305, -0.18204585,  2.0045712 ,
               0.47369143, -0.3122899 ,  0.11701592, -2.5236375 , -0.5746133 ],
             dtype=float32)                                                    ,
       array([-0.51935434, -2.6184506 ,  1.1929085 ,  0.1288419 ,  1.8770852 ,
               0.4296908 , -0.22015049,  3.7748828 , -2.3134274 , -1.5123445 ],
             dtype=float32)                                                    ],
      dtype=object)
table["pred"]=[np.argmax(table[0][i]) for i in range(table.shape[0])]
table
0 1 pred
0 [-0.49338394, -1.098238, 0.40724754, 1.7330961... 3 3
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8 1
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8 1
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0 8
4 [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... 6 6
... ... ... ...
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8 5
9996 [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... 3 3
9997 [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... 5 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1 4
9999 [-0.51935434, -2.6184506, 1.1929085, 0.1288419... 7 7

10000 rows × 3 columns

13. 对预测错误的样本点进行可视化

mismatch=table[table[1]!=table["pred"]]
mismatch
0 1 pred
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8 1
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8 1
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0 8
8 [0.02641207, -3.6653092, 2.294829, 2.2884543, ... 3 5
12 [-1.4556388, -1.7955011, -0.6100754, 1.169481,... 5 6
... ... ... ...
9989 [-0.2553262, -2.8777533, 3.4579017, 0.3079242,... 2 4
9993 [-0.077826336, -3.14616, 0.8994149, 3.5604722,... 5 3
9994 [-1.2543154, -2.4472265, 0.6754027, 2.0582433,... 3 6
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1 4

4657 rows × 3 columns

from matplotlib import pyplot as plt
plt.scatter(mismatch[1],mismatch["pred"])
<matplotlib.collections.PathCollection at 0x1b3a92ef910>

14. 看看错误样本被预测为哪些数据?

mismatch[mismatch[1]==9].sort_values("pred").index


Int64Index([2129, 1465, 2907,  787, 2902, 2307, 4588, 5737, 8276, 8225,
            ...
            7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344],
           dtype='int64', length=396)
idx_lst=mismatch[mismatch[1]==9].sort_values("pred").index.values
idx_lst,len(idx_lst)
(array([2129, 1465, 2907,  787, 2902, 2307, 4588, 5737, 8276, 8225, 8148,
        4836, 1155, 7218, 8034, 7412, 5069, 1629, 5094, 5109, 7685, 5397,
        1427, 5308, 8727, 2960, 2491, 6795, 1997, 6686, 9449, 6545, 8985,
        9401, 3564, 6034,  383, 9583, 9673,  507, 3288, 6868, 9133, 9085,
         577, 4261, 6974,  411, 6290, 5416, 5350, 5950, 5455, 5498, 6143,
        5964, 5864, 5877, 6188, 5939,   14, 5300, 3501, 3676, 3770, 3800,
        3850, 3893, 3902, 4233, 4252, 4253, 4276, 5335, 4297, 4418, 4445,
        4536, 4681, 6381, 4929, 4945, 5067, 5087, 5166, 5192, 4364, 4928,
        7024, 6542, 8144, 8312, 8385, 8406, 8453, 8465, 8521, 8585, 8673,
        8763, 8946, 9067, 9069, 9199, 9209, 9217, 9280, 9403, 9463, 9518,
        9692, 9743, 9871, 9875, 9881, 8066, 6509, 8057, 7826, 6741, 6811,
        6814, 6840, 6983, 7007, 3492, 7028, 7075, 7121, 7232, 7270, 7424,
        7431, 7444, 7492, 7499, 7501, 7578, 7639, 7729, 7767, 7792, 7818,
        7824, 7942, 3459, 4872, 1834, 1487, 1668, 1727, 1732, 1734, 1808,
        1814, 1815, 1831, 1927, 2111, 2126, 2190, 2246, 2290, 2433, 2596,
        2700, 2714, 1439, 1424, 1376, 1359,   28,  151,  172,  253,  259,
         335,  350,  591,  625, 2754,  734,  940,  951,  970, 1066, 1136,
        1177, 1199, 1222, 1231,  853, 2789, 9958, 2946, 3314, 3307, 2876,
        3208, 3166, 2944, 2817, 2305, 7522, 7155, 7220, 4590, 2899, 2446,
        2186, 7799, 9492, 3163, 4449, 2027, 2387, 1064, 3557, 2177,  654,
        9791, 2670, 2514, 2495, 3450, 8972, 3210, 3755, 2756, 7967, 3970,
        4550, 6017,  938,  744, 6951, 3397, 4852, 3133, 7931,  707, 3312,
        7470, 6871, 8292, 7100, 9529, 9100, 3853, 9060, 9732, 2521, 3789,
        2974, 5311, 3218, 5736, 3055, 7076, 1220, 9147, 1344,  532, 8218,
        3569, 1008, 8475, 8877, 1582, 8936, 4758, 1837, 9517,  252, 5832,
        1916, 6369, 4979, 9324, 6218, 9777, 7923, 4521, 2868,  213, 8083,
        5952, 5579, 4508, 5488, 2460, 5332, 5180, 8323, 8345, 3776, 2568,
        5151, 4570, 2854, 8488, 4874,  680, 2810, 1285, 6136, 3339, 9143,
        6852, 1906, 7067, 7073, 2975, 1924, 6804, 6755, 9299, 2019, 9445,
        9560,  360, 1601, 7297, 9122, 6377, 9214, 6167, 3980,  394, 7491,
        7581, 9349, 8953,  222,  139,  530, 3577, 9868,  247, 9099, 9026,
         209,  538, 3229, 9258,  585, 9204, 9643, 1492, 3609, 6570, 6561,
        6469, 6435, 6419, 2155, 6275, 4481, 2202, 1987, 2271, 2355, 2366,
        2432, 5400, 2497, 2727, 4931, 4619, 9884, 5902, 8796, 6848, 6960,
        8575, 8413,  981, 8272, 8145, 3172, 1221, 3168, 1256, 1889, 1291,
        3964, 7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344],
       dtype=int64),
 396)
import numpy as np
img=np.stack(list(test_dataset[idx_lst[i]][0][0] for i in range(5)),axis=1).reshape(32,32*5)
plt.imshow(img)
plt.axis('off')
(-0.5, 159.5, 31.5, -0.5)

51b620f1d16b42cd90faf82a9e30d3bb.png

#显示4行
import numpy as np
img20=np.stack(
    tuple(np.stack(
            tuple(test_dataset[idx_lst[i+j*5]][0][0] for i in range(5)),
        axis=1).reshape(32,32*5) for j in range(4)),axis=0).reshape(32*4,32*5)
plt.imshow(img20)
plt.axis('off')
(-0.5, 159.5, 127.5, -0.5)

0c783bf29137476c835285acc890ac80.png

15.输出错误的模型类别

idx_lst=mismatch[mismatch[1]==9].index.values
table.iloc[idx_lst[:], 2].values
array([1, 1, 8, 1, 1, 8, 7, 8, 8, 6, 1, 1, 1, 1, 7, 0, 7, 0, 0, 8, 6, 8,
       0, 8, 1, 1, 3, 7, 5, 1, 4, 0, 1, 4, 1, 1, 1, 8, 6, 3, 1, 1, 0, 1,
       1, 6, 8, 1, 1, 8, 7, 8, 6, 1, 1, 1, 0, 1, 0, 1, 8, 6, 7, 8, 0, 8,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 8, 7, 6, 7, 1, 8, 0, 7, 3, 1, 1, 0,
       8, 3, 3, 1, 8, 1, 8, 1, 2, 0, 8, 8, 3, 8, 1, 3, 7, 0, 3, 8, 3, 5,
       7, 1, 3, 1, 1, 8, 1, 3, 1, 7, 1, 7, 7, 1, 3, 0, 0, 1, 1, 0, 5, 7,
       6, 4, 3, 1, 8, 8, 1, 3, 5, 8, 0, 1, 5, 1, 7, 8, 4, 3, 1, 1, 1, 3,
       0, 6, 8, 8, 1, 3, 1, 7, 5, 1, 1, 5, 1, 1, 8, 8, 4, 7, 8, 8, 1, 1,
       1, 0, 1, 1, 1, 1, 1, 3, 8, 7, 7, 1, 4, 7, 0, 2, 8, 1, 6, 0, 4, 1,
       7, 1, 1, 8, 1, 6, 1, 0, 1, 0, 0, 7, 1, 7, 1, 1, 0, 5, 7, 1, 1, 0,
       8, 1, 1, 7, 1, 7, 5, 0, 6, 1, 1, 8, 1, 1, 7, 1, 4, 0, 7, 1, 7, 1,
       6, 8, 1, 6, 7, 1, 8, 8, 8, 1, 1, 0, 8, 8, 0, 1, 7, 0, 7, 1, 1, 1,
       8, 7, 0, 5, 4, 8, 0, 1, 1, 1, 1, 7, 7, 1, 6, 5, 1, 2, 8, 0, 2, 1,
       1, 7, 0, 1, 1, 1, 5, 7, 1, 1, 1, 2, 8, 8, 1, 7, 8, 1, 0, 1, 1, 1,
       3, 1, 1, 1, 7, 4, 1, 4, 0, 1, 1, 7, 1, 8, 0, 6, 0, 8, 0, 5, 1, 7,
       7, 1, 1, 8, 1, 1, 6, 7, 1, 8, 1, 1, 0, 1, 8, 6, 6, 1, 8, 3, 0, 8,
       5, 1, 1, 0, 8, 5, 7, 0, 7, 6, 1, 8, 1, 7, 1, 8, 1, 7, 6, 8, 0, 1,
       7, 0, 1, 3, 6, 1, 5, 7, 0, 8, 0, 1, 5, 1, 6, 3, 8, 1, 1, 1, 8, 1],
      dtype=int64)
arr2=table.iloc[idx_lst[:], 2].values
print('错误模型共' + str(len(arr2)) + '个')
for i in range(33):
    for j in range(12):
        print(classes[arr2[j+i*12]],end=" ")
    print()
错误模型共396个
car car ship car car ship horse ship ship frog car car 
car car horse plane horse plane plane ship frog ship plane ship 
car car cat horse dog car deer plane car deer car car 
car ship frog cat car car plane car car frog ship car 
car ship horse ship frog car car car plane car plane car 
ship frog horse ship plane ship car car car car car car 
car car car frog ship horse frog horse car ship plane horse 
cat car car plane ship cat cat car ship car ship car 
bird plane ship ship cat ship car cat horse plane cat ship 
cat dog horse car cat car car ship car cat car horse 
car horse horse car cat plane plane car car plane dog horse 
frog deer cat car ship ship car cat dog ship plane car 
dog car horse ship deer cat car car car cat plane frog 
ship ship car cat car horse dog car car dog car car 
ship ship deer horse ship ship car car car plane car car 
car car car cat ship horse horse car deer horse plane bird 
ship car frog plane deer car horse car car ship car frog 
car plane car plane plane horse car horse car car plane dog 
horse car car plane ship car car horse car horse dog plane 
frog car car ship car car horse car deer plane horse car 
horse car frog ship car frog horse car ship ship ship car 
car plane ship ship plane car horse plane horse car car car 
ship horse plane dog deer ship plane car car car car horse 
horse car frog dog car bird ship plane bird car car horse 
plane car car car dog horse car car car bird ship ship 
car horse ship car plane car car car cat car car car 
horse deer car deer plane car car horse car ship plane frog 
plane ship plane dog car horse horse car car ship car car 
frog horse car ship car car plane car ship frog frog car 
ship cat plane ship dog car car plane ship dog horse plane 
horse frog car ship car horse car ship car horse frog ship 
plane car horse plane car cat frog car dog horse plane ship 
plane car dog car frog cat ship car car car ship car 


目录
相关文章
|
9天前
|
监控 数据可视化 数据挖掘
Python Rich库使用指南:打造更美观的命令行应用
Rich库是Python的终端美化利器,支持彩色文本、智能表格、动态进度条和语法高亮,大幅提升命令行应用的可视化效果与用户体验。
51 0
|
1月前
|
数据采集 监控 Java
Python 函数式编程的执行效率:实际应用中的权衡
Python 函数式编程的执行效率:实际应用中的权衡
201 102
|
17天前
|
机器学习/深度学习 数据采集 算法
量子机器学习入门:三种数据编码方法对比与应用
在量子机器学习中,数据编码方式决定了量子模型如何理解和处理信息。本文详解角度编码、振幅编码与基础编码三种方法,分析其原理、实现及适用场景,帮助读者选择最适合的编码策略,提升量子模型性能。
110 8
|
10天前
|
机器学习/深度学习 算法 安全
【强化学习应用(八)】基于Q-learning的无人机物流路径规划研究(Python代码实现)
【强化学习应用(八)】基于Q-learning的无人机物流路径规划研究(Python代码实现)
|
25天前
|
设计模式 缓存 运维
Python装饰器实战场景解析:从原理到应用的10个经典案例
Python装饰器是函数式编程的精华,通过10个实战场景,从日志记录、权限验证到插件系统,全面解析其应用。掌握装饰器,让代码更优雅、灵活,提升开发效率。
86 0
|
29天前
|
数据采集 存储 数据可视化
Python网络爬虫在环境保护中的应用:污染源监测数据抓取与分析
在环保领域,数据是决策基础,但分散在多个平台,获取困难。Python网络爬虫技术灵活高效,可自动化抓取空气质量、水质、污染源等数据,实现多平台整合、实时更新、结构化存储与异常预警。本文详解爬虫实战应用,涵盖技术选型、代码实现、反爬策略与数据分析,助力环保数据高效利用。
102 0
|
1月前
|
存储 程序员 数据处理
Python列表基础操作全解析:从创建到灵活应用
本文深入浅出地讲解了Python列表的各类操作,从创建、增删改查到遍历与性能优化,内容详实且贴近实战,适合初学者快速掌握这一核心数据结构。
166 0
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
Java 大视界 -- Java 大数据机器学习模型在自然语言生成中的可控性研究与应用(229)
本文深入探讨Java大数据与机器学习在自然语言生成(NLG)中的可控性研究,分析当前生成模型面临的“失控”挑战,如数据噪声、标注偏差及黑盒模型信任问题,提出Java技术在数据清洗、异构框架融合与生态工具链中的关键作用。通过条件注入、强化学习与模型融合等策略,实现文本生成的精准控制,并结合网易新闻与蚂蚁集团的实战案例,展示Java在提升生成效率与合规性方面的卓越能力,为金融、法律等强监管领域提供技术参考。
|
1月前
|
机器学习/深度学习 算法 Java
Java 大视界 -- Java 大数据机器学习模型在生物信息学基因功能预测中的优化与应用(223)
本文探讨了Java大数据与机器学习模型在生物信息学中基因功能预测的优化与应用。通过高效的数据处理能力和智能算法,提升基因功能预测的准确性与效率,助力医学与农业发展。
|
1月前
|
机器学习/深度学习 搜索推荐 数据可视化
Java 大视界 -- Java 大数据机器学习模型在电商用户流失预测与留存策略制定中的应用(217)
本文探讨 Java 大数据与机器学习在电商用户流失预测与留存策略中的应用。通过构建高精度预测模型与动态分层策略,助力企业提前识别流失用户、精准触达,实现用户留存率与商业价值双提升,为电商应对用户流失提供技术新思路。

推荐镜像

更多