【解决方案 二十八】Java实现逻辑回归预测模型

简介: 【解决方案 二十八】Java实现逻辑回归预测模型

R语言实现逻辑回归预测模型可以说相当方便,因为标准的库已经有人写好了,Java似乎不擅长统计学领域,所以实现比较复杂,这里给出一个Java实现的逻辑回归预测模型实现方式以及一些常用函数:

package com.exaple.cunzai;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
public class BatchGradientDescent {
    //存储数据内容
    private static List<Double[]> list=new ArrayList<Double[]>();
    //构建训练集矩阵
    private static Double[][] Matrix;
    // 步长->学习率
    private static  double alpha = 0.001;
    // 迭代次数
    private static  int steps = 500;
    //初始化权重向量
    private static Double[][] weights;
    //初始化分类标签列表
    private static Double[][] target;
    //构建训练集矩阵
    public static void geMatrix(){
            //开始构建x+b 系数矩阵:b这里默认为1
            //初始化第一列默认1
            Matrix= new Double[list.size()][list.get(0).length];
            for(int i=0;i<list.size();i++){
                Matrix[i][0]=1.0;
            }
            //初始化第二列->值为list.get(i)数组中的第一列
            for(int i=0;i<list.size();i++){
                Matrix[i][1]=list.get(i)[0];
            }
            //初始化第二列->值为list.get(i)数组中的第二列
            for(int i=0;i<list.size();i++){
                Matrix[i][2]=list.get(i)[1];
            }
            //训练集矩阵构建完成list.size()个样本,特征list.get(i).length  矩阵(list.size()维度,list.get(i).length维度)
    }
    //初始化权重向量矩阵和真实标签矩阵
    public static void initWeights(){
        weights=new Double[list.get(0).length][1];
        weights[0][0]=1.0;
        weights[1][0]=1.0;
        weights[2][0]=1.0;
        target=new Double[list.size()][1];
        for(int i=0;i<list.size();i++){
            target[i][0]=list.get(i)[2];
        }
    }
    // Logistic函数->sigmoid
    public static Double[][] sigmoid(Double[][] wx) {
        Double[][] sigmod=new Double[wx.length][wx[0].length];
        for(int i=0;i<wx.length;i++){
            double v = 1.0 / (1 + Math.exp(-wx[i][0]));
            sigmod [i][0]=v;
        }
        return sigmod;
    }
    //矩阵相乘
    public static Double[][] MatrixMutMatrix(Double a[][], Double b[][]) {
        int arow = a.length;
        int bcol = b[0].length;
        int m = b.length;
        Double[][] c = new Double[arow][bcol];
        for (int i = 0; i < arow; i++) {
            for (int j = 0; j < bcol; j++) {
                Double result = 0.0;
                for (int k = 0; k < m; k++) {
                    result += a[i][k] * b[k][j];
                }
                c[i][j] = result;
            }
        }
        return c;
    }
    //矩阵相减->计算误差
    public static Double[][] subMatrix(Double[][] A, Double[][] B){
        int line=A.length,list=A[0].length;
        Double[][] C =new Double[line][list];
        for(int i=0;i<line;i++)
        {
            for(int j=0;j<list;j++)
            {
                C[i][j]=A[i][j]-B[i][j];
            }
        }
        return C;
    }
    // 将矩阵转置
    public static Double[][] revMatrix(Double temp [][]) {
        Double[][] result =new Double[temp[0].length][temp.length];
        for (int i = 0; i < result.length; i++) {
            for (int j = 0; j < result[i].length; j++) {
                result[i][j] = temp[j][i] ;
            }
        }
        return result;
    }
    // 将矩阵乘以一个数
    public static Double[][] mutMatrix(Double temp [][],Double v) {
        for (int i = 0; i < temp.length; i++) {
            for (int j = 0; j < temp[i].length; j++) {
                temp[i][j] = temp[i][j]*v;
            }
        }
        return temp;
    }
    //矩阵相加
    public static Double[][] AddsMatrix(Double[][]A,Double[][] B){
        int line=A.length,list=A[0].length;
        Double[][]C=new Double[line][list];
        for(int i=0;i<line;i++)
        {
            for(int j=0;j<list;j++)
            {
                C[i][j]=A[i][j]+B[i][j];
            }
        }
        return C;
    }
    //回归函数
    public static Double regression_calc(Double[][] w,Double[][] x){
        Double[][] result=sigmoid(MatrixMutMatrix(w,x));
        Double value=result[0][0];
        return value;
    }
    //分类函数
    public static Double classifier(Double[][]x,Double[][] w){
        Double[][] result=sigmoid(MatrixMutMatrix(w,x));
        Double value=result[0][0];
        Double v;
        if(value>0.5){
            v=1.0;
        }else{
            v=0.0;
        }
        return v;
    }
    //解析数据
    public static void getDate(){
        try {
            File file = new File("D:\\data\\lr_data\\testSet.txt");
            InputStreamReader inputReader = new InputStreamReader(new FileInputStream(file));
            BufferedReader bf = new BufferedReader(inputReader);
            String str;
            while ((str = bf.readLine()) != null){
                Double[] arr=new Double[3];
                String[] result=str.split("\t");
                for(int i=0;i<result.length;i++){
                    arr[i]=Double.parseDouble(result[i]);
                }
                list.add(arr);
            }
            bf.close();
            inputReader.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    /**
     *
     * @param args
     * 1、设置初始w,计算F(w)
     * 2、计算梯度 • 下降方向
     * 3、尝试梯度更新
     * 4、如果 较小,停止; 否则 ;跳到第2步
     */
    public static void main(String[] args) {
        getDate();
        geMatrix();
        initWeights();
        for(int i=0;i<steps;i++){
            //训练集矩阵 乘  权重  w*x
            Double[][] gradient=MatrixMutMatrix(Matrix,weights);
            //sigmoid函数  1/1+exp(-wx) 返回预测值
            Double[][] output=sigmoid(gradient);
            //真实值减预测值  返回误差
            Double[][] errors = subMatrix(target,output);
            //训练集矩阵 转置
            Double[][] dataMat=revMatrix(Matrix);
            //转置后的训练集矩阵 乘 步长
            Double[][] mut=mutMatrix(dataMat,alpha);
            //所有样本乘以误差
            Double[][] err=MatrixMutMatrix(mut,errors);
            //更新权重  权重 +步长∗ 梯度(误差)
            weights = AddsMatrix(weights,err);
        }
        System.out.println(weights[0][0]);
        System.out.println(weights[1][0]);
        System.out.println(weights[2][0]);
        /*得到权重
        4.178813076565532
        0.5048987439366058
        0.6198026439379993*/
        Double[][] x=new Double[1][3];
        x[0][0]=1.0;
        x[0][1]=0.9316350;
        x[0][2]=-1.589505;
        Double[][] w=new Double[1][3];
        w[0][0]=4.178813076565532;
        w[0][1]=0.5048987439366058;
        w[0][2]=0.6198026439379993;
        //回归函数
        Double a=regression_calc(w,x);
        //分类函数
        Double b=classifier(w,x);
        System.out.println(a);
        System.out.println(b);
    }
}


相关文章
|
1月前
|
Java 大数据 Go
从混沌到秩序:Java共享内存模型如何通过显式约束驯服并发?
并发编程旨在混乱中建立秩序。本文对比Java共享内存模型与Golang消息传递模型,剖析显式同步与隐式因果的哲学差异,揭示happens-before等机制如何保障内存可见性与数据一致性,展现两大范式的深层分野。(238字)
59 4
|
3月前
|
缓存 前端开发 Java
Java类加载机制与双亲委派模型
本文深入解析Java类加载机制,涵盖类加载过程、类加载器、双亲委派模型、自定义类加载器及实战应用,帮助开发者理解JVM核心原理与实际运用。
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
Java 大视界 -- Java 大数据机器学习模型在自然语言生成中的可控性研究与应用(229)
本文深入探讨Java大数据与机器学习在自然语言生成(NLG)中的可控性研究,分析当前生成模型面临的“失控”挑战,如数据噪声、标注偏差及黑盒模型信任问题,提出Java技术在数据清洗、异构框架融合与生态工具链中的关键作用。通过条件注入、强化学习与模型融合等策略,实现文本生成的精准控制,并结合网易新闻与蚂蚁集团的实战案例,展示Java在提升生成效率与合规性方面的卓越能力,为金融、法律等强监管领域提供技术参考。
|
3月前
|
机器学习/深度学习 算法 Java
Java 大视界 -- Java 大数据机器学习模型在生物信息学基因功能预测中的优化与应用(223)
本文探讨了Java大数据与机器学习模型在生物信息学中基因功能预测的优化与应用。通过高效的数据处理能力和智能算法,提升基因功能预测的准确性与效率,助力医学与农业发展。
|
3月前
|
机器学习/深度学习 搜索推荐 数据可视化
Java 大视界 -- Java 大数据机器学习模型在电商用户流失预测与留存策略制定中的应用(217)
本文探讨 Java 大数据与机器学习在电商用户流失预测与留存策略中的应用。通过构建高精度预测模型与动态分层策略,助力企业提前识别流失用户、精准触达,实现用户留存率与商业价值双提升,为电商应对用户流失提供技术新思路。
|
3月前
|
机器学习/深度学习 存储 分布式计算
Java 大视界 --Java 大数据机器学习模型在金融风险压力测试中的应用与验证(211)
本文探讨了Java大数据与机器学习模型在金融风险压力测试中的创新应用。通过多源数据采集、模型构建与优化,结合随机森林、LSTM等算法,实现信用风险动态评估、市场极端场景模拟与操作风险预警。案例分析展示了花旗银行与蚂蚁集团的智能风控实践,验证了技术在提升风险识别效率与降低金融风险损失方面的显著成效。
|
3月前
|
机器学习/深度学习 自然语言处理 算法
Java 大视界 -- Java 大数据机器学习模型在自然语言处理中的对抗训练与鲁棒性提升(205)
本文探讨Java大数据与机器学习在自然语言处理中的对抗训练与鲁棒性提升,分析对抗攻击原理,结合Java技术构建对抗样本、优化训练策略,并通过智能客服等案例展示实际应用效果。
|
4月前
|
机器学习/深度学习 分布式计算 Java
Java 大视界 -- Java 大数据机器学习模型在遥感图像土地利用分类中的优化与应用(199)
本文探讨了Java大数据与机器学习模型在遥感图像土地利用分类中的优化与应用。面对传统方法效率低、精度差的问题,结合Hadoop、Spark与深度学习框架,实现了高效、精准的分类。通过实际案例展示了Java在数据处理、模型融合与参数调优中的强大能力,推动遥感图像分类迈向新高度。
|
4月前
|
机器学习/深度学习 存储 Java
Java 大视界 -- Java 大数据机器学习模型在游戏用户行为分析与游戏平衡优化中的应用(190)
本文探讨了Java大数据与机器学习模型在游戏用户行为分析及游戏平衡优化中的应用。通过数据采集、预处理与聚类分析,开发者可深入洞察玩家行为特征,构建个性化运营策略。同时,利用回归模型优化游戏数值与付费机制,提升游戏公平性与用户体验。
|
4月前
|
机器学习/深度学习 算法 Java
Java 大视界 -- Java 大数据机器学习模型在舆情分析中的情感倾向判断与话题追踪(185)
本篇文章深入探讨了Java大数据与机器学习在舆情分析中的应用,重点介绍了情感倾向判断与话题追踪的技术实现。通过实际案例,展示了如何利用Java生态工具如Hadoop、Hive、Weka和Deeplearning4j进行舆情数据处理、情感分类与趋势预测,揭示了其在企业品牌管理与政府决策中的重要价值。文章还展望了多模态融合、实时性提升及个性化服务等未来发展方向。