使用PyTorch Lightning构建轻量化强化学习DQN(附完整源码)(二)

简介: 使用PyTorch Lightning构建轻量化强化学习DQN(附完整源码)(二)

智能体

智能体类将处理与环境的交互。智能体类主要有三种方法:

get_action:使用传递的ε值,智能体决定是使用随机操作,还是从网络输出中执行Q值最高的操作。

play_step:在这里,智能体通过从get_action中选择的操作在环境中执行一个步骤。从环境中获得反馈后,经验将存储在重播缓冲区中。如果环境已完成该步骤,则环境将重置。最后,返回当前的奖励和完成标志。

reset:重置环境并更新存储在代理中的当前状态。

class Agent:
    """
    Base Agent class handeling the interaction with the environment
    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """
    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        self.reset()
        self.state = self.env.reset()
    def reset(self) -> None:
        """ Resents the environment and updates the state"""
        self.state = self.env.reset()
    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy
        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device
        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])
            if device not in ['cpu']:
                state = state.cuda(device)
            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())
        return action
    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the environment
        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device
        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)
        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)
        exp = Experience(self.state, action, reward, done, new_state)
        self.replay_buffer.append(exp)
        self.state = new_state
        if done:
            self.reset()
        return reward, done

Lightning模块

现在我们已经为DQN建立了核心类,我们可以开始考虑训练DQN智能体。这就是lighting要介入的地方。我们将通过构建一个lighting模块,以一种干净和结构化的方式布置我们所有的训练逻辑。

Lightning提供了很多接口和可重写的函数,以获得最大的灵活性,但是我们必须实现4个关键方法才能使项目运行。就是下面的:

  1. forward()
  2. configure_optimizers
  3. train_dataloader
  4. train_step

有了这4种方法的填充,我们可以使我们遇到的任何ML模型都得到很好的训练。任何需要超过这些方法的东西都可以很好地与Lightning中剩余的接口和回调配合。有关这些可用接口的完整列表,请查看Lightning文档。现在,让我们看看我们的轻量化模型。

初始化

首先,我们需要初始化我们的环境、网络、智能体和重播缓冲区。我们还调用populate函数,它将以随机方式填充重播缓冲区(populate函数在下面的完整代码示例中显示)。

image.png

前向传递

我们在这里所做的就是封装我们的DQN网络的前向传递函数。

image.png

损失函数

在开始训练智能体之前,我们需要定义损失函数。这里使用的损失函数是基于Lapan的实现。

这是一个简单的均方误差(MSE)损失,将我们的DQN网络的当前状态动作值与下一个状态的预期状态动作值进行比较。在RL中我们没有完美的标签可以学习;相反,智能体从它期望的下一个状态的值的目标值中学习。

然而,通过使用同一个网络来预测当前状态的值和下一个状态的值,结果会成为一个不稳定的运动目标。为了对抗这种情况,我们使用目标网络。此网络是主网络的副本,并定期与主网络同步。这提供了一个临时固定的目标,允许代理计算更稳定的损失函数。

image.png

如您所见,状态操作值使用主网络计算,而下一个状态值(相当于我们的目标/标签)使用目标网络。

优化器

这是另外一个简单的补充,只是告诉lighting什么优化器将在反向传递期间使用。我们将使用标准的Adam优化器。

image.png

训练数据加载器

接下来,我们需要向Lightning提供我们的训练数据加载器。如您所料,我们初始化了先前创建的IterableDataset。然后像往常一样把这个传递给数据加载器。Lightning将在培训期间处理提供的批次,并将这些批次转换为Pythorch张量,并将它们移动到正确的设备。

image.png

训练步骤

最后我们有了训练的步骤。在这里,我们输入了每个训练迭代要执行的所有逻辑。

在每次训练迭代过程中,我们希望智能体通过调用前面定义的agent.play_step()并传入当前设备和ε值,在环境中执行一步。这将返回该步骤的奖励,以及本次迭代是否在该步骤中完成。我们将步骤奖励添加到整个事件中,以便跟踪智能体在该事件中的成功程度。

接下来,我们使用lighting提供的当前小批量,计算我们的损失。

如果我们已经到了本次迭代的结尾,用done标志表示,我们将用session reward更新当前的total_reward变量。

在步骤的最后,我们检查是否是同步主网络和目标网络的时间。通常在只更新一部分权重的情况下使用软更新,但对于这个简单的示例来说,完全更新就足够了。

最后,我们需要返回一个Dict,其中包含Lightning将用于反向传播的损耗,一个Dict包含我们要记录的值(注意:这些值必须是张量),另一个Dict包含我们要在进度条上显示的任何值。

image.png

就这样,我们现在有了运行DQN智能体所需的一切。

运行智能体

现在要做的就是初始化并适应我们的lighting模型。在我们的主python文件中,我们将设置种子,并提供一个arg解析器,其中包含我们要传递给模型的任何必要的超参数。

