使用PyTorch Lightning构建轻量化强化学习DQN(附完整源码)(一)

简介: 使用PyTorch Lightning构建轻量化强化学习DQN(附完整源码)(一)

什么是lighting?

image.png

Lightning是一个最近发布的Pythorch库,它可以清晰地抽象和自动化ML模型所附带的所有日常样板代码,允许您专注于实际的ML部分(这些也往往是最有趣的部分)。

除了自动化样板代码外,Lightning还可以作为一种样式指南,用于构建干净且可复制的ML系统。

这非常吸引人,原因如下:

  1. 通过抽象出样板工程代码,可以更容易地识别和理解ML代码。
  2. Lightning的统一结构使得在现有项目的基础上进行构建和理解变得非常容易。
  3. Lightning 自动化的代码是用经过全面测试、定期维护并遵循ML最佳实践的高质量代码构建的。

DQN

image.png

在我们进入代码之前,让我们快速回顾一下DQN的功能。DQN通过学习在特定状态下执行每个操作的值来学习给定环境的最佳策略。这些值称为Q值。

最初,智能体对其环境的理解非常差,因为它没有太多的经验。因此,它的Q值将非常不准确。然而,随着时间的推移,当智能体探索其环境时,它会学习到更精确的Q值,然后可以做出正确的决策。这允许它进一步改进,直到它最终收敛到一个最优策略(理想情况下)。

我们感兴趣的大多数环境,如现代电子游戏和模拟环境,都过于复杂和庞大,无法存储每个状态/动作对的值。这就是为什么我们使用深度神经网络来近似这些值。

智能体的一般生命周期如下所述:

  1. 智能体获取环境的当前状态并将其通过网络进行运算。然后,网络输出给定状态的每个动作的Q值。
  2. 接下来,我们决定是使用由网络给出智能体所认为的最优操作,还是采取随机操作,以便进一步探索。
  3. 这个动作被传递到环境中并得到反馈,告诉智能体它处于的下一个状态是什么,在上一个状态中执行上一个动作所得到的奖励,以及该步骤中的事件是否完成。
  4. 我们以元组(状态, 行为, 奖励, 下一状态, 已经完成的事件)的形式获取在最后一步中获得的经验,并将其存储在智能体内存中。
  5. 最后,我们从智能体内存中抽取一小批重复经验,并使用这些过去的经验计算智能体的损失。

这是DQN功能的一个高度概述。

轻量化DQN

image.png

启蒙时代是一场支配思想世界的智力和哲学运动,让我们看看构成我们的DQN的组成部分

模型:用来逼近Q值的神经网络

重播缓冲区:这是我们智能体的内存,用于存储以前的经验

智能体:智能体本身就是与环境和重播缓冲区交互的东西

Lightning模块:处理智能体的所有训练

模型

对于这个例子,我们可以使用一个非常简单的多层感知器(MLP)。这意味着我们没有使用任何花哨的东西,像卷积层或递归层,只是正常的线性层。这样做的原因是由于卡倒立摆环境的简单性,任何比这更复杂的东西都是过度复杂的。

class DQN(nn.Module):
    """
    Simple MLP network
    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """
    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )
    def forward(self, x):
        return self.net(x.float())

重播缓冲区

重播缓冲区的构建相当直接,我们只需要某种类型的数据结构来存储元组。我们需要能够对这些元组进行采样并添加新的元组。本例中的缓冲区基于Lapins重播缓冲区,因为它是迄今为止我发现的最简洁并且最快的实现。代码如下

# Named tuple for storing experience steps gathered in training
Experience = collections.namedtuple(
    'Experience', field_names=['state', 'action', 'reward',
                               'done', 'new_state'])
