【机器学习】numpy实现Momentum动量优化器

简介: 【机器学习】numpy实现Momentum动量优化器

Momentum优化原理

SGD难以在沟壑中航行,即曲面在一个维度上的曲线比在另一个维度上的曲线陡峭得多的区域,这在局部最优点附近很常见。在这些情况下,SGD在沟谷的斜坡上振荡,同时沿着底部向局部最优方向缓慢前进,如图所示。

动量是一种有助于在相关方向上加速SGD并抑制振荡的方法,如图b所示。它通过将过去时间步长的更新向量的一小部分添加到当前更新向量来实现这一点:

image.png

动量项通常设置为0.9或类似值。

基本上,当使用动量时,我们将球推下山。球在下坡时积累动量,在路上越来越快(如果有空气阻力,直到达到其极限速度,即1)。同样的情况也发生在参数更新上:对于渐变指向相同方向的维度,动量项增加;对于渐变改变方向的维度,动量项减少更新。因此,我们获得了更快的收敛速度和更少的振荡。

Momentum说白了就是同方向加速,反方向减速,如果上次迭代的梯度方向和本次梯度方向一致,那么本次迭代过程将大幅度更新参数,而如果上次和本次梯度相反,那么就会在一定程度上抵消一部分,导致本次梯度更新只是小幅度更新。

在训练初期,初始化的参数距离最优位置很远,所以此时应该大幅度进行更新参数,类似于小球滚下山,越滚越快,这样就会使我们能够快速到达谷底,减少收敛次数,当到了谷底后,我们不希望还向之前大步更新,应该进行微调,如果仍然大步就会导致不断在最优位置进行振荡,所以此时采用动量的话,当快到了谷底的时候,此时的梯度方向和原来的方向相反,这样就会抵消一部分,导致更新的步幅减小,更容易进行微调找到最优位置。

迭代过程

代码实践

import numpy as np
import matplotlib.pyplot as plt
class Optimizer:
    def __init__(self,
                 epsilon = 1e-10, # 误差
                 iters = 100000,  # 最大迭代次数
                 lamb = 0.01, # 学习率
                 gamma = 0.0, # 动量项系数
                ):
        self.epsilon = epsilon
        self.iters = iters
        self.lamb = lamb
        self.gamma = gamma
    def momentum(self, x_0 = 0.5, y_0 = 0.5):
        f1, f2 = self.fn(x_0, y_0), 0
        w = np.array([x_0, y_0]) # 每次迭代后的函数值,用于绘制梯度曲线
        k = 0 # 当前迭代次数
        v_t = 0.0
        while True:
            if abs(f1 - f2) <= self.epsilon or k > self.iters:
                break
            f1 = self.fn(x_0, y_0)
            g = np.array([self.dx(x_0, y_0), self.dy(x_0, y_0)])
            v_t = self.gamma * v_t + self.lamb * g
            x_0, y_0 = np.array([x_0, y_0]) - v_t
            f2 = self.fn(x_0, y_0)
            w = np.vstack((w, (x_0, y_0)))
            k += 1
        self.print_info(k, x_0, y_0, f2)
        self.draw_process(w)
    def print_info(self, k, x_0, y_0, f2):
        print('迭代次数:{}'.format(k))
        print('极值点:【x_0】:{} 【y_0】:{}'.format(x_0, y_0))
        print('函数的极值:{}'.format(f2))
    def draw_process(self, w):
        X = np.arange(0, 1.5, 0.01)
        Y = np.arange(-1, 1, 0.01)
        [x, y] = np.meshgrid(X, Y)
        f = x**3 - y**3 + 3 * x**2 + 3 * y**2 - 9 * x
        plt.contour(x, y, f, 20)
        plt.plot(w[:, 0],w[:, 1], 'g*', w[:, 0], w[:, 1])
        plt.show()
    def fn(self, x, y):
        return x**3 - y**3 + 3 * x**2 + 3 * y**2 - 9 * x
    def dx(self, x, y):
        return 3 * x**2 + 6 * x - 9
    def dy(self, x, y):
        return - 3 * y**2 + 6 * y
"""
    函数: f(x) = x**3 - y**3 + 3 * x**2 + 3 * y**2 - 9 * x
    最优解: x = 1, y = 0
    极小值: f(x,y) = -5
"""
optimizer = Optimizer()
optimizer.momentum()



