智能体
智能体类将处理与环境的交互。智能体类主要有三种方法:
get_action:使用传递的ε值,智能体决定是使用随机操作,还是从网络输出中执行Q值最高的操作。
play_step:在这里,智能体通过从get_action中选择的操作在环境中执行一个步骤。从环境中获得反馈后,经验将存储在重播缓冲区中。如果环境已完成该步骤,则环境将重置。最后,返回当前的奖励和完成标志。
reset:重置环境并更新存储在代理中的当前状态。
class Agent: """ Base Agent class handeling the interaction with the environment Args: env: training environment replay_buffer: replay buffer storing experiences """ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: self.env = env self.replay_buffer = replay_buffer self.reset() self.state = self.env.reset() def reset(self) -> None: """ Resents the environment and updates the state""" self.state = self.env.reset() def get_action(self, net: nn.Module, epsilon: float, device: str) -> int: """ Using the given network, decide what action to carry out using an epsilon-greedy policy Args: net: DQN network epsilon: value to determine likelihood of taking a random action device: current device Returns: action """ if np.random.random() < epsilon: action = self.env.action_space.sample() else: state = torch.tensor([self.state]) if device not in ['cpu']: state = state.cuda(device) q_values = net(state) _, action = torch.max(q_values, dim=1) action = int(action.item()) return action @torch.no_grad() def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]: """ Carries out a single interaction step between the agent and the environment Args: net: DQN network epsilon: value to determine likelihood of taking a random action device: current device Returns: reward, done """ action = self.get_action(net, epsilon, device) # do step in the environment new_state, reward, done, _ = self.env.step(action) exp = Experience(self.state, action, reward, done, new_state) self.replay_buffer.append(exp) self.state = new_state if done: self.reset() return reward, done
Lightning模块
现在我们已经为DQN建立了核心类,我们可以开始考虑训练DQN智能体。这就是lighting要介入的地方。我们将通过构建一个lighting模块,以一种干净和结构化的方式布置我们所有的训练逻辑。
Lightning提供了很多接口和可重写的函数,以获得最大的灵活性,但是我们必须实现4个关键方法才能使项目运行。就是下面的:
- forward()
- configure_optimizers
- train_dataloader
- train_step
有了这4种方法的填充,我们可以使我们遇到的任何ML模型都得到很好的训练。任何需要超过这些方法的东西都可以很好地与Lightning中剩余的接口和回调配合。有关这些可用接口的完整列表,请查看Lightning文档。现在,让我们看看我们的轻量化模型。
初始化
首先,我们需要初始化我们的环境、网络、智能体和重播缓冲区。我们还调用populate函数,它将以随机方式填充重播缓冲区(populate函数在下面的完整代码示例中显示)。
前向传递
我们在这里所做的就是封装我们的DQN网络的前向传递函数。
损失函数
在开始训练智能体之前,我们需要定义损失函数。这里使用的损失函数是基于Lapan的实现。
这是一个简单的均方误差(MSE)损失,将我们的DQN网络的当前状态动作值与下一个状态的预期状态动作值进行比较。在RL中我们没有完美的标签可以学习;相反,智能体从它期望的下一个状态的值的目标值中学习。
然而,通过使用同一个网络来预测当前状态的值和下一个状态的值,结果会成为一个不稳定的运动目标。为了对抗这种情况,我们使用目标网络。此网络是主网络的副本,并定期与主网络同步。这提供了一个临时固定的目标,允许代理计算更稳定的损失函数。
如您所见,状态操作值使用主网络计算,而下一个状态值(相当于我们的目标/标签)使用目标网络。
优化器
这是另外一个简单的补充,只是告诉lighting什么优化器将在反向传递期间使用。我们将使用标准的Adam优化器。
训练数据加载器
接下来,我们需要向Lightning提供我们的训练数据加载器。如您所料,我们初始化了先前创建的IterableDataset。然后像往常一样把这个传递给数据加载器。Lightning将在培训期间处理提供的批次,并将这些批次转换为Pythorch张量,并将它们移动到正确的设备。
训练步骤
最后我们有了训练的步骤。在这里,我们输入了每个训练迭代要执行的所有逻辑。
在每次训练迭代过程中,我们希望智能体通过调用前面定义的agent.play_step()并传入当前设备和ε值,在环境中执行一步。这将返回该步骤的奖励,以及本次迭代是否在该步骤中完成。我们将步骤奖励添加到整个事件中,以便跟踪智能体在该事件中的成功程度。
接下来,我们使用lighting提供的当前小批量,计算我们的损失。
如果我们已经到了本次迭代的结尾,用done标志表示,我们将用session reward更新当前的total_reward变量。
在步骤的最后,我们检查是否是同步主网络和目标网络的时间。通常在只更新一部分权重的情况下使用软更新,但对于这个简单的示例来说,完全更新就足够了。
最后,我们需要返回一个Dict,其中包含Lightning将用于反向传播的损耗,一个Dict包含我们要记录的值(注意:这些值必须是张量),另一个Dict包含我们要在进度条上显示的任何值。
就这样,我们现在有了运行DQN智能体所需的一切。
运行智能体
现在要做的就是初始化并适应我们的lighting模型。在我们的主python文件中,我们将设置种子,并提供一个arg解析器,其中包含我们要传递给模型的任何必要的超参数。
然后在我们的主方法中,我们用指定的参数初始化dqnlighting模型。接下来是Lightning训练器的设置。
在这里,我们设置教练过程使用GPU。如果您没有访问GPU的权限,请从培训器中删除“GPU”和“distributed_backend”参数。这种模式训练非常快,即使是使用CPU,所以为了在运行过程中观察Lightning,我们将关闭早停机制。
最后,因为我们使用的是可迭代数据集,所以需要指定val_check_interval。通常,此间隔是根据数据集的长度自动设置的。然而,可迭代数据集没有一个长度函数。因此,我们需要自己设置这个值,即使我们没有执行验证步骤。
最后一步是调用我们的模型上的trainer.fit(),并观看它的训练。
结果
大约1200代后,您将看到智能体的总奖励达到最大得分200。为了看到正在绘制的奖励指标,调用
tensorboard --logdir lightning_logs
在左边的图中你可以看到每一步的奖励。由于环境的性质,这将始终是1,因为智能体每一步都会得到+1的奖励,极点从没有下降(这就是全部奖励)。在右边的图中我们可以看到每一步的总奖励。智能体很快就达到了最高奖励,然后在好的状态和不好的状态之间波动。
结论
现在您已经看到了在强化学习项目中利用PyTorch Lightning的力量是多么简单和实用。
这是一个非常简单的例子,只是为了说明lighting在RL中的使用,所以这里有很多改进的空间。如果您想将此代码作为模板,并尝试实现自己的代理,下面是一些我会尝试的事情。
降低学习率或许更好。通过在configure_optimizer方法中初始化学习率调度程序来使用它。
- 提高目标网络的同步速率或使用软更新而不是完全更新
- 在更多步骤的过程中使用更渐进的ε衰减。
- 通过在训练器中设置max_epochs来增加训练的代数。
- 除了跟踪tensorboard日志中的总奖励,还跟踪平均总奖励。
- 使用test/val Lightning hook添加测试和验证步骤
- 最后,尝试一些更复杂的模型和环境
- 我希望这篇文章是有帮助的,将有助于启动您使用lighting启动自己的项目。快乐编码!








