基于PyTorch对凸函数采用SGD算法优化实例(附源码)

简介: 基于PyTorch对凸函数采用SGD算法优化实例(附源码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文基于PyTorch实例说明SGD(随机梯度下降)优化方法。

随机梯度下降(Stochastic Gradient Descent, SGD)是一种在机器学习和深度学习中广泛使用的优化算法,用于最小化模型的损失函数。SGD 适用于大规模数据集和复杂的模型,尤其是在训练神经网络时。

1. SGD算法介绍

SGD 的思想源于传统的梯度下降法(GD),后者是通过计算整个数据集上的损失函数梯度来确定参数更新的方向和步长。然而,在大数据场景下,每次迭代都遍历所有样本的计算成本很高。SGD 就是在这种背景下提出的,它采用了一个更为高效的策略:在每一步迭代中,仅随机抽取一个样本来估计梯度,然后用这个估计梯度更新模型参数。

1.1 基本流程
  1. 初始化模型参数 θ。
  2. 对于每个训练迭代:
    a. 随机选择一个样本 (x, y) 或者一个包含多个样本的小批量数据(mini-batch)。
    b. 计算选定样本或小批量数据上的损失函数 L ( θ ) L(θ) L(θ)关于当前参数 θ θ θ 的梯度 ∇ L ( θ ) ∇L(θ) L(θ)
    c. 使用学习率 η 来更新模型参数: θ i + 1 = θ i − η ∇ L ( θ i ) θ_{i+1} = θ_i - η∇L(θ_i) θi+1=θiηL(θi)
  3. 重复步骤2直到达到预设的停止条件(如达到最大迭代次数、损失函数收敛等)。
1.2 SGD特点
  • 优点
  • 计算效率高,尤其对于大规模数据集,可以快速获得反馈并进行更新。
  • 具有在线学习的能力,可以适应实时流式数据输入。
  • 由于噪声的存在,SGD有助于避免局部极小点,并有可能跳到全局最优附近。
  • 缺点
  • 梯度估计具有随机性,可能导致训练过程不稳定,收敛速度也可能会受噪声影响而变慢。
  • 学习率的选择非常关键,如果过大可能错过最优解,过小则会导致收敛慢。
  • 不像批量梯度下降那样能准确反映整体数据集的趋势。
1.3 SGD进阶形式

为了改进标准 SGD 的性能,研究者们提出了一系列增强版的 SGD 算法,例如:

  • Momentum(动量):引入历史梯度信息,减少振荡并加速收敛。
  • Nesterov Accelerated Gradient (NAG):在计算梯度时提前考虑动量的影响。
  • Adagrad、RMSprop、Adadelta:自适应调整学习率,对不同参数使用不同的学习率。
  • Adam 和 AdaMax:基于梯度的一阶矩和二阶矩估计进行自适应学习率调整,是目前最常用的优化器之一。

这些算法结合了 SGD 的随机性和其它技术以提高优化效果和稳定性。

2. 实例说明

基于Pytorch,手动编写SGD(随机梯度下降)方法,求-sin2(x)-sin2(y)的最小值,x∈[-2.5 , 2.5] , y∈[-2.5 , 2.5]。

画一下要拟合的函数图像:

代码

import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错
def fix_fun(x,y):  #构建函数fix_fun = (x^2+y+10)^4+(x+y^2-5)^4
    return -numpy.sin(x)**2 - numpy.sin(y)**2
#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)
ax.plot_surface(x,y,z, cmap='rainbow')
matplotlib.pyplot.show()

画出来后,曲面长下面的样子↓

3. SGD算法构建思路

非常简单:用.backward()方法求出要优化的函数的梯度,按照SGD的定义找出最优解。

4. 运行结果

看看随机生成的几个优化的结果

5. 源码

老规矩,细节的备注仍然写在代码里面。

import torch
from torch.autograd import Variable
import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错
def fix_fun_numpy(x,y):  #构建函数fix_fun(因为tensor数据类型和numpy画图不能混用,构建两个函数)
    return -numpy.sin(x)**2 - numpy.sin(y)**2
def fix_fun_tensor(x,y):  #构建函数fix_fun
    return -torch.sin(x)**2 - torch.sin(y)**2
#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun_numpy(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)
#构建SGD算法
X = (torch.rand(2,dtype=float)-0.5)*5 #0~1分布改成-2.5~2.5分布,生成随机初始点(所以叫‘随机’梯度下降)
X_dots = []
Y_dots = []
Z_dots = []
for iter in range(50):
    X = Variable(X, requires_grad = True)  #需要求导的时候要加这一句
    Z = fix_fun_tensor(X[0],X[1])  #这一句一定要写在Variable后面
    Z.backward()  #求导(梯度)
    X = X - 0.1*X.grad #learning rate=0.1
    X_dots.append(X[0].detach().numpy())  #记录下每个学习过程点(为了后面画出学习的路径)
    Y_dots.append(X[1].detach().numpy())
    Z_dots.append(Z.detach().numpy())
ax.plot_wireframe(x, y, z)  #换成wireframe,因为用surface会和曲线重合,看不出来
ax.plot(X_dots,Y_dots,Z_dots,color='red')
matplotlib.pyplot.show()

6. 后记

喜欢探索的同学还可以把上面的源码改一改,并试试Adagrad, Adam, Momentum这些方法,也对比下各自的优劣。


相关文章
|
9天前
|
人工智能 算法 数据安全/隐私保护
基于遗传优化的SVD水印嵌入提取算法matlab仿真
该算法基于遗传优化的SVD水印嵌入与提取技术,通过遗传算法优化水印嵌入参数,提高水印的鲁棒性和隐蔽性。在MATLAB2022a环境下测试,展示了优化前后的性能对比及不同干扰下的水印提取效果。核心程序实现了SVD分解、遗传算法流程及其参数优化,有效提升了水印技术的应用价值。
|
8天前
|
存储 缓存 算法
优化轮询算法以提高资源分配的效率
【10月更文挑战第13天】通过以上这些优化措施,可以在一定程度上提高轮询算法的资源分配效率,使其更好地适应不同的应用场景和需求。但需要注意的是,优化策略的选择和实施需要根据具体情况进行详细的分析和评估,以确保优化效果的最大化。
|
9天前
|
并行计算 算法 IDE
【灵码助力Cuda算法分析】分析共享内存的矩阵乘法优化
本文介绍了如何利用通义灵码在Visual Studio 2022中对基于CUDA的共享内存矩阵乘法优化代码进行深入分析。文章从整体程序结构入手,逐步深入到线程调度、矩阵分块、循环展开等关键细节,最后通过带入具体值的方式进一步解析复杂循环逻辑,展示了通义灵码在辅助理解和优化CUDA编程中的强大功能。
|
9天前
|
存储 缓存 算法
前端算法:优化与实战技巧的深度探索
【10月更文挑战第21天】前端算法:优化与实战技巧的深度探索
10 1
|
10天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
|
17天前
|
机器学习/深度学习 人工智能 算法
[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解
[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解
52 0
[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解
|
10天前
|
数据采集 缓存 算法
算法优化的常见策略有哪些
【10月更文挑战第20天】算法优化的常见策略有哪些
|
10天前
|
缓存 分布式计算 监控
算法优化:提升程序性能的艺术
【10月更文挑战第20天】算法优化:提升程序性能的艺术
|
10天前
|
缓存 分布式计算 监控
优化算法和代码需要注意什么
【10月更文挑战第20天】优化算法和代码需要注意什么
14 0
|
15天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化卷积神经网络(Bayes-CNN)的多因子数据分类识别算法matlab仿真
本项目展示了贝叶斯优化在CNN中的应用,包括优化过程、训练与识别效果对比,以及标准CNN的识别结果。使用Matlab2022a开发,提供完整代码及视频教程。贝叶斯优化通过构建代理模型指导超参数优化,显著提升模型性能,适用于复杂数据分类任务。