【阿旭机器学习实战】【7】岭回归基本原理及其λ的选取方法

简介: 【阿旭机器学习实战】【7】岭回归基本原理及其λ的选取方法

线性回归之岭回归


1、原理


如果数据的特征数比样本点还多应该怎么办?是否还可以使用普通的线性回归来做预测?

答案是否定的。因为输入数据的矩阵X不是满秩矩阵。非满秩矩阵在求逆时会出现问题

为了解决这个问题,统计学家引入了岭回归(ridge regression)的概念。


44d007709508442a9d71d4f5af0a5c44.png

缩减方法可以去掉不重要的参数,因此能更好地理解数据。此外,与简单的线性回归相比,缩减法能取得更好的预测效果。


【注意】在岭回归里面,决定回归模型性能的除了数据算法以外,还有一个缩减值lambda * I


岭回归是加了二阶正则项(lambda*I)的最小二乘,主要适用于过拟合严重或各变量之间存在多重共线性的时候,岭回归是有bias的,这里的bias是为了让variance更小。


2、岭回归主要处理的问题


岭回归主要用于处理下面两类问题:


1.数据点少于特征变量个数


2.变量间存在共线性(最小二乘回归得到的系数不稳定,方差很大)


3、归纳总结


1.岭回归可以解决特征数量比样本量多的问题


2.岭回归作为一种缩减算法可以判断哪些特征重要或者不重要,有点类似于降维的效果


3.缩减算法可以看作是对一个模型增加偏差的同时减少方差


4、岭回归实例


# 导入岭回归模型
from sklearn.linear_model import Ridge
from sklearn.linear_model import LinearRegression
import numpy as np
# 手动创建训练数据
x_train = np.array([[1,1,2,1,4],[2,3,4,1,2],[1,3,2,4,1],[2,1,3,4,5]])
y_train = np.array([1,2,3,4])
# 测试数据
x_test = np.array([[1,2,3,1,2]])


4.1普通线性回归进行预测


