【绝技揭秘】Andrew Ng 机器学习课程第十周:解锁梯度下降的神秘力量,带你飞速征服数据山峰!

简介: 【8月更文挑战第16天】Andrew Ng 的机器学习课程是学习该领域的经典资源。第十周聚焦于优化梯度下降算法以提升效率。课程涵盖不同类型的梯度下降(批量、随机及小批量)及其应用场景,介绍如何选择合适的批量大小和学习率调整策略。还介绍了动量法、RMSProp 和 Adam 优化器等高级技巧,这些方法能有效加速收敛并改善模型性能。通过实践案例展示如何使用 Python 和 NumPy 实现小批量梯度下降。

Andrew Ng 的机器学习课程是学习机器学习基础知识的经典资源之一。第十周的内容主要关注于如何优化梯度下降算法,使其更加高效地收敛。本周的学习重点在于理解不同类型的梯度下降算法及其应用场景,以及如何选择合适的参数来加速学习过程。

首先,回顾一下梯度下降的基本概念。梯度下降是一种迭代优化算法,用于最小化成本函数。在每一步迭代中,参数都会沿着负梯度的方向更新,直到达到局部最小值。梯度下降有三种主要类型:批量梯度下降、随机梯度下降和小批量梯度下降。每种类型都有其特点和适用场景。

批量梯度下降

批量梯度下降是最原始的形式,它使用整个训练集来计算梯度。这种方法的优点在于每次迭代都会朝着全局最小值方向移动,但是当数据集很大时,计算全部样本的梯度非常耗时。

随机梯度下降

随机梯度下降(SGD)在每次迭代时只使用一个训练样本来估计梯度。这种方法的主要优点是计算速度快,尤其当数据集非常大时。然而,由于每次迭代都只基于一个样本,梯度的估计往往不够准确,导致参数更新路径波动较大。

小批量梯度下降

小批量梯度下降(Mini-batch Gradient Descent)是介于批量梯度下降和随机梯度下降之间的一种折衷方案。它在每次迭代时使用一小批样本(例如 50 或 100 个样本)来计算梯度。这种方法兼顾了速度和稳定性,是实践中最常用的选择。

选择合适的批量大小

批量大小的选择对梯度下降算法的性能有很大影响。较小的批量大小可以加快学习过程,但也可能导致更多的迭代次数。较大的批量大小则可以减少迭代次数,但可能需要更多的时间来完成一次迭代。实践中,通常会选择介于 50 到 256 之间的批量大小。

学习率调整策略

学习率是梯度下降算法中一个重要的超参数。初始阶段,较高的学习率可以使算法更快地收敛;随着迭代次数增加,学习率应该逐渐减小,以便算法能够更精细地逼近最小值。常见的学习率调整策略包括固定衰减、指数衰减等。

动量法

动量法是一种改进梯度下降算法的方法,它通过引入一个动量项来加速收敛过程。动量项相当于梯度下降过程中的一种惯性,可以减少振荡,使算法更快地沿梯度方向移动。动量法的更新规则如下:

v = beta * v + alpha * gradient
theta = theta - v

其中,alpha 是学习率,beta 是动量系数(通常取值为 0.9),gradient 是梯度,v 是累积的梯度,theta 是待优化的参数。

RMSProp

RMSProp 是另一种优化算法,它可以自动调整每个参数的学习率,以解决学习率衰减的问题。RMSProp 的更新规则如下:

cache = gamma * cache + (1 - gamma) * gradient ** 2
theta = theta - (alpha / (np.sqrt(cache) + epsilon)) * gradient

其中,gamma 是衰减率(通常取值为 0.9),epsilon 是一个小常数(如 1e-8),用于防止除以零的情况。

Adam 优化器

Adam 优化器综合了动量法和 RMSProp 的优点,它同时考虑了梯度的一阶矩估计(动量)和二阶矩估计(RMSProp)。Adam 优化器在实践中表现出色,通常作为默认选择。

实践案例

下面是一个使用 Python 和 NumPy 实现的小批量梯度下降示例:

import numpy as np

def compute_cost(X, y, theta):
    m = len(y)
    predictions = X.dot(theta)
    cost = (1.0 / (2 * m)) * np.sum(np.square(predictions - y))
    return cost

