基于梯度下降算法求解线性回归

简介: 线性回归(Linear Regression)梯度下降算法在机器学习方法分类中属于监督学习。利用它可以求解线性回归问题,计算一组二维数据之间的线性关系,假设有一组数据如下下图所示 其中X轴方向表示房屋面积、Y轴表示房屋价格。

线性回归(Linear Regression)

梯度下降算法在机器学习方法分类中属于监督学习。利用它可以求解线性回归问题,计算一组二维数据之间的线性关系,假设有一组数据如下下图所示
这里写图片描述
其中X轴方向表示房屋面积、Y轴表示房屋价格。我们希望根据上述的数据点,拟合出一条直线,能跟对任意给定的房屋面积实现价格预言,这样求解得到直线方程过程就叫线性回归,得到的直线为回归直线,数学公式表示如下:
这里写图片描述

二:梯度下降 (Gradient Descent)

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

三:代码实现

数据读入

public List<DataItem> getData(String fileName) {
    List<DataItem> items = new ArrayList<DataItem>();
    File f = new File(fileName);
    try {
        if (f.exists()) {
            BufferedReader br = new BufferedReader(new FileReader(f));
            String line = null;
            while((line = br.readLine()) != null) {
                String[] data = line.split(",");
                if(data != null && data.length == 2) {
                    DataItem item = new DataItem();
                    item.x = Integer.parseInt(data[0]);
                    item.y = Integer.parseInt(data[1]);
                    items.add(item);
                }
            }
            br.close();
        }
    } catch (IOException ioe) {
        System.err.println(ioe);
    }
    return items;
}

归一化处理

public void normalization(List<DataItem> items) {
    float min = 100000;
    float max = 0;
    for(DataItem item : items) {
        min = Math.min(min, item.x);
        max = Math.max(max, item.x);
    }
    float delta = max - min;
    for(DataItem item : items) {
        item.x = (item.x - min) / delta;
    }
}

梯度下降

public float[] gradientDescent(List<DataItem> items) {
    int repetion = 1500;
    float learningRate = 0.1f;
    float[] theta = new float[2];
    Arrays.fill(theta, 0);
    float[] hmatrix = new float[items.size()];
    Arrays.fill(hmatrix, 0);
    int k=0;
    float s1 = 1.0f / items.size();
    float sum1=0, sum2=0;
    for(int i=0; i<repetion; i++) {
        for(k=0; k<items.size(); k++ ) {
            hmatrix[k] = ((theta[0] + theta[1]*items.get(k).x) - items.get(k).y);
        }

        for(k=0; k<items.size(); k++ ) {
            sum1 += hmatrix[k];
            sum2 += hmatrix[k]*items.get(k).x;
        }

        sum1 = learningRate*s1*sum1;
        sum2 = learningRate*s1*sum2;

        // 更新 参数theta
        theta[0] = theta[0] - sum1;
        theta[1] = theta[1] - sum2;
    }

    return theta;
}

价格预言

public float predict(float input, float[] theta) {
    float result = theta[0] + theta[1]*input;
    return result;
}

线性回归图

public void drawPlot(List<DataItem> series1, List<DataItem> series2, float[] theta) {
    int w = 500;
    int h = 500;
    BufferedImage plot = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
    Graphics2D g2d = plot.createGraphics();
    g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
    g2d.setPaint(Color.WHITE);
    g2d.fillRect(0, 0, w, h);
    g2d.setPaint(Color.BLACK);
    int margin = 50;
    g2d.drawLine(margin, 0, margin, h);
    g2d.drawLine(0, h-margin, w, h-margin);
    float minx=Float.MAX_VALUE, maxx=Float.MIN_VALUE;
    float miny=Float.MAX_VALUE, maxy=Float.MIN_VALUE;
    for(DataItem item : series1) {
        minx = Math.min(item.x, minx);
        maxx = Math.max(maxx, item.x);
        miny = Math.min(item.y, miny);
        maxy = Math.max(item.y, maxy);
    }
    for(DataItem item : series2) {
        minx = Math.min(item.x, minx);
        maxx = Math.max(maxx, item.x);
        miny = Math.min(item.y, miny);
        maxy = Math.max(item.y, maxy);
    }
    // draw X, Y Title and Aixes
    g2d.setPaint(Color.BLACK);
    g2d.drawString("价格(万)", 0, h/2);
    g2d.drawString("面积(平方米)", w/2, h-20);

    // draw labels and legend
    g2d.setPaint(Color.BLUE);
    float xdelta = maxx - minx;
    float ydelta = maxy - miny;
    float xstep = xdelta / 10.0f;
    float ystep = ydelta / 10.0f;
    int dx = (w - 2*margin) / 11;
    int dy = (h - 2*margin) / 11;

    // draw labels
    for(int i=1; i<11; i++) {
        g2d.drawLine(margin+i*dx, h-margin, margin+i*dx, h-margin-10);
        g2d.drawLine(margin, h-margin-dy*i, margin+10, h-margin-dy*i);
        int xv = (int)(minx + (i-1)*xstep);
        float yv = (int)((miny + (i-1)*ystep)/10000.0f);
        g2d.drawString(""+xv, margin+i*dx, h-margin+15);
        g2d.drawString(""+yv, margin-25, h-margin-dy*i);
    }

    // draw point
    g2d.setPaint(Color.BLUE);
    for(DataItem item : series1) {
        float xs = (item.x - minx) / xstep + 1;
        float ys = (item.y - miny) / ystep + 1;
        g2d.fillOval((int)(xs*dx+margin-3), (int)(h-margin-ys*dy-3), 7,7);
    }
    g2d.fillRect(100, 20, 20, 10);
    g2d.drawString("训练数据", 130, 30);

    // draw regression line
    g2d.setPaint(Color.RED);
    for(int i=0; i<series2.size()-1; i++) {
        float x1 = (series2.get(i).x - minx) / xstep + 1;
        float y1 = (series2.get(i).y - miny) / ystep + 1;
        float x2 = (series2.get(i+1).x - minx) / xstep + 1;
        float y2 = (series2.get(i+1).y - miny) / ystep + 1;
        g2d.drawLine((int)(x1*dx+margin-3), (int)(h-margin-y1*dy-3), (int)(x2*dx+margin-3), (int)(h-margin-y2*dy-3));
    }
    g2d.fillRect(100, 50, 20, 10);
    g2d.drawString("线性回归", 130, 60);


    g2d.dispose();
    saveImage(plot);
}

