深度学习经典算法PPO的通俗理解

简介: #1 前置知识点 基本概念 [https://www.yuque.com/docs/share/04b60c4c-90ec-49c7-8a47-0dae7d3c78c7?#](https://www.yuque.com/docs/share/04b60c4c-90ec-49c7-8a47-0dae7d3c78c7?#) (部分符合的定义在这里) 要理解PPO,就必须先理解Actor

1 前置知识点

基本概念
https://www.yuque.com/docs/share/04b60c4c-90ec-49c7-8a47-0dae7d3c78c7?#
(部分符合的定义在这里)

要理解PPO,就必须先理解Actor-Critic.
Actor负责输出policy,也就是在某个状态下执行各种action的概率分布
Critic负责输出Vaue of state。
Actor和Critic的默契:Actor相信Critic给的状态的value就是真的; Critic也相信Actor选送过来的(s,a)中的a就是最优的action。通过不断的迭代,强化这个信任关系的正确性。
(这体现了我们的价值观 [因为信任,所以简单],哈哈哈~)
image.png

所以这样就不难理解Critic的Loss是怎么来的了,Critic的输出就是state的Value,那就让Critic模型的输出使得以下公式成立:
Vs=rs,a+γVs

 
其中,rs,a,s,a,s是训练Critic需要的数据,s是在状态s下执行动作a得到新状态, rs,a是reward, γ 是discount factor。
跟基础概念的区别是,这里的系统假定是执行动作a只能到s, 没有体现执行a可以得到不同的状态; (但是其实这种概率可以体现在训练数据中,因为(s,a,rs,a)s 不一定是一一对应,其概率可以通过sampling得到的数据分布体现)
所以Critic的Loss就是|rs,a+γVsVs|,也就是所谓的TD(Time Difference)-Error的L1,或者L2也可以.
那么Actor的Loss怎么计算呢?
这里就先来明白Advantage的概念,其实也就是TD-Error
 Adv=rs,a+γVsVs
 
之所以称之为Advantage,是因为假如Advantage>0, 就是说实际执行a之后,发现当前的状态Value实际上比当前Critic估计出来的要大,所以这是个好的Action,它能够让Vs 变大,Actor应该增大这个action的概率;反之,Advantage<0,这个action就是不好的,应该减小执行它概率。
所以Actor的Loss就是log(π(a|s))Adv
, 因为要最小化Loss,所以前面加个负号;Adv的符号代表了应该增大这个action的输出概率,还是减小这个action的输出概率;Adv的大小代表了增加或减小的幅度应该有多大。

2 Proximal Policy Optimization(PPO)

2.1 PPO主要是来解决什么问题?

它是为了让我们在训练Agent的时候能够take the biggest possible improvement step on a policy using the data we currently have, without stepping so far that we accidentally cause performance collapse。就是更新policy的时候尽可能步子迈大点,但也要防止扯着蛋,即模型参数崩了。

2.2 PPO怎么解决这个问题的?

简单来说,相同的网络结构,维护两组模型参数Old和New,在每次迭代的时候去更新New的参数,但是要控制New的模型输出Policy和Old的Policy不要差距太大,本轮迭代结束后,把New的参数覆盖掉Old的参数。

怎么去控制差距不要太大呢?作者给了两种方式: PPO-Penalty, PPO-Clip

2.2.1 PPO-Clip

先说PPO-Clip, 它通过下面的公式来更新策略:
θk+1=argmaxθEs,aπθk[L(s,a,θk,θ)]

 
就是最大化L(s,a,θk,θ),
L(s,a,θk,θ)=min(πθ(a|s)πθk(a|s)Aπθk(s,a),clip(πθ(a|s)πθk(a|s),1ϵ,1+ϵ)Aπθk(s,a))

这个形式主要是为了让我们理解为啥叫PPO-Clip(我感觉直接用后面那个Clip项其实就够了,这个表达有点冗余),θk 就是当前Old的参数,θ 是New的参数。πθ(a|s) 是New Actor输出的Policy上状态s时执行a的概率,πθk(a|s) 表示的Old Actor输出的Policy上状态s时执行a的概率。Aπθk(s,a)是基于Old Critic得到的Advantage.
对这个公式进行改写,更容易看出它的真实目的,
L(s,a,θk,θ)=min(πθ(a|s)πθk(a|s)Aπθk(s,a),g(ϵ,Aπθk(s,a) ))

其中,

