一文入门Java机器学习

简介: 现在机器学习是一个非常热门的词,在生活中很多事情都会和机器学习扯上关系。那你知道机器学习是什么吗?其实很多读者都已经接触过机器学习的一些内容,只是没有用机器学习这个词罢了。今天就带大家来了解一下机器学习相关的知识,并使用Java实现一个非常常用的机器学习算法--线性回归。

一、前言

现在机器学习是一个非常热门的词,在生活中很多事情都会和机器学习扯上关系。那你知道机器学习是什么吗?其实很多读者都已经接触过机器学习的一些内容,只是没有用机器学习这个词罢了。

今天就带大家来了解一下机器学习相关的知识,并使用Java实现一个非常常用的机器学习算法--线性回归。

二、机器学习

机器学习是人工智能的一个子集,机器学习的目的是构建一个模型,通过已有的经验不断学习优化参数。用这个优化后的模型来预测还未发生的事情。这里就牵扯到了四个东西:

  1. 模型
  2. 经验
  3. 学习
  4. 预测

模型就是我们常说的机器学习算法,通常我们都会选择已有的一些模型。比如线性回归、逻辑回归、支持向量机等。我们需要根据问题类型来选择模型,本文介绍的是线性回归模型。

经验其实就是我们常说的数据,如果是面对天气预测的问题,我们的经验就是前几个小时或者前几天的天气数据

学习是机器学习中非常重要的一步。但是机器本身不会自己学习,需要人告诉机器如何学习。因此我们需要定义一个特殊的函数(损失函数),来帮助机器学习参数。

预测则是我们使用模型的过程,也是最简单的一步。

现在我们了解了一些基础的内容,下面我们详细看看线性回归的一些细节。

三、线性回归

3.1、找规律

相信大家都做过这样的题目,给定下面的数字,让你猜下一个数字是多少:

4
7
10
13
16
复制代码

如果盲目的猜我们很难猜到结果,现在我们假设上面数字满足方程:

image.png

我把数字的序号当作x,结果数字当作y。那我们就可以得到下面这几个坐标:

(1, 4)
(2, 7)
(3, 10)
(4, 13)
(5, 16)
复制代码

然后选两个坐标带入方程,比如选(1,4)和(3,10),我们就可以得到下面方程组:

image.png

我们就可以解出

image.png

然后再用其它坐标代入,验证一下方程的准确性。接下来我们就可以猜下一个数字是多少了,下一个数字的x是6,那么下一个数字应该是19。

3.2、线性回归

其实上面的就是线性回归,我们的目标就是找到一组最优的k和b。但是但是我们通常会把线性回归方程写出下面的样子:

image.png

这个方程和之前的方程没有区别,只是换了一个字母表示。其中w意为权重(weight),b意为偏置(bias)。

现在我们求解参数的方式和前面不一样了,并不是直接带入坐标然后解方程,而是通过一种叫梯度下降的算法来调节参数。在下一节我们会详细说到。

在机器学习中,我们把用来调节(训练)参数的数据叫做训练数据。其中x叫做特征值,y叫做目标值。

三、损失函数

在学习之前,我们通常会给模型一个初始化的参数。但是这个初始的参数通常不会太好,我们用这组参数会得到一个比较差的结果。那这个结果有多差呢?这就需要定义一个损失函数来评估这组参数到底有多差了。

损失函数是用来评估参数好坏的一个函数,通常也有很多现成的函数供我们选择。这里我们选择均方误差作为我们的损失函数:

image.png

其中y和x是我们已知的数据,而w和b是这个函数的变量。我们用wx+b计算出预测结果,然后用实际的y减去预测的y就是我们预测的误差了。为了保证这个值为正数,所以加了个平方。

这里需要注意,我们可以选择平方开根号,也可以选择不开根号。当我们开根号时,则叫做均方根误差。

上面我们只是对其中一个数字求了损失值,但是我们训练数据通常有很多,因此我们需要求所有误差的平均。所以完整的损失值计算如下:

