动手强化学习(八):DQN 改进算法——Dueling DQN

简介: DQN 算法敲开了深度强化学习的大门,但是作为先驱性的工作,其本身存在着一些问题以及一些可以改进的地方。于是,在 DQN 之后,学术界涌现出了非常多的改进算法。本章将介绍其中两个非常著名的算法:Double DQN 和 Dueling DQN,这两个算法的实现非常简单,只需要在 DQN 的基础上稍加修改,它们能在一定程度上改善 DQN 的效果。

文章转于 伯禹学习平台-动手学强化学习 (强推)


本文所有代码均可在jupyter notebook运行


与君共勉,一起学习。


1. 简介


 DQN 算法敲开了深度强化学习的大门,但是作为先驱性的工作,其本身存在着一些问题以及一些可以改进的地方。于是,在 DQN 之后,学术界涌现出了非常多的改进算法。本章将介绍其中两个非常著名的算法:Double DQNDueling DQN,这两个算法的实现非常简单,只需要在 DQN 的基础上稍加修改,它们能在一定程度上改善 DQN 的效果。


2. Dueling DQN


Dueling DQN 是 DQN 另一种的改进算法,它在传统 DQN 的基础上只进行了微小的改动,但却能大幅提升 DQN 的表现。在强化学习中,我们将状态动作价值函数 Q减去状态价值函数 V 的结果定 义为优势函数 A ,即 A ( s , a ) = Q ( s , a ) − V ( s )  。在同一个状态下,所有动作的优势值之和为 0 ,因为所有动作的动作价值的期望就是这个状态的状态价值。据此,在 Dueling DQN 中,Q网络被 建模为:


Q η , α , β ( s , a ) = V η , α ( s ) + A η , β ( s , a )


其中, V η , α ( s )  为状态价值函数,而 A η , β ( s , a ) 则为该状态下采取不同动作的优势函数,表示采取不同动作的差异性; η  是状态价值函数和优势函数共享的网络参数,一般用在神经网络中,用来提 取特征的前几层; 而 α和 β  分别为状态价值函数和优势函数的参数。在这样的模型下,我们不再让神经网络直接输出 Q 值,而是训练神经网络的最后几层的两个分支,分别输出状态价值函数和优势 函数,再求和得到 Q值。Dueling DQN 的网络结构如图所示。


88db8ff9b1f34aebb2e7f538e41105e0.png


将状态价值函数和优势函数分别建模的好处在于:某些情境下智能体只会关注状态的价值,而并不关心不同动作导致的差异,此时将二者分开建模能够使智能体更好地处理与动作关联较小的状态。在下图所示的驾驶车辆游戏中,智能体注意力集中的部位被显示为橙色,当智能体前面没有车时,车辆自身动作并没有太大差异,此时智能体更关注状态价值,而当智能体前面有车时(智能体需要超车),智能体开始关注不同动作优势值的差异。


6369a9b185f64365aded37c78f7acd5b.png


对于 Dueling DQN 中的公式 Q η , α , β ( s , a ) = V η , α ( s ) + A η , β ( s , a ) ,它存在对于 V  值和 A 值建模不唯一性的问题。例如,对于同样的 Q值,如果将 V  值加上任意大小的常数 C  ,再将所有 A值减去 C  ,则得到的 Q 值依然不变,这就导致了训练的不稳定性。为了解决这一问题,Dueling DQN 强制最优动作的优势函数的实际输出为 0 ,即:


image.png


此时 V ( s ) = max ⁡ a Q ( s , a ) ,可以确保 V 值建模的唯一性。在实现过程中,我们还可以用平均代替最大化操作,即:


image.png


此时 image.png。在下面的代码实现中,我们将采取此种方式,虽然它不再满足贝尔曼最优方程,但实际应用时更加稳定。


有人可能会问:“为什么 Dueling DQN 会比 DQN 好? "部分原因在于 Dueling DQN 能更高效学习状态价值函数。每一次更新时,函数 V都会被更新,这也会影响到其他动作的 Q  值。而传统的 DQN 只会更新某个动作的 Q  值,其他动作的 Q 值就不会更新。因此,Dueling DQN 能够更加频繁、准确地学习状态价值函数。


3. Dueling DQN 代码实践


Dueling DQN 与 DQN 相比的差异只是在网络结构上,大部分代码依然可以继续沿用。我们定义状态价值函数和优势函数的复合神经网络VAnet。


