动手强化学习(八):DQN 改进算法——Dueling DQN

本文涉及的产品
函数计算FC,每月15万CU 3个月
简介: DQN 算法敲开了深度强化学习的大门,但是作为先驱性的工作,其本身存在着一些问题以及一些可以改进的地方。于是,在 DQN 之后,学术界涌现出了非常多的改进算法。本章将介绍其中两个非常著名的算法:Double DQN 和 Dueling DQN,这两个算法的实现非常简单,只需要在 DQN 的基础上稍加修改,它们能在一定程度上改善 DQN 的效果。

文章转于 伯禹学习平台-动手学强化学习 (强推)


本文所有代码均可在jupyter notebook运行


与君共勉,一起学习。


1. 简介


 DQN 算法敲开了深度强化学习的大门,但是作为先驱性的工作,其本身存在着一些问题以及一些可以改进的地方。于是,在 DQN 之后,学术界涌现出了非常多的改进算法。本章将介绍其中两个非常著名的算法:Double DQNDueling DQN,这两个算法的实现非常简单,只需要在 DQN 的基础上稍加修改,它们能在一定程度上改善 DQN 的效果。


2. Dueling DQN


Dueling DQN 是 DQN 另一种的改进算法,它在传统 DQN 的基础上只进行了微小的改动,但却能大幅提升 DQN 的表现。在强化学习中,我们将状态动作价值函数 Q减去状态价值函数 V 的结果定 义为优势函数 A ,即 A ( s , a ) = Q ( s , a ) − V ( s )  。在同一个状态下,所有动作的优势值之和为 0 ,因为所有动作的动作价值的期望就是这个状态的状态价值。据此,在 Dueling DQN 中,Q网络被 建模为:


Q η , α , β ( s , a ) = V η , α ( s ) + A η , β ( s , a )


其中, V η , α ( s )  为状态价值函数,而 A η , β ( s , a ) 则为该状态下采取不同动作的优势函数,表示采取不同动作的差异性; η  是状态价值函数和优势函数共享的网络参数,一般用在神经网络中,用来提 取特征的前几层; 而 α和 β  分别为状态价值函数和优势函数的参数。在这样的模型下,我们不再让神经网络直接输出 Q 值,而是训练神经网络的最后几层的两个分支,分别输出状态价值函数和优势 函数,再求和得到 Q值。Dueling DQN 的网络结构如图所示。


88db8ff9b1f34aebb2e7f538e41105e0.png


将状态价值函数和优势函数分别建模的好处在于:某些情境下智能体只会关注状态的价值,而并不关心不同动作导致的差异,此时将二者分开建模能够使智能体更好地处理与动作关联较小的状态。在下图所示的驾驶车辆游戏中,智能体注意力集中的部位被显示为橙色,当智能体前面没有车时,车辆自身动作并没有太大差异,此时智能体更关注状态价值,而当智能体前面有车时(智能体需要超车),智能体开始关注不同动作优势值的差异。


6369a9b185f64365aded37c78f7acd5b.png


对于 Dueling DQN 中的公式 Q η , α , β ( s , a ) = V η , α ( s ) + A η , β ( s , a ) ,它存在对于 V  值和 A 值建模不唯一性的问题。例如,对于同样的 Q值,如果将 V  值加上任意大小的常数 C  ,再将所有 A值减去 C  ,则得到的 Q 值依然不变,这就导致了训练的不稳定性。为了解决这一问题,Dueling DQN 强制最优动作的优势函数的实际输出为 0 ,即:


image.png


此时 V ( s ) = max ⁡ a Q ( s , a ) ,可以确保 V 值建模的唯一性。在实现过程中,我们还可以用平均代替最大化操作,即:


image.png


此时 image.png。在下面的代码实现中,我们将采取此种方式,虽然它不再满足贝尔曼最优方程,但实际应用时更加稳定。


有人可能会问:“为什么 Dueling DQN 会比 DQN 好? "部分原因在于 Dueling DQN 能更高效学习状态价值函数。每一次更新时,函数 V都会被更新,这也会影响到其他动作的 Q  值。而传统的 DQN 只会更新某个动作的 Q  值,其他动作的 Q 值就不会更新。因此,Dueling DQN 能够更加频繁、准确地学习状态价值函数。


3. Dueling DQN 代码实践


Dueling DQN 与 DQN 相比的差异只是在网络结构上,大部分代码依然可以继续沿用。我们定义状态价值函数和优势函数的复合神经网络VAnet。