image.png

这里只是加了个求平均的操作。

四、梯度下降

4.1、梯度下降

我们可以先观察一下参数和损失值之间的函数图形,为了方便观看,我们只考虑w参数,图像如下:

网络异常,图片无法展示
|

其中x轴为w参数,y轴为损失值。从图中可以看到B点是损失值最小的点(不考虑图外的区域)。

在B点的左右有A、C两个点,我们分别来分析一下这两个点。

A点在B点的左边,为了让损失值变小,我们需要往右更新参数,即增加参数。我们从图中可以看出,A点切线的斜率是小于0的,也就是A点处的导数小于0。这一点很重要,我们分析完C点再一起详细说。

C点在B点的右边,因此我们需要往左更新参数,即减少参数。从图中可以看出,C点的斜率要大于0,也就是C点处导数大于0。

分析完A、C两个点后,我们发现。当导数大于0时,我们需要减小参数,当导数小于0是我们需要增加参数。因此我们可以按照如下公式来更新参数:

image.png

其中更新后的参数等于原本参数减去损失函数对参数w的导数。

4.2、学习率

用上面的方式更新参数其实是有问题的,因为我们求到的导数很多时候是一个比较大的数值,如果我们直接减去导数的话就会出现左图的情况:

网络异常,图片无法展示
|

可以看到参数一直在峡谷振荡,要更新到最优参数要花很长的时间,或者根本更新不到最优参数。

因此我们可以让导数乘一个很小的数,这个数就是学习率,当我们使用学习率后我们参数更新可能就会更接近右图了。

4.3、多个参数更新

我们已经知道一个参数怎么更新了,那多个参数要怎么更新呢?在我们前面线性回归方程有w和b两个参数,更新会有什么区别吗?

其实并没有太大区别,我们只需要将求导数换成求偏导就好了,然后分别对w和b进行更新。

其中损失函数对w的偏导如下:

image.png

损失函数对b的偏导如下:

image.png

我们参数更新函数如下:

image.png

如果不记得求导公式可以直接套用上面的公式更新参数即可。

五、实现线性回归

下面我们用Java来实现一个线性回归模型。

5.1、类的创建和初始化参数

首先我们创建一个名为LinearRegression的类,代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LinearRegression {
    //权重
    private double weight;
    //偏置值
    private double bias;
    //特征值
    private ArrayList<Double> features;
    //目标值
    private ArrayList<Double> targets;
    /**
     * 构造线性回归模型
     * @param features 训练数据的特征值
     * @param targets 训练数据的目标值
     */
    public LinearRegression(ArrayList<Double> features, ArrayList<Double> targets){
        this.features = features;
        this.targets = targets;
        initParameter();
    }
    /**
     * 初始化权重和偏置值
     */
    public void initParameter(){
        this.weight = Math.random();
        this.bias = Math.random();
    }
}
复制代码

在类中我们定义了四个成员变量,分别是两个模型参数和训练数据的特征值和目标值。

然后我们需要在创建模型时传入训练数据的特征值和目标值,并调用initParameter函数随机初始模型的参数。

5.2、梯度下降

下面我们编写梯度下降算法,代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LinearRegression {
  ......
    /**
     * 梯度下降更新参数
     * @param learning_rate 学习率
     * @return 损失值
     */
    public double gradientDecent(double learning_rate){
        double w_ = 0;
        double b_ = 0;
        double totalLoss = 0;
        int n = this.features.size();
        double n = this.features.size();
        for (int i = 0; i < this.features.size(); i++) {
            double yPredict = this.features.get(i) * this.weight + this.bias;
            // loss对w的偏导
            w_ += -2 * learning_rate * this.features.get(i) * (this.targets.get(i) - yPredict) / n;
            // loss对b的偏导
            b_ += -2 * learning_rate * (this.targets.get(i) - (yPredict)) / n;
            // 计算loss用于输出
            totalLoss += Math.pow(this.targets.get(i) - yPredict, 2) / n;
        }
        //更新参数
        this.weight -= w_;
        this.bias -= b_;
        return totalLoss;
    }
}
复制代码

