【解决方案 二十八】Java实现逻辑回归预测模型

简介: 【解决方案 二十八】Java实现逻辑回归预测模型

R语言实现逻辑回归预测模型可以说相当方便,因为标准的库已经有人写好了,Java似乎不擅长统计学领域,所以实现比较复杂,这里给出一个Java实现的逻辑回归预测模型实现方式以及一些常用函数:

package com.exaple.cunzai;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
public class BatchGradientDescent {
    //存储数据内容
    private static List<Double[]> list=new ArrayList<Double[]>();
    //构建训练集矩阵
    private static Double[][] Matrix;
    // 步长->学习率
    private static  double alpha = 0.001;
    // 迭代次数
    private static  int steps = 500;
    //初始化权重向量
    private static Double[][] weights;
    //初始化分类标签列表
    private static Double[][] target;
    //构建训练集矩阵
    public static void geMatrix(){
            //开始构建x+b 系数矩阵:b这里默认为1
            //初始化第一列默认1
            Matrix= new Double[list.size()][list.get(0).length];
            for(int i=0;i<list.size();i++){
                Matrix[i][0]=1.0;
            }
            //初始化第二列->值为list.get(i)数组中的第一列
            for(int i=0;i<list.size();i++){
                Matrix[i][1]=list.get(i)[0];
            }
            //初始化第二列->值为list.get(i)数组中的第二列
            for(int i=0;i<list.size();i++){
                Matrix[i][2]=list.get(i)[1];
            }
            //训练集矩阵构建完成list.size()个样本,特征list.get(i).length  矩阵(list.size()维度,list.get(i).length维度)
    }
    //初始化权重向量矩阵和真实标签矩阵
    public static void initWeights(){
        weights=new Double[list.get(0).length][1];
        weights[0][0]=1.0;
        weights[1][0]=1.0;
        weights[2][0]=1.0;
        target=new Double[list.size()][1];
        for(int i=0;i<list.size();i++){
            target[i][0]=list.get(i)[2];
        }
    }
    // Logistic函数->sigmoid
    public static Double[][] sigmoid(Double[][] wx) {
        Double[][] sigmod=new Double[wx.length][wx[0].length];
        for(int i=0;i<wx.length;i++){
            double v = 1.0 / (1 + Math.exp(-wx[i][0]));
            sigmod [i][0]=v;
        }
        return sigmod;
    }
    //矩阵相乘
    public static Double[][] MatrixMutMatrix(Double a[][], Double b[][]) {
        int arow = a.length;
        int bcol = b[0].length;
        int m = b.length;
        Double[][] c = new Double[arow][bcol];
        for (int i = 0; i < arow; i++) {
            for (int j = 0; j < bcol; j++) {
                Double result = 0.0;
                for (int k = 0; k < m; k++) {
                    result += a[i][k] * b[k][j];
                }
                c[i][j] = result;
            }
        }
        return c;
    }
    //矩阵相减->计算误差
    public static Double[][] subMatrix(Double[][] A, Double[][] B){
        int line=A.length,list=A[0].length;
        Double[][] C =new Double[line][list];
        for(int i=0;i<line;i++)
        {
            for(int j=0;j<list;j++)
            {
                C[i][j]=A[i][j]-B[i][j];
            }
        }
        return C;
    }
    // 将矩阵转置
    public static Double[][] revMatrix(Double temp [][]) {
        Double[][] result =new Double[temp[0].length][temp.length];
        for (int i = 0; i < result.length; i++) {
            for (int j = 0; j < result[i].length; j++) {
                result[i][j] = temp[j][i] ;
            }
        }
        return result;
    }
    // 将矩阵乘以一个数
    public static Double[][] mutMatrix(Double temp [][],Double v) {
        for (int i = 0; i < temp.length; i++) {
            for (int j = 0; j < temp[i].length; j++) {
                temp[i][j] = temp[i][j]*v;
            }
        }
        return temp;
    }
    //矩阵相加
    public static Double[][] AddsMatrix(Double[][]A,Double[][] B){
        int line=A.length,list=A[0].length;
        Double[][]C=new Double[line][list];
        for(int i=0;i<line;i++)
        {
            for(int j=0;j<list;j++)
            {
                C[i][j]=A[i][j]+B[i][j];
            }
        }
        return C;
    }
    //回归函数
    public static Double regression_calc(Double[][] w,Double[][] x){
        Double[][] result=sigmoid(MatrixMutMatrix(w,x));
        Double value=result[0][0];
        return value;
    }
    //分类函数
    public static Double classifier(Double[][]x,Double[][] w){
        Double[][] result=sigmoid(MatrixMutMatrix(w,x));
        Double value=result[0][0];
        Double v;
        if(value>0.5){
            v=1.0;
        }else{
            v=0.0;
        }
        return v;
    }
    //解析数据
    public static void getDate(){
        try {
            File file = new File("D:\\data\\lr_data\\testSet.txt");
            InputStreamReader inputReader = new InputStreamReader(new FileInputStream(file));
            BufferedReader bf = new BufferedReader(inputReader);
            String str;
            while ((str = bf.readLine()) != null){
                Double[] arr=new Double[3];
                String[] result=str.split("\t");
                for(int i=0;i<result.length;i++){
                    arr[i]=Double.parseDouble(result[i]);
                }
                list.add(arr);
            }
            bf.close();
            inputReader.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    /**
     *
     * @param args
     * 1、设置初始w,计算F(w)
     * 2、计算梯度 • 下降方向
     * 3、尝试梯度更新
     * 4、如果 较小,停止; 否则 ;跳到第2步
     */
    public static void main(String[] args) {
        getDate();
        geMatrix();
        initWeights();
        for(int i=0;i<steps;i++){
            //训练集矩阵 乘  权重  w*x
            Double[][] gradient=MatrixMutMatrix(Matrix,weights);
            //sigmoid函数  1/1+exp(-wx) 返回预测值
            Double[][] output=sigmoid(gradient);
            //真实值减预测值  返回误差
            Double[][] errors = subMatrix(target,output);
            //训练集矩阵 转置
            Double[][] dataMat=revMatrix(Matrix);
            //转置后的训练集矩阵 乘 步长
            Double[][] mut=mutMatrix(dataMat,alpha);
            //所有样本乘以误差
            Double[][] err=MatrixMutMatrix(mut,errors);
            //更新权重  权重 +步长∗ 梯度(误差)
            weights = AddsMatrix(weights,err);
        }
        System.out.println(weights[0][0]);
        System.out.println(weights[1][0]);
        System.out.println(weights[2][0]);
        /*得到权重
        4.178813076565532
        0.5048987439366058
        0.6198026439379993*/
        Double[][] x=new Double[1][3];
        x[0][0]=1.0;
        x[0][1]=0.9316350;
        x[0][2]=-1.589505;
        Double[][] w=new Double[1][3];
        w[0][0]=4.178813076565532;
        w[0][1]=0.5048987439366058;
        w[0][2]=0.6198026439379993;
        //回归函数
        Double a=regression_calc(w,x);
        //分类函数
        Double b=classifier(w,x);
        System.out.println(a);
        System.out.println(b);
    }
}


相关文章
|
2月前
|
关系型数据库 MySQL Java
【IDEA】java后台操作mysql数据库驱动常见错误解决方案
【IDEA】java后台操作mysql数据库驱动常见错误解决方案
82 0
|
21天前
|
安全 Java 开发者
Java多线程编程中的常见问题与解决方案
本文深入探讨了Java多线程编程中常见的问题,包括线程安全问题、死锁、竞态条件等,并提供了相应的解决策略。文章首先介绍了多线程的基础知识,随后详细分析了每个问题的产生原因和典型场景,最后提出了实用的解决方案,旨在帮助开发者提高多线程程序的稳定性和性能。
|
27天前
|
人工智能 监控 数据可视化
Java智慧工地信息管理平台源码 智慧工地信息化解决方案SaaS源码 支持二次开发
智慧工地系统是依托物联网、互联网、AI、可视化建立的大数据管理平台,是一种全新的管理模式,能够实现劳务管理、安全施工、绿色施工的智能化和互联网化。围绕施工现场管理的人、机、料、法、环五大维度,以及施工过程管理的进度、质量、安全三大体系为基础应用,实现全面高效的工程管理需求,满足工地多角色、多视角的有效监管,实现工程建设管理的降本增效,为监管平台提供数据支撑。
35 3
|
1月前
|
Java API Apache
|
2月前
|
安全 Java
Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧
【10月更文挑战第20天】Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧,包括避免在循环外调用wait()、优先使用notifyAll()、确保线程安全及处理InterruptedException等,帮助读者更好地掌握这些方法的应用。
22 1
|
2月前
|
Java
短频快task的java解决方案
本文探讨了Java自带WorkStealingPool的缺陷,特别是在任务中断方面的不足。普通线程池在处理短频快任务时存在锁竞争问题,导致性能损耗。文章提出了一种基于任务窃取机制的优化方案,通过设计合理的窃取逻辑和减少性能损耗,实现了任务的高效执行和资源的充分利用。最后总结了不同场景下应选择的线程池类型。
|
2月前
|
SQL 分布式计算 Java
Hadoop-11-MapReduce JOIN 操作的Java实现 Driver Mapper Reducer具体实现逻辑 模拟SQL进行联表操作
Hadoop-11-MapReduce JOIN 操作的Java实现 Driver Mapper Reducer具体实现逻辑 模拟SQL进行联表操作
37 3
|
2月前
|
小程序 Java
小程序访问java后台失败解决方案
小程序访问java后台失败解决方案
47 2
|
2月前
|
存储 前端开发 Java
浅谈Java中文乱码浅析及解决方案
浅谈Java中文乱码浅析及解决方案
80 0
|
2月前
|
Java
Error:java: 无效的目标发行版: 11解决方案
Error:java: 无效的目标发行版: 11解决方案
82 0