def gradient_descent(X, y, theta, alpha, num_iters, batch_size):
    m = len(y)
    J_history = np.zeros(num_iters)

    for iter in range(num_iters):
        shuffled_indices = np.random.permutation(m)
        X_shuffled = X[shuffled_indices]
        y_shuffled = y[shuffled_indices]

        for start in range(0, m, batch_size):
            end = min(start + batch_size, m)
            X_batch = X_shuffled[start:end]
            y_batch = y_shuffled[start:end]

            gradients = (1.0 / batch_size) * X_batch.T.dot(X_batch.dot(theta) - y_batch)
            theta = theta - alpha * gradients

        J_history[iter] = compute_cost(X, y, theta)

    return theta, J_history

# 示例数据
X = np.array([[1, 2], [1, 3], [1, 4]])
y = np.array([3, 5, 7])
theta = np.zeros(2)
alpha = 0.01
num_iters = 1500
batch_size = 2

theta, J_history = gradient_descent(X, y, theta, alpha, num_iters, batch_size)

print("Final theta:", theta)

总之,优化梯度下降算法对于提高机器学习模型的训练效率至关重要。通过合理选择批量大小、学习率调整策略以及应用先进的优化技巧(如动量法、RMSProp 和 Adam 优化器),可以显著加速模型的训练过程,并提高模型的整体性能。

相关文章
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
梯度下降求极值,机器学习&深度学习
梯度下降求极值,机器学习&深度学习
42 0
|
2月前
|
机器学习/深度学习 算法 Python
探索机器学习中的梯度下降优化算法
【8月更文挑战第1天】在机器学习的广阔天地里,梯度下降法如同一位勇敢的探险家,指引我们穿越复杂的数学丛林,寻找模型参数的最优解。本文将深入探讨梯度下降法的核心原理,并通过Python代码示例,展示其在解决实际问题中的应用。
58 3
|
4月前
|
机器学习/深度学习 人工智能 算法
【机器学习】深度探索:从基础概念到深度学习关键技术的全面解析——梯度下降、激活函数、正则化与批量归一化
【机器学习】深度探索:从基础概念到深度学习关键技术的全面解析——梯度下降、激活函数、正则化与批量归一化
50 3
|
5月前
|
机器学习/深度学习 监控 算法
LabVIEW使用机器学习分类模型探索基于技能课程的学习
LabVIEW使用机器学习分类模型探索基于技能课程的学习
43 1
|
5月前
|
机器学习/深度学习
Coursera 吴恩达Machine Learning(机器学习)课程 |第五周测验答案(仅供参考)
Coursera 吴恩达Machine Learning(机器学习)课程 |第五周测验答案(仅供参考)
|
5月前
|
机器学习/深度学习 人工智能 算法
【人工智能】<吴恩达-机器学习>批量梯度下降&矩阵和向量运算概述
【1月更文挑战第26天】【人工智能】<吴恩达-机器学习>批量梯度下降&矩阵和向量运算概述
|
5月前
|
机器学习/深度学习 人工智能
【人工智能】<吴恩达-机器学习>单变量的线性回归&认识梯度下降
【1月更文挑战第26天】【人工智能】<吴恩达-机器学习>单变量的线性回归&认识梯度下降
|
5月前
|
机器学习/深度学习 算法
【机器学习】三种梯度下降对比
【1月更文挑战第24天】【机器学习】三种梯度下降对比
|
18天前
|
机器学习/深度学习 算法 TensorFlow
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
交通标志识别系统。本系统使用Python作为主要编程语言,在交通标志图像识别功能实现中,基于TensorFlow搭建卷积神经网络算法模型,通过对收集到的58种常见的交通标志图像作为数据集,进行迭代训练最后得到一个识别精度较高的模型文件,然后保存为本地的h5格式文件。再使用Django开发Web网页端操作界面,实现用户上传一张交通标志图片,识别其名称。
45 6
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
|
2月前
|
机器学习/深度学习 算法 数据挖掘
8个常见的机器学习算法的计算复杂度总结
8个常见的机器学习算法的计算复杂度总结
8个常见的机器学习算法的计算复杂度总结