动手强化学习(十):Actor-Critic 算法

简介: 在之前的内容中,我们学习了基于值函数的方法(DQN)和基于策略的方法(REINFORCE),其中基于值函数的方法只学习一个价值函数,而基于策略的方法只学习一个策略函数。那么一个很自然的问题,有没有什么方法既学习价值函数,又学习策略函数呢?答案就是 Actor-Critic。Actor-Critic 是一系列算法,目前前沿的很多高效算法都属于 Actor-Critic 算法,今天我们将会介绍一种最简单的 Actor-Critic 算法。需要明确的是,Actor-Critic 算法本质上是基于策略的算法,因为这系列算法都是去优化一个带参数的策略,只是其中会额外学习价值函数来帮助策略函数的学习。

文章转于 伯禹学习平台-动手学强化学习 (强推)


本文所有代码均可在jupyter notebook运行


与君共勉,一起学习。


1. 简介


 在之前的内容中,我们学习了基于值函数的方法(DQN)和基于策略的方法(REINFORCE),其中基于值函数的方法只学习一个价值函数,而基于策略的方法只学习一个策略函数。那么一个很自然的问题,有没有什么方法既学习价值函数,又学习策略函数呢?答案就是 Actor-Critic。Actor-Critic 是一系列算法,目前前沿的很多高效算法都属于 Actor-Critic 算法,今天我们将会介绍一种最简单的 Actor-Critic 算法。需要明确的是,Actor-Critic 算法本质上是基于策略的算法,因为这系列算法都是去优化一个带参数的策略,只是其中会额外学习价值函数来帮助策略函数的学习。


2. Actor-Critic 算法


 我们回顾一下在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报,来指导策略的更新。而值函数的概念正是基于期望回报,我们能不能考虑拟合一个值函数来指导策略进行学习呢?这正是 Actor-Critic 算法所做的。让我们先回顾一下策略梯度的形式,在策略梯度中,我们可以把梯度写成下面这个形式:


image.png


其中 ψ t 可以有很多种形式:


image.png


 在 REINFORCE 的最后部分,我们提到了 REINFORCE通过蒙特卡洛采样的方法对梯度的估计是无偏的,但是方差非常大,我们可以用第三种形式引入基线 (baseline) b ( s t ) 来减小方差。此外我们也可以采用 Actor-Critic 算法,估计 一个动作价值函数 Q 来代替蒙特卡洛采样得到的回报,这便是第 4 种形式。这个时候,我们也可以把状态价值函数 V  作为基线,从偍牧但是用神经网络进行估计的方法可以减小方差、提高鲁棒性。除此之外,REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,而 Actor-Critic 的方法则可以在每一步之后都进行更新。


我们将 Actor-Critic 分为两个部分: 分别是 Actor (策略网络) 和 Critic (价值网络):


  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于帮助 Actor 进行更新策略。


  • Actor 要做的则是与环境交互,并利用 Ctitic 价值函数来用策略梯度学习一个更好的策略。


image.png


 与 DQN 中一样,我们采取类似于目标网络的方法,上式中 r + γ V ω ( s t + 1 )作为时序差分目标,不会产生梯度来更新价值函数。所以价值函数的梯度为


image.png


然后使用梯度下降方法即可。接下来让我们总体看看 Actor-Critic 算法的流程吧!


  • 初始化策略网络参数 θ  ,价值网络参数 ω


  • 不断进行如下循环 (每个循环是一条序列) :


。 用当前策略 π θ 平样轨 迹 { s 1 , a 1 , r 1 , s 2 , a 2 , r 2 … }


。 为每一步数据计算: δ t = r t + γ V ω ( s t + 1 ) − V ω ( s )


。 更新价值参数 w = w + α ω t δ tω V ω ( s )


。 更新策略参数 θ = θ + α θt δ tθ log ⁡ π θ ( a ∣ s )


 好了!这就是 Actor-Critic 算法的流程啦,让我们来用代码实现它看看效果如何吧!


3. Actor-Critic 代码实践


 我们仍然在 Cartpole 环境上进行 Actor-Critic 算法的实验。


import gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils


 定义我们的策略网络 PolicyNet,与 REINFORCE 算法中一样。


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return  F.softmax(self.fc2(x),dim=1)


 Actor-Critic 算法中额外引入一个价值网络,接下来的代码定义我们的价值网络 ValueNet,输入是状态,输出状态的价值。


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


 再定义我们的 ActorCritic 算法。主要包含采取动作和更新网络参数两个函数。


class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device) # 价值网络
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr) # 价值网络优化器
        self.gamma = gamma
    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones) # 时序差分目标
        td_delta = td_target - self.critic(states) # 时序差分误差
        log_probs = torch.log(self.actor(states).gather(1, actions))
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach())) # 均方误差损失函数
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward() # 计算策略网络的梯度
        critic_loss.backward() # 计算价值网络的梯度
        self.actor_optimizer.step() # 更新策略网络参数
        self.critic_optimizer.step() # 更新价值网络参数


 定义好 Actor 和 Critic,我们就可以开始实验了,看看 Actor-Critic 在 Cartpole 环境上表现如何吧!


actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device)
return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)
----------------------------------------------------------------------------------------------Iteration 0: 100%|██████████| 100/100 [00:00<00:00, 218.65it/s, episode=100, return=21.100]
Iteration 1: 100%|██████████| 100/100 [00:01<00:00, 95.81it/s, episode=200, return=72.800]
Iteration 2: 100%|██████████| 100/100 [00:02<00:00, 45.96it/s, episode=300, return=109.300]
Iteration 3: 100%|██████████| 100/100 [00:05<00:00, 12.55it/s, episode=400, return=163.000]
Iteration 4: 100%|██████████| 100/100 [00:08<00:00, 11.24it/s, episode=500, return=193.600]
Iteration 5: 100%|██████████| 100/100 [00:08<00:00, 11.11it/s, episode=600, return=195.900]
Iteration 6: 100%|██████████| 100/100 [00:08<00:00, 11.88it/s, episode=700, return=199.100]
Iteration 7: 100%|██████████| 100/100 [00:08<00:00, 11.77it/s, episode=800, return=186.900]
Iteration 8: 100%|██████████| 100/100 [00:08<00:00, 11.23it/s, episode=900, return=200.000]
Iteration 9: 100%|██████████| 100/100 [00:08<00:00, 11.22it/s, episode=1000, return=200.000]


 在 CartPole-v0 环境中,满分就是 200 分,让我们来看看每个序列得分如何吧!