这个函数我们传入一个学习率,然后对参数进行更新。

首先w_、b_在这里表示的是两个参数的变化量,totalLoss表示参数的损失值。我们来关注下面这段代码:

double yPredict = this.features.get(i) * this.weight + this.bias;
// loss对w的偏导
w_ += -2 * learning_rate * this.features.get(i) * (this.targets.get(i) - yPredict) / n;
// loss对b的偏导
b_ += -2 * learning_rate * (this.targets.get(i) - (yPredict)) / n;
// 计算loss用于输出
totalLoss += Math.pow(this.targets.get(i) - yPredict, 2) / n;
复制代码

可以看到我们只不过是在代入求偏导的公式。因为我们是循环求了n(训练数据的数量)次,所以我们需要除一个n。

然后我们更新参数即可。

5.3、预测结果

我们训练模型的目的就是预测结果,这一步非常简单,我们只需要带入x然后求y的值即可。代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LinearRegression {
  ......
    /**
     * 预测结果
     * @param features 特征值
     * @return 预测结果
     */
    public ArrayList<Double> predict(ArrayList<Double> features){
        //用于装预测结果
        ArrayList<Double> yPredict = new ArrayList<>();
        for (Double feature : features) {
            //对每个x进行预测
            yPredict.add(feature * this.weight + this.bias);
        }
        return yPredict;
    }
}
复制代码

这部分代码很简单,我们直接使用训练好的参数计算y的值,然后加入到预测集合,把结果返回就好了。

5.4、测试程序

下面我们编写一个测试程序来看看这个模型的效果如何。代码如下:

package com.zack.lr;
import java.util.ArrayList;
public class LrDemo {
    public static void main(String[] args) {
        ArrayList<Double> features = new ArrayList<>();
        ArrayList<Double> targets = new ArrayList<>();
        //准备特征值
        for (int i = 0; i < 200; i++) {
            features.add((double)i);
        }
        //用y = 3x + 1生成目标值
        for (Double feature : features) {
            //生成目标值,并加上一个随机数
            double target = feature * 3 + 1 + Math.random() * 3;
            targets.add(target);
        }
        //创建线性回归模型
        LinearRegression linearRegression = new LinearRegression(features, targets);
        for (long i = 1; i <= 300; i++){
            double loss = linearRegression.gradientDecent(1e-6);
            if(i % 100 == 0){
                System.out.println("第" + i + "次更新后");
                System.out.println("weight = " + linearRegression.getWeight());
                System.out.println("bias = " + linearRegression.getBias());
                System.out.println("loss = " + loss);
            }
        }
        //准备数据用于测试
        ArrayList<Double> testList = new ArrayList<>();
        testList.add(100.0);
        testList.add(27.0);
        ArrayList<Double> testPredict = linearRegression.predict(testList);
        System.out.println("真实结果");
        for (Double testX : testList) {
            System.out.println(testX * 3 + 1);
        }
        System.out.println("预测结果");
        for (Double predict : testPredict) {
            System.out.println(predict);
        }
    }
}
复制代码

我们生成了一些数据用来训练。然后设置一个比较小的学习率进行梯度下降,这也是我们通常的做法。

我们更新了300次参数就得到了一个比较好的结果,下面是输出结果:

第100次更新后
weight = 2.824277848777937
bias = 0.5474555396904982
loss = 511.60209782851484
第200次更新后
weight = 3.0023165757117525
bias = 0.5488865404002248
loss = 3.9653518084868575
第300次更新后
weight = 3.0144922328217816
bias = 0.5490704219725514
loss = 1.5908606096719522
真实结果
301.0
82.0
预测结果
301.9982937041507
81.94036070816065
复制代码

可以看到预测结果和真实结果已经非常接近了。到此我们就完成了一个简单的线性回归算法了。

本次案例是仅考虑一元的一元线性回归,感兴趣的读者可以考虑扩展到多元线性回归。感谢阅读~

