一文入门Java机器学习

简介: 现在机器学习是一个非常热门的词,在生活中很多事情都会和机器学习扯上关系。那你知道机器学习是什么吗?其实很多读者都已经接触过机器学习的一些内容,只是没有用机器学习这个词罢了。今天就带大家来了解一下机器学习相关的知识,并使用Java实现一个非常常用的机器学习算法--线性回归。

一、前言

现在机器学习是一个非常热门的词,在生活中很多事情都会和机器学习扯上关系。那你知道机器学习是什么吗?其实很多读者都已经接触过机器学习的一些内容,只是没有用机器学习这个词罢了。

今天就带大家来了解一下机器学习相关的知识,并使用Java实现一个非常常用的机器学习算法--线性回归。

二、机器学习

机器学习是人工智能的一个子集,机器学习的目的是构建一个模型,通过已有的经验不断学习优化参数。用这个优化后的模型来预测还未发生的事情。这里就牵扯到了四个东西:

  1. 模型
  2. 经验
  3. 学习
  4. 预测

模型就是我们常说的机器学习算法,通常我们都会选择已有的一些模型。比如线性回归、逻辑回归、支持向量机等。我们需要根据问题类型来选择模型,本文介绍的是线性回归模型。

经验其实就是我们常说的数据,如果是面对天气预测的问题,我们的经验就是前几个小时或者前几天的天气数据

学习是机器学习中非常重要的一步。但是机器本身不会自己学习,需要人告诉机器如何学习。因此我们需要定义一个特殊的函数(损失函数),来帮助机器学习参数。

预测则是我们使用模型的过程,也是最简单的一步。

现在我们了解了一些基础的内容,下面我们详细看看线性回归的一些细节。

三、线性回归

3.1、找规律

相信大家都做过这样的题目,给定下面的数字,让你猜下一个数字是多少:

4
7
10
13
16
复制代码

如果盲目的猜我们很难猜到结果,现在我们假设上面数字满足方程:

image.png

我把数字的序号当作x,结果数字当作y。那我们就可以得到下面这几个坐标:

(1, 4)
(2, 7)
(3, 10)
(4, 13)
(5, 16)
复制代码

然后选两个坐标带入方程,比如选(1,4)和(3,10),我们就可以得到下面方程组:

image.png

我们就可以解出

image.png

然后再用其它坐标代入,验证一下方程的准确性。接下来我们就可以猜下一个数字是多少了,下一个数字的x是6,那么下一个数字应该是19。

3.2、线性回归

其实上面的就是线性回归,我们的目标就是找到一组最优的k和b。但是但是我们通常会把线性回归方程写出下面的样子:

image.png

这个方程和之前的方程没有区别,只是换了一个字母表示。其中w意为权重(weight),b意为偏置(bias)。

现在我们求解参数的方式和前面不一样了,并不是直接带入坐标然后解方程,而是通过一种叫梯度下降的算法来调节参数。在下一节我们会详细说到。

在机器学习中,我们把用来调节(训练)参数的数据叫做训练数据。其中x叫做特征值,y叫做目标值。

三、损失函数

在学习之前,我们通常会给模型一个初始化的参数。但是这个初始的参数通常不会太好,我们用这组参数会得到一个比较差的结果。那这个结果有多差呢?这就需要定义一个损失函数来评估这组参数到底有多差了。

损失函数是用来评估参数好坏的一个函数,通常也有很多现成的函数供我们选择。这里我们选择均方误差作为我们的损失函数:

image.png

其中y和x是我们已知的数据,而w和b是这个函数的变量。我们用wx+b计算出预测结果,然后用实际的y减去预测的y就是我们预测的误差了。为了保证这个值为正数,所以加了个平方。

这里需要注意,我们可以选择平方开根号,也可以选择不开根号。当我们开根号时,则叫做均方根误差。

上面我们只是对其中一个数字求了损失值,但是我们训练数据通常有很多,因此我们需要求所有误差的平均。所以完整的损失值计算如下:

image.png

这里只是加了个求平均的操作。

