pytorch中optimizer为不同参数设置不同的学习率

简介: pytorch中optimizer为不同参数设置不同的学习率

在pytorch中已经实现了一些常见的优化器,例如Adam、SGD、Adagrad、RMsprop等,但是有些任务中我们需要设定不同的学习策略,例如给模型的不同参数设置不同的学习率

class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(3, 4))
        self.b1 = nn.Parameter(torch.randn(1, 3))
        self.w2 = nn.Parameter(torch.randn(3, 2))
        self.b2 = nn.Parameter(torch.randn(1, 2))
    def forward(self, x):
        x = F.linear(x, self.w1, self.b1)
        return F.linear(x, self.w2, self.b2)

该网络我们定义了4个可学习参数,2个是权重矩阵w,2个是偏置矩阵b,我们假定要为w矩阵设置学习率为1e-2,而为b矩阵设置为1e-3。

实现这种需求其实很简单,只需要在定义优化器时传入一个字典,分别传入需要优化的参数列表,以及对应的学习率。

model = Linear()
w_params = [param for name, param in model.named_parameters() if 'w' in name]
b_params = [param for name, param in model.named_parameters() if 'b' in name]
optimizer = torch.optim.Adam([
    {'params': w_params, 'lr': 1e-2},
    {'params': b_params, 'lr': 1e-3}
])


目录
相关文章
|
2月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】自调整学习率实例应用(附代码)
【PyTorch实战演练】自调整学习率实例应用(附代码)
46 0
|
9月前
|
机器学习/深度学习 缓存 监控
Pytorch学习笔记(7):优化器、学习率及调整策略、动量
Pytorch学习笔记(7):优化器、学习率及调整策略、动量
362 0
Pytorch学习笔记(7):优化器、学习率及调整策略、动量
|
16天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
17 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
PyTorch使用Tricks:学习率衰减 !!
PyTorch使用Tricks:学习率衰减 !!
42 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
从零开始学习线性回归:理论、实践与PyTorch实现
从零开始学习线性回归:理论、实践与PyTorch实现
从零开始学习线性回归:理论、实践与PyTorch实现
|
4月前
|
机器学习/深度学习 PyTorch 语音技术
Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类
深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。 为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。
|
9月前
|
机器学习/深度学习 并行计算 PyTorch
迁移学习的 PyTorch 实现
迁移学习的 PyTorch 实现
|
5月前
|
机器学习/深度学习 PyTorch 调度
迁移学习的 PyTorch 实现
迁移学习的 PyTorch 实现