12. 采用pandas可视化数据
import pandas as pd

# One row per test sample: column 0 holds the raw prediction vector,
# column 1 holds the ground-truth label it should be compared against.
table = pd.DataFrame(list(zip(pre_10000, label_10000)))
table
0 | 1 | |
0 | [-0.49338394, -1.098238, 0.40724754, 1.7330961... | 3 |
1 | [4.0153656, 4.4736323, -0.29209492, -3.2882178... | 8 |
2 | [1.3858219, 3.2021556, -0.70040375, -1.0123051... | 8 |
3 | [2.11508, 0.82618773, 0.007076204, -1.1409527,... | 0 |
4 | [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... | 6 |
... | ... | ... |
9995 | [-0.55809855, -4.3891077, -0.3040389, 3.001731... | 8 |
9996 | [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... | 3 |
9997 | [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... | 5 |
9998 | [0.66809845, -0.5327946, 0.30590305, -0.182045... | 1 |
9999 | [-0.51935434, -2.6184506, 1.1929085, 0.1288419... | 7 |
10000 rows × 2 columns
# Underlying object array of per-sample prediction vectors (column 0).
table[0].to_numpy()
array([array([-0.49338394, -1.098238 , 0.40724754, 1.7330961 , -0.4455951 , 1.6433077 , 0.1720748 , -0.40375623, -1.165497 , -0.820113 ], dtype=float32) , array([ 4.0153656 , 4.4736323 , -0.29209492, -3.2882178 , -1.6234205 , -4.481386 , -3.1240807 , -2.392501 , 4.317573 , 4.190993 ], dtype=float32) , array([ 1.3858219 , 3.2021556 , -0.70040375, -1.0123051 , -1.7393746 , -1.6656632 , -3.2578242 , -2.2767155 , 3.092283 , 2.373978 ], dtype=float32) , ..., array([-1.9550545 , -3.808494 , 1.7917161 , 2.6365147 , 0.37311587, 3.545672 , -0.43889195, 2.110389 , -2.9572597 , -1.7386926 ], dtype=float32) , array([ 0.66809845, -0.5327946 , 0.30590305, -0.18204585, 2.0045712 , 0.47369143, -0.3122899 , 0.11701592, -2.5236375 , -0.5746133 ], dtype=float32) , array([-0.51935434, -2.6184506 , 1.1929085 , 0.1288419 , 1.8770852 , 0.4296908 , -0.22015049, 3.7748828 , -2.3134274 , -1.5123445 ], dtype=float32) ], dtype=object)
import numpy as np

# Predicted class for each sample = index of the largest score in its
# prediction vector (column 0). Stacking the per-row vectors into one
# (n_samples, n_scores) matrix lets argmax run once in NumPy instead of a
# Python loop doing label lookups, and works for any row index, not only
# the default RangeIndex.
scores = table[0].to_numpy()
table["pred"] = (
    np.stack(scores).argmax(axis=1)
    if len(scores)
    # np.stack rejects an empty sequence; keep an empty integer column instead.
    else np.empty(0, dtype=np.intp)
)
table
0 | 1 | pred | |
0 | [-0.49338394, -1.098238, 0.40724754, 1.7330961... | 3 | 3 |
1 | [4.0153656, 4.4736323, -0.29209492, -3.2882178... | 8 | 1 |
2 | [1.3858219, 3.2021556, -0.70040375, -1.0123051... | 8 | 1 |
3 | [2.11508, 0.82618773, 0.007076204, -1.1409527,... | 0 | 8 |
4 | [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... | 6 | 6 |
... | ... | ... | ... |
9995 | [-0.55809855, -4.3891077, -0.3040389, 3.001731... | 8 | 5 |
9996 | [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... | 3 | 3 |
9997 | [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... | 5 | 5 |
9998 | [0.66809845, -0.5327946, 0.30590305, -0.182045... | 1 | 4 |
9999 | [-0.51935434, -2.6184506, 1.1929085, 0.1288419... | 7 | 7 |
10000 rows × 3 columns
13. 对预测错误的样本点进行可视化
# Keep only the rows whose ground-truth label (column 1) differs from the predicted class.
mismatch = table.loc[table[1].ne(table["pred"])]
mismatch
0 | 1 | pred | |
1 | [4.0153656, 4.4736323, -0.29209492, -3.2882178... | 8 | 1 |
2 | [1.3858219, 3.2021556, -0.70040375, -1.0123051... | 8 | 1 |
3 | [2.11508, 0.82618773, 0.007076204, -1.1409527,... | 0 | 8 |
8 | [0.02641207, -3.6653092, 2.294829, 2.2884543, ... | 3 | 5 |
12 | [-1.4556388, -1.7955011, -0.6100754, 1.169481,... | 5 | 6 |
... | ... | ... | ... |
9989 | [-0.2553262, -2.8777533, 3.4579017, 0.3079242,... | 2 | 4 |
9993 | [-0.077826336, -3.14616, 0.8994149, 3.5604722,... | 5 | 3 |
9994 | [-1.2543154, -2.4472265, 0.6754027, 2.0582433,... | 3 | 6 |
9995 | [-0.55809855, -4.3891077, -0.3040389, 3.001731... | 8 | 5 |
9998 | [0.66809845, -0.5327946, 0.30590305, -0.182045... | 1 | 4 |
4657 rows × 3 columns
from matplotlib import pyplot as plt

# A plain scatter of (true, pred) stacks thousands of points onto a handful of
# integer grid cells, so every confusion pair looks equally frequent. Count each
# (true label, predicted label) pair and scale the marker area by that count.
pair_counts = mismatch.groupby([1, "pred"]).size()
true_lbl = pair_counts.index.get_level_values(0)
pred_lbl = pair_counts.index.get_level_values(1)
plt.xlabel("true label")
plt.ylabel("predicted label")
plt.scatter(true_lbl, pred_lbl, s=pair_counts.to_numpy())
<matplotlib.collections.PathCollection at 0x1b3a92ef910>
14. 看看真实类别为 9 的错误样本被预测成了哪些类别？
# Row labels of the misclassified samples whose true label is 9, ordered by predicted class.
mismatch.loc[mismatch[1].eq(9)].sort_values(by="pred").index
Int64Index([2129, 1465, 2907, 787, 2902, 2307, 4588, 5737, 8276, 8225, ... 7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344], dtype='int64', length=396)
# Misclassified samples with true label 9, grouped by what they were predicted as.
wrong_9 = mismatch.loc[mismatch[1].eq(9)]
idx_lst = wrong_9.sort_values(by="pred").index.to_numpy()
idx_lst, len(idx_lst)
(array([2129, 1465, 2907, 787, 2902, 2307, 4588, 5737, 8276, 8225, 8148, 4836, 1155, 7218, 8034, 7412, 5069, 1629, 5094, 5109, 7685, 5397, 1427, 5308, 8727, 2960, 2491, 6795, 1997, 6686, 9449, 6545, 8985, 9401, 3564, 6034, 383, 9583, 9673, 507, 3288, 6868, 9133, 9085, 577, 4261, 6974, 411, 6290, 5416, 5350, 5950, 5455, 5498, 6143, 5964, 5864, 5877, 6188, 5939, 14, 5300, 3501, 3676, 3770, 3800, 3850, 3893, 3902, 4233, 4252, 4253, 4276, 5335, 4297, 4418, 4445, 4536, 4681, 6381, 4929, 4945, 5067, 5087, 5166, 5192, 4364, 4928, 7024, 6542, 8144, 8312, 8385, 8406, 8453, 8465, 8521, 8585, 8673, 8763, 8946, 9067, 9069, 9199, 9209, 9217, 9280, 9403, 9463, 9518, 9692, 9743, 9871, 9875, 9881, 8066, 6509, 8057, 7826, 6741, 6811, 6814, 6840, 6983, 7007, 3492, 7028, 7075, 7121, 7232, 7270, 7424, 7431, 7444, 7492, 7499, 7501, 7578, 7639, 7729, 7767, 7792, 7818, 7824, 7942, 3459, 4872, 1834, 1487, 1668, 1727, 1732, 1734, 1808, 1814, 1815, 1831, 1927, 2111, 2126, 2190, 2246, 2290, 2433, 2596, 2700, 2714, 1439, 1424, 1376, 1359, 28, 151, 172, 253, 259, 335, 350, 591, 625, 2754, 734, 940, 951, 970, 1066, 1136, 1177, 1199, 1222, 1231, 853, 2789, 9958, 2946, 3314, 3307, 2876, 3208, 3166, 2944, 2817, 2305, 7522, 7155, 7220, 4590, 2899, 2446, 2186, 7799, 9492, 3163, 4449, 2027, 2387, 1064, 3557, 2177, 654, 9791, 2670, 2514, 2495, 3450, 8972, 3210, 3755, 2756, 7967, 3970, 4550, 6017, 938, 744, 6951, 3397, 4852, 3133, 7931, 707, 3312, 7470, 6871, 8292, 7100, 9529, 9100, 3853, 9060, 9732, 2521, 3789, 2974, 5311, 3218, 5736, 3055, 7076, 1220, 9147, 1344, 532, 8218, 3569, 1008, 8475, 8877, 1582, 8936, 4758, 1837, 9517, 252, 5832, 1916, 6369, 4979, 9324, 6218, 9777, 7923, 4521, 2868, 213, 8083, 5952, 5579, 4508, 5488, 2460, 5332, 5180, 8323, 8345, 3776, 2568, 5151, 4570, 2854, 8488, 4874, 680, 2810, 1285, 6136, 3339, 9143, 6852, 1906, 7067, 7073, 2975, 1924, 6804, 6755, 9299, 2019, 9445, 9560, 360, 1601, 7297, 9122, 6377, 9214, 6167, 3980, 394, 7491, 7581, 9349, 8953, 222, 139, 530, 3577, 
9868, 247, 9099, 9026, 209, 538, 3229, 9258, 585, 9204, 9643, 1492, 3609, 6570, 6561, 6469, 6435, 6419, 2155, 6275, 4481, 2202, 1987, 2271, 2355, 2366, 2432, 5400, 2497, 2727, 4931, 4619, 9884, 5902, 8796, 6848, 6960, 8575, 8413, 981, 8272, 8145, 3172, 1221, 3168, 1256, 1889, 1291, 3964, 7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344], dtype=int64), 396)
import numpy as np

# Show up to 5 misclassified samples side by side, in idx_lst order.
# NOTE(review): test_dataset[k][0] is assumed to be a (C, H, W) image tensor —
# confirm. Only its first channel ([0]) is drawn, so a grayscale colormap is used
# instead of the default viridis, which would add false colors to a single channel.
tiles = [np.asarray(test_dataset[k][0][0]) for k in idx_lst[:5]]
if tiles:
    # Concatenating along the width axis lays the tiles out left to right. This is
    # the same layout as stack(axis=1).reshape(32, 32*5), but it works for any tile
    # size and for fewer than 5 tiles.
    img = np.concatenate(tiles, axis=1)
    plt.imshow(img, cmap="gray")
plt.axis('off')
(-0.5, 159.5, 31.5, -0.5)
# Show the first misclassified samples as a 4-row x 5-column grid, filled row by row in idx_lst order.
import numpy as np

n_rows, n_cols = 4, 5
# NOTE(review): only the first channel of each (assumed C, H, W) image is drawn — confirm.
tiles = [np.asarray(test_dataset[k][0][0]) for k in idx_lst[: n_rows * n_cols]]
if tiles:
    # Pad with blank tiles so a short idx_lst still yields a full rectangular grid
    # instead of raising IndexError.
    blank = np.zeros_like(tiles[0])
    tiles += [blank] * (n_rows * n_cols - len(tiles))
    # np.block joins each inner list horizontally and the rows vertically. This is the
    # same layout as the nested stack/reshape, without hard-coding the 32-pixel tile size.
    img20 = np.block([tiles[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)])
    plt.imshow(img20, cmap="gray")
plt.axis('off')
(-0.5, 159.5, 127.5, -0.5)
15. 输出这些错误样本的预测类别
# Rebind idx_lst to the class-9 misclassified row labels in their original row
# order (no sort by "pred" here), then show what each of them was predicted as.
idx_lst = mismatch[mismatch[1] == 9].index.values
# Select by row label and column name. The previous iloc[idx_lst, 2] fed row labels
# into positional indexing and relied on "pred" staying the third column.
table.loc[idx_lst, "pred"].values
array([1, 1, 8, 1, 1, 8, 7, 8, 8, 6, 1, 1, 1, 1, 7, 0, 7, 0, 0, 8, 6, 8, 0, 8, 1, 1, 3, 7, 5, 1, 4, 0, 1, 4, 1, 1, 1, 8, 6, 3, 1, 1, 0, 1, 1, 6, 8, 1, 1, 8, 7, 8, 6, 1, 1, 1, 0, 1, 0, 1, 8, 6, 7, 8, 0, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 8, 7, 6, 7, 1, 8, 0, 7, 3, 1, 1, 0, 8, 3, 3, 1, 8, 1, 8, 1, 2, 0, 8, 8, 3, 8, 1, 3, 7, 0, 3, 8, 3, 5, 7, 1, 3, 1, 1, 8, 1, 3, 1, 7, 1, 7, 7, 1, 3, 0, 0, 1, 1, 0, 5, 7, 6, 4, 3, 1, 8, 8, 1, 3, 5, 8, 0, 1, 5, 1, 7, 8, 4, 3, 1, 1, 1, 3, 0, 6, 8, 8, 1, 3, 1, 7, 5, 1, 1, 5, 1, 1, 8, 8, 4, 7, 8, 8, 1, 1, 1, 0, 1, 1, 1, 1, 1, 3, 8, 7, 7, 1, 4, 7, 0, 2, 8, 1, 6, 0, 4, 1, 7, 1, 1, 8, 1, 6, 1, 0, 1, 0, 0, 7, 1, 7, 1, 1, 0, 5, 7, 1, 1, 0, 8, 1, 1, 7, 1, 7, 5, 0, 6, 1, 1, 8, 1, 1, 7, 1, 4, 0, 7, 1, 7, 1, 6, 8, 1, 6, 7, 1, 8, 8, 8, 1, 1, 0, 8, 8, 0, 1, 7, 0, 7, 1, 1, 1, 8, 7, 0, 5, 4, 8, 0, 1, 1, 1, 1, 7, 7, 1, 6, 5, 1, 2, 8, 0, 2, 1, 1, 7, 0, 1, 1, 1, 5, 7, 1, 1, 1, 2, 8, 8, 1, 7, 8, 1, 0, 1, 1, 1, 3, 1, 1, 1, 7, 4, 1, 4, 0, 1, 1, 7, 1, 8, 0, 6, 0, 8, 0, 5, 1, 7, 7, 1, 1, 8, 1, 1, 6, 7, 1, 8, 1, 1, 0, 1, 8, 6, 6, 1, 8, 3, 0, 8, 5, 1, 1, 0, 8, 5, 7, 0, 7, 6, 1, 8, 1, 7, 1, 8, 1, 7, 6, 8, 0, 1, 7, 0, 1, 3, 6, 1, 5, 7, 0, 8, 0, 1, 5, 1, 6, 3, 8, 1, 1, 1, 8, 1], dtype=int64)
# Predicted class ids for the rows in idx_lst, selected by label and column name.
arr2 = table.loc[idx_lst, "pred"].values
print('错误模型共' + str(len(arr2)) + '个')
# Print the predicted class names 12 per line. Stepping over len(arr2) replaces the
# fixed 33 x 12 grid, which raised IndexError (or silently dropped names) whenever
# the count was not exactly 396.
per_line = 12
for start in range(0, len(arr2), per_line):
    print(" ".join(classes[k] for k in arr2[start:start + per_line]))
错误模型共396个 car car ship car car ship horse ship ship frog car car car car horse plane horse plane plane ship frog ship plane ship car car cat horse dog car deer plane car deer car car car ship frog cat car car plane car car frog ship car car ship horse ship frog car car car plane car plane car ship frog horse ship plane ship car car car car car car car car car frog ship horse frog horse car ship plane horse cat car car plane ship cat cat car ship car ship car bird plane ship ship cat ship car cat horse plane cat ship cat dog horse car cat car car ship car cat car horse car horse horse car cat plane plane car car plane dog horse frog deer cat car ship ship car cat dog ship plane car dog car horse ship deer cat car car car cat plane frog ship ship car cat car horse dog car car dog car car ship ship deer horse ship ship car car car plane car car car car car cat ship horse horse car deer horse plane bird ship car frog plane deer car horse car car ship car frog car plane car plane plane horse car horse car car plane dog horse car car plane ship car car horse car horse dog plane frog car car ship car car horse car deer plane horse car horse car frog ship car frog horse car ship ship ship car car plane ship ship plane car horse plane horse car car car ship horse plane dog deer ship plane car car car car horse horse car frog dog car bird ship plane bird car car horse plane car car car dog horse car car car bird ship ship car horse ship car plane car car car cat car car car horse deer car deer plane car car horse car ship plane frog plane ship plane dog car horse horse car car ship car car frog horse car ship car car plane car ship frog frog car ship cat plane ship dog car car plane ship dog horse plane horse frog car ship car horse car ship car horse frog ship plane car horse plane car cat frog car dog horse plane ship plane car dog car frog cat ship car car car ship car