四、反向传播 back propagation

简介: 四、反向传播 back propagation

一、原理

1.1 计算图

在简单的线性模型中,我们可以通过数学推导求出梯度公式。

在复杂网络中,因为太复杂,无法直接数学计算梯度公式。

考虑将这样的复杂网络看成是图,我们在图上传播梯度,最后根据链式求导求出梯度(反向传播)。

 

计算图(一个简单的二层神经网络)

由于可以线性展开,这样就等于一个网络,中间计算的权重就没有意义。所以我们要加上一个非线性的东西,增加模型的复杂度。

 


 

1.2 反向传播

 

反向传播的过程:(先是蓝色箭头,然后红色)

Example:y = w * x

 

 

 

二、PyTorch实现

PyTorch :

Tenso(张量):PyTorch中存储数据的基本元素。

Tensor两个重要的成员,data和grad。(grad也是个张量)

 

#pytorch,线性模型
 
import torch
 
 
#数据样本x,y
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
 
#设置权重w 初值1.0 ,注意中括号[]
w = torch.Tensor([1.0])
w.requires_grad = True #w需要计算梯度
 
# y = x * w
def forward(x):
    return x * w #由于w是Tensor类型,这里x会自动类型转换为Tensor
 
#损失函数
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2
 
 
print("predict (before training) x={x},y={y}".format(x=4,y=forward(4).item()))
 
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y) #Forward, compute loss。
        l.backward()   #Backward, 反向传播后就得到梯度
        print('\tgrad:',x, y, w.grad.item())
        w.data = w.data  - 0.01 * w.grad.data #更新w
 
        w.grad.data.zero_()#默认w的梯度累加,但这里需要清零。所以要显式清零。
    print("progress: epoch={epoch}, loss = {l}".format(epoch=epoch, l =l.item()))
print("predict (after training) x={x},y={y}".format(x=4,y=forward(4).item()))

相关文章
|
机器学习/深度学习 算法 PyTorch
反向传播(Backpropagation)
反向传播(Backpropagation)是一种用于训练神经网络的常用算法。它通过计算神经网络中各个参数对于损失函数的梯度,从而实现参数的更新和优化。神经网络是一种模拟人脑神经元相互连接的计算模型,用于解决各种机器学习和深度学习任务。
189 1
|
4月前
|
算法 数据挖掘
文献解读-Consistency and reproducibility of large panel next-generation sequencing: Multi-laboratory assessment of somatic mutation detection on reference materials with mismatch repair and proofreading deficiency
Consistency and reproducibility of large panel next-generation sequencing: Multi-laboratory assessment of somatic mutation detection on reference materials with mismatch repair and proofreading deficiency,大panel二代测序的一致性和重复性:对具有错配修复和校对缺陷的参考物质进行体细胞突变检测的多实验室评估
37 6
文献解读-Consistency and reproducibility of large panel next-generation sequencing: Multi-laboratory assessment of somatic mutation detection on reference materials with mismatch repair and proofreading deficiency
|
7月前
|
机器学习/深度学习 算法 Python
Backpropagation
【6月更文挑战第24天】
58 7
|
7月前
|
监控 Java API
Java一分钟之-JPA事务管理:PROPAGATION_REQUIRED, PROPAGATION_REQUIRES_NEW等
【6月更文挑战第14天】Java企业开发中,事务管理确保数据一致性,Spring事务管理核心概念包括`PROPAGATION_REQUIRED`和`PROPAGATION_REQUIRES_NEW`。前者在无事务时新建,有事务时加入,常用于保证业务方法在事务中执行。后者始终创建新事务,独立于当前事务,适用于需隔离影响的场景。理解其应用场景和易错点,合理配置事务传播行为,能提升应用的健壮性和性能。通过监控和日志优化事务策略是关键。
187 1
|
8月前
|
机器学习/深度学习 人工智能 算法
神经网络算法——反向传播 Back Propagation
神经网络算法——反向传播 Back Propagation
97 0
|
机器学习/深度学习 算法 PyTorch
Back Propagation 反向传播
Back Propagation 反向传播
119 0
|
存储 JSON 数据挖掘
PlistEdit Pro
PlistEdit Pro 是一款为 macOS 设计的属性列表(Plist)编辑器,它可以帮助用户直观且方便地查看、编辑和管理 Plist 文件。Plist 文件是 macOS 和 iOS 操作系统中许多应用程序和系统功能使用的配置文件格式,通常包含键值对、数组和字典等数据结构。
311 0
|
机器学习/深度学习 算法 Python
BP神经网络(Back Propagation Neural Network)算法原理推导与Python实现详解
BP神经网络(Back Propagation Neural Network)算法原理推导与Python实现详解
|
机器学习/深度学习 资源调度 Python
一文弄懂神经网络中的反向传播法——Back Propagation
其实应用挺广的,在图像识别,文本分类等等都会用到,我会专门再写一篇Auto-Encoder的文章来说明,包括一些变种之类的。如果你的输出和原始输入不一样,那么就是很常见的人工神经网络了,相当于让原始数据通过一个映射来得到我们想要的输出数据,也就是我们今天要讲的话题。
一文弄懂神经网络中的反向传播法——Back Propagation
|
机器学习/深度学习 算法 搜索推荐
On the Unreasonable Effectiveness of Feature propagation in Learning on Graphs with Missing 论文阅读笔记
On the Unreasonable Effectiveness of Feature propagation in Learning on Graphs with Missing 论文阅读笔记
220 0
On the Unreasonable Effectiveness of Feature propagation in Learning on Graphs with Missing 论文阅读笔记