四:总结

本文通过最简单的示例,演示了利用梯度下降算法实现线性回归分析,使用更新收敛的算法常被称为LMS(Least Mean Square)又叫Widrow-Hoff学习规则,此外梯度下降算法还可以进一步区分为增量梯度下降算法与批量梯度下降算法,这两种梯度下降方法在基于神经网络的机器学习中经常会被提及,对此感兴趣的可以自己进一步探索与研究。

只分享干货,不止于代码

目录
相关文章
|
29天前
|
机器学习/深度学习 自然语言处理 算法
深入理解机器学习算法:从线性回归到神经网络
深入理解机器学习算法:从线性回归到神经网络
|
2月前
|
机器学习/深度学习 算法 大数据
机器学习入门:梯度下降算法(下)
机器学习入门:梯度下降算法(下)
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
3月前
|
机器学习/深度学习 算法
深度学习中的优化算法:从梯度下降到Adam
本文深入探讨了深度学习中的核心——优化算法,重点分析了梯度下降及其多种变体。通过比较梯度下降、动量方法、AdaGrad、RMSProp以及Adam等算法,揭示了它们如何更高效地找到损失函数的最小值。此外,文章还讨论了不同优化算法在实际模型训练中的表现和选择依据,为深度学习实践提供了宝贵的指导。
112 7
|
2月前
|
机器学习/深度学习 算法
机器学习入门:梯度下降算法(上)
机器学习入门:梯度下降算法(上)
|
3月前
|
存储 算法 测试技术
预见未来?Python线性回归算法:数据中的秘密预言家
【9月更文挑战第11天】在数据的海洋中,线性回归算法犹如智慧的预言家,助我们揭示未知。本案例通过收集房屋面积、距市中心距离等数据,利用Python的pandas和scikit-learn库构建房价预测模型。经过训练与测试,模型展现出较好的预测能力,均方根误差(RMSE)低,帮助房地产投资者做出更明智决策。尽管现实关系复杂多变,线性回归仍提供了有效工具,引领我们在数据世界中自信前行。
55 5
|
4月前
|
机器学习/深度学习 人工智能 算法
【人工智能】线性回归模型:数据结构、算法详解与人工智能应用,附代码实现
线性回归是一种预测性建模技术,它研究的是因变量(目标)和自变量(特征)之间的关系。这种关系可以表示为一个线性方程,其中因变量是自变量的线性组合。
83 2
|
4月前
|
机器学习/深度学习 算法 Python
探索机器学习中的梯度下降优化算法
【8月更文挑战第1天】在机器学习的广阔天地里,梯度下降法如同一位勇敢的探险家,指引我们穿越复杂的数学丛林,寻找模型参数的最优解。本文将深入探讨梯度下降法的核心原理,并通过Python代码示例,展示其在解决实际问题中的应用。
90 3
|
4月前
|
存储 算法 定位技术
预见未来?Python线性回归算法:数据中的秘密预言家
【8月更文挑战第3天】站在数据的海洋边,线性回归算法犹如智慧的预言家,揭示着房价的秘密。作为房地产投资者,面对复杂的市场,我们可通过收集房屋面积、位置等数据并利用Python的pandas及scikit-learn库,建立线性回归模型预测房价。通过评估模型的均方根误差(RMSE),我们可以更精准地判断投资时机,让数据引领我们走向成功的彼岸。
27 1
|
4月前
|
机器学习/深度学习 算法 数据可视化
Python数据分析高手修炼手册:线性回归算法,让你的数据说话更有力
【8月更文挑战第1天】在数据驱动时代,掌握数据分析技能至关重要。线性回归是最基础且强大的工具之一,能从复杂数据中提炼简单有效的模型。本文探索Python中线性回归的应用并通过实战示例加深理解。线性回归建立变量间线性关系模型:Y = β0 + β1*X + ε。使用scikit-learn库进行实战:首先安装必要库,然后加载数据、训练模型并评估性能。示例展示了如何使用`LinearRegression`模型进行房价预测,包括数据可视化。掌握线性回归,让数据“说话”更有力。
48 2
下一篇
DataWorks