四、梯度下降

4.1、梯度下降

我们可以先观察一下参数和损失值之间的函数图形,为了方便观看,我们只考虑w参数,图像如下:

网络异常,图片无法展示
|

其中x轴为w参数,y轴为损失值。从图中可以看到B点是损失值最小的点(不考虑图外的区域)。

在B点的左右有A、C两个点,我们分别来分析一下这两个点。

A点在B点的左边,为了让损失值变小,我们需要往右更新参数,即增加参数。我们从图中可以看出,A点切线的斜率是小于0的,也就是A点处的导数小于0。这一点很重要,我们分析完C点再一起详细说。

C点在B点的右边,因此我们需要往左更新参数,即减少参数。从图中可以看出,C点的斜率要大于0,也就是C点处导数大于0。

分析完A、C两个点后,我们发现。当导数大于0时,我们需要减小参数,当导数小于0是我们需要增加参数。因此我们可以按照如下公式来更新参数:

image.png

其中更新后的参数等于原本参数减去损失函数对参数w的导数。

4.2、学习率

用上面的方式更新参数其实是有问题的,因为我们求到的导数很多时候是一个比较大的数值,如果我们直接减去导数的话就会出现左图的情况:

网络异常,图片无法展示
|

可以看到参数一直在峡谷振荡,要更新到最优参数要花很长的时间,或者根本更新不到最优参数。

因此我们可以让导数乘一个很小的数,这个数就是学习率,当我们使用学习率后我们参数更新可能就会更接近右图了。

4.3、多个参数更新

我们已经知道一个参数怎么更新了,那多个参数要怎么更新呢?在我们前面线性回归方程有w和b两个参数,更新会有什么区别吗?

其实并没有太大区别,我们只需要将求导数换成求偏导就好了,然后分别对w和b进行更新。

其中损失函数对w的偏导如下:

image.png

损失函数对b的偏导如下:

image.png

我们参数更新函数如下:

image.png

如果不记得求导公式可以直接套用上面的公式更新参数即可。

五、实现线性回归

下面我们用Java来实现一个线性回归模型。

5.1、类的创建和初始化参数

首先我们创建一个名为LinearRegression的类,代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LinearRegression {
    //权重
    private double weight;
    //偏置值
    private double bias;
    //特征值
    private ArrayList<Double> features;
    //目标值
    private ArrayList<Double> targets;
    /**
     * 构造线性回归模型
     * @param features 训练数据的特征值
     * @param targets 训练数据的目标值
     */
    public LinearRegression(ArrayList<Double> features, ArrayList<Double> targets){
        this.features = features;
        this.targets = targets;
        initParameter();
    }
    /**
     * 初始化权重和偏置值
     */
    public void initParameter(){
        this.weight = Math.random();
        this.bias = Math.random();
    }
}
复制代码

在类中我们定义了四个成员变量,分别是两个模型参数和训练数据的特征值和目标值。

然后我们需要在创建模型时传入训练数据的特征值和目标值,并调用initParameter函数随机初始模型的参数。

5.2、梯度下降

下面我们编写梯度下降算法,代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LinearRegression {
  ......
    /**
     * 梯度下降更新参数
     * @param learning_rate 学习率
     * @return 损失值
     */
    public double gradientDecent(double learning_rate){
        double w_ = 0;
        double b_ = 0;
        double totalLoss = 0;
        int n = this.features.size();
        double n = this.features.size();
        for (int i = 0; i < this.features.size(); i++) {
            double yPredict = this.features.get(i) * this.weight + this.bias;
            // loss对w的偏导
            w_ += -2 * learning_rate * this.features.get(i) * (this.targets.get(i) - yPredict) / n;
            // loss对b的偏导
            b_ += -2 * learning_rate * (this.targets.get(i) - (yPredict)) / n;
            // 计算loss用于输出
            totalLoss += Math.pow(this.targets.get(i) - yPredict, 2) / n;
        }
        //更新参数
        this.weight -= w_;
        this.bias -= b_;
        return totalLoss;
    }
}
复制代码

