一、前言
现在机器学习是一个非常热门的词,在生活中很多事情都会和机器学习扯上关系。那你知道机器学习是什么吗?其实很多读者都已经接触过机器学习的一些内容,只是没有用机器学习这个词罢了。
今天就带大家来了解一下机器学习相关的知识,并使用Java实现一个非常常用的机器学习算法--线性回归。
二、机器学习
机器学习是人工智能的一个子集,机器学习的目的是构建一个模型,通过已有的经验不断学习优化参数。用这个优化后的模型来预测还未发生的事情。这里就牵扯到了四个东西:
- 模型
- 经验
- 学习
- 预测
模型就是我们常说的机器学习算法,通常我们都会选择已有的一些模型。比如线性回归、逻辑回归、支持向量机等。我们需要根据问题类型来选择模型,本文介绍的是线性回归模型。
经验其实就是我们常说的数据,如果是面对天气预测的问题,我们的经验就是前几个小时或者前几天的天气数据
学习是机器学习中非常重要的一步。但是机器本身不会自己学习,需要人告诉机器如何学习。因此我们需要定义一个特殊的函数(损失函数),来帮助机器学习参数。
预测则是我们使用模型的过程,也是最简单的一步。
现在我们了解了一些基础的内容,下面我们详细看看线性回归的一些细节。
三、线性回归
3.1、找规律
相信大家都做过这样的题目,给定下面的数字,让你猜下一个数字是多少:
4 7 10 13 16 复制代码
如果盲目的猜我们很难猜到结果,现在我们假设上面数字满足方程:
我把数字的序号当作x,结果数字当作y。那我们就可以得到下面这几个坐标:
(1, 4) (2, 7) (3, 10) (4, 13) (5, 16) 复制代码
然后选两个坐标带入方程,比如选(1,4)和(3,10),我们就可以得到下面方程组:
我们就可以解出
然后再用其它坐标代入,验证一下方程的准确性。接下来我们就可以猜下一个数字是多少了,下一个数字的x是6,那么下一个数字应该是19。
3.2、线性回归
其实上面的就是线性回归,我们的目标就是找到一组最优的k和b。但是但是我们通常会把线性回归方程写出下面的样子:
这个方程和之前的方程没有区别,只是换了一个字母表示。其中w意为权重(weight),b意为偏置(bias)。
现在我们求解参数的方式和前面不一样了,并不是直接带入坐标然后解方程,而是通过一种叫梯度下降的算法来调节参数。在下一节我们会详细说到。
在机器学习中,我们把用来调节(训练)参数的数据叫做训练数据。其中x叫做特征值,y叫做目标值。
三、损失函数
在学习之前,我们通常会给模型一个初始化的参数。但是这个初始的参数通常不会太好,我们用这组参数会得到一个比较差的结果。那这个结果有多差呢?这就需要定义一个损失函数来评估这组参数到底有多差了。
损失函数是用来评估参数好坏的一个函数,通常也有很多现成的函数供我们选择。这里我们选择均方误差作为我们的损失函数:
其中y和x是我们已知的数据,而w和b是这个函数的变量。我们用wx+b
计算出预测结果,然后用实际的y减去预测的y就是我们预测的误差了。为了保证这个值为正数,所以加了个平方。
这里需要注意,我们可以选择平方开根号,也可以选择不开根号。当我们开根号时,则叫做均方根误差。
上面我们只是对其中一个数字求了损失值,但是我们训练数据通常有很多,因此我们需要求所有误差的平均。所以完整的损失值计算如下:
这里只是加了个求平均的操作。
四、梯度下降
4.1、梯度下降
我们可以先观察一下参数和损失值之间的函数图形,为了方便观看,我们只考虑w参数,图像如下:
其中x轴为w参数,y轴为损失值。从图中可以看到B点是损失值最小的点(不考虑图外的区域)。
在B点的左右有A、C两个点,我们分别来分析一下这两个点。
A点在B点的左边,为了让损失值变小,我们需要往右更新参数,即增加参数。我们从图中可以看出,A点切线的斜率是小于0的,也就是A点处的导数小于0。这一点很重要,我们分析完C点再一起详细说。
C点在B点的右边,因此我们需要往左更新参数,即减少参数。从图中可以看出,C点的斜率要大于0,也就是C点处导数大于0。
分析完A、C两个点后,我们发现。当导数大于0时,我们需要减小参数,当导数小于0是我们需要增加参数。因此我们可以按照如下公式来更新参数:
其中更新后的参数等于原本参数减去损失函数对参数w的导数。
4.2、学习率
用上面的方式更新参数其实是有问题的,因为我们求到的导数很多时候是一个比较大的数值,如果我们直接减去导数的话就会出现左图的情况:
可以看到参数一直在峡谷振荡,要更新到最优参数要花很长的时间,或者根本更新不到最优参数。
因此我们可以让导数乘一个很小的数,这个数就是学习率,当我们使用学习率后我们参数更新可能就会更接近右图了。
4.3、多个参数更新
我们已经知道一个参数怎么更新了,那多个参数要怎么更新呢?在我们前面线性回归方程有w和b两个参数,更新会有什么区别吗?
其实并没有太大区别,我们只需要将求导数换成求偏导就好了,然后分别对w和b进行更新。
其中损失函数对w的偏导如下:
损失函数对b的偏导如下:
我们参数更新函数如下:
如果不记得求导公式可以直接套用上面的公式更新参数即可。
五、实现线性回归
下面我们用Java来实现一个线性回归模型。
5.1、类的创建和初始化参数
首先我们创建一个名为LinearRegression的类,代码如下:
package com.zack.lr; import java.util.ArrayList; public class LinearRegression { //权重 private double weight; //偏置值 private double bias; //特征值 private ArrayList<Double> features; //目标值 private ArrayList<Double> targets; /** * 构造线性回归模型 * @param features 训练数据的特征值 * @param targets 训练数据的目标值 */ public LinearRegression(ArrayList<Double> features, ArrayList<Double> targets){ this.features = features; this.targets = targets; initParameter(); } /** * 初始化权重和偏置值 */ public void initParameter(){ this.weight = Math.random(); this.bias = Math.random(); } } 复制代码
在类中我们定义了四个成员变量,分别是两个模型参数和训练数据的特征值和目标值。
然后我们需要在创建模型时传入训练数据的特征值和目标值,并调用initParameter函数随机初始模型的参数。
5.2、梯度下降
下面我们编写梯度下降算法,代码如下:
package com.zack.lr; import java.util.ArrayList; public class LinearRegression { ...... /** * 梯度下降更新参数 * @param learning_rate 学习率 * @return 损失值 */ public double gradientDecent(double learning_rate){ double w_ = 0; double b_ = 0; double totalLoss = 0; int n = this.features.size(); double n = this.features.size(); for (int i = 0; i < this.features.size(); i++) { double yPredict = this.features.get(i) * this.weight + this.bias; // loss对w的偏导 w_ += -2 * learning_rate * this.features.get(i) * (this.targets.get(i) - yPredict) / n; // loss对b的偏导 b_ += -2 * learning_rate * (this.targets.get(i) - (yPredict)) / n; // 计算loss用于输出 totalLoss += Math.pow(this.targets.get(i) - yPredict, 2) / n; } //更新参数 this.weight -= w_; this.bias -= b_; return totalLoss; } } 复制代码
这个函数我们传入一个学习率,然后对参数进行更新。
首先w_、b_在这里表示的是两个参数的变化量,totalLoss表示参数的损失值。我们来关注下面这段代码:
double yPredict = this.features.get(i) * this.weight + this.bias; // loss对w的偏导 w_ += -2 * learning_rate * this.features.get(i) * (this.targets.get(i) - yPredict) / n; // loss对b的偏导 b_ += -2 * learning_rate * (this.targets.get(i) - (yPredict)) / n; // 计算loss用于输出 totalLoss += Math.pow(this.targets.get(i) - yPredict, 2) / n; 复制代码
可以看到我们只不过是在代入求偏导的公式。因为我们是循环求了n(训练数据的数量)次,所以我们需要除一个n。
然后我们更新参数即可。
5.3、预测结果
我们训练模型的目的就是预测结果,这一步非常简单,我们只需要带入x然后求y的值即可。代码如下:
package com.zack.lr; import java.util.ArrayList; public class LinearRegression { ...... /** * 预测结果 * @param features 特征值 * @return 预测结果 */ public ArrayList<Double> predict(ArrayList<Double> features){ //用于装预测结果 ArrayList<Double> yPredict = new ArrayList<>(); for (Double feature : features) { //对每个x进行预测 yPredict.add(feature * this.weight + this.bias); } return yPredict; } } 复制代码
这部分代码很简单,我们直接使用训练好的参数计算y的值,然后加入到预测集合,把结果返回就好了。
5.4、测试程序
下面我们编写一个测试程序来看看这个模型的效果如何。代码如下:
package com.zack.lr; import java.util.ArrayList; public class LrDemo { public static void main(String[] args) { ArrayList<Double> features = new ArrayList<>(); ArrayList<Double> targets = new ArrayList<>(); //准备特征值 for (int i = 0; i < 200; i++) { features.add((double)i); } //用y = 3x + 1生成目标值 for (Double feature : features) { //生成目标值,并加上一个随机数 double target = feature * 3 + 1 + Math.random() * 3; targets.add(target); } //创建线性回归模型 LinearRegression linearRegression = new LinearRegression(features, targets); for (long i = 1; i <= 300; i++){ double loss = linearRegression.gradientDecent(1e-6); if(i % 100 == 0){ System.out.println("第" + i + "次更新后"); System.out.println("weight = " + linearRegression.getWeight()); System.out.println("bias = " + linearRegression.getBias()); System.out.println("loss = " + loss); } } //准备数据用于测试 ArrayList<Double> testList = new ArrayList<>(); testList.add(100.0); testList.add(27.0); ArrayList<Double> testPredict = linearRegression.predict(testList); System.out.println("真实结果"); for (Double testX : testList) { System.out.println(testX * 3 + 1); } System.out.println("预测结果"); for (Double predict : testPredict) { System.out.println(predict); } } } 复制代码
我们生成了一些数据用来训练。然后设置一个比较小的学习率进行梯度下降,这也是我们通常的做法。
我们更新了300次参数就得到了一个比较好的结果,下面是输出结果:
第100次更新后 weight = 2.824277848777937 bias = 0.5474555396904982 loss = 511.60209782851484 第200次更新后 weight = 3.0023165757117525 bias = 0.5488865404002248 loss = 3.9653518084868575 第300次更新后 weight = 3.0144922328217816 bias = 0.5490704219725514 loss = 1.5908606096719522 真实结果 301.0 82.0 预测结果 301.9982937041507 81.94036070816065 复制代码
可以看到预测结果和真实结果已经非常接近了。到此我们就完成了一个简单的线性回归算法了。
本次案例是仅考虑一元的一元线性回归,感兴趣的读者可以考虑扩展到多元线性回归。感谢阅读~