手动构造一个params_dict列表来初始化Optimizer

简介: q

Optimizer基本属性

所有Optimizer公有的一些基本属性:

lr: learning rate,学习率

eps: 学习率最小值,在动态更新学习率时,学习率最小不会小于该值。

weight_decay: 权值衰减。相当于对参数进行L2正则化(使模型复杂度尽可能低,防止过拟合),该值可以理解为正则化项的系数。

betas: (待研究)

amsgrad: (bool)(待研究)

每个Optimizer都维护一个param_groups的list。该list中维护需要优化的参数以及对应的属性设置。

optimizer基本方法
add_param_group(param_group): 为optimizer的param_groups增加一个参数组。这在微调预先训练的网络时非常有用,因为冻结层可以训练并随着训练的进行添加到优化器中。

load_state_dict(state_dict): 加载optimizer state。参数必须是optimizer.state_dict()返回的对象。

state_dict(): 返回一个dict,包含optimizer的状态:state和param_groups。

step(closure): 执行一次参数更新过程。

zero_grad(): 清除所有已经更新的参数的梯度。

我们在构造优化器时,最简单的方法通常如下:

model.parameters()返回网络model的全部参数。

将model的全部参数传入Adam中构造出一个Adam优化器,并设置 learning rate=0.1。因此该 Adam 优化器的 param_groups 维护的就是模型 model 的全部参数,并且学习率为0.1。这样在调用optimizer_Adam.step()时,就会对model的全部参数进行更新。

Optimizer的param_groups是一个list,其中的每个元素都是一组独立的参数,以dict的方式存储。结构如下:

这样可以实现很多灵活的操作,比如以下。

只训练模型的一部分参数
例如,只想训练上面的model中的layer参数,而保持layer2的参数不动。可以如下设置Optimizer:

不同部分的参数设置不同的学习率(以及其他属性)
例如,要想使model的layer参数学习率为0.1,layer2的参数学习率为0.2,可以如下设置Optimizer:

这种方法更为灵活,手动构造一个params_dict列表来初始化Optimizer。注意,字典中的参数部分的 key 必须为 ‘params’。

相关文章
|
存储 JSON 数据格式
数据集加载时报错'dict' object has no attribute 'requests‘
数据集加载时报错'dict' object has no attribute 'requests‘
289 5
|
3月前
|
Rust 编译器 C++
使用 def、cdef、cpdef 创建函数
使用 def、cdef、cpdef 创建函数
61 0
|
4月前
|
搜索推荐 API UED
路由的query参数和params参数
理解并正确使用Query参数和Params参数,是构建清晰、高效Web应用的关键之一。开发者应根据实际应用场景灵活选择参数类型,从而优化用户体验和应用性能。
196 6
|
6月前
|
容器
C++11 列表初始化(initializer_list),pair
C++11 列表初始化(initializer_list),pair
|
Python
python之列表中常用的函数:append,extend,insert,pop,remove,del函数的定义与使用方法,元素是否在列表中的判断
python之列表中常用的函数:append,extend,insert,pop,remove,del函数的定义与使用方法,元素是否在列表中的判断
153 0
|
7月前
|
XML SQL JSON
query 与 params:选择正确的参数传递方式
query 与 params:选择正确的参数传递方式
dict中所有方法的使用
提示:以下是本篇文章正文内容,下面案例可供参考
59 0
|
C++ Python
python类中初始化形式:def __init__(self)和def __init__(self, 参数1,参数2,,,参数n)区别
python类中初始化形式:def __init__(self)和def __init__(self, 参数1,参数2,,,参数n)区别
173 0
|
机器学习/深度学习 存储 PyTorch
params.data.clone()是什么意思?params是模型的参数
在深度学习中,模型的参数通常是由多个张量组成的。这些张量存储了模型在训练过程中学到的权重和偏置等参数。 params.data 是一个张量,其中包含了模型的参数数据。clone() 是 PyTorch 中的一个方法,它用于创建一个与当前张量具有相同数据但不同内存地址的新张量。 因此,params.data.clone() 的意思是创建一个与 params.data 张量具有相同数据但不同内存地址的新张量。通常,这个方法被用来复制模型参数,以便在优化器中使用。
258 0
|
缓存 Python