【Python机器学习】实验14 手写体卷积神经网络2

简介: 【Python机器学习】实验14 手写体卷积神经网络2

11. 读取测试集的图片预测值(神经网络的输出为10)

#导入模型
model1=torch.load('./model-cifar10.pth')
pre_result=torch.zeros(len(test_dataset),10)
for i in range(len(test_dataset)):
    pre_result[i,:]=model1(torch.reshape(test_dataset[i][0],(-1,1,28,28)))
pre_result    


tensor([[-1.8005, -0.1725,  1.4765,  ..., 13.5399, -0.6261,  3.8320],
        [ 4.3233,  7.6017, 16.5872,  ..., -0.0560,  5.2066, -7.0792],
        [-2.1821,  9.3779,  0.7749,  ...,  1.8749,  1.6951, -2.9422],
        ...,
        [-5.5755, -2.2075, -9.2250,  ..., -0.2451,  3.3703,  1.2951],
        [ 0.5491, -7.7018, -5.8288,  ..., -7.6734,  9.3946, -1.9603],
        [ 3.5516, -8.2659, -0.5965,  ..., -8.5934,  1.0133, -2.2048]],
       grad_fn=<CopySlices>)
pre_result.shape
torch.Size([10000, 10])
pre_result[:5]
tensor([[-1.8005e+00, -1.7254e-01,  1.4765e+00,  3.0824e+00, -2.5454e+00,
         -7.6911e-01, -1.2368e+01,  1.3540e+01, -6.2614e-01,  3.8320e+00],
        [ 4.3233e+00,  7.6017e+00,  1.6587e+01,  3.6477e+00, -6.6674e+00,
         -6.0579e+00, -1.5660e+00, -5.5983e-02,  5.2066e+00, -7.0792e+00],
        [-2.1821e+00,  9.3779e+00,  7.7487e-01, -3.0049e+00,  1.3374e+00,
         -1.6613e+00,  8.8544e-01,  1.8749e+00,  1.6951e+00, -2.9422e+00],
        [ 1.3456e+01, -9.9020e+00,  2.8586e+00, -2.2105e+00, -1.8515e+00,
          1.7651e-03,  4.7584e+00, -1.3772e+00, -2.2127e+00,  1.5543e+00],
        [-2.9573e+00, -2.5707e+00, -3.5142e+00, -1.7487e+00,  1.2020e+01,
         -8.8355e-01, -1.0698e+00,  6.3823e-01, -3.5542e-01,  3.6258e+00]],
       grad_fn=<SliceBackward0>)
