线性回归 最小二乘法的求解推导与基于Python的底层代码实现

简介: 作为最常见的方法之一,线性回归仍可视为有监督机器学习的方法之一,同时也是一种广泛应用统计学和数据分析的基本技术。它是一种用于估计两个或多个变量之间线性关系的方法,其中一个变量是自变量,另一个变量是因变量。线性回归假设这两个变量之间存在线性关系,并试图找到一条最佳拟合直线,使预测值与实际值之间的误差最小化。

作为最常见的方法之一,线性回归仍可视为有监督机器学习的方法之一,同时也是一种广泛应用统计学和数据分析的基本技术。它是一种用于估计两个或多个变量之间线性关系的方法,其中一个变量是自变量,另一个变量是因变量。线性回归假设这两个变量之间存在线性关系,并试图找到一条最佳拟合直线,使预测值与实际值之间的误差最小化。



ac26054d89f8b85638eb227770baf29c.png




1 线性回归模型

1.1 模型本体

线性回归模型假设因变量 y yy 与自变量 x 1 , x 2 , . . . , x p   之间存在线性关系,即

image.png

其中,β 0 \beta_0β 0  是截距,β 1 , β 2 , . . . , β p \beta_1, \beta_2, ..., \beta_pβ 1 ,β,...,β p  是自变量的系数,ϵ \epsilonϵ 是误差项。线性回归模型的目标是估计系数 β 0 , β 1 , β 2 , . . . , β p \beta_0, \beta_1, \beta_2, ..., \beta_pβ 0 ,β 1 ,β 2,...,β p ,使得模型对因变量 y yy 的预测误差最小。


1.2 最小二乘法求解

最小二乘法是一种用于估计线性回归模型参数的常用方法,它的原理是寻找最小化残差平方和的系数估计值。残差是预测值与实际值之间的差异,残差平方和是所有残差平方的和。

image.png


最小化残差平方和的解可以用公式表示为

image.png

其中,β ^ \hat{\beta}  是系数估计值,X XX 是自变量矩阵,y yy 是因变量向量。如果 X T X X^TXX TX 可逆,则上述公式存在唯一解。


1.3 最小二乘求解推导

假设我们有一个大小为 m×n 的数据集,其中每个样本由 n 个特征和一个输出值组成。我们用 X 来表示数据集中的所有特征,用 Y 来表示数据集中的所有输出值。我们的目标是找到一个大小为 n×1 的权重向量 θ,使得 X*θ 和 Y 之间的误差最小化。


为了最小化误差,我们定义一个代价函数 J(θ),如下所示:

image.png

我们的目标是找到一个最小化代价函数的 θ 值。为了找到最小化代价函数的 θ 值,我们需要对代价函数求导数,令导数为 0,求得 θ。


我们首先对代价函数 J(θ) 进行展开:

image.png


我们对 θ 求导数,得到:(矩阵求导规则如有遗忘,请查看后面的表格)

image.png


令导数为 0,解得:

image.png

在第二部分的代码中,将直接使用我们推导的结果进行计算。


附:矩阵求导规则

3c92d0b527b0ebac436420baa2c3ec5b.png


2 最小二乘法求解代码

代码部分使用回归案例的经典数据集:波士顿房价数据集。其中x 1 x_{1}x 1

代表房屋面积,x 2 x_{2}x 2

代表房间数量,Y代表房价。


2.1 原始数据

X_ = np.array([  
    [10, 1],  
    [15, 1],  
    [20, 1],  
    [30, 1],  
    [50, 2],  
    [60, 1],  
    [60, 2],  
    [70, 2]]).reshape((-1, 2))  
Y = np.array([0.8, 1.0, 1.8, 2.0, 3.2, 3.0, 3.1, 3.5]).reshape((-1, 1))


此处reshape进行了兜底操作,实际上数据结构已经符合要求,对于此示例数据结构不加不会报错,但可以避免日后输入不规范带来的Bug。


2.2 判断是否有截距项并进行操作

flag = True
if flag:  
    # 添加一个截距项对应的X值  
    X = np.column_stack((X_, np.ones(shape=(X_.shape[0], 1))))  
else:  
    # 不加入截距项  
    X = X_


添加截距项的方法,本质上是给X在最右侧添加了一个新的特征,但该特征恒为1。对应该特征的自变量系数即为截距。


2.3 求解θ \thetaθ

X = np.mat(X)  
Y = np.mat(Y)
theta = (X.T * X).I * X.T * Y  
print(theta)


此处使用np.mat,可以将ndarray转为矩阵。转换之后矩阵的计算敲代码会更加容易,可以直接使用矩阵乘法和求逆矩阵等操作。如果不转换也可以,须将计算theta的部分改为:

theta = np.dot(np.dot(np.linalg.inv(np.dot(X.T, X)), X.T), Y)


层层嵌套,看起来就有些复杂了。


2.4 模型的使用

if flag:  
    x = np.mat(np.array([[55.0, 2.0,1.0]]))  
else:  
    x = np.mat(np.array([[55.0, 2.0]]))  
pred_y = x * theta

print(f"当面积为55平并且房间数目为2的时候,预测价格为:{float(pred_y):.2f}")


此处直接使用计算后的theta进行预测,并设置输出保留两位小数。


2.5 绘制结果

绘制的结果应该包含两部分,实际数据的散点图和方程所表示的平面。首先绘制散点图。


from mpl_toolkits.mplot3d import Axes3D
x1 = X[:, 0]  
x2 = X[:, 1]  
fig = plt.figure()  
ax = Axes3D(fig)  
ax.scatter(x1, x2, Y, s=40, c='r')