linear = LinearRegression()
• 1
linear.fit(x_train,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
linear.predict(x_test) # 这个预测是不准确的,x_train不可逆
• 1
array([1.20837809])
• 1



4.2 岭回归进行预测


rigde = Ridge(alpha=1000)
# alpha就是 lambda*I中的lambda
• 1
• 2
rigde.fit(x_train,y_train)
Ridge(alpha=1000, copy_X=True, fit_intercept=True, max_iter=None,
   normalize=False, random_state=None, solver='auto', tol=0.001)
rigde.predict(x_test)
• 1
array([2.48971789])
# 回归系数W = (X^T*X + lambda*I)^(-1) * X^T*Y
# Y = W^T * X + b
rigde.coef_ # 回归系数的衰减和alpha值有关,alpha越大,alpha对回归的影响就越大

array([9.97254560e-04, 5.40720570e-06, 5.06027985e-04, 5.94723394e-03,
       9.89143751e-04])
# 普通线性回归系数
linear.coef_
• 1
• 2
array([ 0.30827068,  0.11385607,  0.36949517,  0.72824919,  0.13748657])



4.3 岭回归的核心问题是找到合适的alpha值


选取alpha值的一般原则是:各回归系数的岭估计基本稳定


# 建立训练数据集
x_train = 1/(np.arange(1,11)+ np.arange(0,10).reshape((10,1)))
x_train
array([[1.        , 0.5       , 0.33333333, 0.25      , 0.2       ,
        0.16666667, 0.14285714, 0.125     , 0.11111111, 0.1       ],
       [0.5       , 0.33333333, 0.25      , 0.2       , 0.16666667,
        0.14285714, 0.125     , 0.11111111, 0.1       , 0.09090909],
       [0.33333333, 0.25      , 0.2       , 0.16666667, 0.14285714,
        0.125     , 0.11111111, 0.1       , 0.09090909, 0.08333333],
       [0.25      , 0.2       , 0.16666667, 0.14285714, 0.125     ,
        0.11111111, 0.1       , 0.09090909, 0.08333333, 0.07692308],
       [0.2       , 0.16666667, 0.14285714, 0.125     , 0.11111111,
        0.1       , 0.09090909, 0.08333333, 0.07692308, 0.07142857],
       [0.16666667, 0.14285714, 0.125     , 0.11111111, 0.1       ,
        0.09090909, 0.08333333, 0.07692308, 0.07142857, 0.06666667],
       [0.14285714, 0.125     , 0.11111111, 0.1       , 0.09090909,
        0.08333333, 0.07692308, 0.07142857, 0.06666667, 0.0625    ],
       [0.125     , 0.11111111, 0.1       , 0.09090909, 0.08333333,
        0.07692308, 0.07142857, 0.06666667, 0.0625    , 0.05882353],
       [0.11111111, 0.1       , 0.09090909, 0.08333333, 0.07692308,
        0.07142857, 0.06666667, 0.0625    , 0.05882353, 0.05555556],
       [0.1       , 0.09090909, 0.08333333, 0.07692308, 0.07142857,
        0.06666667, 0.0625    , 0.05882353, 0.05555556, 0.05263158]])
y_train = np.ones(10)
y_train
• 1
• 2
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


# 创建一系列alpha值作为岭回归的缩减系数
alphas = np.logspace(-2,5,200)
alphas
array([1.00000000e-02, 1.08436597e-02, 1.17584955e-02, 1.27505124e-02,
       1.38262217e-02, 1.49926843e-02, 1.62575567e-02, 1.76291412e-02,
       1.91164408e-02, 2.07292178e-02, 2.24780583e-02, 2.43744415e-02,
       2.64308149e-02, 2.86606762e-02, 3.10786619e-02, 3.37006433e-02,
       3.65438307e-02, 3.96268864e-02, 4.29700470e-02, 4.65952567e-02,
       5.05263107e-02, 5.47890118e-02, 5.94113398e-02, 6.44236351e-02,
       6.98587975e-02, 7.57525026e-02, 8.21434358e-02, 8.90735464e-02,
       9.65883224e-02, 1.04737090e-01, 1.13573336e-01, 1.23155060e-01,
       1.33545156e-01, 1.44811823e-01, 1.57029012e-01, 1.70276917e-01,
       1.84642494e-01, 2.00220037e-01, 2.17111795e-01, 2.35428641e-01,
       2.55290807e-01, 2.76828663e-01, 3.00183581e-01, 3.25508860e-01,
       3.52970730e-01, 3.82749448e-01, 4.15040476e-01, 4.50055768e-01,
       4.88025158e-01, 5.29197874e-01, 5.73844165e-01, 6.22257084e-01,
       6.74754405e-01, 7.31680714e-01, 7.93409667e-01, 8.60346442e-01,
       9.32930403e-01, 1.01163798e+00, 1.09698580e+00, 1.18953407e+00,
       1.28989026e+00, 1.39871310e+00, 1.51671689e+00, 1.64467618e+00,
       1.78343088e+00, 1.93389175e+00, 2.09704640e+00, 2.27396575e+00,
       2.46581108e+00, 2.67384162e+00, 2.89942285e+00, 3.14403547e+00,
       3.40928507e+00, 3.69691271e+00, 4.00880633e+00, 4.34701316e+00,
       4.71375313e+00, 5.11143348e+00, 5.54266452e+00, 6.01027678e+00,
       6.51733960e+00, 7.06718127e+00, 7.66341087e+00, 8.30994195e+00,
       9.01101825e+00, 9.77124154e+00, 1.05956018e+01, 1.14895100e+01,
       1.24588336e+01, 1.35099352e+01, 1.46497140e+01, 1.58856513e+01,
       1.72258597e+01, 1.86791360e+01, 2.02550194e+01, 2.19638537e+01,
       2.38168555e+01, 2.58261876e+01, 2.80050389e+01, 3.03677112e+01,
       3.29297126e+01, 3.57078596e+01, 3.87203878e+01, 4.19870708e+01,
       4.55293507e+01, 4.93704785e+01, 5.35356668e+01, 5.80522552e+01,
       6.29498899e+01, 6.82607183e+01, 7.40196000e+01, 8.02643352e+01,
       8.70359136e+01, 9.43787828e+01, 1.02341140e+02, 1.10975250e+02,
       1.20337784e+02, 1.30490198e+02, 1.41499130e+02, 1.53436841e+02,
       1.66381689e+02, 1.80418641e+02, 1.95639834e+02, 2.12145178e+02,
       2.30043012e+02, 2.49450814e+02, 2.70495973e+02, 2.93316628e+02,
       3.18062569e+02, 3.44896226e+02, 3.73993730e+02, 4.05546074e+02,
       4.39760361e+02, 4.76861170e+02, 5.17092024e+02, 5.60716994e+02,
       6.08022426e+02, 6.59318827e+02, 7.14942899e+02, 7.75259749e+02,
       8.40665289e+02, 9.11588830e+02, 9.88495905e+02, 1.07189132e+03,
       1.16232247e+03, 1.26038293e+03, 1.36671636e+03, 1.48202071e+03,
       1.60705282e+03, 1.74263339e+03, 1.88965234e+03, 2.04907469e+03,
       2.22194686e+03, 2.40940356e+03, 2.61267523e+03, 2.83309610e+03,
       3.07211300e+03, 3.33129479e+03, 3.61234270e+03, 3.91710149e+03,
       4.24757155e+03, 4.60592204e+03, 4.99450512e+03, 5.41587138e+03,
       5.87278661e+03, 6.36824994e+03, 6.90551352e+03, 7.48810386e+03,
       8.11984499e+03, 8.80488358e+03, 9.54771611e+03, 1.03532184e+04,
       1.12266777e+04, 1.21738273e+04, 1.32008840e+04, 1.43145894e+04,
       1.55222536e+04, 1.68318035e+04, 1.82518349e+04, 1.97916687e+04,
       2.14614120e+04, 2.32720248e+04, 2.52353917e+04, 2.73644000e+04,
       2.96730241e+04, 3.21764175e+04, 3.48910121e+04, 3.78346262e+04,
       4.10265811e+04, 4.44878283e+04, 4.82410870e+04, 5.23109931e+04,
       5.67242607e+04, 6.15098579e+04, 6.66991966e+04, 7.23263390e+04,
       7.84282206e+04, 8.50448934e+04, 9.22197882e+04, 1.00000000e+05])
# 用上面一系列的alpha值,创建对应的算法
# 定义一个列表,用于存放每次训练的回归系数
w = []
# 创建岭回归模型
r = Ridge(fit_intercept=False)
for alpha in alphas:
    # 给算法设置不同的alpha值
    r.set_params(alpha=alpha)
    # 对不同的alpha值的模型进行训练
    r.fit(x_train,y_train)
    # 取出对应的回归系数
    w.append(r.coef_)


画岭迹线(回归系数和alpha值之间的关系)


import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(12,9))
axes = plt.subplot(111)
axes.plot(alphas,w)
axes.set_xscale("log")

