基于PyTorch对凸函数采用SGD算法优化实例(附源码)

简介: 基于PyTorch对凸函数采用SGD算法优化实例(附源码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文基于PyTorch实例说明SGD(随机梯度下降)优化方法。

随机梯度下降(Stochastic Gradient Descent, SGD)是一种在机器学习和深度学习中广泛使用的优化算法,用于最小化模型的损失函数。SGD 适用于大规模数据集和复杂的模型,尤其是在训练神经网络时。

1. SGD算法介绍

SGD 的思想源于传统的梯度下降法(GD),后者是通过计算整个数据集上的损失函数梯度来确定参数更新的方向和步长。然而,在大数据场景下,每次迭代都遍历所有样本的计算成本很高。SGD 就是在这种背景下提出的,它采用了一个更为高效的策略:在每一步迭代中,仅随机抽取一个样本来估计梯度,然后用这个估计梯度更新模型参数。

1.1 基本流程
  1. 初始化模型参数 θ。
  2. 对于每个训练迭代:
    a. 随机选择一个样本 (x, y) 或者一个包含多个样本的小批量数据(mini-batch)。
    b. 计算选定样本或小批量数据上的损失函数 L ( θ ) L(θ) L(θ)关于当前参数 θ θ θ 的梯度 ∇ L ( θ ) ∇L(θ) L(θ)
    c. 使用学习率 η 来更新模型参数: θ i + 1 = θ i − η ∇ L ( θ i ) θ_{i+1} = θ_i - η∇L(θ_i) θi+1=θiηL(θi)
  3. 重复步骤2直到达到预设的停止条件(如达到最大迭代次数、损失函数收敛等)。
1.2 SGD特点
  • 优点
  • 计算效率高,尤其对于大规模数据集,可以快速获得反馈并进行更新。
  • 具有在线学习的能力,可以适应实时流式数据输入。
  • 由于噪声的存在,SGD有助于避免局部极小点,并有可能跳到全局最优附近。
  • 缺点
  • 梯度估计具有随机性,可能导致训练过程不稳定,收敛速度也可能会受噪声影响而变慢。
  • 学习率的选择非常关键,如果过大可能错过最优解,过小则会导致收敛慢。
  • 不像批量梯度下降那样能准确反映整体数据集的趋势。
1.3 SGD进阶形式

为了改进标准 SGD 的性能,研究者们提出了一系列增强版的 SGD 算法,例如:

  • Momentum(动量):引入历史梯度信息,减少振荡并加速收敛。
  • Nesterov Accelerated Gradient (NAG):在计算梯度时提前考虑动量的影响。
  • Adagrad、RMSprop、Adadelta:自适应调整学习率,对不同参数使用不同的学习率。
  • Adam 和 AdaMax:基于梯度的一阶矩和二阶矩估计进行自适应学习率调整,是目前最常用的优化器之一。

这些算法结合了 SGD 的随机性和其它技术以提高优化效果和稳定性。

2. 实例说明

基于Pytorch,手动编写SGD(随机梯度下降)方法,求-sin2(x)-sin2(y)的最小值,x∈[-2.5 , 2.5] , y∈[-2.5 , 2.5]。

画一下要拟合的函数图像:

代码

import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错
def fix_fun(x,y):  #构建函数fix_fun = (x^2+y+10)^4+(x+y^2-5)^4
    return -numpy.sin(x)**2 - numpy.sin(y)**2
#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)
ax.plot_surface(x,y,z, cmap='rainbow')
matplotlib.pyplot.show()

画出来后,曲面长下面的样子↓

3. SGD算法构建思路

非常简单:用.backward()方法求出要优化的函数的梯度,按照SGD的定义找出最优解。

4. 运行结果

看看随机生成的几个优化的结果

5. 源码

老规矩,细节的备注仍然写在代码里面。

import torch
from torch.autograd import Variable
import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错
def fix_fun_numpy(x,y):  #构建函数fix_fun(因为tensor数据类型和numpy画图不能混用,构建两个函数)
    return -numpy.sin(x)**2 - numpy.sin(y)**2
def fix_fun_tensor(x,y):  #构建函数fix_fun
    return -torch.sin(x)**2 - torch.sin(y)**2
#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun_numpy(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)
#构建SGD算法
X = (torch.rand(2,dtype=float)-0.5)*5 #0~1分布改成-2.5~2.5分布,生成随机初始点(所以叫‘随机’梯度下降)
X_dots = []
Y_dots = []
Z_dots = []
for iter in range(50):
    X = Variable(X, requires_grad = True)  #需要求导的时候要加这一句
    Z = fix_fun_tensor(X[0],X[1])  #这一句一定要写在Variable后面
    Z.backward()  #求导(梯度)
    X = X - 0.1*X.grad #learning rate=0.1
    X_dots.append(X[0].detach().numpy())  #记录下每个学习过程点(为了后面画出学习的路径)
    Y_dots.append(X[1].detach().numpy())
    Z_dots.append(Z.detach().numpy())
ax.plot_wireframe(x, y, z)  #换成wireframe,因为用surface会和曲线重合,看不出来
ax.plot(X_dots,Y_dots,Z_dots,color='red')
matplotlib.pyplot.show()

6. 后记

喜欢探索的同学还可以把上面的源码改一改,并试试Adagrad, Adam, Momentum这些方法,也对比下各自的优劣。


相关文章
|
2月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
PyTorch中的`nn.AdaptiveAvgPool2d()`函数用于实现自适应平均池化,能够将输入特征图调整到指定的输出尺寸,而不需要手动计算池化核大小和步长。
195 1
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
|
2月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(六):view()和nn.Linear()函数详解
这篇博客文章详细介绍了PyTorch中的`view()`和`nn.Linear()`函数,包括它们的语法格式、参数解释和具体代码示例。`view()`函数用于调整张量的形状,而`nn.Linear()`则作为全连接层,用于固定输出通道数。
118 0
Pytorch学习笔记(六):view()和nn.Linear()函数详解
|
2月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(四):nn.MaxPool2d()函数详解
这篇博客文章详细介绍了PyTorch中的nn.MaxPool2d()函数,包括其语法格式、参数解释和具体代码示例,旨在指导读者理解和使用这个二维最大池化函数。
185 0
Pytorch学习笔记(四):nn.MaxPool2d()函数详解
|
2月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(三):nn.BatchNorm2d()函数详解
本文介绍了PyTorch中的BatchNorm2d模块,它用于卷积层后的数据归一化处理,以稳定网络性能,并讨论了其参数如num_features、eps和momentum,以及affine参数对权重和偏置的影响。
250 0
Pytorch学习笔记(三):nn.BatchNorm2d()函数详解
|
3天前
|
机器学习/深度学习 人工智能 PyTorch
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
本文将系统阐述DPO的工作原理、实现机制,以及其与传统RLHF和SFT方法的本质区别。
41 22
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
|
25天前
|
搜索推荐 Python
利用Python内置函数实现的冒泡排序算法
在上述代码中,`bubble_sort` 函数接受一个列表 `arr` 作为输入。通过两层循环,外层循环控制排序的轮数,内层循环用于比较相邻的元素并进行交换。如果前一个元素大于后一个元素,就将它们交换位置。
125 67
|
22天前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
41 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
7天前
|
机器学习/深度学习 前端开发 算法
婚恋交友系统平台 相亲交友平台系统 婚恋交友系统APP 婚恋系统源码 婚恋交友平台开发流程 婚恋交友系统架构设计 婚恋交友系统前端/后端开发 婚恋交友系统匹配推荐算法优化
婚恋交友系统平台通过线上互动帮助单身男女找到合适伴侣,提供用户注册、个人资料填写、匹配推荐、实时聊天、社区互动等功能。开发流程包括需求分析、技术选型、系统架构设计、功能实现、测试优化和上线运维。匹配推荐算法优化是核心,通过用户行为数据分析和机器学习提高匹配准确性。
32 3
|
1月前
|
监控 PyTorch 数据处理
通过pin_memory 优化 PyTorch 数据加载和传输:工作原理、使用场景与性能分析
在 PyTorch 中,`pin_memory` 是一个重要的设置,可以显著提高 CPU 与 GPU 之间的数据传输速度。当 `pin_memory=True` 时,数据会被固定在 CPU 的 RAM 中,从而加快传输到 GPU 的速度。这对于处理大规模数据集、实时推理和多 GPU 训练等任务尤为重要。本文详细探讨了 `pin_memory` 的作用、工作原理及最佳实践,帮助你优化数据加载和传输,提升模型性能。
85 4
通过pin_memory 优化 PyTorch 数据加载和传输:工作原理、使用场景与性能分析
|
1月前
|
搜索推荐 算法 C语言
【排序算法】八大排序(上)(c语言实现)(附源码)
本文介绍了四种常见的排序算法:冒泡排序、选择排序、插入排序和希尔排序。通过具体的代码实现和测试数据,详细解释了每种算法的工作原理和性能特点。冒泡排序通过不断交换相邻元素来排序,选择排序通过选择最小元素进行交换,插入排序通过逐步插入元素到已排序部分,而希尔排序则是插入排序的改进版,通过预排序使数据更接近有序,从而提高效率。文章最后总结了这四种算法的空间和时间复杂度,以及它们的稳定性。
102 8