#显示这10000张图片的标签
label_10000=[test_dataset[i][1] for i in range(10000)]
label_10000
[7,
 2,
 1,
 0,
 4,
 1,
 4,
 9,
 5,
 9,
 0,
 6,
 9,
 0,
 1,
 5,
 9,
 7,
 3,
 4,
 9,
 6,
 6,
 5,
 4,
 0,
 7,
 4,
 0,
 1,
 3,
 1,
 3,
 4,
 7,
 2,
 7,
 1,
 2,
 1,
 1,
 7,
 4,
 2,
 3,
 5,
 1,
 2,
 4,
 4,
 6,
 3,
 5,
 5,
 6,
 0,
 4,
 1,
 9,
 5,
 7,
 8,
 9,
 3,
 7,
 4,
 6,
 4,
 3,
 0,
 7,
 0,
 2,
 9,
 1,
 7,
 3,
 2,
 9,
 7,
 7,
 6,
 2,
 7,
 8,
 4,
 7,
 3,
 6,
 1,
 3,
 6,
 9,
 3,
 1,
 4,
 1,
 7,
 6,
 9,
 6,
 0,
 5,
 4,
 9,
 9,
 2,
 1,
 9,
 4,
 8,
 7,
 3,
 9,
 7,
 4,
 4,
 4,
 9,
 2,
 5,
 4,
 7,
 6,
 7,
 9,
 0,
 5,
 8,
 5,
 6,
 6,
 5,
 7,
 8,
 1,
 0,
 1,
 6,
 4,
 6,
 7,
 3,
 1,
 7,
 1,
 8,
 2,
 0,
 2,
 9,
 9,
 5,
 5,
 1,
 5,
 6,
 0,
 3,
 4,
 4,
 6,
 5,
 4,
 6,
 5,
 4,
 5,
 1,
 4,
 4,
 7,
 2,
 3,
 2,
 7,
 1,
 8,
 1,
 8,
 1,
 8,
 5,
 0,
 8,
 9,
 2,
 5,
 0,
 1,
 1,
 1,
 0,
 9,
 0,
 3,
 1,
 6,
 4,
 2,
 3,
 6,
 1,
 1,
 1,
 3,
 9,
 5,
 2,
 9,
 4,
 5,
 9,
 3,
 9,
 0,
 3,
 6,
 5,
 5,
 7,
 2,
 2,
 7,
 1,
 2,
 8,
 4,
 1,
 7,
 3,
 3,
 8,
 8,
 7,
 9,
 2,
 2,
 4,
 1,
 5,
 9,
 8,
 7,
 2,
 3,
 0,
 4,
 4,
 2,
 4,
 1,
 9,
 5,
 7,
 7,
 2,
 8,
 2,
 6,
 8,
 5,
 7,
 7,
 9,
 1,
 8,
 1,
 8,
 0,
 3,
 0,
 1,
 9,
 9,
 4,
 1,
 8,
 2,
 1,
 2,
 9,
 7,
 5,
 9,
 2,
 6,
 4,
 1,
 5,
 8,
 2,
 9,
 2,
 0,
 4,
 0,
 0,
 2,
 8,
 4,
 7,
 1,
 2,
 4,
 0,
 2,
 7,
 4,
 3,
 3,
 0,
 0,
 3,
 1,
 9,
 6,
 5,
 2,
 5,
 9,
 2,
 9,
 3,
 0,
 4,
 2,
 0,
 7,
 1,
 1,
 2,
 1,
 5,
 3,
 3,
 9,
 7,
 8,
 6,
 5,
 6,
 1,
 3,
 8,
 1,
 0,
 5,
 1,
 3,
 1,
 5,
 5,
 6,
 1,
 8,
 5,
 1,
 7,
 9,
 4,
 6,
 2,
 2,
 5,
 0,
 6,
 5,
 6,
 3,
 7,
 2,
 0,
 8,
 8,
 5,
 4,
 1,
 1,
 4,
 0,
 3,
 3,
 7,
 6,
 1,
 6,
 2,
 1,
 9,
 2,
 8,
 6,
 1,
 9,
 5,
 2,
 5,
 4,
 4,
 2,
 8,
 3,
 8,
 2,
 4,
 5,
 0,
 3,
 1,
 7,
 7,
 5,
 7,
 9,
 7,
 1,
 9,
 2,
 1,
 4,
 2,
 9,
 2,
 0,
 4,
 9,
 1,
 4,
 8,
 1,
 8,
 4,
 5,
 9,
 8,
 8,
 3,
 7,
 6,
 0,
 0,
 3,
 0,
 2,
 6,
 6,
 4,
 9,
 3,
 3,
 3,
 2,
 3,
 9,
 1,
 2,
 6,
 8,
 0,
 5,
 6,
 6,
 6,
 3,
 8,
 8,
 2,
 7,
 5,
 8,
 9,
 6,
 1,
 8,
 4,
 1,
 2,
 5,
 9,
 1,
 9,
 7,
 5,
 4,
 0,
 8,
 9,
 9,
 1,
 0,
 5,
 2,
 3,
 7,
 8,
 9,
 4,
 0,
 6,
 3,
 9,
 5,
 2,
 1,
 3,
 1,
 3,
 6,
 5,
 7,
 4,
 2,
 2,
 6,
 3,
 2,
 6,
 5,
 4,
 8,
 9,
 7,
 1,
 3,
 0,
 3,
 8,
 3,
 1,
 9,
 3,
 4,
 4,
 6,
 4,
 2,
 1,
 8,
 2,
 5,
 4,
 8,
 8,
 4,
 0,
 0,
 2,
 3,
 2,
 7,
 7,
 0,
 8,
 7,
 4,
 4,
 7,
 9,
 6,
 9,
 0,
 9,
 8,
 0,
 4,
 6,
 0,
 6,
 3,
 5,
 4,
 8,
 3,
 3,
 9,
 3,
 3,
 3,
 7,
 8,
 0,
 8,
 2,
 1,
 7,
 0,
 6,
 5,
 4,
 3,
 8,
 0,
 9,
 6,
 3,
 8,
 0,
 9,
 9,
 6,
 8,
 6,
 8,
 5,
 7,
 8,
 6,
 0,
 2,
 4,
 0,
 2,
 2,
 3,
 1,
 9,
 7,
 5,
 1,
 0,
 8,
 4,
 6,
 2,
 6,
 7,
 9,
 3,
 2,
 9,
 8,
 2,
 2,
 9,
 2,
 7,
 3,
 5,
 9,
 1,
 8,
 0,
 2,
 0,
 5,
 2,
 1,
 3,
 7,
 6,
 7,
 1,
 2,
 5,
 8,
 0,
 3,
 7,
 2,
 4,
 0,
 9,
 1,
 8,
 6,
 7,
 7,
 4,
 3,
 4,
 9,
 1,
 9,
 5,
 1,
 7,
 3,
 9,
 7,
 6,
 9,
 1,
 3,
 7,
 8,
 3,
 3,
 6,
 7,
 2,
 8,
 5,
 8,
 5,
 1,
 1,
 4,
 4,
 3,
 1,
 0,
 7,
 7,
 0,
 7,
 9,
 4,
 4,
 8,
 5,
 5,
 4,
 0,
 8,
 2,
 1,
 0,
 8,
 4,
 5,
 0,
 4,
 0,
 6,
 1,
 7,
 3,
 2,
 6,
 7,
 2,
 6,
 9,
 3,
 1,
 4,
 6,
 2,
 5,
 4,
 2,
 0,
 6,
 2,
 1,
 7,
 3,
 4,
 1,
 0,
 5,
 4,
 3,
 1,
 1,
 7,
 4,
 9,
 9,
 4,
 8,
 4,
 0,
 2,
 4,
 5,
 1,
 1,
 6,
 4,
 7,
 1,
 9,
 4,
 2,
 4,
 1,
 5,
 5,
 3,
 8,
 3,
 1,
 4,
 5,
 6,
 8,
 9,
 4,
 1,
 5,
 3,
 8,
 0,
 3,
 2,
 5,
 1,
 2,
 8,
 3,
 4,
 4,
 0,
 8,
 8,
 3,
 3,
 1,
 7,
 3,
 5,
 9,
 6,
 3,
 2,
 6,
 1,
 3,
 6,
 0,
 7,
 2,
 1,
 7,
 1,
 4,
 2,
 4,
 2,
 1,
 7,
 9,
 6,
 1,
 1,
 2,
 4,
 8,
 1,
 7,
 7,
 4,
 8,
 0,
 7,
 3,
 1,
 3,
 1,
 0,
 7,
 7,
 0,
 3,
 5,
 5,
 2,
 7,
 6,
 6,
 9,
 2,
 8,
 3,
 5,
 2,
 2,
 5,
 6,
 0,
 8,
 2,
 9,
 2,
 8,
 8,
 8,
 8,
 7,
 4,
 9,
 3,
 0,
 6,
 6,
 3,
 2,
 1,
 3,
 2,
 2,
 9,
 3,
 0,
 0,
 5,
 7,
 8,
 1,
 4,
 4,
 6,
 0,
 2,
 9,
 1,
 4,
 7,
 4,
 7,
 3,
 9,
 8,
 8,
 4,
 7,
 1,
 2,
 1,
 2,
 2,
 3,
 2,
 3,
 2,
 3,
 9,
 1,
 7,
 4,
 0,
 3,
 5,
 5,
 8,
 6,
 3,
 2,
 6,
 7,
 6,
 6,
 3,
 2,
 7,
 8,
 1,
 1,
 7,
 5,
 6,
 4,
 9,
 5,
 1,
 3,
 3,
 4,
 7,
 8,
 9,
 1,
 1,
 6,
 9,
 1,
 4,
 4,
 5,
 4,
 0,
 6,
 2,
 2,
 3,
 1,
 5,
 1,
 2,
 0,
 3,
 8,
 1,
 2,
 6,
 7,
 1,
 6,
 2,
 3,
 9,
 0,
 1,
 2,
 2,
 0,
 8,
 9,
 ...]
