权重衰减的简单示例代码,采用L2正则项

简介: 权重衰减的简单示例代码,采用L2正则项

创造数据

x的数据维度为(200,100)

w的数据维度为(100,1)

利用data_iter获得批次数据

import torch
from torch.utils import data
import torch.nn as nn
n_examples=200
n_features=100
true_w=torch.randn(100,1)
true_b=torch.tensor(0.54)
x_=torch.randn(200,100)
y_=x_@true_w+true_b
y_+=torch.normal(0,0.01,y_.shape)
def data_iter(x,y,batch_size):
    n_example=len(x)
    indices=torch.randperm(n_example)
    for i in range(0,n_example,batch_size):
        batch_indices=indices[i:min(i+batch_size,n_example)]
        yield x[batch_indices],y[batch_indices]

只对参数w进行权重衰减,b不需要

方式一

在优化器的参数中,利用字典的方式指名对待不同的参数实行不同的执行原则

wd=3
net=nn.Linear(100,1)
loss_function=nn.MSELoss()
optimizer=torch.optim.SGD([{'params':net.weight,
                           'weight_decay':wd},
                          {'params':net.bias}],lr=0.03)
epochs=3
for epoch in range(epochs):
    net.train()
    losses=0.0
    for x,y in data_iter(x_,y_,batch_size=20):
        y_hat=net(x)
        loss=loss_function(y_hat,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses+=loss.item()
    print(losses)

方式二

方式二用了两个优化器,第一个掌管参数w的优化,第二个负责偏置b的优化,但是这样较为麻烦,需要两次梯度清0,且进行两次梯度更新

wd=3
net=nn.Linear(100,1)
loss_function=nn.MSELoss()
optimizer_w=torch.optim.SGD([net.weight],lr=0.03,weight_decay=wd)
optimizer_b=torch.optim.SGD([net.bias],lr=0.03)
epochs=3
for epoch in range(epochs):
    net.train()
    losses=0.0
    for x,y in data_iter(x_,y_,batch_size=20):
        y_hat=net(x)
        loss=loss_function(y_hat,y)
        optimizer_w.zero_grad()
        optimizer_b.zero_grad()
        loss.backward()
        optimizer_w.step()
        optimizer_b.step()
        losses+=loss.item()
    print(losses)


目录
相关文章
|
算法 C语言
C语言的伪代码结构
C语言的伪代码结构
263 1
|
Java Maven
SpringBoot项目如何打包、部署
SpringBoot项目如何打包、部署
256 0
|
供应链 算法 Java
使用Java构建区块链应用
使用Java构建区块链应用
|
前端开发 UED Python
Wagtail-基于Python Django的内容管理系统CMS实现公网访问
Wagtail-基于Python Django的内容管理系统CMS实现公网访问
283 0
|
编解码 算法 计算机视觉
OpenCV(十七):拉普拉斯图像金字塔
OpenCV(十七):拉普拉斯图像金字塔
460 0
|
存储 安全 Java
利用POI多线程导出数据错位解决
通过反射替换解决
934 0
|
存储 SQL 分布式计算
当流计算邂逅数据湖:Paimon 的前生今世
希望通过笔者以下的经历,回顾流计算一步一步扩大场景的过程,并引出 Apache Paimon 的前生今世。
1816 0
当流计算邂逅数据湖:Paimon 的前生今世
|
Java 开发者
JavaHTTP心跳:服务器与客户端实时连接的实现方式
JavaHTTP心跳:服务器与客户端实时连接的实现方式 在网络通信中,实时连接是一种至关重要的功能。它允许服务器与客户端之间保持持久的通信信道,实现快速、高效的数据传输。对于Java开发者来说,实现服务器与客户端之间的实时连接可以通过JavaHTTP心跳技术来实现。本文将介绍如何利用JavaHTTP心跳来实现服务器与客户端之间的实时连接。
514 0