image.png

然后在我们的主方法中,我们用指定的参数初始化dqnlighting模型。接下来是Lightning训练器的设置。

在这里,我们设置教练过程使用GPU。如果您没有访问GPU的权限,请从培训器中删除“GPU”和“distributed_backend”参数。这种模式训练非常快,即使是使用CPU,所以为了在运行过程中观察Lightning,我们将关闭早停机制。

最后,因为我们使用的是可迭代数据集,所以需要指定val_check_interval。通常,此间隔是根据数据集的长度自动设置的。然而,可迭代数据集没有一个长度函数。因此,我们需要自己设置这个值,即使我们没有执行验证步骤。

image.png

最后一步是调用我们的模型上的trainer.fit(),并观看它的训练。

结果

大约1200代后,您将看到智能体的总奖励达到最大得分200。为了看到正在绘制的奖励指标,调用

tensorboard --logdir lightning_logs

image.png

在左边的图中你可以看到每一步的奖励。由于环境的性质,这将始终是1,因为智能体每一步都会得到+1的奖励,极点从没有下降(这就是全部奖励)。在右边的中我们可以看到每一步的总奖励。智能体很快就达到了最高奖励,然后在好的状态和不好的状态之间波动。

结论

现在您已经看到了在强化学习项目中利用PyTorch Lightning的力量是多么简单和实用。

这是一个非常简单的例子,只是为了说明lighting在RL中的使用,所以这里有很多改进的空间。如果您想将此代码作为模板,并尝试实现自己的代理,下面是一些我会尝试的事情。

降低学习率或许更好。通过在configure_optimizer方法中初始化学习率调度程序来使用它。

  1. 提高目标网络的同步速率或使用软更新而不是完全更新
  2. 在更多步骤的过程中使用更渐进的ε衰减。
  3. 通过在训练器中设置max_epochs来增加训练的代数。
  4. 除了跟踪tensorboard日志中的总奖励,还跟踪平均总奖励。
  5. 使用test/val Lightning hook添加测试和验证步骤
  6. 最后,尝试一些更复杂的模型和环境
  7. 我希望这篇文章是有帮助的,将有助于启动您使用lighting启动自己的项目。快乐编码!
目录
相关文章
|
机器学习/深度学习 存储 数据管理
面向强化学习的状态空间建模:RSSM的介绍和PyTorch实现
循环状态空间模型(Recurrent State Space Models, RSSM)由 Danijar Hafer 等人提出,是现代基于模型的强化学习(MBRL)中的关键组件。RSSM 旨在构建可靠的环境动态预测模型,使智能体能够模拟未来轨迹并进行前瞻性规划。本文介绍了如何用 PyTorch 实现 RSSM,包括环境配置、模型架构(编码器、动态模型、解码器和奖励模型)、训练系统设计(经验回放缓冲区和智能体)及训练器实现。通过具体案例展示了在 CarRacing 环境中的应用,详细说明了数据收集、训练过程和实验结果。
1066 13
面向强化学习的状态空间建模:RSSM的介绍和PyTorch实现
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
5500 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
机器学习/深度学习 并行计算 PyTorch
TorchOptimizer:基于贝叶斯优化的PyTorch Lightning超参数调优框架
TorchOptimizer 是一个基于贝叶斯优化方法的超参数优化框架,专为 PyTorch Lightning 模型设计。它通过高斯过程建模目标函数,实现智能化的超参数组合选择,并利用并行计算加速优化过程。该框架支持自定义约束条件、日志记录和检查点机制,显著提升模型性能,适用于各种规模的深度学习项目。相比传统方法,TorchOptimizer 能更高效地确定最优超参数配置。
739 7
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
并行计算 监控 搜索推荐
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
416 5
|
机器学习/深度学习 监控 PyTorch
深度学习工程实践:PyTorch Lightning与Ignite框架的技术特性对比分析
在深度学习框架的选择上,PyTorch Lightning和Ignite代表了两种不同的技术路线。本文将从技术实现的角度,深入分析这两个框架在实际应用中的差异,为开发者提供客观的技术参考。
507 7
|
机器学习/深度学习 人工智能 PyTorch
使用Pytorch构建视觉语言模型(VLM)
视觉语言模型(Vision Language Model,VLM)正在改变计算机对视觉和文本信息的理解与交互方式。本文将介绍 VLM 的核心组件和实现细节,可以让你全面掌握这项前沿技术。我们的目标是理解并实现能够通过指令微调来执行有用任务的视觉语言模型。
592 2
|
存储 缓存 PyTorch
使用PyTorch从零构建Llama 3
本文将详细指导如何从零开始构建完整的Llama 3模型架构,并在自定义数据集上执行训练和推理。
665 1

热门文章

最新文章

推荐镜像

更多