class VAnet(torch.nn.Module):
    ''' 只有一层隐藏层的A网络和V网络 '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(VAnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)  # 共享网络部分
        self.fc_A = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_V = torch.nn.Linear(hidden_dim, 1)
    def forward(self, x):
        A = self.fc_A(F.relu(self.fc1(x)))
        V = self.fc_V(F.relu(self.fc1(x)))
        Q = V + A - A.mean(1).view(-1, 1)  # Q值由V值和A值计算得到
        return Q
class DQN:
    ''' DQN算法,包括Double DQN和Dueling DQN '''
    def __init__(self,
                 state_dim,
                 hidden_dim,
                 action_dim,
                 learning_rate,
                 gamma,
                 epsilon,
                 target_update,
                 device,
                 dqn_type='VanillaDQN'):
        self.action_dim = action_dim
        if dqn_type == 'DuelingDQN':  # Dueling DQN采取不一样的网络框架
            self.q_net = VAnet(state_dim, hidden_dim,
                               self.action_dim).to(device)
            self.target_q_net = VAnet(state_dim, hidden_dim,
                                      self.action_dim).to(device)
        else:
            self.q_net = Qnet(state_dim, hidden_dim,
                              self.action_dim).to(device)
            self.target_q_net = Qnet(state_dim, hidden_dim,
                                     self.action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(),
                                          lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update
        self.count = 0
        self.dqn_type = dqn_type
        self.device = device
    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action
    def max_q_value(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        return self.q_net(state).max().item()
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        q_values = self.q_net(states).gather(1, actions)
        if self.dqn_type == 'DoubleDQN':
            max_action = self.q_net(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = self.target_q_net(next_states).gather(
                1, max_action)
        else:
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(
                -1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device, 'DuelingDQN')
return_list, max_q_value_list = train_DQN(agent, env, num_episodes,
                                          replay_buffer, minimal_size,
                                          batch_size)
episodes_list = list(range(len(return_list)))
mv_return = rl_utils.moving_average(return_list, 5)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Dueling DQN on {}'.format(env_name))
plt.show()
frames_list = list(range(len(max_q_value_list)))
plt.plot(frames_list, max_q_value_list)
plt.axhline(0, c='orange', ls='--')
plt.axhline(10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('Dueling DQN on {}'.format(env_name))
plt.show()
-------------------------------------------------------------------------------------------
Iteration 0: 100%|████████████████████████████████████████| 20/20 [00:10<00:00,  1.87it/s, episode=20, return=-708.652]
Iteration 1: 100%|████████████████████████████████████████| 20/20 [00:15<00:00,  1.28it/s, episode=40, return=-229.557]
Iteration 2: 100%|████████████████████████████████████████| 20/20 [00:15<00:00,  1.32it/s, episode=60, return=-184.607]
Iteration 3: 100%|████████████████████████████████████████| 20/20 [00:13<00:00,  1.50it/s, episode=80, return=-200.323]
Iteration 4: 100%|███████████████████████████████████████| 20/20 [00:13<00:00,  1.51it/s, episode=100, return=-213.811]
Iteration 5: 100%|███████████████████████████████████████| 20/20 [00:13<00:00,  1.53it/s, episode=120, return=-181.165]
Iteration 6: 100%|███████████████████████████████████████| 20/20 [00:14<00:00,  1.35it/s, episode=140, return=-222.040]
Iteration 7: 100%|███████████████████████████████████████| 20/20 [00:14<00:00,  1.35it/s, episode=160, return=-173.313]
Iteration 8: 100%|███████████████████████████████████████| 20/20 [00:12<00:00,  1.62it/s, episode=180, return=-236.372]
Iteration 9: 100%|███████████████████████████████████████| 20/20 [00:12<00:00,  1.57it/s, episode=200, return=-230.058]


6c7a1432e349486c8a8430ed57de2b72.png


da50de09890544f2ae341efeaf3997cb.png


根据代码运行结果我们可以发现,相比于传统的 DQN,Dueling DQN 在多个动作选择下的学习更加稳定,得到的回报最大值也更大。由 Dueling DQN 的原理可知,随着动作空间的增大,Dueling DQN 相比于 DQN 的优势更为明显。之前我们在环境中设置的离散动作数为 11,我们可以增加离散动作数(例如 15、25 等),继续进行对比实验。


4. 对 Q 值过高估计的定量分析


对 Q 值过高估计的定量分析 Q ω − ( s , a ) − V 服从 [ − 1 , 1 ]之间的均匀独立同分布;假设动作空间大小为 m ∘ 那么,对于任意状态 s ,有:


image.png


即状态空间 m  越大时, Q 值过高,估计越严重。


证明: 将估算误差记为 ϵ a = Q ω − ( s , a ) − max ⁡ a ′ Q ( s , a ′ )  ,由于估算误差对于不同的动作是独立的,因此有:


image.png


P ( ϵ a ≤ x )  是 ϵ a 的累积分布函数 (cumulative distribution function,即 CDF),它可以具体被写为:


image.png


因此,我们得到关于 max ⁡ a ϵ a的累积分布函数:


image.png


最后我们可以得到:


image.png


虽然这一分析简化了实际环境,但它仍然正确刻画了Q值过高估计的一些性质,比如Q 值的过高估计随动作空间大小m 的增加而增加,换言之,在动作选择数更多的环境中,Q值的过高估计会更严重。


总结


在传统的 DQN 基础上,有两种非常容易实现的变式——Double DQN 和 Dueling DQN,Double DQN 解决了 DQN 中对Q 值的过高估计,而 Dueling DQN 能够很好地学习到不同动作的差异性,在动作空间较大的环境下非常有效。从 Double DQN 和 Dueling DQN 的方法原理中,我们也能感受到深度强化学习的研究是在关注深度学习和强化学习有效结合:一是在深度学习的模块的基础上,强化学习方法如何更加有效地工作,并避免深度模型学习行为带来的一些问题,例如使用 Double DQN 解决Q 值过高估计的问题;二是在强化学习的场景下,深度学习模型如何有效学习到有用的模式,例如设计 Dueling DQN 网络架构来高效地学习状态价值函数以及动作优势函数。


相关资源来自:伯禹学习平台-动手学强化学习


35f0c043b96a4143bb9612b6bc0f1c4b.png

相关实践学习
【玩转ComfyUI】基于函数计算一键部署AI生图平台ComfyUI
本次实验将带大家通过使用阿里云产品函数计算FC,快速使用ComfyUI实现更高质量的图像生成。
从 0 入门函数计算
在函数计算的架构中,开发者只需要编写业务代码,并监控业务运行情况就可以了。这将开发者从繁重的运维工作中解放出来,将精力投入到更有意义的开发任务上。
目录
相关文章
|
机器学习/深度学习 数据采集 算法
智能限速算法:基于强化学习的动态请求间隔控制
本文分享了通过强化学习解决抖音爬虫限速问题的技术实践。针对固定速率请求易被封禁的问题,引入基于DQN的动态请求间隔控制算法,智能调整请求间隔以平衡效率与稳定性。文中详细描述了真实经历、问题分析、技术突破及代码实现,包括代理配置、状态设计与奖励机制,并反思成长,提出未来优化方向。此方法具通用性,适用于多种动态节奏控制场景。
794 6
智能限速算法:基于强化学习的动态请求间隔控制
|
10月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
264 1
|
10月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
413 0
|
12月前
|
机器学习/深度学习 存储 算法
强化学习算法基准测试:6种算法在多智能体环境中的表现实测
本文系统研究了多智能体强化学习的算法性能与评估框架,选用井字棋和连珠四子作为基准环境,对比分析Q-learning、蒙特卡洛、Sarsa等表格方法在对抗场景中的表现。实验表明,表格方法在小规模状态空间(如井字棋)中可有效学习策略,但在大规模状态空间(如连珠四子)中因泛化能力不足而失效,揭示了向函数逼近技术演进的必要性。研究构建了标准化评估流程,明确了不同算法的适用边界,为理解强化学习的可扩展性问题提供了实证支持与理论参考。
563 0
强化学习算法基准测试:6种算法在多智能体环境中的表现实测
|
机器学习/深度学习 算法 数据可视化
基于Qlearning强化学习的机器人迷宫路线搜索算法matlab仿真
本内容展示了基于Q-learning算法的机器人迷宫路径搜索仿真及其实现过程。通过Matlab2022a进行仿真,结果以图形形式呈现,无水印(附图1-4)。算法理论部分介绍了Q-learning的核心概念,包括智能体、环境、状态、动作和奖励,以及Q表的构建与更新方法。具体实现中,将迷宫抽象为二维网格世界,定义起点和终点,利用Q-learning训练机器人找到最优路径。核心程序代码实现了多轮训练、累计奖励值与Q值的可视化,并展示了机器人从起点到终点的路径规划过程。
682 0
|
机器学习/深度学习 算法 机器人
强化学习:时间差分(TD)(SARSA算法和Q-Learning算法)(看不懂算我输专栏)——手把手教你入门强化学习(六)
本文介绍了时间差分法(TD)中的两种经典算法:SARSA和Q-Learning。二者均为无模型强化学习方法,通过与环境交互估算动作价值函数。SARSA是On-Policy算法,采用ε-greedy策略进行动作选择和评估;而Q-Learning为Off-Policy算法,评估时选取下一状态中估值最大的动作。相比动态规划和蒙特卡洛方法,TD算法结合了自举更新与样本更新的优势,实现边行动边学习。文章通过生动的例子解释了两者的差异,并提供了伪代码帮助理解。
1151 2
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
5963 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
9月前
|
机器学习/深度学习 算法 机器人
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
754 0
|
9月前
|
数据采集 分布式计算 并行计算
mRMR算法实现特征选择-MATLAB
mRMR算法实现特征选择-MATLAB
480 2

热门文章

最新文章