g(ϵ,A)={(1+ϵ)AA0(1ϵ)AA<0

  当Advantage>=0的时候, L(s,a,θk,θ)=min(πθ(a|s)πθk(a|s),1+ϵ)Aπθk(s,a)
这就清楚的说明,这时候应该增大πθ(a|s),也就是认为这个action是好的,增加选择a的概率。但是通过1+ϵ 限制增大的幅度。 同理,当Advantage<0的时候 L(s,a,θk,θ)=min(πθ(a|s)πθk(a|s),1ϵ)Aπθk(s,a)
缩小πθ(a|s),但是幅度不能小于1ϵ 另外,根据我的理解,πθk(a|s)应该截断梯度,也就是反向传到的时候用不着去更新Old Actor的参数。在OpenAI Spinningup的代码([https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/ppo/ppo.py](https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/ppo/ppo.py))确实是这样处理的,但是在Tianshou的代码里([https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/ppo.py](https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/ppo.py))没有做截断,结果也OK,想来对于πθ(a|s)来说,πθk(a|s)就是一个scalar factor, 这个factor是变量还是静态值,也许影响不那么大,而且本轮迭代结束后θk也会被覆盖掉,反向传导更新了也白搭。 到这里,其实说的都是如何更新Actor。 怎么更新Critic的参数呢? Lcs,a,rs,a,s=|rs,a+VπθksVπθ|
唯一的不同是target value是用Old Critic计算的,这也是DRL领域的常规操作了. 小结一下,PPO-Clip就是通过Clip操作来防止步子迈太大的。作者实验证明Clip的效果比Penalty好。 ### 2.2.2 PPO-Penalty LKLPEN(θ)=πθ(a|s)πθk(a|s)Aπθk(s,a)βKLD(πθ(|s),πθk(|s))
理解上上面的,这个理解起来也就容易了,就是增加一个新旧Policy差异的惩罚项,差异通过KL divergence来衡量 (PS: 如理解有误支持,欢迎批评指正~)

烁凡
+关注
目录
打赏
0
0
0
0
49
分享
相关文章
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
460 55
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
眼疾识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了4种常见的眼疾图像数据集(白内障、糖尿病性视网膜病变、青光眼和正常眼睛) 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,实现用户上传一张眼疾图片识别其名称。
248 5
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
基于MobileNet深度学习网络的活体人脸识别检测算法matlab仿真
本内容主要介绍一种基于MobileNet深度学习网络的活体人脸识别检测技术及MQAM调制类型识别方法。完整程序运行效果无水印,需使用Matlab2022a版本。核心代码包含详细中文注释与操作视频。理论概述中提到,传统人脸识别易受非活体攻击影响,而MobileNet通过轻量化的深度可分离卷积结构,在保证准确性的同时提升检测效率。活体人脸与非活体在纹理和光照上存在显著差异,MobileNet可有效提取人脸高级特征,为无线通信领域提供先进的调制类型识别方案。
基于深度学习的路面裂缝检测算法matlab仿真
本项目基于YOLOv2算法实现高效的路面裂缝检测,使用Matlab 2022a开发。完整程序运行效果无水印,核心代码配有详细中文注释及操作视频。通过深度学习技术,将目标检测转化为回归问题,直接预测裂缝位置和类别,大幅提升检测效率与准确性。适用于实时检测任务,确保道路安全维护。 简介涵盖了算法理论、数据集准备、网络训练及检测过程,采用Darknet-19卷积神经网络结构,结合随机梯度下降算法进行训练。
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
252 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
深度学习中的优化算法及其应用
【10月更文挑战第8天】 本文将探讨深度学习中常用的优化算法,包括梯度下降法、Adam和RMSProp等,介绍这些算法的基本原理与应用场景。通过实例分析,帮助读者更好地理解和应用这些优化算法,提高深度学习模型的训练效率与性能。
394 63
基于GoogleNet深度学习网络的手语识别算法matlab仿真
本项目展示了基于GoogleNet的深度学习手语识别算法,使用Matlab2022a实现。通过卷积神经网络(CNN)识别手语手势,如&quot;How are you&quot;、&quot;I am fine&quot;、&quot;I love you&quot;等。核心在于Inception模块,通过多尺度处理和1x1卷积减少计算量,提高效率。项目附带完整代码及操作视频。
基于深度学习网络的宝石类型识别算法matlab仿真
本项目利用GoogLeNet深度学习网络进行宝石类型识别,实验包括收集多类宝石图像数据集并按7:1:2比例划分。使用Matlab2022a实现算法,提供含中文注释的完整代码及操作视频。GoogLeNet通过其独特的Inception模块,结合数据增强、学习率调整和正则化等优化手段,有效提升了宝石识别的准确性和效率。
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
214 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
217 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型