import numpy as np
pre_10000=pre_result.detach()
pre_10000
pre_10000=np.array(pre_10000)
pre_10000
array([[-1.8004757 , -0.17253768,  1.4764961 , ..., 13.539932  ,
        -0.6261405 ,  3.832048  ],
       [ 4.323273  ,  7.601658  , 16.587166  , ..., -0.05598306,
         5.20656   , -7.0792093 ],
       [-2.1820781 ,  9.377863  ,  0.7748679 , ...,  1.8749483 ,
         1.6950815 , -2.9421623 ],
       ...,
       [-5.575542  , -2.2075167 , -9.225033  , ..., -0.24509335,
         3.3702612 ,  1.2950805 ],
       [ 0.5491407 , -7.7017508 , -5.8287773 , ..., -7.6733685 ,
         9.39456   , -1.9602803 ],
       [ 3.5516088 , -8.265893  , -0.59651583, ..., -8.593432  ,
         1.0132635 , -2.2048213 ]], dtype=float32)

12. 采用pandas可视化数据

import pandas as pd 
table=pd.DataFrame(zip(pre_10000,label_10000))
table

0
1
0 [-1.8004757, -0.17253768, 1.4764961, 3.0824265... 7
1 [4.323273, 7.601658, 16.587166, 3.6476722, -6.... 2
2 [-2.1820781, 9.377863, 0.7748679, -3.0049446, ... 1
3 [13.455704, -9.902006, 2.8586285, -2.2104588, ... 0
4 [-2.9572597, -2.5707455, -3.5142026, -1.748683... 4
... ... ...
9995 [-2.5784128, 10.5256405, 23.895123, 8.827512, ... 2
9996 [-2.773907, 0.56169015, 1.6811254, 15.230703, ... 3
9997 [-5.575542, -2.2075167, -9.225033, -5.60418, 1... 4
9998 [0.5491407, -7.7017508, -5.8287773, 2.2394006,... 5
9999 [3.5516088, -8.265893, -0.59651583, -4.034732,... 6

10000 rows × 2 columns

table[0].values
array([array([ -1.8004757 ,  -0.17253768,   1.4764961 ,   3.0824265 ,
               -2.545419  ,  -0.76911056, -12.368087  ,  13.539932  ,
               -0.6261405 ,   3.832048  ], dtype=float32)            ,
       array([ 4.323273  ,  7.601658  , 16.587166  ,  3.6476722 , -6.6673512 ,
              -6.05786   , -1.5660243 , -0.05598306,  5.20656   , -7.0792093 ],
             dtype=float32)                                                    ,
       array([-2.1820781,  9.377863 ,  0.7748679, -3.0049446,  1.3374403,
              -1.6612737,  0.8854448,  1.8749483,  1.6950815, -2.9421623],
             dtype=float32)                                               ,
       ...,
       array([-5.575542  , -2.2075167 , -9.225033  , -5.60418   , 17.216341  ,
               2.8671436 ,  1.0113716 , -0.24509335,  3.3702612 ,  1.2950805 ],
             dtype=float32)                                                    ,
       array([ 0.5491407, -7.7017508, -5.8287773,  2.2394006, -7.533697 ,
              13.003905 ,  6.1807218, -7.6733685,  9.39456  , -1.9602803],
             dtype=float32)                                               ,
       array([ 3.5516088 , -8.265893  , -0.59651583, -4.034732  ,  1.3853229 ,
               6.1974382 , 16.321545  , -8.593432  ,  1.0132635 , -2.2048213 ],
             dtype=float32)                                                    ],
      dtype=object)
table["pred"]=[np.argmax(table[0][i]) for i in range(table.shape[0])]
table

0
1 pred
0 [-1.8004757, -0.17253768, 1.4764961, 3.0824265... 7 7
1 [4.323273, 7.601658, 16.587166, 3.6476722, -6.... 2 2
2 [-2.1820781, 9.377863, 0.7748679, -3.0049446, ... 1 1
3 [13.455704, -9.902006, 2.8586285, -2.2104588, ... 0 0
4 [-2.9572597, -2.5707455, -3.5142026, -1.748683... 4 4
... ... ... ...
9995 [-2.5784128, 10.5256405, 23.895123, 8.827512, ... 2 2
9996 [-2.773907, 0.56169015, 1.6811254, 15.230703, ... 3 3
9997 [-5.575542, -2.2075167, -9.225033, -5.60418, 1... 4 4
9998 [0.5491407, -7.7017508, -5.8287773, 2.2394006,... 5 5
9999 [3.5516088, -8.265893, -0.59651583, -4.034732,... 6 6

10000 rows × 3 columns

13. 对预测错误的样本点进行可视化

mismatch=table[table[1]!=table["pred"]]
mismatch


0 1 pred
247 [-0.28747877, 1.9184055, 8.627771, -3.1354206,... 4 2
340 [-5.550468, 1.6552217, -0.96347404, 9.110174, ... 5 3
449 [-6.0154114, -3.7659, -2.7571707, 14.220249, -... 3 5
582 [-1.4626387, 1.3258317, 10.138913, 5.996572, -... 8 2
659 [-3.1300178, 8.830592, 8.781635, 5.6512327, -3... 2 1
... ... ... ...
9768 [2.6190603, -5.539648, 3.0145228, 4.8416886, -... 2 3
9770 [7.0385275, -9.72994, 0.03886398, -0.3356622, ... 5 6
9792 [-0.84618676, -0.038114145, -4.388391, 0.12577... 4 9
9904 [1.6193992, -7.525599, 2.833153, 3.7744582, -2... 2 8
9982 [0.8662107, -7.932593, -0.3750058, 1.9749051, ... 5 6

158 rows × 3 columns

from matplotlib import pyplot as plt
plt.scatter(mismatch[1],mismatch["pred"])


<matplotlib.collections.PathCollection at 0x217dc403490>

14. 看看错误样本被预测为哪些数据

mismatch[mismatch[1]==8].sort_values("pred").index
Int64Index([4807, 2896,  582, 6625, 7220, 3871, 4123, 1878, 1319, 2179, 4601,
            4956, 3023, 9280, 8408, 6765, 4497, 1530,  947],
           dtype='int64')
table.iloc[4500,:]
0       [-4.9380565, 6.2523484, -1.2272537, 0.32682633...
1                                                       9
pred                                                    1
Name: 4500, dtype: object
idx_lst=mismatch[mismatch[1]==8].sort_values("pred").index.values
idx_lst,len(idx_lst)
(array([4807, 2896,  582, 6625, 7220, 3871, 4123, 1878, 1319, 2179, 4601,
        4956, 3023, 9280, 8408, 6765, 4497, 1530,  947], dtype=int64),
 19)
mismatch[mismatch[1]==8].sort_values("pred")
0 1 pred
4807 [5.3192024, -4.2546616, 3.6083155, 3.8956034, ... 8 0
2896 [7.4840407, -8.972937, 0.9461607, 1.6278361, -... 8 0
582 [-1.4626387, 1.3258317, 10.138913, 5.996572, -... 8 2
6625 [-5.413072, 2.7984824, 6.0430045, 2.3938487, 0... 8 2
7220 [-3.1443837, -3.4629154, 4.8560658, 12.752452,... 8 3
3871 [0.1749076, -5.8143945, 3.083826, 8.113558, -5... 8 3
4123 [-3.8682778, -2.290763, 6.1067047, 10.920237, ... 8 3
1878 [-2.8437655, -2.4290323, 3.1861248, 9.739316, ... 8 3
1319 [3.583813, -6.279593, -0.21310738, 7.2746606, ... 8 3
2179 [-0.57300043, -3.8434098, 8.02766, 12.139142, ... 8 3
4601 [-9.5640745, -2.1305811, -5.2161045, 2.3105593... 8 4
4956 [-7.5286517, -4.080871, -6.850239, -2.9094412,... 8 4
3023 [-2.6319933, -11.065216, -1.3231966, 0.0415189... 8 5
9280 [-1.9706918, -11.544259, -0.51283014, 3.955923... 8 5
8408 [1.0573181, -3.7079592, 0.34973174, -0.3489528... 8 6
6765 [2.8831, -2.6855779, 0.39529848, -1.855415, -2... 8 6
4497 [-4.830113, -0.28656, 4.911254, 4.4041815, -2.... 8 7
1530 [-4.4495664, -2.5381584, 5.4418654, 9.994939, ... 8 7
947 [-2.8835857, -8.3713045, -1.5150836, 3.1263702... 8 9
import numpy as np
img=np.stack(list(test_dataset[idx_lst[i]][0][0] for i in range(5)),axis=1).reshape(28,28*5)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x217dc28e9d0>


#显示3行
import numpy as np
img30=np.stack(
    tuple(np.stack(
            tuple(test_dataset[idx_lst[i+j*5]][0][0] for i in range(5)),
        axis=1).reshape(28,28*5) for j in range(3)),axis=0).reshape(28*3,28*5)
plt.imshow(img30)
plt.axis('off')
(-0.5, 139.5, 83.5, -0.5)

arr2=table.iloc[idx_lst[:30],2].values
arr2
array([0, 0, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 9],
      dtype=int64)


目录
相关文章
|
11天前
|
机器学习/深度学习 人工智能 算法
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
海洋生物识别系统。以Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经网络算法,通过对22种常见的海洋生物('蛤蜊', '珊瑚', '螃蟹', '海豚', '鳗鱼', '水母', '龙虾', '海蛞蝓', '章鱼', '水獭', '企鹅', '河豚', '魔鬼鱼', '海胆', '海马', '海豹', '鲨鱼', '虾', '鱿鱼', '海星', '海龟', '鲸鱼')数据集进行训练,得到一个识别精度较高的模型文件,然后使用Django开发一个Web网页平台操作界面,实现用户上传一张海洋生物图片识别其名称。
94 7
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
|
4天前
|
机器学习/深度学习 人工智能 算法
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
乐器识别系统。使用Python为主要编程语言,基于人工智能框架库TensorFlow搭建ResNet50卷积神经网络算法,通过对30种乐器('迪吉里杜管', '铃鼓', '木琴', '手风琴', '阿尔卑斯号角', '风笛', '班卓琴', '邦戈鼓', '卡萨巴', '响板', '单簧管', '古钢琴', '手风琴(六角形)', '鼓', '扬琴', '长笛', '刮瓜', '吉他', '口琴', '竖琴', '沙槌', '陶笛', '钢琴', '萨克斯管', '锡塔尔琴', '钢鼓', '长号', '小号', '大号', '小提琴')的图像数据集进行训练,得到一个训练精度较高的模型,并将其
17 0
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
|
11天前
|
机器学习/深度学习 人工智能 算法
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
昆虫识别系统,使用Python作为主要开发语言。通过TensorFlow搭建ResNet50卷积神经网络算法(CNN)模型。通过对10种常见的昆虫图片数据集('蜜蜂', '甲虫', '蝴蝶', '蝉', '蜻蜓', '蚱蜢', '蛾', '蝎子', '蜗牛', '蜘蛛')进行训练,得到一个识别精度较高的H5格式模型文件,然后使用Django搭建Web网页端可视化操作界面,实现用户上传一张昆虫图片识别其名称。
144 7
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
|
5天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLOv8改进-卷积Conv】DualConv( Dual Convolutional):用于轻量级深度神经网络的双卷积核
**摘要:** 我们提出DualConv,一种融合$3\times3$和$1\times1$卷积的轻量级DNN技术,适用于资源有限的系统。它通过组卷积结合两种卷积核,减少计算和参数量,同时增强准确性。在MobileNetV2上,参数减少54%,CIFAR-100精度仅降0.68%。在YOLOv3中,DualConv提升检测速度并增4.4%的PASCAL VOC准确性。论文及代码已开源。
|
4天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【YOLOv8改进 - 注意力机制】SimAM:轻量级注意力机制,解锁卷积神经网络新潜力
YOLO目标检测专栏介绍了SimAM,一种无参数的CNN注意力模块,基于神经科学理论优化能量函数,提升模型表现。SimAM通过计算3D注意力权重增强特征表示,无需额外参数。文章提供论文链接、Pytorch实现代码及详细配置,展示了如何在目标检测任务中应用该模块。
|
8天前
|
机器学习/深度学习 人工智能 自然语言处理
机器学习算法入门:从K-means到神经网络
【6月更文挑战第26天】机器学习入门:从K-means到神经网络。文章涵盖了K-means聚类、逻辑回归、决策树和神经网络的基础原理及应用场景。K-means用于数据分组,逻辑回归适用于二分类,决策树通过特征划分做决策,神经网络则在复杂任务如图像和语言处理中大显身手。是初学者的算法导览。
|
9天前
|
机器学习/深度学习 算法 数据挖掘
Python机器学习10大经典算法的讲解和示例
为了展示10个经典的机器学习算法的最简例子,我将为每个算法编写一个小的示例代码。这些算法将包括线性回归、逻辑回归、K-最近邻(KNN)、支持向量机(SVM)、决策树、随机森林、朴素贝叶斯、K-均值聚类、主成分分析(PCA)、和梯度提升(Gradient Boosting)。我将使用常见的机器学习库,如 scikit-learn,numpy 和 pandas 来实现这些算法。
|
21天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
|
21天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
|
2天前
|
机器学习/深度学习 编解码 数据可视化
图神经网络版本的Kolmogorov Arnold(KAN)代码实现和效果对比
目前我们看到有很多使用KAN替代MLP的实验,但是目前来说对于图神经网络来说还没有类似的实验,今天我们就来使用KAN创建一个图神经网络Graph Kolmogorov Arnold(GKAN),来测试下KAN是否可以在图神经网络方面有所作为。
15 0