fae3704da42645f29de0bc63959ff98e.png


通过对岭迹线的观察发现超100以后岭迹线基本趋于稳定,alpha合适值在100以后

如果内容对你有帮助,感谢记得点赞+关注哦!

相关文章
|
2月前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
140 4
|
4天前
|
机器学习/深度学习 人工智能 算法
机器学习算法的优化与改进:提升模型性能的策略与方法
机器学习算法的优化与改进:提升模型性能的策略与方法
45 13
机器学习算法的优化与改进:提升模型性能的策略与方法
|
24天前
|
机器学习/深度学习 传感器 运维
使用机器学习技术进行时间序列缺失数据填充:基础方法与入门案例
本文探讨了时间序列分析中数据缺失的问题,并通过实际案例展示了如何利用机器学习技术进行缺失值补充。文章构建了一个模拟的能源生产数据集,采用线性回归和决策树回归两种方法进行缺失值补充,并从统计特征、自相关性、趋势和季节性等多个维度进行了详细评估。结果显示,决策树方法在处理复杂非线性模式和保持数据局部特征方面表现更佳,而线性回归方法则适用于简单的线性趋势数据。文章最后总结了两种方法的优劣,并给出了实际应用建议。
60 7
使用机器学习技术进行时间序列缺失数据填充:基础方法与入门案例
|
1月前
|
机器学习/深度学习 存储 运维
分布式机器学习系统:设计原理、优化策略与实践经验
本文详细探讨了分布式机器学习系统的发展现状与挑战,重点分析了数据并行、模型并行等核心训练范式,以及参数服务器、优化器等关键组件的设计与实现。文章还深入讨论了混合精度训练、梯度累积、ZeRO优化器等高级特性,旨在提供一套全面的技术解决方案,以应对超大规模模型训练中的计算、存储及通信挑战。
72 4
|
2月前
|
机器学习/深度学习 算法 UED
在数据驱动时代,A/B 测试成为评估机器学习项目不同方案效果的重要方法
在数据驱动时代,A/B 测试成为评估机器学习项目不同方案效果的重要方法。本文介绍 A/B 测试的基本概念、步骤及其在模型评估、算法改进、特征选择和用户体验优化中的应用,同时提供 Python 实现示例,强调其在确保项目性能和用户体验方面的关键作用。
40 6
|
2月前
|
机器学习/深度学习 数据采集 算法
机器学习在医疗诊断中的前沿应用,包括神经网络、决策树和支持向量机等方法,及其在医学影像、疾病预测和基因数据分析中的具体应用
医疗诊断是医学的核心,其准确性和效率至关重要。本文探讨了机器学习在医疗诊断中的前沿应用,包括神经网络、决策树和支持向量机等方法,及其在医学影像、疾病预测和基因数据分析中的具体应用。文章还讨论了Python在构建机器学习模型中的作用,面临的挑战及应对策略,并展望了未来的发展趋势。
155 1
|
2月前
|
机器学习/深度学习 数据采集 数据可视化
Python数据科学实战:从Pandas到机器学习
Python数据科学实战:从Pandas到机器学习
|
2月前
|
机器学习/深度学习 数据采集 数据处理
谷歌提出视觉记忆方法,让大模型训练数据更灵活
谷歌研究人员提出了一种名为“视觉记忆”的方法,结合了深度神经网络的表示能力和数据库的灵活性。该方法将图像分类任务分为图像相似性和搜索两部分,支持灵活添加和删除数据、可解释的决策机制以及大规模数据处理能力。实验结果显示,该方法在多个数据集上取得了优异的性能,如在ImageNet上实现88.5%的top-1准确率。尽管有依赖预训练模型等限制,但视觉记忆为深度学习提供了新的思路。
39 2
|
2月前
|
机器学习/深度学习 TensorFlow API
机器学习实战:TensorFlow在图像识别中的应用探索
【10月更文挑战第28天】随着深度学习技术的发展,图像识别取得了显著进步。TensorFlow作为Google开源的机器学习框架,凭借其强大的功能和灵活的API,在图像识别任务中广泛应用。本文通过实战案例,探讨TensorFlow在图像识别中的优势与挑战,展示如何使用TensorFlow构建和训练卷积神经网络(CNN),并评估模型的性能。尽管面临学习曲线和资源消耗等挑战,TensorFlow仍展现出广阔的应用前景。
80 5
|
2月前
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
107 0