【阿旭机器学习实战】【7】岭回归基本原理及其λ的选取方法

简介: 【阿旭机器学习实战】【7】岭回归基本原理及其λ的选取方法

线性回归之岭回归


1、原理


如果数据的特征数比样本点还多应该怎么办?是否还可以使用普通的线性回归来做预测?

答案是否定的。因为输入数据的矩阵X不是满秩矩阵。非满秩矩阵在求逆时会出现问题

为了解决这个问题,统计学家引入了岭回归(ridge regression)的概念。


44d007709508442a9d71d4f5af0a5c44.png

缩减方法可以去掉不重要的参数,因此能更好地理解数据。此外,与简单的线性回归相比,缩减法能取得更好的预测效果。


【注意】在岭回归里面,决定回归模型性能的除了数据算法以外,还有一个缩减值lambda * I


岭回归是加了二阶正则项(lambda*I)的最小二乘,主要适用于过拟合严重或各变量之间存在多重共线性的时候,岭回归是有bias的,这里的bias是为了让variance更小。


2、岭回归主要处理的问题


岭回归主要用于处理下面两类问题:


1.数据点少于特征变量个数


2.变量间存在共线性(最小二乘回归得到的系数不稳定,方差很大)


3、归纳总结


1.岭回归可以解决特征数量比样本量多的问题


2.岭回归作为一种缩减算法可以判断哪些特征重要或者不重要,有点类似于降维的效果


3.缩减算法可以看作是对一个模型增加偏差的同时减少方差


4、岭回归实例


# 导入岭回归模型
from sklearn.linear_model import Ridge
from sklearn.linear_model import LinearRegression
import numpy as np
# 手动创建训练数据
x_train = np.array([[1,1,2,1,4],[2,3,4,1,2],[1,3,2,4,1],[2,1,3,4,5]])
y_train = np.array([1,2,3,4])
# 测试数据
x_test = np.array([[1,2,3,1,2]])


4.1普通线性回归进行预测


