Back Propagation 反向传播

简介: Back Propagation 反向传播

3、Back Propagation 反向传播

B站视频教程传送门:PyTorch深度学习实践 - 反向传播

3.1 引出算法

对于简单(单一)的权重 W ,我们可以将它封装到 Neuron(神经元)中,然后对权重W进行更新即可:

但是,往往实际情况中会涉及多个权重(例如下图:会有几十到上百个),我们几乎不可能做到写出所有的解析式:

矩阵论参考:The Matrix Cookbook - http://matrixcookbook.com

所以,引出反向传播(Back Propagation)算法:

当然,也可以将括号里面展开,如下如所示:

注意:不断地进行线性变化,不管有多少层,最后都会统一成一种形式。

所以:不能够化简(展开),因为增加的这些权重毫无意义。

3.2 非线性函数

所以我们需要引进非线性函数,每一层都需要一个非线性函数

3.3 算法步骤

在具体讲解反向传播之间,先回顾一下链式求导法则

(1)创建计算图表 Create Computational Graph (Forward)

(2)本地梯度 Local Gradient

(3)给出连续节点的梯度 Given gradient from successive node

(4)使用链式规则来计算梯度 Use chain rule to compute the gradient (Backward)

(5)Example: 𝑓 = 𝑥 ∙ 𝜔, 𝑥 = 2, 𝜔 = 3

3.3.1 例子

3.3.2 作业1

解答如下:

3.3.3 作业2

解答如下:

3.4 Tensor in PyTorch

3.5 PyTorch实现线性模型

import torch
x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [2.0, 4.0, 6.0, 8.0]
w = torch.Tensor([1.0])
w.requires_grad = True  # 计算梯度
def forward(x):
    return x * w
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2
print("predict (before training)", 4, forward(4).item())
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()
        print('\tgrad:', x, y, w.grad.item())
        w.data -= 0.01 * w.grad.data
        w.grad.data.zero_()
    print("progress:", epoch, l.item())
print("predict (after training)", 4, forward(4).item())
predict (before training) 4 4.0
  grad: 1.0 2.0 -2.0
  grad: 2.0 4.0 -7.840000152587891
  grad: 3.0 6.0 -16.228801727294922
  grad: 4.0 8.0 -23.657981872558594
progress: 0 8.745314598083496
...
  grad: 1.0 2.0 -2.384185791015625e-07
  grad: 2.0 4.0 -9.5367431640625e-07
  grad: 3.0 6.0 -2.86102294921875e-06
  grad: 4.0 8.0 -3.814697265625e-06
progress: 99 2.2737367544323206e-13
predict (after training) 4 7.999999523162842

3.6 作业3

解答如下:

代码实现:

import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w1 = torch.Tensor([1.0])  # 初始权值
w1.requires_grad = True  # 计算梯度,默认是不计算的
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = True
def forward(x):
    return w1 * x ** 2 + w2 * x + b
def loss(x, y):  # 构建计算图
    y_pred = forward(x)
    return (y_pred - y) ** 2
print('Predict (before training)', 4, forward(4))
for epoch in range(100):
    l = loss(1, 2)  # 为了在for循环之前定义l,以便之后的输出,无实际意义
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()
        print('\tgrad:', x, y, w1.grad.item(), w2.grad.item(), b.grad.item())
        w1.data = w1.data - 0.01 * w1.grad.data  # 注意这里的grad是一个tensor,所以要取他的data
        w2.data = w2.data - 0.01 * w2.grad.data
        b.data = b.data - 0.01 * b.grad.data
        w1.grad.data.zero_()  # 释放之前计算的梯度
        w2.grad.data.zero_()
        b.grad.data.zero_()
    print('Epoch:', epoch, l.item())
print('Predict (after training)', 4, forward(4).item())
Predict (before training) 4 tensor([21.], grad_fn=<AddBackward0>)
  grad: 1.0 2.0 2.0 2.0 2.0
  grad: 2.0 4.0 22.880001068115234 11.440000534057617 5.720000267028809
  grad: 3.0 6.0 77.04720306396484 25.682401657104492 8.560800552368164
