Batched computations

The last unexplored part of this tutorial is TorchRL's ability to batch computations. Because our environment makes no assumptions about the shape of the input data, we can seamlessly execute it over batches of data. Even better: for non-batch-locked environments such as our pendulum, we can change the batch size on the fly without recreating the environment. To do this, we simply generate parameters with the desired shape.
```python
batch_size = 10  # number of environments to be executed in batch
td = env.reset(env.gen_params(batch_size=[batch_size]))
print("reset (batch size of 10)", td)
td = env.rand_step(td)
print("rand step (batch size of 10)", td)
```
```
reset (batch size of 10) TensorDict(
    fields={
        cos: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        sin: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
rand step (batch size of 10) TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        cos: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                cos: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        g: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        l: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        m: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        max_speed: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                        max_torque: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                sin: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                th: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                thdot: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        sin: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
```
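Because the environment is not batch-locked, nothing needs to be rebuilt to run with another batch size. The following minimal sketch (not part of the original tutorial; it only reuses the `env.gen_params`, `env.reset` and `env.rand_step` calls shown above) illustrates switching batch sizes on the fly:

```python
# Hedged sketch: switch to a different batch size simply by generating
# parameters of another shape; the environment object itself is unchanged.
td_small = env.reset(env.gen_params(batch_size=[5]))
print(td_small.batch_size)               # torch.Size([5])
td_small = env.rand_step(td_small)
print(td_small["next", "reward"].shape)  # torch.Size([5, 1])
```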
Executing a rollout with a batch of data requires us to reset the environment outside of the rollout function, since we need to define the batch size dynamically and this is not supported by `rollout()`:
```python
rollout = env.rollout(
    3,
    auto_reset=False,  # we're executing the reset out of the ``rollout`` call
    tensordict=env.reset(env.gen_params(batch_size=[batch_size])),
)
print("rollout of len 3 (batch size of 10):", rollout)
```
```
rollout of len 3 (batch size of 10): TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        cos: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                cos: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        dt: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        g: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        l: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        m: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        max_speed: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                        max_torque: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([10, 3]),
                    device=None,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                sin: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                th: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                thdot: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10, 3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10, 3]),
            device=None,
            is_shared=False),
        sin: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)
```
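Since the rollout is a plain `TensorDict` with batch size `[10, 3]` (environments, then steps), it can be sliced like a tensor. A small illustrative sketch (not from the original tutorial):

```python
# Hedged sketch: index the batched rollout along its batch dimensions.
first_traj = rollout[0]        # trajectory of the first environment, batch_size [3]
last_steps = rollout[..., -1]  # last step of every environment, batch_size [10]
print(first_traj["next", "reward"].shape)  # torch.Size([3, 1])
print(last_steps["next", "reward"].shape)  # torch.Size([10, 1])
```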
Training a simple policy

In this example, we will train a simple policy using the reward as a differentiable objective, such as a negative loss. We will take advantage of the fact that our dynamical system is fully differentiable to backpropagate through the trajectory return and adjust the weights of our policy to maximize this value directly. Of course, in many settings the assumptions we make here, such as a differentiable system and full access to the underlying mechanics, do not hold.

Still, this is a very simple example that shows how a training loop can be written with a custom environment in TorchRL.

Let us first write the policy network:
```python
torch.manual_seed(0)
env.set_seed(0)

net = nn.Sequential(
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(1),
)
policy = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=["action"],
)
```
```
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
```
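Since the network is built from `LazyLinear` modules, its parameter shapes are only materialized on the first forward pass. As an optional sanity check (a sketch, not part of the original tutorial), we can run the policy once on a reset `TensorDict` and inspect the resulting action:

```python
# Hedged sketch: a single forward pass materializes the lazy layers and
# writes an "action" entry into the tensordict.
td_check = env.reset(env.gen_params(batch_size=[batch_size]))
policy(td_check)
print(td_check["action"].shape)  # torch.Size([10, 1]) for batch_size = 10
```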
And our optimizer:
```python
optim = torch.optim.Adam(policy.parameters(), lr=2e-3)
```
Training loop

We will successively:

- generate a trajectory
- sum the rewards
- backpropagate through the graph defined by these operations
- clip the gradient norm and make an optimization step
- repeat

At the end of the training loop, we should have a final reward close to 0, which demonstrates that the pendulum is upright and still.
```python
batch_size = 32
pbar = tqdm.tqdm(range(20_000 // batch_size))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 20_000)
logs = defaultdict(list)

for _ in pbar:
    init_td = env.reset(env.gen_params(batch_size=[batch_size]))
    rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    scheduler.step()

def plot():
    import matplotlib
    from matplotlib import pyplot as plt

    is_ipython = "inline" in matplotlib.get_backend()
    if is_ipython:
        from IPython import display

    with plt.ion():
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(logs["return"])
        plt.title("returns")
        plt.xlabel("iteration")
        plt.subplot(1, 2, 2)
        plt.plot(logs["last_reward"])
        plt.title("last reward")
        plt.xlabel("iteration")
        if is_ipython:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        plt.show()

plot()
```
```
  0%| | 0/625 [00:00<?, ?it/s]
reward: -6.0488, last reward: -5.0748, gradient norm: 8.518:   0%| | 1/625 [00:00<02:36, 3.99it/s]
reward: -7.0499, last reward: -7.4472, gradient norm: 5.073:   0%| | 2/625 [00:00<02:32, 4.08it/s]
reward: -7.0685, last reward: -7.0408, gradient norm: 5.552:   0%| | 3/625 [00:00<02:29, 4.15it/s]
reward: -6.5154, last reward: -5.9086, gradient norm: 2.526:   1%| | 4/625 [00:00<02:29, 4.14it/s]
reward: -6.2004, last reward: -5.9401, gradient norm: 7.964:   1%| | 5/625 [00:01<02:29, 4.14it/s]
reward: -6.2566, last reward: -5.4981, gradient norm: 4.446:   1%| | 6/625 [00:01<02:28, 4.17it/s]
reward: -5.8926, last reward: -8.4134, gradient norm: 2.108:   1%|1 | 7/625 [00:01<02:27, 4.19it/s]
reward: -6.3541, last reward: -9.1257, gradient norm: 2.045:   1%|1 | 8/625 [00:01<02:26, 4.20it/s]
...
reward: -5.3867, last reward: -6.7588, gradient norm: 4.311:  11%|#1 | 69/625 [00:16<02:13, 4.18it/s]
reward: -5.3548, last reward: -8.1878, gradient norm: 44.19:  11%|#1 | 70/625 [00:16<02:12, 4.18it/s]
reward: -5.3264, last reward: -6.2046, gradient norm: 6.25:   11%|#1 | 71/625 [00:16<02:12, 4.19it/s]
reward: -5.3723, last reward: -5.9680, gradient norm: 11.1:   11%|#1 | 71/625 [00:17<02:12, 4.19it/s]
...
```
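After training, one might want to check how the policy behaves. The snippet below is a hedged sketch (not part of the original tutorial) that rolls the trained policy out without gradient tracking and prints the reward at the final step, which should approach 0 for an upright, still pendulum:

```python
# Hedged sketch: evaluate the trained policy without tracking gradients.
with torch.no_grad():
    eval_td = env.reset(env.gen_params(batch_size=[batch_size]))
    eval_rollout = env.rollout(200, policy, tensordict=eval_td, auto_reset=False)
print("mean reward over the rollout:", eval_rollout["next", "reward"].mean().item())
print("reward at the last step:", eval_rollout[..., -1]["next", "reward"].mean().item())
```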
Conclusion

In this tutorial, we have learned how to code a stateless environment from scratch. We touched upon the following topics:

- the four essential components that need to be taken care of when coding an environment (`step`, `reset`, seeding and building specs), and how these methods and classes interact with the `TensorDict` class;
- how to test that an environment is properly coded using `check_env_specs()`;
- how to append transforms in the context of stateless environments and how to write custom transformations;
- how to train a policy on a fully differentiable simulator.
Total running time of the script: (2 minutes 30.147 seconds)

Download Python source code: pendulum.py

Download Jupyter notebook: pendulum.ipynb