aaf494d8b3e700c9d416f2e6ab37c2b4.png


之后绘制方程所表示的平面,这里稍微麻烦一点。


x = np.arange(0, 100)  
y = np.arange(0, 4)  
x, y = np.meshgrid(x, y)


首先指定了x和y的范围,然后使用meshgrid计算由x和y范围内的点形成的网格的坐标。meshgrid前xydeshape分别为(100,)(4,),meshgrid后xy的顺序均为(4,100)。In other words,我们得到了这400个点对应的xy坐标。随后


如果希望使用矩阵与数值直接相乘,(4,100)无法和(1,1)相乘,需要把theat转换成浮点数,即

def predict(x1, x2, theta, base=False):
    if base:
        y_ = x1 * float(theta[0]) + x2 * float(theta[1]) + float(theta[2])
    else:
        y_ = x1 * theta[0] + x2 * theta[1]
    return y_
z = predict(x1, x2, theta, base=True)
z.shape = x1.shape
1


如果希望使用矩阵与矩阵相乘,将x1,x2组合成一个矩阵X,然后:

def predict(x1, x2, theta, base=False):  
    if base:  
        y_ = X * self.theta[0:2] + self.theta[2]
    else:  
        y_ = X * self.theta
    return y_  
z = predict(X, theta, base=True)
z.shape = x1.shape


两种方法获得z后,都要将z的shape同样转换为(4,100)。

ax.plot_surface(x1, x2, z, rstride=1, cstride=1, cmap=plt.cm.jet)  ##画超平面   cmap=plt.cm.jet彩图
ax.set_title(u'房屋租赁价格预测')
plt.show()


最后,使用plot_surface画图。rstride 和 cstride:分别表示行跨度和列跨度,用于控制曲面的网格线密度。默认值均为 1,表示每行和每列都画线。如果将其设置为 2,则表示每隔一行或一列画一条线。cmap:表示曲面的颜色映射,用于表示高度值与颜色的对应关系。可以使用 plt.cm 模块提供的内置颜色映射,如 plt.cm.jet 表示使用蓝-绿-黄-红的颜色映射。也可以自定义颜色映射。

最终的绘制结果为:



代码比较简单,就不发合并后的代码了,有任何问题直接文末留言即可。如需下载合并后的代码可点此下载30d7f34994f0c941adab7f854145f051.png

另外实际使用过程中建议封装成类。


在实际项目中,单纯的一次线性回归可能效果并不理想,此时可能会考虑才用特征扩张来优化模型,具体内容可查看

线性回归 特征扩展的原理与python代码的实现


相关文章
|
2天前
|
机器学习/深度学习 人工智能 数据挖掘
Numba是一个Python库,用于对Python代码进行即时(JIT)编译,以便在硬件上高效执行。
Numba是一个Python库,用于对Python代码进行即时(JIT)编译,以便在硬件上高效执行。
20 9
|
2天前
|
机器人 Shell 开发者
`roslibpy`是一个Python库,它允许非ROS(Robot Operating System)环境(如Web浏览器、移动应用等)与ROS环境进行交互。通过使用`roslibpy`,开发者可以编写Python代码来远程控制ROS节点,发布和订阅话题,以及调用服务。
`roslibpy`是一个Python库,它允许非ROS(Robot Operating System)环境(如Web浏览器、移动应用等)与ROS环境进行交互。通过使用`roslibpy`,开发者可以编写Python代码来远程控制ROS节点,发布和订阅话题,以及调用服务。
18 8
|
1天前
|
存储 缓存 算法
如何优化Python代码?
【7月更文挑战第14天】如何优化Python代码?
13 6
|
2天前
|
机器学习/深度学习 TensorFlow API
Keras是一个高层神经网络API,由Python编写,并能够在TensorFlow、Theano或CNTK之上运行。Keras的设计初衷是支持快速实验,能够用最少的代码实现想法,并且能够方便地在CPU和GPU上运行。
Keras是一个高层神经网络API,由Python编写,并能够在TensorFlow、Theano或CNTK之上运行。Keras的设计初衷是支持快速实验,能够用最少的代码实现想法,并且能够方便地在CPU和GPU上运行。
9 0
|
2天前
|
Unix Linux Python
`subprocess`模块是Python中用于生成新进程、连接到它们的输入/输出/错误管道,并获取它们的返回(退出)代码的模块。
`subprocess`模块是Python中用于生成新进程、连接到它们的输入/输出/错误管道,并获取它们的返回(退出)代码的模块。
6 0
|
2天前
|
Unix Shell Python
Python代码示例标准输出与标准错误输出
Python代码示例标准输出与标准错误输出
5 0
|
2天前
|
SQL Java C++
Python代码示例简单的print()函数使用
Python代码示例简单的print()函数使用
4 0
|
2天前
|
Shell 开发者 C++
`mypy` 是一个Python的静态类型检查器,它可以在不运行代码的情况下发现潜在的类型错误。
`mypy` 是一个Python的静态类型检查器,它可以在不运行代码的情况下发现潜在的类型错误。
5 0
|
2天前
|
监控 程序员 持续交付
`pylint`是一个高度可配置的Python代码分析工具,它可以帮助程序员查找代码中的错误、样式问题、可能的bug以及不符合编码标准的部分。
`pylint`是一个高度可配置的Python代码分析工具,它可以帮助程序员查找代码中的错误、样式问题、可能的bug以及不符合编码标准的部分。
7 0
|
文字识别 算法 前端开发
100行Python代码实现一款高精度免费OCR工具
近期Github开源了一款基于Python开发、名为Textshot的截图工具,刚开源不到半个月已经500+Star。
100行Python代码实现一款高精度免费OCR工具