1.4 ME-TRPO
Model-free RL算法往往会有很高的sample-complexity,无法应用在真实世界的环境中,而经典的Model-based RL算法会有很高的model bias,所以ME-TRPO使用了emsemble的方式来保持模型的不确定性,通过改变初始化权重和训练输入序列来区分神经网络。
成功的model-free强化学习方法反复收集数据,估计策略梯度,改进策略,然后丢弃数据。相反,model-based的强化学习更广泛地使用数据;它使用所有收集的数据来训练环境的动态模型。经典的the vanilla model-based reinforcement learning algorithm的流程如Algorithm 1所示:
dynamic model使用神经网络进行建模。即输入一个状态和一个动作,预测状态的变化,则下一个状态可以表示为神经网络输出+输入之和。但是,vanilla算法有诸多弊端,例如梯度爆炸、梯度消失、策略在数据稀疏分布处更新等。本文提出了ME-TRPO算法,与vanilla算法的三种不同在于:拟合了一组dynamic model {f_ϕ1,…,f_ϕK} 。差异在于初始化权重不同,以及训练的mini-batch的顺序不同;使用TRPO算法代替BPTT优化策略;使用model emsemble在验证集监督策略性能,策略性能不再提升时立即停止迭代。
ME-TRPO的伪代码如下:
策略优化。为了克服BPTT的问题,考虑可以使用model-free RL文献中的似然比方法,例如Vanilla Policy Gradient (VPG)、Proximal Policy Optimization (PPO)、Trust Region Policy Optimization (TRPO)等。在作者文献中最好的结果是由TRPO实现的。为了估计梯度,使用学到的模型来模拟轨迹,具体如下:在每一步长,随机选择一个模型来预测当前状态和行动的下一个状态。这避免了策略在一个episode中对任何单一模型的过度拟合,从而保证更稳定的学习。
策略验证。使用学到的K个模型监测策略的性能。具体来说,计算策略改善的模型的比率:
只要这个比率超过某个阈值,当前的迭代就会继续。在实践中,在每5次梯度更新后对策略进行验证,使用70%作为阈值。如果该比率低于阈值,可以容忍少量的更新,如果性能提高,则终止当前的迭代。然后,重复整个过程,用策略收集更多的真实世界数据,优化模型组合,并使用模型组合来改进策略。这个过程一直持续到达到真实环境中的理想性能。
当前 SOTA!平台收录ME-TRPO共2个模型实现资源。
项目 | SOTA!平台项目详情页 |
ME-TRPO | 前往SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/me-trpo |
1.5 DMVE
基于模型的动态地平线价值扩展( Dynamic-horizon Model-based Value Expansion,DMVE)在不同的 rollout horizons 下,调整world model的使用。受可用于视觉数据新颖性检测的基于重建技术的启发,引入带有重建模块的world model进行图像特征提取,以获得更精确的价值估计。原始图像和重建图像都被用来确定适应性价值扩展的适当horizon。
图4. DMVE概述。为了估计状态值,DMVE首先采用了一个重建网络来重建原始观测。然后,将原始图像和重建图像送入过渡模块,进行 H-step latent imagination,并以不同的 rollout horizons h=1,2,...,H 计算两者的 value expansion 。之后,选择对应于原始图像和重建图像之间前K个最小值扩展误差的horizon。最后,从与所选horizon相对应的扩展值中取平均值进行数值估计
DMVE算法伪代码如下:
DMVE框架的核心是一种类似MVE的算法,利用世界模型(world model)、行动模型(action model)和价值模型(value model)来估计状态值。
World model。许多model-based RL方法首先建立一个世界模型,并进一步用它来推导行为。在学习世界模型的基础上,模型学习和策略学习的过程可以交替并行。通常,世界模型提供了一个系统的动态,从当前的状态和行动映射到下一个状态,并对这种transition给予奖励。DMVE使用的world model通过重建原始图像来学习动态的规划。同时,这种基于重构的架构适合于动态水平线的选择。世界模型w_θ是由以下部分组成的:
视觉控制的任务被表述为部分可观察的马尔可夫决策过程( Partially Observable Markov Decision Process, POMDP),因为agent不能直接观察这种任务的基本状态。它可以被描述为一个7元组(S, A, O, T, R, P, γ),其中S表示状态集合,A表示行动集合,O表示观察集合。agent在一连串的离散时间步长中的每一步与环境互动。T(s_t+1 | s_t, a_t)是状态s_t∈S中行动a_t∈A导致状态s_t+1的条件转换概率。R(s_t, a_t)是在状态s_t下执行行动a_t的实值奖励,P表示观察概率P(o_t|s_t+1, a_t),其中,o_t代表agent的观察,执行行动a_t,世界移动到状态s_t+1。γ∈(0,1)是一个折扣系数。策略π(s)从环境状态映射到行动。在每个时间步长t,环境处于某个状态s_t∈S,agent在一个状态下选择一个可行的行动a_t∈A,表征环境以概率T(s_t+1 | s_t, a_t)过渡到状态s_t+1∈S。在环境中进行的行动之后,agent收到一个概率为P(o_t | s_t+1, a_t)的观察o_t∈O,以及一个数字奖励R(s_t, a_t)。然后,上述互动过程重复进行。RL的目标是通过使累积奖励最大化来学习一个最佳策略:
在POMDP设置中,状态不能直接获得,因此应用表征模块(representation model)将观察到的行动映射到低维连续向量,这些向量被视为马尔科夫转换( Markovian transitions)状态。重建模块从状态中估计原始观测值,并通过最小化重建误差确保状态能够代表原始输入数据中的有效信息。奖励模块根据环境反馈的实值奖励,预测想象轨迹中的奖励。过渡模块(transition module)根据当前的状态和行动来预测下一个状态,而没有看到原始的观察结果。过渡模块被实现为一个循环的状态空间模型( Recurrent State Space Model,RSSM)。表征模块是RSSM和卷积神经网络的结合。重建模块是一个转置的CNN,而奖励模块是一个密集网络。所有四个模型组件都通过随机反向传播进行联合优化。
Policy Learning。采用actor-critic方法进行策略学习。除了估计的值函数之外,actor-critic法还有一个独立的结构来表示策略。策略结构即actor,因为是用来推导行为的。同时,值函数为critic,因为它批评由actor决定的行为。对于状态s_t,actor和critic模型被定义为:
actor和critic模型均为密集网络(dense network)。状态值需要为actor和critic模型优化进行估计。MVE可以通过假设一个近似的动力学模型和一个奖励函数来改进值的估计。由于上述世界模型包含价值扩展所需的元素,可以用它来估计状态值。选择一个h∈{1,2,...,H},可以用想象的轨迹计算s_t的价值估计:
本文使用的是 MVE-LI (MVE的变体),状态s_t的估计值可以表示为V_h(s_t),在估计了状态值之后,可以优化actor和critic神经网络,而世界模型是固定的。MVE-LI中actor和critic模型的学习目标被设定为:
MVE-LI以一个固定的最大rollout horizon H想象潜在空间中的未来状态。然而,对于不同的imagination horizons h∈{1,2,...,H},可以为状态s_t估计不同的值:
给定最大rollout horizon H,可以用世界模型得到不同h∈{1,2,...,H}的数据序列的价值估计V_h,其中,L表示序列长度。首先,时间步长t的状态可以通过st∼w_θ(s_t | s_t-1, a_t-1, o_t)得出。随后,状态s_t的值可以通过以下方式估计出来:
同样,对于重建的图像oˆt∼wOθ(ˆo_t|s_t),基于重建的状态可以通过sˆt∼wSθ(ˆs_t|s_t-1, a_t-1, oˆt)得出,相应的值可以通过如上相同的过程来估计。给定所选 horizons 的数量K,rollout horizons 的集合H可以通过以下方式确定:
DMVE的最终价值估计为:
在上述自适应horizon的选择过程中,原始图像和重建图像被送入不同的世界模型组件,以预测未来的序列并进行价值估计。根据我们的假设,这些图像之间的世界模型的输出误差反映了它对不同输入的概括能力。因此,基于两者的估计状态值之间的误差被用来确定最终值估计的horizon。
项目 | SOTA!平台项目详情页 |
DMVE | 前往SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/dmve |
前往 SOTA!模型资源站(sota.jiqizhixin.com)即可获取本文中包含的模型实现代码、预训练模型及API等资源。
网页端访问:在浏览器地址栏输入新版站点地址 sota.jiqizhixin.com ,即可前往「SOTA!模型」平台,查看关注的模型是否有新资源收录。
移动端访问:在微信移动端中搜索服务号名称「机器之心SOTA模型」或 ID 「sotaai」,关注 SOTA!模型服务号,即可通过服务号底部菜单栏使用平台功能,更有最新AI技术、开发资源及社区动态定期推送。