【机器学习】numpy实现NAG(Nesterov accelerated gradient)优化器

简介: 【机器学习】numpy实现NAG(Nesterov accelerated gradient)优化器

NAG(Nesterov accelerated gradient)优化原理

Momentum是基于动量原理的,就是每次更新参数时,梯度的方向都会和上一次迭代的方向有关,当一个球向山下滚的时候,它会越滚越快,能够加快收敛,但是这样也会存在一个问题,每次梯度都是本次和上次之和,如果是同向,那么将导致梯度很大,当到达谷底的时候很容易动量过大导致小球冲过谷底,跳过当前局部最优位置。

我们希望有一个更智能的球,一个知道它要去哪里的球,这样它知道在山坡再次向上倾斜之前减速。

Nesterov accelerated gradient是一种使动量项具有这种预见性的方法。我们知道我们将使用动量项来移动参数。因此,为我们提供了参数下一个位置的近似值(完全更新时缺少梯度),大致了解了参数的位置。我们现在可以通过计算梯度(不是我们当前参数的,而是我们参数的近似未来位置的)有效地向前看。

image.png

第一个公式分为两个部分看,第一项是动量部分,保持上次的梯度方向,第二项就是当前梯度,但是这个不太一样,梯度参数里面是 θ − γ ∗ v t − 1 ,由于我们希望小球可以知道自己何时停下,所以希望小球可以预测未来梯度的趋势,一旦发现前方的坡度上升,那么就应该减小步伐,以免冲出最低点,从公式角度理解,更新当前梯度时,我先按照上次梯度方向更新,计算一个大概未来的一个梯度,如果为正,那么说明本次更新后仍和我之前更行的方向一致,说明本次不会冲出去,保持更新即可,但是如果为负,说明本次更新后梯度方向变化了,即冲过了最优点,那么正好和上次的动量方向抵消一部分,因为两者异号,这样小球就知道自己此次更新会冲过去,所以两者抵消一部分导致本次更新步伐没有那么大。

我们再次将动量项的值设置为0.9左右。动量首先计算当前梯度(图3中的蓝色小矢量),然后在更新的累积梯度(蓝色大矢量)的方向上进行大跳跃,而NAG首先在先前累积梯度(棕色矢量)的方向上进行大跳跃,测量梯度,然后进行校正(绿色矢量)。此预期更新可防止我们进行得太快,从而提高响应能力,从而显著提高RNN在许多任务上的性能。

现在我们能够根据误差函数的斜率调整更新,并反过来加快SGD,我们还希望根据每个参数的重要性调整更新,以执行更大或更小的更新。

迭代过程

代码实践

import numpy as np
import matplotlib.pyplot as plt
class Optimizer:
    def __init__(self,
                 epsilon = 1e-10, # 误差
                 iters = 100000,  # 最大迭代次数
                 lamb = 0.01, # 学习率
                 gamma = 0.0, # 动量项系数
                ):
        self.epsilon = epsilon
        self.iters = iters
        self.lamb = lamb
        self.gamma = gamma
    def nag(self, x_0 = 0.5, y_0 = 0.5):
        f1, f2 = self.fn(x_0, y_0), 0
        w = np.array([x_0, y_0]) # 每次迭代后的函数值,用于绘制梯度曲线
        k = 0 # 当前迭代次数
        v_t = 0.0
        while True:
            if abs(f1 - f2) <= self.epsilon or k > self.iters:
                break
            f1 = self.fn(x_0, y_0)
            if k == 0:
                g = np.array([self.dx(x_0, y_0), self.dy(x_0, y_0)])
            else:
                g = np.array([self.dx(x_0 - v_t[0], y_0 - v_t[1]), self.dy(x_0 - v_t[0], y_0 - v_t[1])])
            v_t = self.gamma * v_t + self.lamb * g
            x_0, y_0 = np.array([x_0, y_0]) - v_t
            f2 = self.fn(x_0, y_0)
            w = np.vstack((w, (x_0, y_0)))
            k += 1
        self.print_info(k, x_0, y_0, f2)
        self.draw_process(w)
    def print_info(self, k, x_0, y_0, f2):
        print('迭代次数:{}'.format(k))
        print('极值点:【x_0】:{} 【y_0】:{}'.format(x_0, y_0))
        print('函数的极值:{}'.format(f2))
    def draw_process(self, w):
        X = np.arange(0, 1.5, 0.01)
        Y = np.arange(-1, 1, 0.01)
        [x, y] = np.meshgrid(X, Y)
        f = x**3 - y**3 + 3 * x**2 + 3 * y**2 - 9 * x
        plt.contour(x, y, f, 20)
        plt.plot(w[:, 0],w[:, 1], 'g*', w[:, 0], w[:, 1])
        plt.show()
    def fn(self, x, y):
        return x**3 - y**3 + 3 * x**2 + 3 * y**2 - 9 * x
    def dx(self, x, y):
        return 3 * x**2 + 6 * x - 9
    def dy(self, x, y):
        return - 3 * y**2 + 6 * y
