基于PyTorch对凸函数采用SGD算法优化实例(附源码)

简介: 基于PyTorch对凸函数采用SGD算法优化实例(附源码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文基于PyTorch实例说明SGD(随机梯度下降)优化方法。

随机梯度下降(Stochastic Gradient Descent, SGD)是一种在机器学习和深度学习中广泛使用的优化算法,用于最小化模型的损失函数。SGD 适用于大规模数据集和复杂的模型,尤其是在训练神经网络时。

1. SGD算法介绍

SGD 的思想源于传统的梯度下降法(GD),后者是通过计算整个数据集上的损失函数梯度来确定参数更新的方向和步长。然而,在大数据场景下,每次迭代都遍历所有样本的计算成本很高。SGD 就是在这种背景下提出的,它采用了一个更为高效的策略:在每一步迭代中,仅随机抽取一个样本来估计梯度,然后用这个估计梯度更新模型参数。

1.1 基本流程
  1. 初始化模型参数 θ。
  2. 对于每个训练迭代:
    a. 随机选择一个样本 (x, y) 或者一个包含多个样本的小批量数据(mini-batch)。
    b. 计算选定样本或小批量数据上的损失函数 L ( θ ) L(θ) L(θ)关于当前参数 θ θ θ 的梯度 ∇ L ( θ ) ∇L(θ) L(θ)
    c. 使用学习率 η 来更新模型参数: θ i + 1 = θ i − η ∇ L ( θ i ) θ_{i+1} = θ_i - η∇L(θ_i) θi+1=θiηL(θi)
  3. 重复步骤2直到达到预设的停止条件(如达到最大迭代次数、损失函数收敛等)。
1.2 SGD特点
  • 优点
  • 计算效率高,尤其对于大规模数据集,可以快速获得反馈并进行更新。
  • 具有在线学习的能力,可以适应实时流式数据输入。
  • 由于噪声的存在,SGD有助于避免局部极小点,并有可能跳到全局最优附近。
  • 缺点
  • 梯度估计具有随机性,可能导致训练过程不稳定,收敛速度也可能会受噪声影响而变慢。
  • 学习率的选择非常关键,如果过大可能错过最优解,过小则会导致收敛慢。
  • 不像批量梯度下降那样能准确反映整体数据集的趋势。
1.3 SGD进阶形式

为了改进标准 SGD 的性能,研究者们提出了一系列增强版的 SGD 算法,例如:

  • Momentum(动量):引入历史梯度信息,减少振荡并加速收敛。
  • Nesterov Accelerated Gradient (NAG):在计算梯度时提前考虑动量的影响。
  • Adagrad、RMSprop、Adadelta:自适应调整学习率,对不同参数使用不同的学习率。
  • Adam 和 AdaMax:基于梯度的一阶矩和二阶矩估计进行自适应学习率调整,是目前最常用的优化器之一。

这些算法结合了 SGD 的随机性和其它技术以提高优化效果和稳定性。

2. 实例说明

基于Pytorch,手动编写SGD(随机梯度下降)方法,求-sin2(x)-sin2(y)的最小值,x∈[-2.5 , 2.5] , y∈[-2.5 , 2.5]。

画一下要拟合的函数图像:

代码

import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错
def fix_fun(x,y):  #构建函数fix_fun = (x^2+y+10)^4+(x+y^2-5)^4
    return -numpy.sin(x)**2 - numpy.sin(y)**2
#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)
ax.plot_surface(x,y,z, cmap='rainbow')
matplotlib.pyplot.show()

画出来后,曲面长下面的样子↓

3. SGD算法构建思路

非常简单:用.backward()方法求出要优化的函数的梯度,按照SGD的定义找出最优解。

4. 运行结果

看看随机生成的几个优化的结果

5. 源码

老规矩,细节的备注仍然写在代码里面。

import torch
from torch.autograd import Variable
import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错
def fix_fun_numpy(x,y):  #构建函数fix_fun(因为tensor数据类型和numpy画图不能混用,构建两个函数)
    return -numpy.sin(x)**2 - numpy.sin(y)**2
def fix_fun_tensor(x,y):  #构建函数fix_fun
    return -torch.sin(x)**2 - torch.sin(y)**2
#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun_numpy(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)
#构建SGD算法
X = (torch.rand(2,dtype=float)-0.5)*5 #0~1分布改成-2.5~2.5分布,生成随机初始点(所以叫‘随机’梯度下降)
X_dots = []
Y_dots = []
Z_dots = []
for iter in range(50):
    X = Variable(X, requires_grad = True)  #需要求导的时候要加这一句
    Z = fix_fun_tensor(X[0],X[1])  #这一句一定要写在Variable后面
    Z.backward()  #求导(梯度)
    X = X - 0.1*X.grad #learning rate=0.1
    X_dots.append(X[0].detach().numpy())  #记录下每个学习过程点(为了后面画出学习的路径)
    Y_dots.append(X[1].detach().numpy())
    Z_dots.append(Z.detach().numpy())
ax.plot_wireframe(x, y, z)  #换成wireframe,因为用surface会和曲线重合,看不出来
ax.plot(X_dots,Y_dots,Z_dots,color='red')
matplotlib.pyplot.show()

6. 后记

喜欢探索的同学还可以把上面的源码改一改,并试试Adagrad, Adam, Momentum这些方法,也对比下各自的优劣。


目录
打赏
0
4
3
1
19
分享
相关文章
基于WOA鲸鱼优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本项目基于MATLAB 2022a/2024b实现,采用WOA优化的BiLSTM算法进行序列预测。核心代码包含完整中文注释与操作视频,展示从参数优化到模型训练、预测的全流程。BiLSTM通过前向与后向LSTM结合,有效捕捉序列前后文信息,解决传统RNN梯度消失问题。WOA优化超参数(如学习率、隐藏层神经元数),提升模型性能,避免局部最优解。附有运行效果图预览,最终输出预测值与实际值对比,RMSE评估精度。适合研究时序数据分析与深度学习优化的开发者参考。
基于GA遗传优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本内容包含基于BiLSTM与遗传算法(GA)的算法介绍及实现。算法通过MATLAB2022a/2024b运行,核心为优化BiLSTM超参数(如学习率、神经元数量),提升预测性能。LSTM解决传统RNN梯度问题,捕捉长期依赖;BiLSTM双向处理序列,融合前文后文信息,适合全局信息任务。附完整代码(含注释)、操作视频及无水印运行效果预览,适用于股票预测等场景,精度优于单向LSTM。
基于PSO粒子群优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本项目基于MATLAB2022a/2024b开发,结合粒子群优化(PSO)算法与双向长短期记忆网络(BiLSTM),用于优化序列预测任务中的模型参数。核心代码包含详细中文注释及操作视频,涵盖遗传算法优化过程、BiLSTM网络构建、训练及预测分析。通过PSO优化BiLSTM的超参数(如学习率、隐藏层神经元数等),显著提升模型捕捉长期依赖关系和上下文信息的能力,适用于气象、交通流量等场景。附有运行效果图预览,展示适应度值、RMSE变化及预测结果对比,验证方法有效性。
基于遗传优化ELM网络的时间序列预测算法matlab仿真
本项目实现了一种基于遗传算法优化的极限学习机(GA-ELM)网络时间序列预测方法。通过对比传统ELM与GA-ELM,验证了参数优化对非线性时间序列预测精度的提升效果。核心程序利用MATLAB 2022A完成,采用遗传算法全局搜索最优权重与偏置,结合ELM快速训练特性,显著提高模型稳定性与准确性。实验结果展示了GA-ELM在复杂数据中的优越表现,误差明显降低。此方法适用于金融、气象等领域的时间序列预测任务。
|
21天前
|
基于遗传优化算法的带时间窗多车辆路线规划matlab仿真
本程序基于遗传优化算法,实现带时间窗的多车辆路线规划,并通过MATLAB2022A仿真展示结果。输入节点坐标与时间窗信息后,算法输出最优路径规划方案。示例结果包含4条路线,覆盖所有节点并满足时间窗约束。核心代码包括初始化、适应度计算、交叉变异及局部搜索等环节,确保解的质量与可行性。遗传算法通过模拟自然进化过程,逐步优化种群个体,有效解决复杂约束条件下的路径规划问题。
基于差分进化灰狼混合优化的SVM(DE-GWO-SVM)数据预测算法matlab仿真
本项目实现基于差分进化灰狼混合优化的SVM(DE-GWO-SVM)数据预测算法的MATLAB仿真,对比SVM和GWO-SVM性能。算法结合差分进化(DE)与灰狼优化(GWO),优化SVM参数以提升复杂高维数据预测能力。核心流程包括DE生成新种群、GWO更新位置,迭代直至满足终止条件,选出最优参数组合。适用于分类、回归等任务,显著提高模型效率与准确性,运行环境为MATLAB 2022A。
基于PSO粒子群优化的多无人机路径规划matlab仿真,对比WOA优化算法
本程序基于粒子群优化(PSO)算法实现多无人机路径规划,并与鲸鱼优化算法(WOA)进行对比。使用MATLAB2022A运行,通过四个无人机的仿真,评估两种算法在能耗、复杂度、路径规划效果及收敛曲线等指标上的表现。算法原理源于1995年提出的群体智能优化,模拟鸟群觅食行为,在搜索空间中寻找最优解。环境建模采用栅格或几何法,考虑避障、速度限制等因素,将约束条件融入适应度函数。程序包含初始化粒子群、更新速度与位置、计算适应度值、迭代优化等步骤,最终输出最优路径。
基于GWO灰狼优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本项目基于Matlab 2022a/2024b实现,结合灰狼优化(GWO)算法与双向长短期记忆网络(BiLSTM),用于序列预测任务。核心代码包含数据预处理、种群初始化、适应度计算及参数优化等步骤,完整版附带中文注释与操作视频。BiLSTM通过前向与后向处理捕捉序列上下文信息,GWO优化其参数以提升预测性能。效果图展示训练过程与预测结果,适用于气象、交通等领域。LSTM结构含输入门、遗忘门与输出门,解决传统RNN梯度问题,而BiLSTM进一步增强上下文理解能力。
基于WOA鲸鱼优化的TCN-GRU时间卷积神经网络时间序列预测算法matlab仿真
本内容包含时间序列预测算法的相关资料,涵盖以下几个方面:1. 算法运行效果预览(无水印);2. 运行环境为Matlab 2022a/2024b;3. 提供部分核心程序,完整版含中文注释及操作视频;4. 理论概述:结合时间卷积神经网络(TCN)与鲸鱼优化算法(WOA),优化TCN超参数以提升非线性时间序列预测性能。通过因果卷积层与残差连接构建TCN模型,并用WOA调整卷积核大小、层数等参数,实现精准预测。适用于金融、气象等领域决策支持。

热门文章

最新文章

推荐镜像

更多
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等