torch 如何在迭代训练中,只计算参数梯度,不更新参数。并且将参数的梯度保存到列表,这个列表的值不随着梯度清零而跟着变为零

简介: 在使用 torch.autograd.backward() 计算梯度后,可以通过设置参数的 requires_grad 属性为 False 来阻止参数更新。同时,通过将计算出的梯度值放入一个列表中来保存,这个列表的值不会随着梯度清零而清零

在使用 torch.autograd.backward() 计算梯度后,可以通过设置参数的 requires_grad 属性为 False 来阻止参数更新。同时,通过将计算出的梯度值放入一个列表中来保存,这个列表的值不会随着梯度清零而清零。下面是一个示例代码:

import torch
# 定义网络和优化器
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 定义保存梯度的列表
grads_list = []
# 迭代训练过程
for i in range(10):
    # 获取数据并前向计算
    inputs, labels = get_data()
    outputs = model(inputs)
    # 计算损失函数并进行反向传播计算梯度
    loss = compute_loss(outputs, labels)
    model.zero_grad()
    loss.backward(retain_graph=True)
    # 保存梯度并更新优化器状态
    grads_list.append([p.grad.detach().clone() for p in model.parameters()])
    optimizer.step()
    # 在某些情况下需要清空梯度,但不影响保存下来的梯度值
    for p in model.parameters():
        p.grad = None

在上述代码中,我们首先定义了一个保存梯度的列表 grads_list,然后在每次迭代中使用 loss.backward() 计算梯度,并将计算出的梯度值放入 grads_list 中。由于我们在更新参数之前将所有参数的 requires_grad 属性设置为 False,所以这些参数不会被优化器更新。最后,我们清空梯度以便进行下一轮迭代,并通过 p.grad = None 来清除计算图中的梯度信息,但不影响保存下来的梯度值。

相关文章
|
Python 缓存
Python ChainMap:链式映射的妙用与实战解析
【4月更文挑战第1天】Python中的`collections`模块提供了一个名为`ChainMap`的类,它实现了多个字典的链式查找。`ChainMap`将多个字典组织成一个逻辑上的单一字典,允许你像操作单个字典一样来访问这些字典。当在`ChainMap`中查找一个键时,它会按照字典被添加的顺序从前向后依次查找,直到找到匹配的键为止。如果找不到,就会抛出`KeyError`。
|
弹性计算 关系型数据库 数据库
搭建PostgreSQL主从架构
PostgreSQL是一个关系型数据库管理系统(RDBMS),支持NoSQL数据类型(JSON/XML/hstore)。本教程介绍如何在两台CentOS 7操作系统的ECS实例上搭建PostgreSQL主从架构。
|
Linux iOS开发 MacOS
【MCP教程系列】阿里云百炼MCP全面配置指南:涵盖NPX、UVX、SSE及Streamable HTTP
本文详细介绍如何在阿里云百炼平台及Windows、Linux、MacOS系统中正确配置MCP服务的JSON文件。内容涵盖三种MCP服务配置:npx(基于Stdio)、uvx(Python工具运行)和SSE(服务器发送事件)。同时解析Streamable HTTP作为新一代传输方案的优势与应用,帮助用户掌握每个参数的具体用途及使用方法,解决配置过程中可能遇到的问题,提供完整示例和扩展信息以优化设置体验。
3366 11
|
11月前
|
搜索推荐 物联网 PyTorch
Qwen2.5-7B-Instruct Lora 微调
本教程介绍如何基于Transformers和PEFT框架对Qwen2.5-7B-Instruct模型进行LoRA微调。
11914 34
Qwen2.5-7B-Instruct Lora 微调
|
计算机视觉
【CV大模型SAM(Segment-Anything)】如何保存分割后的对象mask?并提取mask对应的图片区域?
【CV大模型SAM(Segment-Anything)】如何保存分割后的对象mask?并提取mask对应的图片区域?
【CV大模型SAM(Segment-Anything)】如何保存分割后的对象mask?并提取mask对应的图片区域?
|
消息中间件 缓存 运维
中间件数据一致性和可靠性问题
【7月更文挑战第14天】
246 1
中间件数据一致性和可靠性问题
|
机器学习/深度学习 算法 机器人
【博士每天一篇文献-算法】改进的PNN架构Lifelong learning with dynamically expandable networks
本文介绍了一种名为Dynamically Expandable Network(DEN)的深度神经网络架构,它能够在学习新任务的同时保持对旧任务的记忆,并通过动态扩展网络容量和选择性重训练机制,有效防止语义漂移,实现终身学习。
315 9
`cmd`模块是Python标准库中的一个模块,它提供了一个简单的框架来创建命令行解释器。
`cmd`模块是Python标准库中的一个模块,它提供了一个简单的框架来创建命令行解释器。
|
计算机视觉 索引
【OpenCV】直方图计算 & 均衡化直方图
【OpenCV】直方图计算 & 均衡化直方图
532 3
|
NoSQL Serverless Python
在Python的Pandas中,可以通过直接赋值或使用apply函数在DataFrame添加新列。
【5月更文挑战第2天】在Python的Pandas中,可以通过直接赋值或使用apply函数在DataFrame添加新列。方法一是直接赋值,如`df['C'] = 0`,创建新列C并初始化为0。方法二是应用函数,例如定义`add_column`函数计算A列和B列之和,然后使用`df.apply(add_column, axis=1)`,使C列存储每行A、B列的和。
762 0