class ReplayBuffer:
    """
    Replay Buffer for storing past experiences allowing the agent to learn from them
    Args:
        capacity: size of the buffer
    """
    def __init__(self, capacity: int) -> None:
        self.buffer = collections.deque(maxlen=capacity)
    def __len__(self) -> None:
        return len(self.buffer)
    def append(self, experience: Experience) -> None:
        """
        Add experience to the buffer
        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)
    def sample(self, batch_size: int) -> Tuple:
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=np.bool), np.array(next_states))

但我们还没有完成。如果您在知道它的结构是基于创建数据加载器的思想创建的,然后使用它将小批量的经验传递给每个训练步骤这些原理之前使用过Lightning;那么对于大多数ML系统(如监督模型),这一切如何生效的是很清楚的。但是当我们在生成数据集时,它又是如何生效的呢?

我们需要创建自己的可迭代数据集,它使用不断更新的重播缓冲区来采样以前的经验。然后,我们有一小批经验被传递到训练步骤中用于计算我们的损失,就像其他任何模型一样。除了包含输入和标签之外,我们的小批量包含(状态, 行为, 奖励, 下一状态, 已经完成的事件)

class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer
    which will be updated with new experiences during training
    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """
    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size
    def __iter__(self) -> Tuple:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]

您可以看到,在创建数据集时,我们传入重播缓冲区,然后可以从中采样,以允许数据加载器将批处理传递给Lightning模块。

目录
相关文章
|
8月前
|
机器学习/深度学习 存储 数据管理
面向强化学习的状态空间建模:RSSM的介绍和PyTorch实现
循环状态空间模型(Recurrent State Space Models, RSSM)由 Danijar Hafer 等人提出,是现代基于模型的强化学习(MBRL)中的关键组件。RSSM 旨在构建可靠的环境动态预测模型,使智能体能够模拟未来轨迹并进行前瞻性规划。本文介绍了如何用 PyTorch 实现 RSSM,包括环境配置、模型架构(编码器、动态模型、解码器和奖励模型)、训练系统设计(经验回放缓冲区和智能体)及训练器实现。通过具体案例展示了在 CarRacing 环境中的应用,详细说明了数据收集、训练过程和实验结果。
319 13
面向强化学习的状态空间建模:RSSM的介绍和PyTorch实现
|
7月前
|
机器学习/深度学习 算法 安全
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
620 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
|
8月前
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
2016 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
8月前
|
机器学习/深度学习 并行计算 PyTorch
TorchOptimizer:基于贝叶斯优化的PyTorch Lightning超参数调优框架
TorchOptimizer 是一个基于贝叶斯优化方法的超参数优化框架,专为 PyTorch Lightning 模型设计。它通过高斯过程建模目标函数,实现智能化的超参数组合选择,并利用并行计算加速优化过程。该框架支持自定义约束条件、日志记录和检查点机制,显著提升模型性能,适用于各种规模的深度学习项目。相比传统方法,TorchOptimizer 能更高效地确定最优超参数配置。
489 7
|
9月前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
11月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
|
10月前
|
并行计算 监控 搜索推荐
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
179 5
|
10月前
|
机器学习/深度学习 人工智能 PyTorch
使用Pytorch构建视觉语言模型(VLM)
视觉语言模型(Vision Language Model,VLM)正在改变计算机对视觉和文本信息的理解与交互方式。本文将介绍 VLM 的核心组件和实现细节,可以让你全面掌握这项前沿技术。我们的目标是理解并实现能够通过指令微调来执行有用任务的视觉语言模型。
227 2
|
10月前
|
机器学习/深度学习 监控 PyTorch
深度学习工程实践:PyTorch Lightning与Ignite框架的技术特性对比分析
在深度学习框架的选择上,PyTorch Lightning和Ignite代表了两种不同的技术路线。本文将从技术实现的角度,深入分析这两个框架在实际应用中的差异,为开发者提供客观的技术参考。
234 7
|
12月前
|
存储 缓存 PyTorch
使用PyTorch从零构建Llama 3
本文将详细指导如何从零开始构建完整的Llama 3模型架构,并在自定义数据集上执行训练和推理。
183 1

热门文章

最新文章

推荐镜像

更多