SGD、Adam

简介: 【9月更文挑战第23天】

在深度学习中,optimizer.zero_grad() 是一个非常重要的步骤,它用于在每次迭代开始前清除(重置)模型参数的梯度。这是因为大多数优化算法(如SGD、Adam等)都是基于梯度的,它们需要在每次参数更新后清零梯度,以避免梯度累积。

为什么需要 optimizer.zero_grad()

在训练神经网络时,我们使用反向传播算法来计算每个参数的梯度。这些梯度表示损失函数相对于每个参数的敏感度,告诉我们应该如何调整参数以最小化损失。在每次迭代中,我们根据这些梯度来更新参数。

如果不在每次迭代前清零梯度,梯度将会累积,即新的梯度会加到旧的梯度上。这会导致错误的更新,因为每次更新应该只基于当前迭代计算的梯度。

如何使用 optimizer.zero_grad()

在使用PyTorch等深度学习框架时,optimizer.zero_grad() 通常在每次训练迭代的开始被调用。以下是一个简单的训练循环示例,展示了如何在PyTorch中使用它:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

# 定义损失函数
loss_fn = nn.MSELoss()

# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 假设有一些训练数据
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

# 训练循环
for epoch in range(5):
    # 清除梯度
    optimizer.zero_grad()

    # 前向传播
    outputs = model(inputs)

    # 计算损失
    loss = loss_fn(outputs, targets)

    # 反向传播
    loss.backward()

    # 参数更新
    optimizer.step()

    # 打印损失
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

在这个例子中,我们首先定义了一个简单的线性模型、损失函数和优化器。在训练循环中,我们使用 optimizer.zero_grad() 来清除梯度,然后进行前向传播、计算损失、反向传播和参数更新。

代码解释:

  • optimizer.zero_grad(): 清除模型参数的梯度。
  • loss.backward(): 计算梯度,这是通过反向传播自动完成的。
  • optimizer.step(): 根据计算得到的梯度更新模型参数。
目录
相关文章
|
机器学习/深度学习 PyTorch 算法框架/工具
训练误差与泛化误差的说明
训练误差与泛化误差的说明
548 0
|
机器学习/深度学习 并行计算 PyTorch
【已解决】RuntimeError: CUDA error: device-side assert triggeredCUDA kernel errors might be asynchronous
【已解决】RuntimeError: CUDA error: device-side assert triggeredCUDA kernel errors might be asynchronous
|
8月前
|
JSON Shell 数据格式
使用 pipx 安装并执行 Python 应用程序 (1)
使用 pipx 安装并执行 Python 应用程序 (1)
722 17
|
安全 算法 C++
专题九Simulink仿真基础-2
专题九Simulink仿真基础
343 1
|
自动驾驶 物联网 5G
什么是 5G 以及它如何工作?
【8月更文挑战第23天】
2453 0
|
存储 编解码 数据可视化
Transformers 4.37 中文文档(九十一)(4)
Transformers 4.37 中文文档(九十一)
140 0
Transformers 4.37 中文文档(九十一)(4)
|
弹性计算 网络协议 Linux
为什么我的幻兽帕鲁服务器搭建好了之后连不上,提示超时?
幻兽帕鲁服务器刚刚搭建完成,你一定迫不及待的的想要连上去玩耍了,但是连接等待半天后,不是进入到游戏而是提示超时,令人崩溃。
9311 2
|
搜索推荐 数据可视化 Python
Matplotlib图表中的数据标签与图例设置
【4月更文挑战第17天】这篇文章介绍了如何在Python的Matplotlib库中设置数据标签和图例,以增强图表的可读性和解释性。主要内容包括:使用`text`函数添加基本和自定义数据标签,以及自动和手动创建图例。图例的位置和样式可通过`loc`和相关参数调整。文章强调了数据标签和图例结合使用的重要性,提供了一个综合示例来展示实践方法。良好的图表设计旨在清晰有效地传达信息。
|
监控 物联网 智能硬件
MQTT 持久会话与 Clean Session 详解
【2月更文挑战第17天】
895 5
|
测试技术 开发工具 git
【git 实用指南】Git提交指南:如何制定团队友好的提交规则
【git 实用指南】Git提交指南:如何制定团队友好的提交规则
604 0