Epoch: 0 18.321826934814453
...
  grad: 1.0 2.0 0.31661415100097656 0.31661415100097656 0.31661415100097656
  grad: 2.0 4.0 -1.7297439575195312 -0.8648719787597656 -0.4324359893798828
  grad: 3.0 6.0 1.4307546615600586 0.47691822052001953 0.15897274017333984
Epoch: 99 0.00631808303296566
Predict (after training) 4 8.544171333312988


目录
相关文章
|
Linux 网络虚拟化 虚拟化
Linux虚拟网络设备深度解析:使用场景、分类与开发者指南
Linux虚拟网络设备支撑着各种复杂的网络需求和配置,从基础的网络桥接到高级的网络隔离和加密🔐。以下是对主要Linux虚拟网络设备的介绍、它们的作用以及适用场景的概览,同时提出了一种合理的分类,并指出应用开发人员应该着重掌握的设备。
Linux虚拟网络设备深度解析:使用场景、分类与开发者指南
|
机器学习/深度学习 数据采集 人工智能
AI大模型知识点大梳理1
AI大模型是什么 AI大模型发展历程
1211 0
|
Java Nacos 开发工具
【Nacos】心跳断了怎么办?!8步排查法+实战代码,手把手教你解决Nacos客户端不发送心跳检测问题,让服务瞬间恢复活力!
【8月更文挑战第15天】Nacos是一款广受好评的微服务注册与配置中心。然而,“客户端不发送心跳检测”的问题时有发生,可能导致服务实例被视为离线。本文介绍如何排查此类问题:确认Nacos服务器地址配置正确;检查网络连通性;查看客户端日志;确保Nacos SDK版本兼容;调整心跳检测策略;验证服务实例注册状态;必要时重启应用;检查影响行为的环境变量。通过这些步骤,通常可定位并解决问题,保障服务稳定运行。
1026 0
|
机器学习/深度学习 算法 数据挖掘
反向传播算法
反向传播算法
|
监控 网络协议 Unix
容器化部署实践之Django应用部署(二)
容器化部署实践之Django应用部署(二)
|
Go C++ 测试技术
VS代码生成工具ReSharper使用手册:配置快捷键(转)
原文:http://blog.csdn.net/fhzh520/article/details/46364603 VS代码生成工具ReSharper提供了丰富的快捷键,可以极大地提高你的开发效率。
2084 0
|
8天前
|
人工智能 自然语言处理 Shell
🦞 如何在 OpenClaw (Clawdbot/Moltbot) 配置阿里云百炼 API
本教程指导用户在开源AI助手Clawdbot中集成阿里云百炼API,涵盖安装Clawdbot、获取百炼API Key、配置环境变量与模型参数、验证调用等完整流程,支持Qwen3-max thinking (Qwen3-Max-2026-01-23)/Qwen - Plus等主流模型,助力本地化智能自动化。
🦞 如何在 OpenClaw (Clawdbot/Moltbot) 配置阿里云百炼 API
|
6天前
|
人工智能 JavaScript 应用服务中间件
零门槛部署本地AI助手:Windows系统Moltbot(Clawdbot)保姆级教程
Moltbot(原Clawdbot)是一款功能全面的智能体AI助手,不仅能通过聊天互动响应需求,还具备“动手”和“跑腿”能力——“手”可读写本地文件、执行代码、操控命令行,“脚”能联网搜索、访问网页并分析内容,“大脑”则可接入Qwen、OpenAI等云端API,或利用本地GPU运行模型。本教程专为Windows系统用户打造,从环境搭建到问题排查,详细拆解全流程,即使无技术基础也能顺利部署本地AI助理。
6496 13
|
4天前
|
人工智能 机器人 Linux
保姆级 OpenClaw (原 Clawdbot)飞书对接教程 手把手教你搭建 AI 助手
OpenClaw(原Clawdbot)是一款开源本地AI智能体,支持飞书等多平台对接。本教程手把手教你Linux下部署,实现数据私有、系统控制、网页浏览与代码编写,全程保姆级操作,240字内搞定专属AI助手搭建!
3745 11
保姆级 OpenClaw (原 Clawdbot)飞书对接教程 手把手教你搭建 AI 助手