PyTorch 2.2 Official Tutorials in Chinese (8)(4)

Summary: PyTorch 2.2 Official Tutorials in Chinese (8)

Previous part: PyTorch 2.2 Official Tutorials in Chinese (8)(3): https://developer.aliyun.com/article/1482532

Batching computations

The last unexplored part of our tutorial is TorchRL's ability to batch computations. Because our environment makes no assumptions about the shape of its input data, we can seamlessly execute it over batches of data. Even better: for non-batch-locked environments such as our pendulum, we can change the batch size on the fly without recreating the environment. To do this, we simply generate parameters with the desired shape.

batch_size = 10  # number of environments to be executed in batch
td = env.reset(env.gen_params(batch_size=[batch_size]))
print("reset (batch size of 10)", td)
td = env.rand_step(td)
print("rand step (batch size of 10)", td) 
reset (batch size of 10) TensorDict(
    fields={
        cos: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        sin: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
rand step (batch size of 10) TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        cos: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                cos: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        g: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        l: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        m: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        max_speed: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                        max_torque: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                sin: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                th: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                thdot: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        sin: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False) 
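Because the environment is not batch-locked, we are also free to reset it again with a completely different batch size, without rebuilding anything. A minimal sketch, re-using the env defined earlier in this tutorial (the batch size of 4 and the name td_small are arbitrary):

# same environment instance, now driven by a batch of 4 parameter sets
td_small = env.reset(env.gen_params(batch_size=[4]))
print(td_small.batch_size)  # torch.Size([4])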

Executing a rollout with a batch of data requires us to reset the environment outside of the rollout function, since we need to define the batch size dynamically, and rollout() does not support this:

rollout = env.rollout(
    3,
    auto_reset=False,  # we're executing the reset out of the ``rollout`` call
    tensordict=env.reset(env.gen_params(batch_size=[batch_size])),
)
print("rollout of len 3 (batch size of 10):", rollout) 
rollout of len 3 (batch size of 10): TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        cos: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                cos: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        dt: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        g: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        l: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        m: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        max_speed: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                        max_torque: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([10, 3]),
                    device=None,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                sin: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                th: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                thdot: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10, 3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10, 3]),
            device=None,
            is_shared=False),
        sin: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False) 
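As a quick illustration of how such a batched trajectory can be consumed (this snippet is not part of the original script), the reward is stored under the ("next", "reward") key with shape [batch, time, 1], so a per-environment return is simply a sum over the time dimension:

# sum rewards over the time dimension; dim=-2 keeps the trailing feature dim
per_env_return = rollout["next", "reward"].sum(dim=-2)
print(per_env_return.shape)  # torch.Size([10, 1]): one return per environment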

Training a simple policy

In this example, we will train a simple policy using the reward as a differentiable objective, i.e. as a negative loss. We will take advantage of the fact that our dynamical system is fully differentiable to backpropagate through the trajectory return and adjust the weights of our policy so as to maximize this value directly. Of course, in many settings the assumptions we make here do not hold, such as a differentiable system and full access to the underlying mechanics.

Still, this is a very simple example that shows how a training loop can be written with a custom environment in TorchRL.
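To make the gradient flow concrete before writing the TorchRL version, here is a minimal self-contained sketch, independent of the tutorial's environment, of backpropagating a summed reward through a differentiable transition (the toy dynamics and the name policy_toy are purely illustrative):

import torch
from torch import nn

policy_toy = nn.Linear(2, 1)      # maps a 2-d toy state to a 1-d action
state = torch.zeros(8, 2)         # batch of 8 toy states
total_reward = 0.0
for _ in range(5):                # unroll a short trajectory
    action = policy_toy(state)
    state = state + 0.1 * action  # differentiable transition
    total_reward = total_reward - (state ** 2).sum()  # reward = -||state||^2
(-total_reward).backward()        # gradients flow through the whole unroll
print(policy_toy.weight.grad.shape)  # torch.Size([1, 2])

The training loop below follows the same pattern, with env.rollout() playing the role of the hand-written unroll.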

Let us first write the policy network:

torch.manual_seed(0)
env.set_seed(0)
net = nn.Sequential(
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(1),
)
policy = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=["action"],
) 
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning:
Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment. 

and our optimizer:

optim = torch.optim.Adam(policy.parameters(), lr=2e-3) 

Training loop

We will successively:

  • generate a trajectory
  • sum the rewards
  • backpropagate through the graph defined by these operations
  • clip the gradient norm and take an optimization step
  • repeat

At the end of the training loop, we should have a final reward close to 0, which demonstrates that the pendulum is upright and still, as desired.
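To see why a return near 0 means "upright and still", recall the cost used in the _step method in the earlier part of this tutorial (the standard pendulum cost): costs = angle_normalize(th)^2 + 0.1 * thdot^2 + 0.001 * u^2, with reward = -costs. A quick back-of-the-envelope check (illustrative only):

import math

# upright and still, no torque: th = 0, thdot = 0, u = 0
print(-(0.0 ** 2 + 0.1 * 0.0 ** 2 + 0.001 * 0.0 ** 2))      # -0.0, the best attainable reward
# hanging straight down and still: th = pi, thdot = 0, u = 0
print(-(math.pi ** 2 + 0.1 * 0.0 ** 2 + 0.001 * 0.0 ** 2))  # about -9.87, close to the worst values in the logs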

batch_size = 32
pbar = tqdm.tqdm(range(20_000 // batch_size))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 20_000)
logs = defaultdict(list)
for _ in pbar:
    init_td = env.reset(env.gen_params(batch_size=[batch_size]))
    rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[...,  -1]['next',  'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    scheduler.step()
def plot():
    import matplotlib
    from matplotlib import pyplot as plt
    is_ipython = "inline" in matplotlib.get_backend()
    if is_ipython:
        from IPython import display
    with plt.ion():
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(logs["return"])
        plt.title("returns")
        plt.xlabel("iteration")
        plt.subplot(1, 2, 2)
        plt.plot(logs["last_reward"])
        plt.title("last reward")
        plt.xlabel("iteration")
        if is_ipython:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        plt.show()
plot() 

0%|          | 0/625 [00:00<?, ?it/s]
reward: -6.0488, last reward: -5.0748, gradient norm:  8.518:   0%|          | 0/625 [00:00<?, ?it/s]
reward: -6.0488, last reward: -5.0748, gradient norm:  8.518:   0%|          | 1/625 [00:00<02:36,  3.99it/s]
reward: -7.0499, last reward: -7.4472, gradient norm:  5.073:   0%|          | 1/625 [00:00<02:36,  3.99it/s]
reward: -7.0499, last reward: -7.4472, gradient norm:  5.073:   0%|          | 2/625 [00:00<02:32,  4.08it/s]
reward: -7.0685, last reward: -7.0408, gradient norm:  5.552:   0%|          | 2/625 [00:00<02:32,  4.08it/s]
reward: -7.0685, last reward: -7.0408, gradient norm:  5.552:   0%|          | 3/625 [00:00<02:29,  4.15it/s]
reward: -6.5154, last reward: -5.9086, gradient norm:  2.526:   0%|          | 3/625 [00:00<02:29,  4.15it/s]
reward: -6.5154, last reward: -5.9086, gradient norm:  2.526:   1%|          | 4/625 [00:00<02:29,  4.14it/s]
reward: -6.2004, last reward: -5.9401, gradient norm:  7.964:   1%|          | 4/625 [00:01<02:29,  4.14it/s]
reward: -6.2004, last reward: -5.9401, gradient norm:  7.964:   1%|          | 5/625 [00:01<02:29,  4.14it/s]
reward: -6.2566, last reward: -5.4981, gradient norm:  4.446:   1%|          | 5/625 [00:01<02:29,  4.14it/s]
reward: -6.2566, last reward: -5.4981, gradient norm:  4.446:   1%|          | 6/625 [00:01<02:28,  4.17it/s]
reward: -5.8926, last reward: -8.4134, gradient norm:  2.108:   1%|          | 6/625 [00:01<02:28,  4.17it/s]
reward: -5.8926, last reward: -8.4134, gradient norm:  2.108:   1%|1         | 7/625 [00:01<02:27,  4.19it/s]
reward: -6.3541, last reward: -9.1257, gradient norm:  2.045:   1%|1         | 7/625 [00:01<02:27,  4.19it/s]
reward: -6.3541, last reward: -9.1257, gradient norm:  2.045:   1%|1         | 8/625 [00:01<02:26,  4.20it/s]
reward: -6.2071, last reward: -8.8872, gradient norm:  11.97:   1%|1         | 8/625 [00:02<02:26,  4.20it/s]
reward: -6.2071, last reward: -8.8872, gradient norm:  11.97:   1%|1         | 9/625 [00:02<02:26,  4.20it/s]
reward: -6.5838, last reward: -9.2693, gradient norm:  3.34:   1%|1         | 9/625 [00:02<02:26,  4.20it/s]
reward: -6.5838, last reward: -9.2693, gradient norm:  3.34:   2%|1         | 10/625 [00:02<02:26,  4.21it/s]
reward: -6.2601, last reward: -9.0436, gradient norm:  4.874:   2%|1         | 10/625 [00:02<02:26,  4.21it/s]
reward: -6.2601, last reward: -9.0436, gradient norm:  4.874:   2%|1         | 11/625 [00:02<02:25,  4.21it/s]
reward: -6.3676, last reward: -8.2883, gradient norm:  2.542:   2%|1         | 11/625 [00:02<02:25,  4.21it/s]
reward: -6.3676, last reward: -8.2883, gradient norm:  2.542:   2%|1         | 12/625 [00:02<02:25,  4.21it/s]
reward: -5.9768, last reward: -8.4551, gradient norm:  2.931:   2%|1         | 12/625 [00:03<02:25,  4.21it/s]
reward: -5.9768, last reward: -8.4551, gradient norm:  2.931:   2%|2         | 13/625 [00:03<02:25,  4.22it/s]
reward: -5.9597, last reward: -8.0172, gradient norm:  5.493:   2%|2         | 13/625 [00:03<02:25,  4.22it/s]
reward: -5.9597, last reward: -8.0172, gradient norm:  5.493:   2%|2         | 14/625 [00:03<02:24,  4.22it/s]
reward: -6.0045, last reward: -6.3726, gradient norm:  1.216:   2%|2         | 14/625 [00:03<02:24,  4.22it/s]
reward: -6.0045, last reward: -6.3726, gradient norm:  1.216:   2%|2         | 15/625 [00:03<02:24,  4.22it/s]
reward: -6.0157, last reward: -7.4454, gradient norm:  4.614:   2%|2         | 15/625 [00:03<02:24,  4.22it/s]
reward: -6.0157, last reward: -7.4454, gradient norm:  4.614:   3%|2         | 16/625 [00:03<02:24,  4.22it/s]
reward: -5.7248, last reward: -4.7793, gradient norm:  11.7:   3%|2         | 16/625 [00:04<02:24,  4.22it/s]
reward: -5.7248, last reward: -4.7793, gradient norm:  11.7:   3%|2         | 17/625 [00:04<02:24,  4.21it/s]
reward: -5.8783, last reward: -3.7558, gradient norm:  7.704:   3%|2         | 17/625 [00:04<02:24,  4.21it/s]
reward: -5.8783, last reward: -3.7558, gradient norm:  7.704:   3%|2         | 18/625 [00:04<02:24,  4.21it/s]
reward: -6.0913, last reward: -6.0003, gradient norm:  17.23:   3%|2         | 18/625 [00:04<02:24,  4.21it/s]
reward: -6.0913, last reward: -6.0003, gradient norm:  17.23:   3%|3         | 19/625 [00:04<02:24,  4.20it/s]
reward: -5.9328, last reward: -5.2019, gradient norm:  3.004:   3%|3         | 19/625 [00:04<02:24,  4.20it/s]
reward: -5.9328, last reward: -5.2019, gradient norm:  3.004:   3%|3         | 20/625 [00:04<02:24,  4.20it/s]
reward: -6.1899, last reward: -6.5583, gradient norm:  8.905:   3%|3         | 20/625 [00:05<02:24,  4.20it/s]
reward: -6.1899, last reward: -6.5583, gradient norm:  8.905:   3%|3         | 21/625 [00:05<02:23,  4.20it/s]
reward: -5.8776, last reward: -6.3394, gradient norm:  86.04:   3%|3         | 21/625 [00:05<02:23,  4.20it/s]
reward: -5.8776, last reward: -6.3394, gradient norm:  86.04:   4%|3         | 22/625 [00:05<02:23,  4.19it/s]
reward: -6.3972, last reward: -6.5765, gradient norm:  20.42:   4%|3         | 22/625 [00:05<02:23,  4.19it/s]
reward: -6.3972, last reward: -6.5765, gradient norm:  20.42:   4%|3         | 23/625 [00:05<02:23,  4.20it/s]
reward: -6.3652, last reward: -5.7013, gradient norm:  4.733:   4%|3         | 23/625 [00:05<02:23,  4.20it/s]
reward: -6.3652, last reward: -5.7013, gradient norm:  4.733:   4%|3         | 24/625 [00:05<02:22,  4.21it/s]
reward: -5.5586, last reward: -6.3572, gradient norm:  7.792:   4%|3         | 24/625 [00:05<02:22,  4.21it/s]
reward: -5.5586, last reward: -6.3572, gradient norm:  7.792:   4%|4         | 25/625 [00:05<02:22,  4.20it/s]
reward: -5.4795, last reward: -4.5168, gradient norm:  1.692:   4%|4         | 25/625 [00:06<02:22,  4.20it/s]
reward: -5.4795, last reward: -4.5168, gradient norm:  1.692:   4%|4         | 26/625 [00:06<02:22,  4.21it/s]
reward: -5.5407, last reward: -7.0325, gradient norm:  773.3:   4%|4         | 26/625 [00:06<02:22,  4.21it/s]
reward: -5.5407, last reward: -7.0325, gradient norm:  773.3:   4%|4         | 27/625 [00:06<02:22,  4.20it/s]
reward: -5.7399, last reward: -6.0130, gradient norm:  2.865:   4%|4         | 27/625 [00:06<02:22,  4.20it/s]
reward: -5.7399, last reward: -6.0130, gradient norm:  2.865:   4%|4         | 28/625 [00:06<02:22,  4.20it/s]
reward: -6.0738, last reward: -6.5728, gradient norm:  2.833:   4%|4         | 28/625 [00:06<02:22,  4.20it/s]
reward: -6.0738, last reward: -6.5728, gradient norm:  2.833:   5%|4         | 29/625 [00:06<02:21,  4.20it/s]
reward: -6.0101, last reward: -6.4175, gradient norm:  6.212:   5%|4         | 29/625 [00:07<02:21,  4.20it/s]
reward: -6.0101, last reward: -6.4175, gradient norm:  6.212:   5%|4         | 30/625 [00:07<02:21,  4.20it/s]
reward: -5.9955, last reward: -4.7723, gradient norm:  3.158:   5%|4         | 30/625 [00:07<02:21,  4.20it/s]
reward: -5.9955, last reward: -4.7723, gradient norm:  3.158:   5%|4         | 31/625 [00:07<02:21,  4.21it/s]
reward: -5.6103, last reward: -3.8313, gradient norm:  5.422:   5%|4         | 31/625 [00:07<02:21,  4.21it/s]
reward: -5.6103, last reward: -3.8313, gradient norm:  5.422:   5%|5         | 32/625 [00:07<02:20,  4.21it/s]
reward: -5.6042, last reward: -3.8542, gradient norm:  5.069:   5%|5         | 32/625 [00:07<02:20,  4.21it/s]
reward: -5.6042, last reward: -3.8542, gradient norm:  5.069:   5%|5         | 33/625 [00:07<02:20,  4.21it/s]
reward: -5.5265, last reward: -4.3386, gradient norm:  2.368:   5%|5         | 33/625 [00:08<02:20,  4.21it/s]
reward: -5.5265, last reward: -4.3386, gradient norm:  2.368:   5%|5         | 34/625 [00:08<02:20,  4.21it/s]
reward: -5.6277, last reward: -5.1658, gradient norm:  25.25:   5%|5         | 34/625 [00:08<02:20,  4.21it/s]
reward: -5.6277, last reward: -5.1658, gradient norm:  25.25:   6%|5         | 35/625 [00:08<02:20,  4.21it/s]
reward: -5.6876, last reward: -5.1197, gradient norm:  110.2:   6%|5         | 35/625 [00:08<02:20,  4.21it/s]
reward: -5.6876, last reward: -5.1197, gradient norm:  110.2:   6%|5         | 36/625 [00:08<02:19,  4.21it/s]
reward: -6.0015, last reward: -4.9656, gradient norm:  1.3:   6%|5         | 36/625 [00:08<02:19,  4.21it/s]
reward: -6.0015, last reward: -4.9656, gradient norm:  1.3:   6%|5         | 37/625 [00:08<02:19,  4.22it/s]
reward: -5.6628, last reward: -6.0784, gradient norm:  10.63:   6%|5         | 37/625 [00:09<02:19,  4.22it/s]
reward: -5.6628, last reward: -6.0784, gradient norm:  10.63:   6%|6         | 38/625 [00:09<02:19,  4.22it/s]
reward: -5.8188, last reward: -5.3053, gradient norm:  20.95:   6%|6         | 38/625 [00:09<02:19,  4.22it/s]
reward: -5.8188, last reward: -5.3053, gradient norm:  20.95:   6%|6         | 39/625 [00:09<02:19,  4.21it/s]
reward: -5.5934, last reward: -5.4250, gradient norm:  2.52:   6%|6         | 39/625 [00:09<02:19,  4.21it/s]
reward: -5.5934, last reward: -5.4250, gradient norm:  2.52:   6%|6         | 40/625 [00:09<02:19,  4.20it/s]
reward: -5.4317, last reward: -5.2191, gradient norm:  11.53:   6%|6         | 40/625 [00:09<02:19,  4.20it/s]
reward: -5.4317, last reward: -5.2191, gradient norm:  11.53:   7%|6         | 41/625 [00:09<02:19,  4.20it/s]
reward: -5.8227, last reward: -5.2263, gradient norm:  5.554:   7%|6         | 41/625 [00:10<02:19,  4.20it/s]
reward: -5.8227, last reward: -5.2263, gradient norm:  5.554:   7%|6         | 42/625 [00:10<02:19,  4.19it/s]
reward: -5.6086, last reward: -3.3930, gradient norm:  13.2:   7%|6         | 42/625 [00:10<02:19,  4.19it/s]
reward: -5.6086, last reward: -3.3930, gradient norm:  13.2:   7%|6         | 43/625 [00:10<02:18,  4.19it/s]
reward: -5.5969, last reward: -4.8821, gradient norm:  2.538:   7%|6         | 43/625 [00:10<02:18,  4.19it/s]
reward: -5.5969, last reward: -4.8821, gradient norm:  2.538:   7%|7         | 44/625 [00:10<02:18,  4.19it/s]
reward: -5.5018, last reward: -4.3099, gradient norm:  3.416:   7%|7         | 44/625 [00:10<02:18,  4.19it/s]
reward: -5.5018, last reward: -4.3099, gradient norm:  3.416:   7%|7         | 45/625 [00:10<02:18,  4.18it/s]
reward: -5.6813, last reward: -5.1515, gradient norm:  19.79:   7%|7         | 45/625 [00:10<02:18,  4.18it/s]
reward: -5.6813, last reward: -5.1515, gradient norm:  19.79:   7%|7         | 46/625 [00:10<02:18,  4.17it/s]
reward: -5.8823, last reward: -5.6010, gradient norm:  12.73:   7%|7         | 46/625 [00:11<02:18,  4.17it/s]
reward: -5.8823, last reward: -5.6010, gradient norm:  12.73:   8%|7         | 47/625 [00:11<02:18,  4.17it/s]
reward: -5.2582, last reward: -6.6556, gradient norm:  6.568:   8%|7         | 47/625 [00:11<02:18,  4.17it/s]
reward: -5.2582, last reward: -6.6556, gradient norm:  6.568:   8%|7         | 48/625 [00:11<02:17,  4.18it/s]
reward: -5.6368, last reward: -6.3310, gradient norm:  8.046:   8%|7         | 48/625 [00:11<02:17,  4.18it/s]
reward: -5.6368, last reward: -6.3310, gradient norm:  8.046:   8%|7         | 49/625 [00:11<02:17,  4.18it/s]
reward: -5.6776, last reward: -6.1928, gradient norm:  4.976:   8%|7         | 49/625 [00:11<02:17,  4.18it/s]
reward: -5.6776, last reward: -6.1928, gradient norm:  4.976:   8%|8         | 50/625 [00:11<02:17,  4.18it/s]
reward: -5.6418, last reward: -4.5608, gradient norm:  2.355:   8%|8         | 50/625 [00:12<02:17,  4.18it/s]
reward: -5.6418, last reward: -4.5608, gradient norm:  2.355:   8%|8         | 51/625 [00:12<02:17,  4.18it/s]
reward: -5.4142, last reward: -4.4533, gradient norm:  3.903:   8%|8         | 51/625 [00:12<02:17,  4.18it/s]
reward: -5.4142, last reward: -4.4533, gradient norm:  3.903:   8%|8         | 52/625 [00:12<02:16,  4.19it/s]
reward: -5.3920, last reward: -3.6933, gradient norm:  5.534:   8%|8         | 52/625 [00:12<02:16,  4.19it/s]
reward: -5.3920, last reward: -3.6933, gradient norm:  5.534:   8%|8         | 53/625 [00:12<02:16,  4.19it/s]
reward: -5.3322, last reward: -3.1984, gradient norm:  4.058:   8%|8         | 53/625 [00:12<02:16,  4.19it/s]
reward: -5.3322, last reward: -3.1984, gradient norm:  4.058:   9%|8         | 54/625 [00:12<02:16,  4.19it/s]
reward: -5.3709, last reward: -4.5488, gradient norm:  37.33:   9%|8         | 54/625 [00:13<02:16,  4.19it/s]
reward: -5.3709, last reward: -4.5488, gradient norm:  37.33:   9%|8         | 55/625 [00:13<02:16,  4.19it/s]
reward: -5.4076, last reward: -3.1880, gradient norm:  1.395:   9%|8         | 55/625 [00:13<02:16,  4.19it/s]
reward: -5.4076, last reward: -3.1880, gradient norm:  1.395:   9%|8         | 56/625 [00:13<02:16,  4.18it/s]
reward: -5.3727, last reward: -2.1695, gradient norm:  2.613:   9%|8         | 56/625 [00:13<02:16,  4.18it/s]
reward: -5.3727, last reward: -2.1695, gradient norm:  2.613:   9%|9         | 57/625 [00:13<02:15,  4.19it/s]
reward: -5.6188, last reward: -2.7869, gradient norm:  1.464:   9%|9         | 57/625 [00:13<02:15,  4.19it/s]
reward: -5.6188, last reward: -2.7869, gradient norm:  1.464:   9%|9         | 58/625 [00:13<02:15,  4.19it/s]
reward: -5.4788, last reward: -5.2309, gradient norm:  12.19:   9%|9         | 58/625 [00:14<02:15,  4.19it/s]
reward: -5.4788, last reward: -5.2309, gradient norm:  12.19:   9%|9         | 59/625 [00:14<02:16,  4.15it/s]
reward: -5.1972, last reward: -5.1203, gradient norm:  67.95:   9%|9         | 59/625 [00:14<02:16,  4.15it/s]
reward: -5.1972, last reward: -5.1203, gradient norm:  67.95:  10%|9         | 60/625 [00:14<02:15,  4.16it/s]
reward: -5.4977, last reward: -4.8712, gradient norm:  4.688:  10%|9         | 60/625 [00:14<02:15,  4.16it/s]
reward: -5.4977, last reward: -4.8712, gradient norm:  4.688:  10%|9         | 61/625 [00:14<02:15,  4.17it/s]
reward: -5.4804, last reward: -6.0890, gradient norm:  3.287:  10%|9         | 61/625 [00:14<02:15,  4.17it/s]
reward: -5.4804, last reward: -6.0890, gradient norm:  3.287:  10%|9         | 62/625 [00:14<02:14,  4.19it/s]
reward: -5.3051, last reward: -4.3689, gradient norm:  64.25:  10%|9         | 62/625 [00:15<02:14,  4.19it/s]
reward: -5.3051, last reward: -4.3689, gradient norm:  64.25:  10%|#         | 63/625 [00:15<02:14,  4.19it/s]
reward: -5.3228, last reward: -4.2780, gradient norm:  9.055:  10%|#         | 63/625 [00:15<02:14,  4.19it/s]
reward: -5.3228, last reward: -4.2780, gradient norm:  9.055:  10%|#         | 64/625 [00:15<02:14,  4.16it/s]
reward: -5.1394, last reward: -4.0425, gradient norm:  9.393:  10%|#         | 64/625 [00:15<02:14,  4.16it/s]
reward: -5.1394, last reward: -4.0425, gradient norm:  9.393:  10%|#         | 65/625 [00:15<02:14,  4.17it/s]
reward: -5.2673, last reward: -4.0022, gradient norm:  8.597:  10%|#         | 65/625 [00:15<02:14,  4.17it/s]
reward: -5.2673, last reward: -4.0022, gradient norm:  8.597:  11%|#         | 66/625 [00:15<02:14,  4.17it/s]
reward: -5.1040, last reward: -4.5461, gradient norm:  18.81:  11%|#         | 66/625 [00:15<02:14,  4.17it/s]
reward: -5.1040, last reward: -4.5461, gradient norm:  18.81:  11%|#         | 67/625 [00:15<02:13,  4.17it/s]
reward: -5.3599, last reward: -4.0312, gradient norm:  34.25:  11%|#         | 67/625 [00:16<02:13,  4.17it/s]
reward: -5.3599, last reward: -4.0312, gradient norm:  34.25:  11%|#         | 68/625 [00:16<02:13,  4.18it/s]
reward: -5.3867, last reward: -6.7588, gradient norm:  4.311:  11%|#         | 68/625 [00:16<02:13,  4.18it/s]
reward: -5.3867, last reward: -6.7588, gradient norm:  4.311:  11%|#1        | 69/625 [00:16<02:13,  4.18it/s]
reward: -5.3548, last reward: -8.1878, gradient norm:  44.19:  11%|#1        | 69/625 [00:16<02:13,  4.18it/s]
reward: -5.3548, last reward: -8.1878, gradient norm:  44.19:  11%|#1        | 70/625 [00:16<02:12,  4.18it/s]
reward: -5.3264, last reward: -6.2046, gradient norm:  6.25:  11%|#1        | 70/625 [00:16<02:12,  4.18it/s]
reward: -5.3264, last reward: -6.2046, gradient norm:  6.25:  11%|#1        | 71/625 [00:16<02:12,  4.19it/s]
reward: -5.3723, last reward: -5.9680, gradient norm:  11.1:  11%|#1        | 71/625 [00:17<02:12,  4.19it/s]
...

Conclusion

In this tutorial, we have learned how to code a stateless environment from scratch. We touched on the following topics:

  • The four essential components that need to be taken care of when coding an environment (step, reset, seeding, and building specs); we saw how these methods and classes interact with the TensorDict class;
  • How to test that an environment is properly coded using check_env_specs();
  • How to append transforms in the context of stateless environments and how to write custom transformations;
  • How to train a policy on a fully differentiable simulator.

Total running time of the script: (2 minutes 30.147 seconds)

Download Python source code: pendulum.py

Download Jupyter notebook: pendulum.ipynb

Gallery generated by Sphinx-Gallery