目录
相关文章
|
3月前
|
机器学习/深度学习 数据可视化 搜索推荐
Python在社交媒体分析中扮演关键角色,借助Pandas、NumPy、Matplotlib等工具处理、可视化数据及进行机器学习。
【7月更文挑战第5天】Python在社交媒体分析中扮演关键角色,借助Pandas、NumPy、Matplotlib等工具处理、可视化数据及进行机器学习。流程包括数据获取、预处理、探索、模型选择、评估与优化,以及结果可视化。示例展示了用户行为、话题趋势和用户画像分析。Python的丰富生态使得社交媒体洞察变得高效。通过学习和实践,可以提升社交媒体分析能力。
64 1
|
2月前
|
机器学习/深度学习 PyTorch TensorFlow
NumPy 与机器学习框架的集成
【8月更文第30天】NumPy 是 Python 中用于科学计算的核心库之一,它提供了高效的多维数组对象,以及用于操作数组的大量函数。NumPy 的高效性和灵活性使其成为许多机器学习框架的基础。本文将探讨 NumPy 如何与 TensorFlow 和 PyTorch 等流行机器学习框架协同工作,并通过具体的代码示例来展示它们之间的交互。
22 0
|
3月前
|
机器学习/深度学习 数据采集 大数据
驾驭大数据洪流:Pandas与NumPy在高效数据处理与机器学习中的核心作用
【7月更文挑战第13天】在大数据时代,Pandas与NumPy是Python数据分析的核心,用于处理复杂数据集。在一个电商销售数据案例中,首先使用Pandas的`read_csv`加载CSV数据,通过`head`和`describe`进行初步探索。接着,数据清洗涉及填充缺失值和删除异常数据。然后,利用`groupby`和`aggregate`分析销售趋势,并用Matplotlib可视化结果。在机器学习预处理阶段,借助NumPy进行数组操作,如特征缩放。Pandas的数据操作便捷性与NumPy的数值计算效率,共同助力高效的数据分析和建模。
58 3
|
3月前
|
机器学习/深度学习 数据采集 数据处理
重构数据处理流程:Pandas与NumPy高级特性在机器学习前的优化
【7月更文挑战第14天】在数据科学中,Pandas和NumPy是数据处理的关键,用于清洗、转换和计算。用`pip install pandas numpy`安装后,Pandas的`read_csv`读取数据,`fillna`处理缺失值,`drop`删除列。Pandas的`apply`、`groupby`和`merge`执行复杂转换。NumPy加速数值计算,如`square`进行向量化操作,`dot`做矩阵乘法。结合两者优化数据预处理,提升模型训练效率和效果。
48 1
|
4月前
|
机器学习/深度学习 人工智能 资源调度
机器学习之numpy基础——线性代数,不要太简单哦
机器学习之numpy基础——线性代数,不要太简单哦
63 6
|
4月前
|
机器学习/深度学习 人工智能 IDE
人工智能平台PAI操作报错合集之交互式建模(DSW)环境中,numpy模块如何正确安装
阿里云人工智能平台PAI (Platform for Artificial Intelligence) 是阿里云推出的一套全面、易用的机器学习和深度学习平台,旨在帮助企业、开发者和数据科学家快速构建、训练、部署和管理人工智能模型。在使用阿里云人工智能平台PAI进行操作时,可能会遇到各种类型的错误。以下列举了一些常见的报错情况及其可能的原因和解决方法。
|
5月前
|
机器学习/深度学习 数据采集 算法
探索NumPy与机器学习库的集成之路
【4月更文挑战第17天】本文探讨了NumPy在机器学习中的核心作用,它为各类机器学习库提供基础数据处理和数值计算能力。NumPy的线性代数、优化算法和随机数生成等功能,对实现高效模型训练至关重要。scikit-learn等库广泛依赖NumPy进行数据预处理。未来,尽管面临大数据和复杂模型的性能挑战,NumPy与机器学习库的集成将继续深化,推动技术创新。
|
5月前
|
机器学习/深度学习 数据采集 PyTorch
《Numpy 简易速速上手小册》第9章:Numpy 在机器学习中的应用(2024 最新版)
《Numpy 简易速速上手小册》第9章:Numpy 在机器学习中的应用(2024 最新版)
44 0
|
5月前
|
达摩院 开发者 容器
「达摩院MindOpt」优化形状切割问题(MILP)
在制造业,高效地利用材料不仅是节约成本的重要环节,也是可持续发展的关键因素。无论是在金属加工、家具制造还是纺织品生产中,原材料的有效利用都直接影响了整体效率和环境影响。
「达摩院MindOpt」优化形状切割问题(MILP)
|
5月前
|
人工智能 自然语言处理 达摩院
MindOpt 云上建模求解平台:多求解器协同优化
数学规划是一种数学优化方法,主要是寻找变量的取值在特定的约束情况下,使我们的决策目标得到一个最大或者最小值的决策。
下一篇
无影云桌面