【解决方案 二十八】Java实现逻辑回归预测模型

简介: 【解决方案 二十八】Java实现逻辑回归预测模型

R语言实现逻辑回归预测模型可以说相当方便,因为标准的库已经有人写好了,Java似乎不擅长统计学领域,所以实现比较复杂,这里给出一个Java实现的逻辑回归预测模型实现方式以及一些常用函数:

package com.exaple.cunzai;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
public class BatchGradientDescent {
    //存储数据内容
    private static List<Double[]> list=new ArrayList<Double[]>();
    //构建训练集矩阵
    private static Double[][] Matrix;
    // 步长->学习率
    private static  double alpha = 0.001;
    // 迭代次数
    private static  int steps = 500;
    //初始化权重向量
    private static Double[][] weights;
    //初始化分类标签列表
    private static Double[][] target;
    //构建训练集矩阵
    public static void geMatrix(){
            //开始构建x+b 系数矩阵:b这里默认为1
            //初始化第一列默认1
            Matrix= new Double[list.size()][list.get(0).length];
            for(int i=0;i<list.size();i++){
                Matrix[i][0]=1.0;
            }
            //初始化第二列->值为list.get(i)数组中的第一列
            for(int i=0;i<list.size();i++){
                Matrix[i][1]=list.get(i)[0];
            }
            //初始化第二列->值为list.get(i)数组中的第二列
            for(int i=0;i<list.size();i++){
                Matrix[i][2]=list.get(i)[1];
            }
            //训练集矩阵构建完成list.size()个样本,特征list.get(i).length  矩阵(list.size()维度,list.get(i).length维度)
    }
    //初始化权重向量矩阵和真实标签矩阵
    public static void initWeights(){
        weights=new Double[list.get(0).length][1];
        weights[0][0]=1.0;
        weights[1][0]=1.0;
        weights[2][0]=1.0;
        target=new Double[list.size()][1];
        for(int i=0;i<list.size();i++){
            target[i][0]=list.get(i)[2];
        }
    }
    // Logistic函数->sigmoid
    public static Double[][] sigmoid(Double[][] wx) {
        Double[][] sigmod=new Double[wx.length][wx[0].length];
        for(int i=0;i<wx.length;i++){
            double v = 1.0 / (1 + Math.exp(-wx[i][0]));
            sigmod [i][0]=v;
        }
        return sigmod;
    }
    //矩阵相乘
    public static Double[][] MatrixMutMatrix(Double a[][], Double b[][]) {
        int arow = a.length;
        int bcol = b[0].length;
        int m = b.length;
        Double[][] c = new Double[arow][bcol];
        for (int i = 0; i < arow; i++) {
            for (int j = 0; j < bcol; j++) {
                Double result = 0.0;
                for (int k = 0; k < m; k++) {
                    result += a[i][k] * b[k][j];
                }
                c[i][j] = result;
            }
        }
        return c;
    }
    //矩阵相减->计算误差
    public static Double[][] subMatrix(Double[][] A, Double[][] B){
        int line=A.length,list=A[0].length;
        Double[][] C =new Double[line][list];
        for(int i=0;i<line;i++)
        {
            for(int j=0;j<list;j++)
            {
                C[i][j]=A[i][j]-B[i][j];
            }
        }
        return C;
    }
    // 将矩阵转置
    public static Double[][] revMatrix(Double temp [][]) {
        Double[][] result =new Double[temp[0].length][temp.length];
        for (int i = 0; i < result.length; i++) {
            for (int j = 0; j < result[i].length; j++) {
                result[i][j] = temp[j][i] ;
            }
        }
        return result;
    }
    // 将矩阵乘以一个数
    public static Double[][] mutMatrix(Double temp [][],Double v) {
        for (int i = 0; i < temp.length; i++) {
            for (int j = 0; j < temp[i].length; j++) {
                temp[i][j] = temp[i][j]*v;
            }
        }
        return temp;
    }
    //矩阵相加
    public static Double[][] AddsMatrix(Double[][]A,Double[][] B){
        int line=A.length,list=A[0].length;
        Double[][]C=new Double[line][list];
        for(int i=0;i<line;i++)
        {
            for(int j=0;j<list;j++)
            {
                C[i][j]=A[i][j]+B[i][j];
            }
        }
        return C;
    }
    //回归函数
    public static Double regression_calc(Double[][] w,Double[][] x){
        Double[][] result=sigmoid(MatrixMutMatrix(w,x));
        Double value=result[0][0];
        return value;
    }
    //分类函数
    public static Double classifier(Double[][]x,Double[][] w){
        Double[][] result=sigmoid(MatrixMutMatrix(w,x));
        Double value=result[0][0];
        Double v;
        if(value>0.5){
            v=1.0;
        }else{
            v=0.0;
        }
        return v;
    }
    //解析数据
    public static void getDate(){
        try {
            File file = new File("D:\\data\\lr_data\\testSet.txt");
            InputStreamReader inputReader = new InputStreamReader(new FileInputStream(file));
            BufferedReader bf = new BufferedReader(inputReader);
            String str;
            while ((str = bf.readLine()) != null){
                Double[] arr=new Double[3];
                String[] result=str.split("\t");
                for(int i=0;i<result.length;i++){
                    arr[i]=Double.parseDouble(result[i]);
                }
                list.add(arr);
            }
            bf.close();
            inputReader.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    /**
     *
     * @param args
     * 1、设置初始w,计算F(w)
     * 2、计算梯度 • 下降方向
     * 3、尝试梯度更新
     * 4、如果 较小,停止; 否则 ;跳到第2步
     */
    public static void main(String[] args) {
        getDate();
        geMatrix();
        initWeights();
        for(int i=0;i<steps;i++){
            //训练集矩阵 乘  权重  w*x
            Double[][] gradient=MatrixMutMatrix(Matrix,weights);
            //sigmoid函数  1/1+exp(-wx) 返回预测值
            Double[][] output=sigmoid(gradient);
            //真实值减预测值  返回误差
            Double[][] errors = subMatrix(target,output);
            //训练集矩阵 转置
            Double[][] dataMat=revMatrix(Matrix);
            //转置后的训练集矩阵 乘 步长
            Double[][] mut=mutMatrix(dataMat,alpha);
            //所有样本乘以误差
            Double[][] err=MatrixMutMatrix(mut,errors);
            //更新权重  权重 +步长∗ 梯度(误差)
            weights = AddsMatrix(weights,err);
        }
        System.out.println(weights[0][0]);
        System.out.println(weights[1][0]);
        System.out.println(weights[2][0]);
        /*得到权重
        4.178813076565532
        0.5048987439366058
        0.6198026439379993*/
        Double[][] x=new Double[1][3];
        x[0][0]=1.0;
        x[0][1]=0.9316350;
        x[0][2]=-1.589505;
        Double[][] w=new Double[1][3];
        w[0][0]=4.178813076565532;
        w[0][1]=0.5048987439366058;
        w[0][2]=0.6198026439379993;
        //回归函数
        Double a=regression_calc(w,x);
        //分类函数
        Double b=classifier(w,x);
        System.out.println(a);
        System.out.println(b);
    }
}


相关文章
|
1月前
|
网络协议 Java Shell
java spring 项目若依框架启动失败,启动不了服务提示端口8080占用escription: Web server failed to start. Port 8080 was already in use. Action: Identify and stop the process that’s listening on port 8080 or configure this application to listen on another port-优雅草卓伊凡解决方案
java spring 项目若依框架启动失败,启动不了服务提示端口8080占用escription: Web server failed to start. Port 8080 was already in use. Action: Identify and stop the process that’s listening on port 8080 or configure this application to listen on another port-优雅草卓伊凡解决方案
84 7
|
1月前
|
缓存 Java 应用服务中间件
java语言后台管理若依框架-登录提示404-接口异常-系统接口404异常如何处理-登录验证码不显示prod-api/captchaImage 404 (Not Found) 如何处理-解决方案优雅草卓伊凡
java语言后台管理若依框架-登录提示404-接口异常-系统接口404异常如何处理-登录验证码不显示prod-api/captchaImage 404 (Not Found) 如何处理-解决方案优雅草卓伊凡
230 5
|
3月前
|
JSON 前端开发 Java
【Bug合集】——Java大小写引起传参失败,获取值为null的解决方案
类中成员变量命名问题引起传送json字符串,但是变量为null的情况做出解释,@Data注解(Spring自动生成的get和set方法)和@JsonProperty
|
2月前
|
JSON 前端开发 安全
【潜意识java】前后端跨域问题及解决方案
本文深入探讨了跨域问题及其解决方案。跨域是指浏览器出于安全考虑,限制从一个域加载的网页请求另一个域的资源。
138 0
|
4月前
|
设计模式 Java 开发者
Java多线程编程的陷阱与解决方案####
本文深入探讨了Java多线程编程中常见的问题及其解决策略。通过分析竞态条件、死锁、活锁等典型场景,并结合代码示例和实用技巧,帮助开发者有效避免这些陷阱,提升并发程序的稳定性和性能。 ####
|
4月前
|
安全 Java 开发者
Java多线程编程中的常见问题与解决方案
本文深入探讨了Java多线程编程中常见的问题,包括线程安全问题、死锁、竞态条件等,并提供了相应的解决策略。文章首先介绍了多线程的基础知识,随后详细分析了每个问题的产生原因和典型场景,最后提出了实用的解决方案,旨在帮助开发者提高多线程程序的稳定性和性能。
|
4月前
|
人工智能 监控 数据可视化
Java智慧工地信息管理平台源码 智慧工地信息化解决方案SaaS源码 支持二次开发
智慧工地系统是依托物联网、互联网、AI、可视化建立的大数据管理平台,是一种全新的管理模式,能够实现劳务管理、安全施工、绿色施工的智能化和互联网化。围绕施工现场管理的人、机、料、法、环五大维度,以及施工过程管理的进度、质量、安全三大体系为基础应用,实现全面高效的工程管理需求,满足工地多角色、多视角的有效监管,实现工程建设管理的降本增效,为监管平台提供数据支撑。
80 3
|
4月前
|
Java API Apache
|
5月前
|
安全 Java
Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧
【10月更文挑战第20天】Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧,包括避免在循环外调用wait()、优先使用notifyAll()、确保线程安全及处理InterruptedException等,帮助读者更好地掌握这些方法的应用。
54 1
|
5月前
|
Java
短频快task的java解决方案
本文探讨了Java自带WorkStealingPool的缺陷,特别是在任务中断方面的不足。普通线程池在处理短频快任务时存在锁竞争问题,导致性能损耗。文章提出了一种基于任务窃取机制的优化方案,通过设计合理的窃取逻辑和减少性能损耗,实现了任务的高效执行和资源的充分利用。最后总结了不同场景下应选择的线程池类型。