episodes_list = list(range(len(return_list)))
plt.plot(episodes_list,return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()
mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()


f6ea04344d744a95961fdcf60a2804a7.png


fcfd3ac0b55e47ccadcfb6f4b18bc317.png


 根据实验结果我们发现,Actor-Critic 算法很快便能收敛到最优策略,并且训练过程非常稳定,抖动情况相比 REINFORCE 算法有了明显的改进,这多亏了价值函数的引入减小了方差。


4. 总结


 我们在本章中学习了 Actor-Critic 算法,它是基于策略和基于价值的方法的叠加。Actor-Critic 算法非常实用,往后像 DDPG、TRPO、PPO、SAC 这样的算法都是在 Actor-Critic 框架下进行发展的,深入了解 Actor-Critic 算法对读懂目前深度强化学习的研究热点大有裨益。


相关资源来自:伯禹学习平台-动手学强化学习


35f0c043b96a4143bb9612b6bc0f1c4b.png

目录
相关文章
|
2月前
|
机器学习/深度学习 算法 机器人
多代理强化学习综述:原理、算法与挑战
多代理强化学习是强化学习的一个子领域,专注于研究在共享环境中共存的多个学习代理的行为。每个代理都受其个体奖励驱动,采取行动以推进自身利益;在某些环境中,这些利益可能与其他代理的利益相冲突,从而产生复杂的群体动态。
259 5
|
5天前
|
机器学习/深度学习 算法
强化学习之父Richard Sutton给出一个简单思路,大幅增强所有RL算法
Richard Sutton领导的团队提出了一种称为“奖励中心化”的方法,通过从观察到的奖励中减去其经验平均值,使奖励更加集中,显著提高了强化学习算法的性能。该方法在解决持续性问题时表现出色,尤其是在折扣因子接近1的情况下。论文地址:https://arxiv.org/pdf/2405.09999
30 15
|
23天前
|
机器学习/深度学习 人工智能 算法
探索人工智能中的强化学习:原理、算法与应用
探索人工智能中的强化学习:原理、算法与应用
|
23天前
|
机器学习/深度学习 人工智能 算法
探索人工智能中的强化学习:原理、算法及应用
探索人工智能中的强化学习:原理、算法及应用
|
4月前
|
机器学习/深度学习 算法 TensorFlow
深入探索强化学习与深度学习的融合:使用TensorFlow框架实现深度Q网络算法及高效调试技巧
【8月更文挑战第31天】强化学习是机器学习的重要分支,尤其在深度学习的推动下,能够解决更为复杂的问题。深度Q网络(DQN)结合了深度学习与强化学习的优势,通过神经网络逼近动作价值函数,在多种任务中表现出色。本文探讨了使用TensorFlow实现DQN算法的方法及其调试技巧。DQN通过神经网络学习不同状态下采取动作的预期回报Q(s,a),处理高维状态空间。
70 1
|
4月前
|
机器学习/深度学习 存储 算法
强化学习实战:基于 PyTorch 的环境搭建与算法实现
【8月更文第29天】强化学习是机器学习的一个重要分支,它让智能体通过与环境交互来学习策略,以最大化长期奖励。本文将介绍如何使用PyTorch实现两种经典的强化学习算法——Deep Q-Network (DQN) 和 Actor-Critic Algorithm with Asynchronous Advantage (A3C)。我们将从环境搭建开始,逐步实现算法的核心部分,并给出完整的代码示例。
335 1
|
4月前
|
测试技术 数据库
探索JSF单元测试秘籍!如何让您的应用更稳固、更高效?揭秘成功背后的测试之道!
【8月更文挑战第31天】在 JavaServer Faces(JSF)应用开发中,确保代码质量和可维护性至关重要。本文详细介绍了如何通过单元测试实现这一目标。首先,阐述了单元测试的重要性及其对应用稳定性的影响;其次,提出了提高 JSF 应用可测试性的设计建议,如避免直接访问外部资源和使用依赖注入;最后,通过一个具体的 `UserBean` 示例,展示了如何利用 JUnit 和 Mockito 框架编写有效的单元测试。通过这些方法,不仅能够确保代码质量,还能提高开发效率和降低维护成本。
57 0
|
5月前
|
机器学习/深度学习 存储 数据采集
强化学习系列:A3C算法解析
【7月更文挑战第13天】A3C算法作为一种高效且广泛应用的强化学习算法,通过结合Actor-Critic结构和异步训练的思想,实现了在复杂环境下的高效学习和优化策略的能力。其并行化的训练方式和优势函数的引入,使得A3C算法在解决大规模连续动作空间和高维状态空间的问题上表现优异。未来,随着技术的不断发展,A3C算法有望在更多领域发挥重要作用,推动强化学习技术的进一步发展。
|
6月前
|
机器学习/深度学习 分布式计算 算法
在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)
【6月更文挑战第28天】在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)、数据规模与特性(大数据可能适合分布式算法或深度学习)、性能需求(准确性、速度、可解释性)、资源限制(计算与内存)、领域知识应用以及实验验证(交叉验证、模型比较)。迭代过程包括数据探索、模型构建、评估和优化,结合业务需求进行决策。
63 0
|
7月前
|
机器学习/深度学习 算法
算法人生(2):从“强化学习”看如何“活在当下”
本文探讨了强化学习的原理及其在个人生活中的启示。强化学习强调智能体在动态环境中通过与环境交互学习最优策略,不断迭代优化。这种思想类似于“活在当下”的哲学,要求人们专注于当前状态和决策,不过分依赖历史经验或担忧未来。活在当下意味着全情投入每一刻,不被过去或未来牵绊。通过减少执着,提高觉察力和静心练习,我们可以更好地活在当下,同时兼顾历史经验和未来规划。文章建议实践静心、时间管理和接纳每个瞬间,以实现更低焦虑、更高生活质量的生活艺术。
下一篇
DataWorks