这个函数我们传入一个学习率,然后对参数进行更新。

首先w_、b_在这里表示的是两个参数的变化量,totalLoss表示参数的损失值。我们来关注下面这段代码:

double yPredict = this.features.get(i) * this.weight + this.bias;
// loss对w的偏导
w_ += -2 * learning_rate * this.features.get(i) * (this.targets.get(i) - yPredict) / n;
// loss对b的偏导
b_ += -2 * learning_rate * (this.targets.get(i) - (yPredict)) / n;
// 计算loss用于输出
totalLoss += Math.pow(this.targets.get(i) - yPredict, 2) / n;
复制代码

可以看到我们只不过是在代入求偏导的公式。因为我们是循环求了n(训练数据的数量)次,所以我们需要除一个n。

然后我们更新参数即可。

5.3、预测结果

我们训练模型的目的就是预测结果,这一步非常简单,我们只需要带入x然后求y的值即可。代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LinearRegression {
  ......
    /**
     * 预测结果
     * @param features 特征值
     * @return 预测结果
     */
    public ArrayList<Double> predict(ArrayList<Double> features){
        //用于装预测结果
        ArrayList<Double> yPredict = new ArrayList<>();
        for (Double feature : features) {
            //对每个x进行预测
            yPredict.add(feature * this.weight + this.bias);
        }
        return yPredict;
    }
}
复制代码

这部分代码很简单,我们直接使用训练好的参数计算y的值,然后加入到预测集合,把结果返回就好了。

5.4、测试程序

下面我们编写一个测试程序来看看这个模型的效果如何。代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LrDemo {
    public static void main(String[] args) {
        ArrayList<Double> features = new ArrayList<>();
        ArrayList<Double> targets = new ArrayList<>();
        //准备特征值
        for (int i = 0; i < 200; i++) {
            features.add((double)i);
        }
        //用y = 3x + 1生成目标值
        for (Double feature : features) {
            //生成目标值,并加上一个随机数
            double target = feature * 3 + 1 + Math.random() * 3;
            targets.add(target);
        }
        //创建线性回归模型
        LinearRegression linearRegression = new LinearRegression(features, targets);
        for (long i = 1; i <= 300; i++){
            double loss = linearRegression.gradientDecent(1e-6);
            if(i % 100 == 0){
                System.out.println("第" + i + "次更新后");
                System.out.println("weight = " + linearRegression.getWeight());
                System.out.println("bias = " + linearRegression.getBias());
                System.out.println("loss = " + loss);
            }
        }
        //准备数据用于测试
        ArrayList<Double> testList = new ArrayList<>();
        testList.add(100.0);
        testList.add(27.0);
        ArrayList<Double> testPredict = linearRegression.predict(testList);
        System.out.println("真实结果");
        for (Double testX : testList) {
            System.out.println(testX * 3 + 1);
        }
        System.out.println("预测结果");
        for (Double predict : testPredict) {
            System.out.println(predict);
        }
    }
}
复制代码

我们生成了一些数据用来训练。然后设置一个比较小的学习率进行梯度下降,这也是我们通常的做法。

我们更新了300次参数就得到了一个比较好的结果,下面是输出结果:

第100次更新后
weight = 2.824277848777937
bias = 0.5474555396904982
loss = 511.60209782851484
第200次更新后
weight = 3.0023165757117525
bias = 0.5488865404002248
loss = 3.9653518084868575
第300次更新后
weight = 3.0144922328217816
bias = 0.5490704219725514
loss = 1.5908606096719522
真实结果
301.0
82.0
预测结果
301.9982937041507
81.94036070816065
复制代码

可以看到预测结果和真实结果已经非常接近了。到此我们就完成了一个简单的线性回归算法了。

本次案例是仅考虑一元的一元线性回归,感兴趣的读者可以考虑扩展到多元线性回归。感谢阅读~