linear = LinearRegression()
• 1
linear.fit(x_train,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
linear.predict(x_test) # 这个预测是不准确的,x_train不可逆
• 1
array([1.20837809])
• 1



4.2 岭回归进行预测


rigde = Ridge(alpha=1000)
# alpha就是 lambda*I中的lambda
• 1
• 2
rigde.fit(x_train,y_train)
Ridge(alpha=1000, copy_X=True, fit_intercept=True, max_iter=None,
   normalize=False, random_state=None, solver='auto', tol=0.001)
rigde.predict(x_test)
• 1
array([2.48971789])
# 回归系数W = (X^T*X + lambda*I)^(-1) * X^T*Y
# Y = W^T * X + b
rigde.coef_ # 回归系数的衰减和alpha值有关,alpha越大,alpha对回归的影响就越大

array([9.97254560e-04, 5.40720570e-06, 5.06027985e-04, 5.94723394e-03,
       9.89143751e-04])
# 普通线性回归系数
linear.coef_
• 1
• 2
array([ 0.30827068,  0.11385607,  0.36949517,  0.72824919,  0.13748657])



4.3 岭回归的核心问题是找到合适的alpha值


选取alpha值的一般原则是:各回归系数的岭估计基本稳定


# 建立训练数据集
x_train = 1/(np.arange(1,11)+ np.arange(0,10).reshape((10,1)))
x_train
array([[1.        , 0.5       , 0.33333333, 0.25      , 0.2       ,
        0.16666667, 0.14285714, 0.125     , 0.11111111, 0.1       ],
       [0.5       , 0.33333333, 0.25      , 0.2       , 0.16666667,
        0.14285714, 0.125     , 0.11111111, 0.1       , 0.09090909],
       [0.33333333, 0.25      , 0.2       , 0.16666667, 0.14285714,
        0.125     , 0.11111111, 0.1       , 0.09090909, 0.08333333],
       [0.25      , 0.2       , 0.16666667, 0.14285714, 0.125     ,
        0.11111111, 0.1       , 0.09090909, 0.08333333, 0.07692308],
       [0.2       , 0.16666667, 0.14285714, 0.125     , 0.11111111,
        0.1       , 0.09090909, 0.08333333, 0.07692308, 0.07142857],
       [0.16666667, 0.14285714, 0.125     , 0.11111111, 0.1       ,
        0.09090909, 0.08333333, 0.07692308, 0.07142857, 0.06666667],
       [0.14285714, 0.125     , 0.11111111, 0.1       , 0.09090909,
        0.08333333, 0.07692308, 0.07142857, 0.06666667, 0.0625    ],
       [0.125     , 0.11111111, 0.1       , 0.09090909, 0.08333333,
        0.07692308, 0.07142857, 0.06666667, 0.0625    , 0.05882353],
       [0.11111111, 0.1       , 0.09090909, 0.08333333, 0.07692308,
        0.07142857, 0.06666667, 0.0625    , 0.05882353, 0.05555556],
       [0.1       , 0.09090909, 0.08333333, 0.07692308, 0.07142857,
        0.06666667, 0.0625    , 0.05882353, 0.05555556, 0.05263158]])
y_train = np.ones(10)
y_train
• 1
• 2
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


# 创建一系列alpha值作为岭回归的缩减系数
alphas = np.logspace(-2,5,200)
alphas
array([1.00000000e-02, 1.08436597e-02, 1.17584955e-02, 1.27505124e-02,
       1.38262217e-02, 1.49926843e-02, 1.62575567e-02, 1.76291412e-02,
       1.91164408e-02, 2.07292178e-02, 2.24780583e-02, 2.43744415e-02,
       2.64308149e-02, 2.86606762e-02, 3.10786619e-02, 3.37006433e-02,
       3.65438307e-02, 3.96268864e-02, 4.29700470e-02, 4.65952567e-02,
       5.05263107e-02, 5.47890118e-02, 5.94113398e-02, 6.44236351e-02,
       6.98587975e-02, 7.57525026e-02, 8.21434358e-02, 8.90735464e-02,
       9.65883224e-02, 1.04737090e-01, 1.13573336e-01, 1.23155060e-01,
       1.33545156e-01, 1.44811823e-01, 1.57029012e-01, 1.70276917e-01,
       1.84642494e-01, 2.00220037e-01, 2.17111795e-01, 2.35428641e-01,
       2.55290807e-01, 2.76828663e-01, 3.00183581e-01, 3.25508860e-01,
       3.52970730e-01, 3.82749448e-01, 4.15040476e-01, 4.50055768e-01,
       4.88025158e-01, 5.29197874e-01, 5.73844165e-01, 6.22257084e-01,
       6.74754405e-01, 7.31680714e-01, 7.93409667e-01, 8.60346442e-01,
       9.32930403e-01, 1.01163798e+00, 1.09698580e+00, 1.18953407e+00,
       1.28989026e+00, 1.39871310e+00, 1.51671689e+00, 1.64467618e+00,
       1.78343088e+00, 1.93389175e+00, 2.09704640e+00, 2.27396575e+00,
       2.46581108e+00, 2.67384162e+00, 2.89942285e+00, 3.14403547e+00,
       3.40928507e+00, 3.69691271e+00, 4.00880633e+00, 4.34701316e+00,
       4.71375313e+00, 5.11143348e+00, 5.54266452e+00, 6.01027678e+00,
       6.51733960e+00, 7.06718127e+00, 7.66341087e+00, 8.30994195e+00,
       9.01101825e+00, 9.77124154e+00, 1.05956018e+01, 1.14895100e+01,
       1.24588336e+01, 1.35099352e+01, 1.46497140e+01, 1.58856513e+01,
       1.72258597e+01, 1.86791360e+01, 2.02550194e+01, 2.19638537e+01,
       2.38168555e+01, 2.58261876e+01, 2.80050389e+01, 3.03677112e+01,
       3.29297126e+01, 3.57078596e+01, 3.87203878e+01, 4.19870708e+01,
       4.55293507e+01, 4.93704785e+01, 5.35356668e+01, 5.80522552e+01,
       6.29498899e+01, 6.82607183e+01, 7.40196000e+01, 8.02643352e+01,
       8.70359136e+01, 9.43787828e+01, 1.02341140e+02, 1.10975250e+02,
       1.20337784e+02, 1.30490198e+02, 1.41499130e+02, 1.53436841e+02,
       1.66381689e+02, 1.80418641e+02, 1.95639834e+02, 2.12145178e+02,
       2.30043012e+02, 2.49450814e+02, 2.70495973e+02, 2.93316628e+02,
       3.18062569e+02, 3.44896226e+02, 3.73993730e+02, 4.05546074e+02,
       4.39760361e+02, 4.76861170e+02, 5.17092024e+02, 5.60716994e+02,
       6.08022426e+02, 6.59318827e+02, 7.14942899e+02, 7.75259749e+02,
       8.40665289e+02, 9.11588830e+02, 9.88495905e+02, 1.07189132e+03,
       1.16232247e+03, 1.26038293e+03, 1.36671636e+03, 1.48202071e+03,
       1.60705282e+03, 1.74263339e+03, 1.88965234e+03, 2.04907469e+03,
       2.22194686e+03, 2.40940356e+03, 2.61267523e+03, 2.83309610e+03,
       3.07211300e+03, 3.33129479e+03, 3.61234270e+03, 3.91710149e+03,
       4.24757155e+03, 4.60592204e+03, 4.99450512e+03, 5.41587138e+03,
       5.87278661e+03, 6.36824994e+03, 6.90551352e+03, 7.48810386e+03,
       8.11984499e+03, 8.80488358e+03, 9.54771611e+03, 1.03532184e+04,
       1.12266777e+04, 1.21738273e+04, 1.32008840e+04, 1.43145894e+04,
       1.55222536e+04, 1.68318035e+04, 1.82518349e+04, 1.97916687e+04,
       2.14614120e+04, 2.32720248e+04, 2.52353917e+04, 2.73644000e+04,
       2.96730241e+04, 3.21764175e+04, 3.48910121e+04, 3.78346262e+04,
       4.10265811e+04, 4.44878283e+04, 4.82410870e+04, 5.23109931e+04,
       5.67242607e+04, 6.15098579e+04, 6.66991966e+04, 7.23263390e+04,
       7.84282206e+04, 8.50448934e+04, 9.22197882e+04, 1.00000000e+05])
# 用上面一系列的alpha值,创建对应的算法
# 定义一个列表,用于存放每次训练的回归系数
w = []
# 创建岭回归模型
r = Ridge(fit_intercept=False)
for alpha in alphas:
    # 给算法设置不同的alpha值
    r.set_params(alpha=alpha)
    # 对不同的alpha值的模型进行训练
    r.fit(x_train,y_train)
    # 取出对应的回归系数
    w.append(r.coef_)


画岭迹线(回归系数和alpha值之间的关系)


import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(12,9))
axes = plt.subplot(111)
axes.plot(alphas,w)
axes.set_xscale("log")

