基于PyTorch实战权重衰减——L2范数正则化方法(附代码)

简介: 基于PyTorch实战权重衰减——L2范数正则化方法(附代码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文旨在通过实例验证权重衰减法(L2范数正则化方法)对深度学习神经元网络模型训练过程中出现的过拟合现象的抑制作用,加深对这个方法的理解。

1. 权重衰减方法作用

在训练神经元网络模型时,如果训练样本不足或者网络模型过于复杂,往往会导致训练误差可以快速收敛,但是在测试数据集上的泛化误差很大,即出现过拟合现象。

出现这种情况当然可以通过增多训练样本数来解决,但是如果增加额外的训练数据很困难,对应这类过拟合问题常用方法就是权重衰减法

2. 权重衰减方法原理介绍

权重衰减等价于L2范数正则化,其方法是在损失函数中增加权重的L2范数作为惩罚项。以MSE均方误差为例,原本损失函数应该是:

image.png

增加L2范数后变成:

image.png

其中 ||w||^2代表权重的二范数, λ为权重二范数的系数, λ≥0。


可以见得,如果 λ越大,权重的“惩罚力度”就越大,权重  w的绝对值就越接近0,如果 λ=0,相当于没有“惩罚力度”。


3. 验证权重衰减法实例说明

3.1 训练数据样本

本次演示实例使用的输入训练数据为x_train = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],输出训练数据为y_train = [0.52, 8.54, 6.94, 20.76, 32.17, 30.65, 40.46, 80.12, 75.12, 98.83]。


这个数据集是由  y = x^2 函数增加一个噪声数据生成得出,可以理解为 y = x^2 为该实例的真实解析解(真实规律)。

3.2 网络模型

使用torch.nn.Sequential()构建6层全连接层网络,每层神经元个数为:

InputLayer = 1,HiddenLayer1 = 3,HiddenLayer2 = 5,HiddenLayer3 = 10,HiddenLayer4 = 5,OutputLayer = 1

3.3 损失函数

选择MSE均方差损失函数,使用 torch.norm()计算权重的L2范数。

3.4 训练参数

无论是否增加L2范数惩罚项,训练参数都是一样的(控制变量):优化函数选用torch.optim.Adam(),学习速率lr=0.005,训练次数epoch=3000。

4. 结果对比

增加L2范数学习结果为:

其中红点为训练数据;黄色线为解析解,即y=x^2;蓝色线为训练后的模型在 x=[0, 10]上的预测结果。

不加L2惩罚项的学习结果为:

可以见得增加L2范数惩罚项后,测试的输出数据可以明显更贴合 y=x^2 理论曲线,尤其是在0~4范围上。


这里也可以增加一个类似损失函数的方式通过数据说明增加L2范数后学习结果更好,定义为:

image.png

其中 y ^ \widehat{y} y 为网络模型的输出结果,y=x^2

不加L2范数惩罚项 image.png

增加L2范数惩罚项后 image.png

5. 源码

import torch
import matplotlib.pyplot as plt

torch.manual_seed(25)

x_train = torch.tensor([1,2,3,4,5,6,7,8,9,10],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.52,8.54,6.94,20.76,32.17,30.65,40.46,80.12,75.12,98.83],dtype=torch.float32).unsqueeze(-1)
plt.scatter(x_train.detach().numpy(),y_train.detach().numpy(),marker='o',s=50,c='r')

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=1, out_features=3),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=3,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=10),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=10,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=1),
            torch.nn.ReLU(),
        )

    def forward(self,x):
        return self.layers(x)

linear = Linear()

opt = torch.optim.Adam(linear.parameters(),lr= 0.005)
loss = torch.nn.MSELoss()


for epoch in range(3000):
    l = 0
    L2 = 0
    for iter in range(10):

        for w in linear.parameters():
            L2 = torch.norm(w, p=2)*1e8  #计算权重的L2范数,如果要取消L2正则化惩罚只要把这项*0就可以了
        opt.zero_grad()
        output = linear(x_train[iter])
        loss_L2 = loss(output, y_train[iter]) + L2
        loss_L2.backward()
        l = loss_L2.detach() + l
        opt.step()
    print(epoch,L2,loss_L2)

#     plt.scatter(epoch, l, s=5,c='g')
#
# plt.show()


if __name__ == '__main__':
    predict_loss = 0
    for i in range(1000):
        x = torch.tensor([i/100], dtype=torch.float32)
        y_predict = linear(x)
        plt.scatter(x.detach().numpy(),y_predict.detach().numpy(),s=2,c='b')
        plt.scatter(i/100,i*i/10000,s=2,c='y')
        predict_loss = (i*i/10000 - y_predict)**2/(y_predict)**2 + predict_loss   #计算神经元网络模型输出对解析解的loss
# plt.show()

# print(linear.state_dict())
print(predict_loss)

本文的主要参考文献:

[1]Aston Zhang, Mu Li. Dive into deep learning.北京:人民邮电出版社.


相关文章
|
7月前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow与PyTorch深度对比分析:从基础原理到实战选择的完整指南
蒋星熠Jaxonic,深度学习探索者。本文深度对比TensorFlow与PyTorch架构、性能、生态及应用场景,剖析技术选型关键,助力开发者在二进制星河中驾驭AI未来。
921 13
|
9月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
530 9
|
11月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
514 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
381 4
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
|
机器学习/深度学习 人工智能 PyTorch
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
本文将系统阐述DPO的工作原理、实现机制,以及其与传统RLHF和SFT方法的本质区别。
1512 22
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
545 59
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
机器学习/深度学习 PyTorch TensorFlow
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
758 2
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
【7月更文挑战第31天】在数据驱动时代,Python凭借其简洁性与强大的库支持,成为数据分析与机器学习的首选语言。**数据分析基础**从Pandas和NumPy开始,Pandas简化了数据处理和清洗,NumPy支持高效的数学运算。例如,加载并清洗CSV数据、计算总销售额等。
306 2

推荐镜像

更多