"""
    函数: f(x) = x**3 - y**3 + 3 * x**2 + 3 * y**2 - 9 * x
    最优解: x = 1, y = 0
    极小值: f(x,y) = -5
"""
optimizer = Optimizer()
optimizer.nag()


目录
相关文章
|
4月前
|
机器学习/深度学习 数据可视化 搜索推荐
Python在社交媒体分析中扮演关键角色,借助Pandas、NumPy、Matplotlib等工具处理、可视化数据及进行机器学习。
【7月更文挑战第5天】Python在社交媒体分析中扮演关键角色,借助Pandas、NumPy、Matplotlib等工具处理、可视化数据及进行机器学习。流程包括数据获取、预处理、探索、模型选择、评估与优化,以及结果可视化。示例展示了用户行为、话题趋势和用户画像分析。Python的丰富生态使得社交媒体洞察变得高效。通过学习和实践,可以提升社交媒体分析能力。
84 1
|
1月前
|
机器学习/深度学习 算法 数据挖掘
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧1
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧
51 5
|
1月前
|
机器学习/深度学习 算法 数据可视化
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧2
【Python篇】深度探索NumPy(下篇):从科学计算到机器学习的高效实战技巧
40 1
|
3月前
|
机器学习/深度学习 PyTorch TensorFlow
NumPy 与机器学习框架的集成
【8月更文第30天】NumPy 是 Python 中用于科学计算的核心库之一,它提供了高效的多维数组对象,以及用于操作数组的大量函数。NumPy 的高效性和灵活性使其成为许多机器学习框架的基础。本文将探讨 NumPy 如何与 TensorFlow 和 PyTorch 等流行机器学习框架协同工作,并通过具体的代码示例来展示它们之间的交互。
58 0
|
4月前
|
机器学习/深度学习 数据采集 大数据
驾驭大数据洪流:Pandas与NumPy在高效数据处理与机器学习中的核心作用
【7月更文挑战第13天】在大数据时代,Pandas与NumPy是Python数据分析的核心,用于处理复杂数据集。在一个电商销售数据案例中,首先使用Pandas的`read_csv`加载CSV数据,通过`head`和`describe`进行初步探索。接着,数据清洗涉及填充缺失值和删除异常数据。然后,利用`groupby`和`aggregate`分析销售趋势,并用Matplotlib可视化结果。在机器学习预处理阶段,借助NumPy进行数组操作,如特征缩放。Pandas的数据操作便捷性与NumPy的数值计算效率,共同助力高效的数据分析和建模。
90 3
|
4月前
|
机器学习/深度学习 数据采集 数据处理
重构数据处理流程:Pandas与NumPy高级特性在机器学习前的优化
【7月更文挑战第14天】在数据科学中,Pandas和NumPy是数据处理的关键,用于清洗、转换和计算。用`pip install pandas numpy`安装后,Pandas的`read_csv`读取数据,`fillna`处理缺失值,`drop`删除列。Pandas的`apply`、`groupby`和`merge`执行复杂转换。NumPy加速数值计算,如`square`进行向量化操作,`dot`做矩阵乘法。结合两者优化数据预处理,提升模型训练效率和效果。
67 1
|
5月前
|
机器学习/深度学习 人工智能 资源调度
机器学习之numpy基础——线性代数,不要太简单哦
机器学习之numpy基础——线性代数,不要太简单哦
81 6
|
6月前
|
机器学习/深度学习 存储 搜索推荐
利用机器学习算法改善电商推荐系统的效率
电商行业日益竞争激烈,提升用户体验成为关键。本文将探讨如何利用机器学习算法优化电商推荐系统,通过分析用户行为数据和商品信息,实现个性化推荐,从而提高推荐效率和准确性。
239 14
|
6月前
|
机器学习/深度学习 算法 数据可视化
实现机器学习算法时,特征选择是非常重要的一步,你有哪些推荐的方法?
实现机器学习算法时,特征选择是非常重要的一步,你有哪些推荐的方法?
118 1
|
6月前
|
机器学习/深度学习 算法 搜索推荐
Machine Learning机器学习之决策树算法 Decision Tree(附Python代码)
Machine Learning机器学习之决策树算法 Decision Tree(附Python代码)
下一篇
无影云桌面