class VAnet(torch.nn.Module):
    ''' 只有一层隐藏层的A网络和V网络 '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(VAnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)  # 共享网络部分
        self.fc_A = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_V = torch.nn.Linear(hidden_dim, 1)
    def forward(self, x):
        A = self.fc_A(F.relu(self.fc1(x)))
        V = self.fc_V(F.relu(self.fc1(x)))
        Q = V + A - A.mean(1).view(-1, 1)  # Q值由V值和A值计算得到
        return Q
class DQN:
    ''' DQN算法,包括Double DQN和Dueling DQN '''
    def __init__(self,
                 state_dim,
                 hidden_dim,
                 action_dim,
                 learning_rate,
                 gamma,
                 epsilon,
                 target_update,
                 device,
                 dqn_type='VanillaDQN'):
        self.action_dim = action_dim
        if dqn_type == 'DuelingDQN':  # Dueling DQN采取不一样的网络框架
            self.q_net = VAnet(state_dim, hidden_dim,
                               self.action_dim).to(device)
            self.target_q_net = VAnet(state_dim, hidden_dim,
                                      self.action_dim).to(device)
        else:
            self.q_net = Qnet(state_dim, hidden_dim,
                              self.action_dim).to(device)
            self.target_q_net = Qnet(state_dim, hidden_dim,
                                     self.action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(),
                                          lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update
        self.count = 0
        self.dqn_type = dqn_type
        self.device = device
    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action
    def max_q_value(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        return self.q_net(state).max().item()
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        q_values = self.q_net(states).gather(1, actions)
        if self.dqn_type == 'DoubleDQN':
            max_action = self.q_net(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = self.target_q_net(next_states).gather(
                1, max_action)
        else:
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(
                -1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device, 'DuelingDQN')
return_list, max_q_value_list = train_DQN(agent, env, num_episodes,
                                          replay_buffer, minimal_size,
                                          batch_size)
episodes_list = list(range(len(return_list)))
mv_return = rl_utils.moving_average(return_list, 5)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Dueling DQN on {}'.format(env_name))
plt.show()
frames_list = list(range(len(max_q_value_list)))
plt.plot(frames_list, max_q_value_list)
plt.axhline(0, c='orange', ls='--')
plt.axhline(10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('Dueling DQN on {}'.format(env_name))
plt.show()
-------------------------------------------------------------------------------------------
Iteration 0: 100%|████████████████████████████████████████| 20/20 [00:10<00:00,  1.87it/s, episode=20, return=-708.652]
Iteration 1: 100%|████████████████████████████████████████| 20/20 [00:15<00:00,  1.28it/s, episode=40, return=-229.557]
Iteration 2: 100%|████████████████████████████████████████| 20/20 [00:15<00:00,  1.32it/s, episode=60, return=-184.607]
Iteration 3: 100%|████████████████████████████████████████| 20/20 [00:13<00:00,  1.50it/s, episode=80, return=-200.323]
Iteration 4: 100%|███████████████████████████████████████| 20/20 [00:13<00:00,  1.51it/s, episode=100, return=-213.811]
Iteration 5: 100%|███████████████████████████████████████| 20/20 [00:13<00:00,  1.53it/s, episode=120, return=-181.165]
Iteration 6: 100%|███████████████████████████████████████| 20/20 [00:14<00:00,  1.35it/s, episode=140, return=-222.040]
Iteration 7: 100%|███████████████████████████████████████| 20/20 [00:14<00:00,  1.35it/s, episode=160, return=-173.313]
Iteration 8: 100%|███████████████████████████████████████| 20/20 [00:12<00:00,  1.62it/s, episode=180, return=-236.372]
Iteration 9: 100%|███████████████████████████████████████| 20/20 [00:12<00:00,  1.57it/s, episode=200, return=-230.058]


6c7a1432e349486c8a8430ed57de2b72.png


da50de09890544f2ae341efeaf3997cb.png


根据代码运行结果我们可以发现,相比于传统的 DQN,Dueling DQN 在多个动作选择下的学习更加稳定,得到的回报最大值也更大。由 Dueling DQN 的原理可知,随着动作空间的增大,Dueling DQN 相比于 DQN 的优势更为明显。之前我们在环境中设置的离散动作数为 11,我们可以增加离散动作数(例如 15、25 等),继续进行对比实验。


4. 对 Q 值过高估计的定量分析


对 Q 值过高估计的定量分析 Q ω − ( s , a ) − V 服从 [ − 1 , 1 ]之间的均匀独立同分布;假设动作空间大小为 m ∘ 那么,对于任意状态 s ,有:


image.png


即状态空间 m  越大时, Q 值过高,估计越严重。


证明: 将估算误差记为 ϵ a = Q ω − ( s , a ) − max ⁡ a ′ Q ( s , a ′ )  ,由于估算误差对于不同的动作是独立的,因此有:


image.png


P ( ϵ a ≤ x )  是 ϵ a 的累积分布函数 (cumulative distribution function,即 CDF),它可以具体被写为:


image.png


因此,我们得到关于 max ⁡ a ϵ a的累积分布函数:


image.png


最后我们可以得到:


image.png


虽然这一分析简化了实际环境,但它仍然正确刻画了Q值过高估计的一些性质,比如Q 值的过高估计随动作空间大小m 的增加而增加,换言之,在动作选择数更多的环境中,Q值的过高估计会更严重。


总结


在传统的 DQN 基础上,有两种非常容易实现的变式——Double DQN 和 Dueling DQN,Double DQN 解决了 DQN 中对Q 值的过高估计,而 Dueling DQN 能够很好地学习到不同动作的差异性,在动作空间较大的环境下非常有效。从 Double DQN 和 Dueling DQN 的方法原理中,我们也能感受到深度强化学习的研究是在关注深度学习和强化学习有效结合:一是在深度学习的模块的基础上,强化学习方法如何更加有效地工作,并避免深度模型学习行为带来的一些问题,例如使用 Double DQN 解决Q 值过高估计的问题;二是在强化学习的场景下,深度学习模型如何有效学习到有用的模式,例如设计 Dueling DQN 网络架构来高效地学习状态价值函数以及动作优势函数。


相关资源来自:伯禹学习平台-动手学强化学习


35f0c043b96a4143bb9612b6bc0f1c4b.png

相关实践学习
【文生图】一键部署Stable Diffusion基于函数计算
本实验教你如何在函数计算FC上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。函数计算提供一定的免费额度供用户使用。本实验答疑钉钉群:29290019867
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
目录
相关文章
|
2月前
|
机器学习/深度学习 算法 机器人
多代理强化学习综述:原理、算法与挑战
多代理强化学习是强化学习的一个子领域,专注于研究在共享环境中共存的多个学习代理的行为。每个代理都受其个体奖励驱动,采取行动以推进自身利益;在某些环境中,这些利益可能与其他代理的利益相冲突,从而产生复杂的群体动态。
241 5
|
1天前
|
机器学习/深度学习 算法
强化学习之父Richard Sutton给出一个简单思路,大幅增强所有RL算法
Richard Sutton领导的团队提出了一种称为“奖励中心化”的方法,通过从观察到的奖励中减去其经验平均值,使奖励更加集中,显著提高了强化学习算法的性能。该方法在解决持续性问题时表现出色,尤其是在折扣因子接近1的情况下。论文地址:https://arxiv.org/pdf/2405.09999
25 15
|
19天前
|
机器学习/深度学习 人工智能 算法
探索人工智能中的强化学习:原理、算法与应用
探索人工智能中的强化学习:原理、算法与应用
|
19天前
|
机器学习/深度学习 人工智能 算法
探索人工智能中的强化学习:原理、算法及应用
探索人工智能中的强化学习:原理、算法及应用
|
4月前
|
机器学习/深度学习 算法 TensorFlow
深入探索强化学习与深度学习的融合:使用TensorFlow框架实现深度Q网络算法及高效调试技巧
【8月更文挑战第31天】强化学习是机器学习的重要分支,尤其在深度学习的推动下,能够解决更为复杂的问题。深度Q网络(DQN)结合了深度学习与强化学习的优势,通过神经网络逼近动作价值函数,在多种任务中表现出色。本文探讨了使用TensorFlow实现DQN算法的方法及其调试技巧。DQN通过神经网络学习不同状态下采取动作的预期回报Q(s,a),处理高维状态空间。
68 1
|
4月前
|
机器学习/深度学习 存储 算法
强化学习实战:基于 PyTorch 的环境搭建与算法实现
【8月更文第29天】强化学习是机器学习的一个重要分支,它让智能体通过与环境交互来学习策略,以最大化长期奖励。本文将介绍如何使用PyTorch实现两种经典的强化学习算法——Deep Q-Network (DQN) 和 Actor-Critic Algorithm with Asynchronous Advantage (A3C)。我们将从环境搭建开始,逐步实现算法的核心部分,并给出完整的代码示例。
321 1
|
4月前
|
测试技术 数据库
探索JSF单元测试秘籍!如何让您的应用更稳固、更高效?揭秘成功背后的测试之道!
【8月更文挑战第31天】在 JavaServer Faces(JSF)应用开发中,确保代码质量和可维护性至关重要。本文详细介绍了如何通过单元测试实现这一目标。首先,阐述了单元测试的重要性及其对应用稳定性的影响;其次,提出了提高 JSF 应用可测试性的设计建议,如避免直接访问外部资源和使用依赖注入;最后,通过一个具体的 `UserBean` 示例,展示了如何利用 JUnit 和 Mockito 框架编写有效的单元测试。通过这些方法,不仅能够确保代码质量,还能提高开发效率和降低维护成本。
57 0
|
13天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。
|
19天前
|
机器学习/深度学习 算法 Serverless
基于WOA-SVM的乳腺癌数据分类识别算法matlab仿真,对比BP神经网络和SVM
本项目利用鲸鱼优化算法(WOA)优化支持向量机(SVM)参数,针对乳腺癌早期诊断问题,通过MATLAB 2022a实现。核心代码包括参数初始化、目标函数计算、位置更新等步骤,并附有详细中文注释及操作视频。实验结果显示,WOA-SVM在提高分类精度和泛化能力方面表现出色,为乳腺癌的早期诊断提供了有效的技术支持。
|
6天前
|
存储 算法
基于HMM隐马尔可夫模型的金融数据预测算法matlab仿真
本项目基于HMM模型实现金融数据预测,包括模型训练与预测两部分。在MATLAB2022A上运行,通过计算状态转移和观测概率预测未来值,并绘制了预测值、真实值及预测误差的对比图。HMM模型适用于金融市场的时间序列分析,能够有效捕捉隐藏状态及其转换规律,为金融预测提供有力工具。