fae3704da42645f29de0bc63959ff98e.png


通过对岭迹线的观察发现超100以后岭迹线基本趋于稳定,alpha合适值在100以后

如果内容对你有帮助,感谢记得点赞+关注哦!

相关文章
|
3天前
|
机器学习/深度学习 数据采集 搜索推荐
机器学习多场景实战(一)
机器学习已广泛应用,从个性化推荐到金融风控,数据指标是评估其效果的关键。数据指标包括活跃用户(DAU, MAU, WAU)衡量用户粘性,新增用户量和注册转化率评估营销效果,留存率(次日、7日、30日)反映用户吸引力,行为指标如PV(页面浏览量)、UV(独立访客)和转化率分析用户行为。产品数据指标如GMV、ARPU、ARPPU和付费率关注业务变现,推广付费指标(CPM, CPC, CPA等)则关乎广告效率。找到北极星指标,如月销售额或用户留存,可指导业务发展。案例中涉及电商销售数据,计算月销售金额、环比、销量、新用户占比、激活率和留存率以评估业务表现。
|
6天前
|
机器学习/深度学习 人工智能 算法
【机器学习】RLHF:在线方法与离线算法在大模型语言模型校准中的博弈
【机器学习】RLHF:在线方法与离线算法在大模型语言模型校准中的博弈
210 6
|
3天前
|
机器学习/深度学习 搜索推荐 数据挖掘
机器学习多场景实战(二 )
这是一个关于机器学习应用于电商平台用户行为分析的概要,包括以下几个关键点: 1. **月活跃用户分析**:通过购买记录确定活跃用户,计算每月活跃用户数。 2. **月客单价**:定义为月度总销售额除以月活跃用户数,衡量平均每位活跃用户的消费金额。 3. **新用户占比**:基于用户首次购买和最近购买时间判断新老用户,计算每月新用户的购买比例。 4. **激活率计算**:定义为当月与上月都有购买行为的用户数占上月购买用户数的比例,反映用户留存情况。 5. **Pandas数据操作**:使用Pandas库进行数据集合并(concat和merge),以及计算不同维度的组合。
|
5天前
|
机器学习/深度学习 算法 BI
机器学习笔记(一) 感知机算法 之 原理篇
机器学习笔记(一) 感知机算法 之 原理篇
|
6天前
|
机器学习/深度学习 人工智能 Java
【Sping Boot与机器学习融合:构建赋能AI的微服务应用实战】
【Sping Boot与机器学习融合:构建赋能AI的微服务应用实战】
10 1
|
7天前
|
机器学习/深度学习 算法 搜索推荐
机器学习方法之强化学习
强化学习是一种机器学习方法,旨在通过与环境的交互来学习如何做出决策,以最大化累积的奖励。
26 2
|
7天前
|
机器学习/深度学习 搜索推荐
解决冷启动问题的机器学习方法和一个简化的代码示例
解决冷启动问题的机器学习方法和一个简化的代码示例
|
8天前
|
机器学习/深度学习 搜索推荐 PyTorch
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
39 2
|
8天前
|
机器学习/深度学习 数据采集 运维
无监督学习是机器学习的一种重要方法
无监督学习是机器学习的一种重要方法
|
8天前
|
机器学习/深度学习 算法 TensorFlow
强化学习是一种通过与环境交互来学习最优行为策略的机器学习方法。
强化学习是一种通过与环境交互来学习最优行为策略的机器学习方法。

热门文章

最新文章