[A-Xu's Machine Learning in Practice] [15] Automatic Face Completion (Multi-Target Regression), with a Comparison of 5 Different Models


Machine Learning in Practice: Automatic Face Completion (Multi-Target Prediction)


Goal


Predict the lower half of a face from its upper half, thereby completing the face.


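In essence this is a multi-target prediction problem: every pixel of the lower half is a separate regression target, and the fitted model predicts all of these targets from the upper half.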
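As a rough illustration of what "multi-target" means here, the sketch below uses small synthetic arrays (an assumption, standing in for the face data) and checks that fitting `Ridge` once on a 2-D target matrix gives the same predictions as fitting one `Ridge` per target column. This holds for the linear models; `KNeighborsRegressor` instead predicts all targets together by averaging the target vectors of the nearest neighbours.

```python
import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)
X = rng.random((50, 6))          # 50 samples, 6 input features (synthetic stand-in)
Y = rng.random((50, 3))          # 3 regression targets per sample

# One estimator fitted on all targets at once
joint = Ridge(alpha=1.0).fit(X, Y)
joint_pred = joint.predict(X)    # one column of predictions per target

# The same thing done explicitly: one Ridge model per target column
per_target_pred = np.column_stack([
    Ridge(alpha=1.0).fit(X, Y[:, k]).predict(X) for k in range(Y.shape[1])
])

print(joint_pred.shape)
print(np.allclose(joint_pred, per_target_pred))
```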


Dataset


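The Olivetti faces dataset contains 400 grayscale face images of 64*64 pixels, each flattened into a one-dimensional vector of length 4096. There are 40 different subjects, with ten images of each.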


from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression,Ridge,Lasso
from sklearn.ensemble import ExtraTreesRegressor
from sklearn import datasets
faces = datasets.fetch_olivetti_faces()
faces
{'data': array([[0.30991736, 0.3677686 , 0.41735536, ..., 0.15289256, 0.16115703,
         0.1570248 ],
        [0.45454547, 0.47107437, 0.5123967 , ..., 0.15289256, 0.15289256,
         0.15289256],
        [0.3181818 , 0.40082645, 0.49173555, ..., 0.14049587, 0.14876033,
         0.15289256],
        ...,
        [0.5       , 0.53305787, 0.607438  , ..., 0.17768595, 0.14876033,
         0.19008264],
        [0.21487603, 0.21900827, 0.21900827, ..., 0.57438016, 0.59090906,
         0.60330576],
        [0.5165289 , 0.46280992, 0.28099173, ..., 0.35950413, 0.3553719 ,
         0.38429752]], dtype=float32),
 'images': array([[[0.30991736, 0.3677686 , 0.41735536, ..., 0.37190083,
          0.3305785 , 0.30578512],
         [0.3429752 , 0.40495867, 0.43801653, ..., 0.37190083,
          0.338843  , 0.3140496 ],
         [0.3429752 , 0.41735536, 0.45041323, ..., 0.38016528,
          0.338843  , 0.29752067],
         ...,
         [0.21487603, 0.20661157, 0.2231405 , ..., 0.15289256,
          0.16528925, 0.17355372],
         [0.20247933, 0.2107438 , 0.2107438 , ..., 0.14876033,
          0.16115703, 0.16528925],
         [0.20247933, 0.20661157, 0.20247933, ..., 0.15289256,
          0.16115703, 0.1570248 ]],
        [[0.45454547, 0.47107437, 0.5123967 , ..., 0.19008264,
          0.18595041, 0.18595041],
         [0.446281  , 0.48347107, 0.5206612 , ..., 0.21487603,
          0.2107438 , 0.2107438 ],
         [0.49586776, 0.5165289 , 0.53305787, ..., 0.20247933,
          0.20661157, 0.20661157],
         ...,
         [0.77272725, 0.78099173, 0.7933884 , ..., 0.1446281 ,
          0.1446281 , 0.1446281 ],
         [0.77272725, 0.7768595 , 0.7892562 , ..., 0.13636364,
          0.13636364, 0.13636364],
         [0.7644628 , 0.7892562 , 0.78099173, ..., 0.15289256,
          0.15289256, 0.15289256]],
        [[0.3181818 , 0.40082645, 0.49173555, ..., 0.40082645,
          0.3553719 , 0.30991736],
         [0.30991736, 0.3966942 , 0.47933885, ..., 0.40495867,
          0.37603307, 0.30165288],
         [0.26859504, 0.34710744, 0.45454547, ..., 0.3966942 ,
          0.37190083, 0.30991736],
         ...,
         [0.1322314 , 0.09917355, 0.08264463, ..., 0.13636364,
          0.14876033, 0.15289256],
         [0.11570248, 0.09504132, 0.0785124 , ..., 0.1446281 ,
          0.1446281 , 0.1570248 ],
         [0.11157025, 0.09090909, 0.0785124 , ..., 0.14049587,
          0.14876033, 0.15289256]],
        ...,
        [[0.5       , 0.53305787, 0.607438  , ..., 0.28512397,
          0.23966943, 0.21487603],
         [0.49173555, 0.5413223 , 0.60330576, ..., 0.29752067,
          0.20247933, 0.20661157],
         [0.46694216, 0.55785125, 0.6198347 , ..., 0.29752067,
          0.17768595, 0.18595041],
         ...,
         [0.03305785, 0.46280992, 0.5289256 , ..., 0.17355372,
          0.17355372, 0.1694215 ],
         [0.1570248 , 0.5247934 , 0.53305787, ..., 0.16528925,
          0.1570248 , 0.18595041],
         [0.45454547, 0.5206612 , 0.53305787, ..., 0.17768595,
          0.14876033, 0.19008264]],
        [[0.21487603, 0.21900827, 0.21900827, ..., 0.71487606,
          0.71487606, 0.6942149 ],
         [0.20247933, 0.20661157, 0.20661157, ..., 0.7107438 ,
          0.7066116 , 0.6942149 ],
         [0.2107438 , 0.20661157, 0.20661157, ..., 0.6859504 ,
          0.69008267, 0.6942149 ],
         ...,
         [0.2644628 , 0.25619835, 0.2603306 , ..., 0.5413223 ,
          0.57438016, 0.59090906],
         [0.26859504, 0.2644628 , 0.26859504, ..., 0.56198347,
          0.58264464, 0.59504133],
         [0.27272728, 0.26859504, 0.27272728, ..., 0.57438016,
          0.59090906, 0.60330576]],
        [[0.5165289 , 0.46280992, 0.28099173, ..., 0.5785124 ,
          0.5413223 , 0.60330576],
         [0.5165289 , 0.45041323, 0.29338843, ..., 0.58264464,
          0.553719  , 0.5785124 ],
         [0.5165289 , 0.44214877, 0.29338843, ..., 0.59917355,
          0.5785124 , 0.54545456],
         ...,
         [0.39256197, 0.41322315, 0.38842976, ..., 0.33471075,
          0.37190083, 0.3966942 ],
         [0.39256197, 0.38429752, 0.40495867, ..., 0.3305785 ,
          0.35950413, 0.37603307],
         [0.3677686 , 0.40495867, 0.3966942 , ..., 0.35950413,
          0.3553719 , 0.38429752]]], dtype=float32),
 'target': array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11,
        11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13,
        13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15,
        15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
        17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22,
        22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
        23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28,
        28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30,
        30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
        34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35,
        35, 35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39,
        39, 39, 39, 39, 39, 39, 39, 39, 39]),
 'DESCR': 'Modified Olivetti faces dataset.\n\nThe original database was available from\n\n    http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html\n\nThe version retrieved here comes in MATLAB format from the personal\nweb page of Sam Roweis:\n\n    http://www.cs.nyu.edu/~roweis/\n\nThere are ten different images of each of 40 distinct subjects. For some\nsubjects, the images were taken at different times, varying the lighting,\nfacial expressions (open / closed eyes, smiling / not smiling) and facial\ndetails (glasses / no glasses). All the images were taken against a dark\nhomogeneous background with the subjects in an upright, frontal position (with\ntolerance for some side movement).\n\nThe original dataset consisted of 92 x 112, while the Roweis version\nconsists of 64x64 images.\n'}

data = faces.data
target = faces.target
data.shape
(400, 4096)
faces.images.shape
(400, 64, 64)

import matplotlib.pyplot as plt
%matplotlib inline
# Display one face image
plt.imshow(data[100].reshape((64,64)),cmap="gray")


[Figure: the face image at index 100, shown in grayscale]

Splitting the data


Split the data into features and labels: the upper half of each face is the feature vector, and the lower half is the label.


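# Features: the upper half of each face (first 32 rows of pixels)
faces_up = data[:,:2048]
# Targets to predict: the lower half of each face (last 32 rows of pixels)
faces_down = data[:,2048:]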
plt.figure(figsize=(2,2))
plt.imshow(faces_up[10].reshape((32,64)),cmap="gray")
<matplotlib.image.AxesImage at 0x25eca1c8828>

[Figure: upper half of face 10]

plt.figure(figsize=(2,2))
plt.imshow(faces_down[10].reshape((32,64)),cmap="gray")

[Figure: lower half of face 10]
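Slicing at index 2048 works because each 64*64 image is flattened row by row, so the first 2048 values are the top 32 rows. The sketch below checks this on a synthetic array (an assumption, used in place of `faces.images` so that nothing has to be downloaded).

```python
import numpy as np

# Synthetic stand-in for faces.images: 5 "images" of 64x64 pixels
images = np.arange(5 * 64 * 64, dtype=np.float32).reshape(5, 64, 64)
data = images.reshape(5, -1)          # row-major flattening, shape (5, 4096)

upper = data[:, :2048]                # first 32 rows of every image
lower = data[:, 2048:]                # last 32 rows of every image

# Reshaping each half to (32, 64) recovers the matching rows of the image
upper_matches = np.array_equal(upper.reshape(5, 32, 64), images[:, :32, :])
lower_matches = np.array_equal(lower.reshape(5, 32, 64), images[:, 32:, :])
print(upper_matches, lower_matches)
```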

Splitting into training and test sets


# Split into training and test sets
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(faces_up,faces_down,test_size=0.02)
y_train[1]


array([0.5082645 , 0.5082645 , 0.5123967 , ..., 0.16115703, 0.17768595,
       0.1694215 ], dtype=float32)
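With 400 images, `test_size=0.02` leaves only 8 faces for testing, which matches the 8 rows plotted later. The sketch below confirms the sizes on placeholder arrays of the same shape; the `random_state=0` argument is an addition not present in the original code, included only so that the split is reproducible.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholders with the same shapes as faces_up and faces_down
X = np.zeros((400, 2048), dtype=np.float32)
Y = np.zeros((400, 2048), dtype=np.float32)

# random_state is an assumption added here for reproducibility
x_train, x_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.02, random_state=0
)
print(x_train.shape, x_test.shape)
```

Because the split is random and unseeded in the original, the scores reported in the next section will vary somewhat from run to run.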


Building and training different regression models


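Five different models are used here: KNN regression, linear regression, ridge regression, lasso regression, and extremely randomized trees (extra-trees) regression.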


estimators = {
    "knn":KNeighborsRegressor(),
    "linear":LinearRegression(),
    "ridge":Ridge(),
    "lasso":Lasso(),
    "extra":ExtraTreesRegressor()  # extremely randomized trees regression
}
# Dictionary that stores each model's predictions
faces_pre = dict()
for key,estimator in estimators.items():
    # Train the model
    estimator.fit(x_train,y_train)
    # Predict the lower halves of the test faces
    y_ = estimator.predict(x_test)
    # Save the predictions
    faces_pre[key] = y_
    # Score on the test set (R^2)
    score = estimator.score(x_test, y_test)
    print(key, score)
knn 0.4880642098170732
linear 0.18894319531680143
ridge 0.5157197923145055
lasso -0.2100687498661858
extra 0.35087195680524175
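The numbers printed above are R^2 scores. For multi-target data, recent scikit-learn versions compute R^2 for each target column and take the plain average. The sketch below, on synthetic data (an assumption), reproduces `score` by hand and also shows why a score can be negative: a predictor that always returns the training mean cannot do better than R^2 = 0 on the test set.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

rng = np.random.default_rng(0)
X_train, X_test = rng.random((80, 5)), rng.random((20, 5))
W = rng.random((5, 3))
Y_train = X_train @ W + 0.1 * rng.standard_normal((80, 3))
Y_test = X_test @ W + 0.1 * rng.standard_normal((20, 3))

model = Ridge().fit(X_train, Y_train)
score = model.score(X_test, Y_test)

# Same number by hand: R^2 per target column, then a plain average
pred = model.predict(X_test)
ss_res = ((Y_test - pred) ** 2).sum(axis=0)
ss_tot = ((Y_test - Y_test.mean(axis=0)) ** 2).sum(axis=0)
manual = np.mean(1.0 - ss_res / ss_tot)

# A predictor that ignores the input and always returns the training mean
baseline = np.tile(Y_train.mean(axis=0), (len(Y_test), 1))
baseline_r2 = r2_score(Y_test, baseline)
print(score, manual, baseline_r2)
```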
faces_pre
{'knn': array([[0.4471074 , 0.41652894, 0.42066115, ..., 0.54793394, 0.5355372 ,
         0.546281  ],
        [0.34876034, 0.34214878, 0.346281  , ..., 0.42727274, 0.42809922,
         0.43057853],
        [0.5355372 , 0.546281  , 0.58016527, ..., 0.56611574, 0.56280994,
         0.5644628 ],
        ...,
        [0.64793384, 0.67685956, 0.7049587 , ..., 0.41487604, 0.3586777 ,
         0.36776862],
        [0.3942149 , 0.41322312, 0.43553716, ..., 0.45785123, 0.43471074,
         0.39173552],
        [0.47520667, 0.47024792, 0.51404965, ..., 0.631405  , 0.6256199 ,
         0.59173554]], dtype=float32),
 'linear': array([[0.42212042, 0.35969752, 0.39748642, ..., 0.63096315, 0.5628751 ,
         0.5159277 ],
        [0.4241521 , 0.26758337, 0.16570012, ..., 0.09656662, 0.13010818,
         0.19814485],
        [0.62213266, 0.441006  , 0.48480797, ..., 0.5819658 , 0.69699645,
         0.44033697],
        ...,
        [0.71544605, 0.6732123 , 0.7088314 , ..., 0.37067276, 0.39097485,
         0.45659465],
        [0.2940399 , 0.3306437 , 0.32395566, ..., 0.19252078, 0.21714431,
         0.24263924],
        [0.4138433 , 0.47978985, 0.5166639 , ..., 0.5562554 , 0.4086836 ,
         0.42044348]], dtype=float32),
 'ridge': array([[0.4290133 , 0.37331253, 0.4017402 , ..., 0.5793132 , 0.53899723,
         0.4968022 ],
        [0.3253019 , 0.2301054 , 0.17614344, ..., 0.33642793, 0.3497425 ,
         0.3560007 ],
        [0.5519007 , 0.46847916, 0.5257808 , ..., 0.6301012 , 0.69831306,
         0.5881569 ],
        ...,
        [0.6989316 , 0.6826698 , 0.7077453 , ..., 0.29566136, 0.32281214,
         0.3521443 ],
        [0.31752783, 0.33159164, 0.33879474, ..., 0.24723864, 0.23903543,
         0.23862499],
        [0.39791593, 0.4184358 , 0.52279156, ..., 0.58981174, 0.50477254,
         0.5145724 ]], dtype=float32),
 'lasso': array([[0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
         0.307535  ],
        [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
         0.307535  ],
        [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
         0.307535  ],
        ...,
        [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
         0.307535  ],
        [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
         0.307535  ],
        [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
         0.307535  ]], dtype=float32),
 'extra': array([[0.42644627, 0.39462809, 0.40661157, ..., 0.5409091 , 0.53388429,
         0.53966941],
        [0.30619835, 0.33347108, 0.35661157, ..., 0.43057852, 0.42066116,
         0.40909091],
        [0.43842976, 0.47768595, 0.58347108, ..., 0.45867768, 0.40041323,
         0.39380165],
        ...,
        [0.64049588, 0.65702479, 0.6731405 , ..., 0.36157025, 0.37272727,
         0.38429752],
        [0.3161157 , 0.3144628 , 0.37066115, ..., 0.41239669, 0.40206612,
         0.37685951],
        [0.43471075, 0.47272727, 0.51818182, ..., 0.54090908, 0.503719  ,
         0.50041322]])}
faces_pre["knn"]
array([[0.4471074 , 0.41652894, 0.42066115, ..., 0.54793394, 0.5355372 ,
        0.546281  ],
       [0.34876034, 0.34214878, 0.346281  , ..., 0.42727274, 0.42809922,
        0.43057853],
       [0.5355372 , 0.546281  , 0.58016527, ..., 0.56611574, 0.56280994,
        0.5644628 ],
       ...,
       [0.64793384, 0.67685956, 0.7049587 , ..., 0.41487604, 0.3586777 ,
        0.36776862],
       [0.3942149 , 0.41322312, 0.43553716, ..., 0.45785123, 0.43471074,
        0.39173552],
       [0.47520667, 0.47024792, 0.51404965, ..., 0.631405  , 0.6256199 ,
        0.59173554]], dtype=float32)


Comparing the faces predicted by each model with the ground truth


import numpy as np
plt.figure(figsize=(6*3,8*3))
for i in range(8):
    axes = plt.subplot(8,6,i*6+1)
    axes.axis("off")
    face_up = x_test[i]
    face_down = y_test[i]
    face = np.concatenate([face_up,face_down])
    axes.imshow(face.reshape((64,64)),cmap="gray")
    if i==0:
        axes.set_title("True")
    # Concatenate the true upper half with each model's predicted lower half
    for j,key in enumerate(faces_pre):
        axes = plt.subplot(8,6,i*6+2+j)
        axes.axis("off")
        if i==0:
            axes.set_title(key)
        face_up = x_test[i]
        y_pre = faces_pre[key]
        face_down_pre = y_pre[i]
        face =np.concatenate([face_up,face_down_pre])
        axes.imshow(face.reshape((64,64)),cmap="gray")

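[Figure: 8 test faces, one per row; the first column is the true face, followed by the completions from knn, linear, ridge, lasso, and extra]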

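The comparison shows that, in this example, KNN predicts the most plausible face shapes but leaves a visible seam between the two halves, which needs further processing. Linear regression and ridge regression leave no obvious seam, but their predictions are less convincing. Lasso regression and extra-trees produce unsatisfactory faces.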
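One likely explanation for the lasso result (an interpretation, not something stated in the original) is that the default penalty `alpha=1.0` is too strong for inputs and targets that all lie in [0, 1]: every coefficient is shrunk to zero and the model predicts the same training-mean face for any input, which is consistent with the identical rows in the printed `lasso` predictions. The sketch below reproduces the effect on synthetic data; the smaller `alpha=0.001` is an illustrative value, not a tuned one.

```python
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.default_rng(0)
X = rng.random((200, 10))                         # features in [0, 1], like pixel values
Y = np.column_stack([X[:, 0],                     # two targets, also in [0, 1]
                     0.5 * X[:, 1] + 0.5 * X[:, 2]])

strong = Lasso(alpha=1.0).fit(X, Y)               # scikit-learn's default alpha
weak = Lasso(alpha=0.001).fit(X, Y)               # much lighter penalty (illustrative)

X_new = rng.random((5, 10))
strong_pred = strong.predict(X_new)

# Count surviving coefficients under each penalty
print(np.count_nonzero(strong.coef_), np.count_nonzero(weak.coef_))
# With all coefficients at zero, every prediction equals the training mean
print(np.allclose(strong_pred, Y.mean(axis=0)))
```

If this is indeed the cause, choosing a much smaller `alpha` (for example by cross-validation) would be the natural next step before judging lasso on this task.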