目录
相关文章
|
8天前
|
自然语言处理 Java
Java中的字符集编码入门-增补字符(转载)
本文探讨Java对Unicode的支持及其发展历程。文章详细解析了Unicode字符集的结构,包括基本多语言面(BMP)和增补字符的表示方法,以及UTF-16编码中surrogate pair的使用。同时介绍了代码点和代码单元的概念,并解释了UTF-8的编码规则及其兼容性。
81 60
|
26天前
|
机器学习/深度学习 传感器 运维
使用机器学习技术进行时间序列缺失数据填充:基础方法与入门案例
本文探讨了时间序列分析中数据缺失的问题,并通过实际案例展示了如何利用机器学习技术进行缺失值补充。文章构建了一个模拟的能源生产数据集,采用线性回归和决策树回归两种方法进行缺失值补充,并从统计特征、自相关性、趋势和季节性等多个维度进行了详细评估。结果显示,决策树方法在处理复杂非线性模式和保持数据局部特征方面表现更佳,而线性回归方法则适用于简单的线性趋势数据。文章最后总结了两种方法的优劣,并给出了实际应用建议。
65 7
使用机器学习技术进行时间序列缺失数据填充:基础方法与入门案例
|
1月前
|
Java 开发者 微服务
Spring Boot 入门:简化 Java Web 开发的强大工具
Spring Boot 是一个开源的 Java 基础框架,用于创建独立、生产级别的基于Spring框架的应用程序。它旨在简化Spring应用的初始搭建以及开发过程。
67 6
Spring Boot 入门:简化 Java Web 开发的强大工具
|
1月前
|
监控 架构师 Java
Java虚拟机调优的艺术:从入门到精通####
本文作为一篇深入浅出的技术指南,旨在为Java开发者揭示JVM调优的神秘面纱,通过剖析其背后的原理、分享实战经验与最佳实践,引领读者踏上从调优新手到高手的进阶之路。不同于传统的摘要概述,本文将以一场虚拟的对话形式,模拟一位经验丰富的架构师向初学者传授JVM调优的心法,激发学习兴趣,同时概括性地介绍文章将探讨的核心议题——性能监控、垃圾回收优化、内存管理及常见问题解决策略。 ####
|
2月前
|
机器学习/深度学习 数据采集
机器学习入门——使用Scikit-Learn构建分类器
机器学习入门——使用Scikit-Learn构建分类器
|
2月前
|
监控 安全 Java
Java中的多线程编程:从入门到实践####
本文将深入浅出地探讨Java多线程编程的核心概念、应用场景及实践技巧。不同于传统的摘要形式,本文将以一个简短的代码示例作为开篇,直接展示多线程的魅力,随后再详细解析其背后的原理与实现方式,旨在帮助读者快速理解并掌握Java多线程编程的基本技能。 ```java // 简单的多线程示例:创建两个线程,分别打印不同的消息 public class SimpleMultithreading { public static void main(String[] args) { Thread thread1 = new Thread(() -> System.out.prin
|
2月前
|
Java 大数据 API
14天Java基础学习——第1天:Java入门和环境搭建
本文介绍了Java的基础知识,包括Java的简介、历史和应用领域。详细讲解了如何安装JDK并配置环境变量,以及如何使用IntelliJ IDEA创建和运行Java项目。通过示例代码“HelloWorld.java”,展示了从编写到运行的全过程。适合初学者快速入门Java编程。
|
2月前
|
Java 程序员 数据库连接
Java中的异常处理:从入门到精通
在Java编程的海洋中,异常处理是一艘不可或缺的救生艇。它不仅保护你的代码免受错误数据的侵袭,还能确保用户体验的平稳航行。本文将带你领略异常处理的风浪,让你学会如何在Java中捕捉、处理和预防异常,从而成为一名真正的Java航海家。
|
2月前
|
机器学习/深度学习 数据采集 人工智能
机器学习入门:Python与scikit-learn实战
机器学习入门:Python与scikit-learn实战
73 0
|
2月前
|
机器学习/深度学习 算法 Python
机器学习入门:理解并实现K-近邻算法
机器学习入门:理解并实现K-近邻算法
42 0
下一篇
开通oss服务