目录
相关文章
|
7月前
|
存储 Oracle Java
java零基础学习者入门课程
本课程为Java零基础入门教程,涵盖环境搭建、变量、运算符、条件循环、数组及面向对象基础,每讲配示例代码与实践建议,助你循序渐进掌握核心知识,轻松迈入Java编程世界。
581 0
|
9月前
|
安全 Java 数据库连接
2025 年最新 Java 学习路线图含实操指南助你高效入门 Java 编程掌握核心技能
2025年最新Java学习路线图,涵盖基础环境搭建、核心特性(如密封类、虚拟线程)、模块化开发、响应式编程、主流框架(Spring Boot 3、Spring Security 6)、数据库操作(JPA + Hibernate 6)及微服务实战,助你掌握企业级开发技能。
1098 3
|
8月前
|
Java
java入门代码示例
本文介绍Java入门基础,包含Hello World、变量类型、条件判断、循环及方法定义等核心语法示例,帮助初学者快速掌握Java编程基本结构与逻辑。
609 0
|
8月前
|
机器学习/深度学习 数据采集 算法
量子机器学习入门:三种数据编码方法对比与应用
在量子机器学习中,数据编码方式决定了量子模型如何理解和处理信息。本文详解角度编码、振幅编码与基础编码三种方法,分析其原理、实现及适用场景,帮助读者选择最适合的编码策略,提升量子模型性能。
626 8
|
8月前
|
前端开发 Java 数据库连接
帮助新手快速上手的 JAVA 学习路线最详细版涵盖从入门到进阶的 JAVA 学习路线
本Java学习路线涵盖从基础语法、面向对象、异常处理到高级框架、微服务、JVM调优等内容,适合新手入门到进阶,助力掌握企业级开发技能,快速成为合格Java开发者。
1170 3
|
9月前
|
NoSQL Java 关系型数据库
Java 从入门到进阶完整学习路线图规划与实战开发最佳实践指南
本文为Java开发者提供从入门到进阶的完整学习路线图,涵盖基础语法、面向对象、数据结构与算法、并发编程、JVM调优、主流框架(如Spring Boot)、数据库操作(MySQL、Redis)、微服务架构及云原生开发等内容,并结合实战案例与最佳实践,助力高效掌握Java核心技术。
934 0
|
9月前
|
Java 测试技术 API
Java IO流(二):文件操作与NIO入门
本文详解Java NIO与传统IO的区别与优势,涵盖Path、Files类、Channel、Buffer、Selector等核心概念,深入讲解文件操作、目录遍历、NIO实战及性能优化技巧,适合处理大文件与高并发场景,助力高效IO编程与面试准备。
|
9月前
|
Java 编译器 API
Java Lambda表达式与函数式编程入门
Lambda表达式是Java 8引入的重要特性,简化了函数式编程的实现方式。它通过简洁的语法替代传统的匿名内部类,使代码更清晰、易读。本文深入讲解Lambda表达式的基本语法、函数式接口、方法引用等核心概念,并结合集合操作、线程处理、事件回调等实战案例,帮助开发者掌握现代Java编程技巧。同时,还解析了面试中高频出现的相关问题,助你深入理解其原理与应用场景。
|
8月前
|
Java API 数据库
2025 年最新 Java 实操学习路线,从入门到高级应用详细指南
2025年Java最新实操学习路线,涵盖从环境搭建到微服务、容器化部署的全流程实战内容,助你掌握Java 21核心特性、Spring Boot 3.2开发、云原生与微服务架构,提升企业级项目开发能力,适合从入门到高级应用的学习需求。
2486 0
|
8月前
|
监控 Java API
2025 年全新出炉的 Java 学习路线:从入门起步到实操精通的详细指南
2025年Java学习路线与实操指南,涵盖Java 21核心特性、虚拟线程、Spring Boot 3、微服务、Spring Security、容器化部署等前沿技术,助你从入门到企业级开发进阶。
1561 0