PyTorch 2.2 中文官方教程(十九)(4)

简介: PyTorch 2.2 中文官方教程(十九)

PyTorch 2.2 中文官方教程(十九)(3)https://developer.aliyun.com/article/1482623

要求

什么是Join

使用分布式数据并行开始 - 基本用例中,您看到了使用DistributedDataParallel执行数据并行训练的一般框架。这隐式地在每次反向传播中安排所有规约,以在各个 rank 之间同步梯度。这种集体通信需要来自进程组中所有 rank 的参与,因此,如果一个 rank 的输入较少,那么其他 rank 将挂起或出错(取决于后端)。更一般地说,对于执行每次迭代同步集体通信的任何类,这个问题都会持续存在。

Join是一个上下文管理器,用于围绕您的每个 rank 训练循环,以便在不均匀输入下进行训练。上下文管理器允许提前耗尽输入的 rank(即join提前)来模拟尚未加入的 rank 执行的集体通信。通信被模拟的方式由钩子指定。

使用JoinDistributedDataParallel

PyTorch 的DistributedDataParallelJoin上下文管理器完全兼容。以下是一个示例用法:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
    main() 

这将产生以下输出(其中来自 rank 0 和 rank 1 的print()可能是任意顺序):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs! 

注意

DistributedDataParallel在引入这个通用的Join上下文管理器之前提供了自己的join()上下文管理器。在上面的示例中,使用with Join([model]):等同于使用with model.join():。现有的DistributedDataParallel.join()的一个限制是它不允许多个参与类,例如DistributedDataParallelZeroRedundancyOptimizer一起。

使用JoinDistributedDataParallelZeroRedundancyOptimizer

Join上下文管理器不仅适用于单个类,还适用于多个类一起。PyTorch 的ZeroRedundancyOptimizer也与上下文管理器兼容,因此,在这里,我们将检查如何修改之前的示例以同时使用DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam
def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()
    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!") 

这将产生与之前相同的输出。显著的变化是额外将ZeroRedundancyOptimizer实例传递给Join()

传递关键字参数

类可以提供关键字参数,在运行时修改它们在上下文管理器中的行为。例如,DistributedDataParallel提供了一个参数divide_by_initial_world_size,确定梯度是由初始世界大小还是有效世界大小(即非加入等级的数量)除以。这样的关键字参数可以直接传递到上下文管理器中。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ... 

警告

传递给上下文管理器的关键字参数在所有参与类之间共享。这不应该是一个限制,因为我们不希望出现多个Joinable需要相同参数的不同设置的情况。尽管如此,这是需要记住的一点。

Join是如何工作的?

现在我们已经看到了如何使用Join上下文管理器的一些初步示例,让我们深入了解它的工作原理。这将为您提供对其提供的全部功能的更深入了解,并为您准备好制作自己的自定义类。在这里,我们将介绍Join类以及支持类JoinableJoinHook

Joinable

首先,与Join上下文管理器兼容的类必须继承自抽象基类Joinable。特别是,Joinable必须实现:

  • join_hook(self, **kwargs) -> JoinHook

这将返回JoinableJoinHook实例,确定加入的进程应如何模拟Joinable执行的每次迭代集体通信。

  • join_device(self) -> torch.device

这将返回一个设备,该设备将由Join上下文管理器用于执行集体通信,例如torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这将返回要由Join上下文管理器用于执行集体通信的进程组。

特别是,join_devicejoin_process_group是必需的属性,以确保上下文管理器可以安排加入和未加入进程之间的集体通信。一个用法是使用全局归约在每次迭代中计算非加入进程的数量。另一个用法是实现throw_on_early_termination=True所需的机制,我们将在下面稍后解释。

DistributedDataParallelZeroRedundancyOptimizer已经继承自Joinable并实现了上述方法,这就是为什么我们可以直接在之前的示例中使用它们。

Joinable类应确保调用Joinable构造函数,因为它初始化了一个JoinConfig实例,该实例在上下文管理器内部用于确保正确性。这将保存在每个Joinable中作为一个字段_join_config

JoinHook

接下来,让我们来分解JoinHook类。一个JoinHook提供了两个进入上下文管理器的入口点:

  • main_hook(self) -> None

这个钩子在每个已加入的等级中被重复调用,同时存在一个尚未加入的等级。它旨在模拟每个训练迭代中由Joinable执行的集体通信(例如,在一个前向传递、反向传递和优化器步骤中)。

  • post_hook(self, is_last_joiner: bool) -> None

这个钩子在所有等级都加入后被调用。它传递了一个额外的bool参数is_last_joiner,指示该等级是否是最后加入的等级之一。该参数可能对同步有用。

为了给出这些钩子可能看起来像什么的具体示例,ZeroRedundancyOptimizer提供的主要钩子每次执行一步优化器,因为加入的等级仍然负责更新和同步其参数的片段,DistributedDataParallel提供的后钩子将最终更新的模型从最后加入的等级之一广播到所有等级,以确保它在所有等级上都是相同的。

Join

最后,让我们看看这些如何适应 Join 类本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的示例中看到的,构造函数接受参与训练循环的 Joinable 类的列表。这些应该是在每次迭代中执行集体通信的类。

enable 是一个 bool,如果您知道不会有不均匀的输入,可以将其设置为 False,在这种情况下,上下文管理器类似于 contextlib.nullcontext() 变得无效。这也可能会禁用参与的 Joinable 中的与连接相关的计算。

throw_on_early_termination 是一个 bool,如果检测到不均匀的输入,可以将其设置为 True,以便每个等级在那一刻引发异常。这对于不符合上下文管理器要求的情况非常有用,这种情况最典型的是当来自不同类的集体通信可能任意交错时,例如在使用具有 SyncBatchNorm 层的模型时使用 DistributedDataParallel。在这种情况下,应将此参数设置为 True,以便应用逻辑可以捕获异常并确定如何继续。

  • 核心逻辑发生在 __exit__() 方法中,当存在未连接的等级时循环调用每个 Joinable 的主要钩子,然后一旦所有等级都加入,调用它们的后处理钩子。主要钩子和后处理钩子都按照传入的 Joinable 的顺序进行迭代。
  • 上下文管理器需要来自未连接进程的心跳。因此,每个 Joinable 类应在其每次迭代的集体通信之前调用 Join.notify_join_context()。上下文管理器将确保只有第一个传入的 Joinable 实际发送心跳。

警告

如上所述关于 throw_on_early_terminationJoin 上下文管理器与某些类的组合不兼容。JoinableJoinHook 必须是可序列化的,因为每个钩子在继续下一个之前完全执行。换句话说,两个钩子不能重叠。此外,目前主要钩子和后处理钩子都按照相同的确定性顺序进行迭代。如果这看起来是一个主要限制,我们可以修改 API 以允许自定义排序。

使玩具类与 Join 兼容

由于上一节介绍了几个概念,让我们通过一个玩具示例来实践。在这里,我们将实现一个类,该类在其等级加入之前计算所有等级看到的输入数量。这应该提供一个基本的想法,说明您如何使自己的类与 Join 上下文管理器兼容。

具体来说,以下代码使每个等级打印出(1)在其加入之前所有等级看到的输入数量和(2)所有等级看到的总输入数量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
class CounterJoinHook(JoinHook):
  r"""
 Join hook for :class:`Counter`.
 Arguments:
 counter (Counter): the :class:`Counter` object using this hook.
 sync_max_count (bool): whether to sync the max count once all ranks
 join.
 """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count
    def main_hook(self):
  r"""
 Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
 """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)
    def post_hook(self, is_last_joiner: bool):
  r"""
 Synchronizes the max count across all :class:`Counter` s if
 ``sync_max_count=True``.
 """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)
class Counter(Joinable):
  r"""
 Example :class:`Joinable` that counts the number of training iterations
 that it participates in.
 """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()
    def __call__(self):
  r"""
 Counts the number of inputs processed on this iteration by all ranks
 by all-reducing a dim-1 one tensor; increments its own internal count.
 """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t
    def join_hook(self, **kwargs) -> JoinHook:
  r"""
 Return a join hook that shadows the all-reduce in :meth:`__call__`.
 This join hook supports the following keyword arguments:
 sync_max_count (bool, optional): whether to synchronize the maximum
 count across all ranks once all ranks join; default is ``False``.
 """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)
    @property
    def join_device(self) -> torch.device:
        return self.device
    @property
    def join_process_group(self):
        return self.process_group
    def find_common_rank(self, rank, to_consider):
  r"""
 Returns the max rank of the ones to consider over the process group.
 """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank
def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()
    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")
def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
    main() 

由于等级 0 看到 5 个输入,等级 1 看到 6 个输入,因此产生输出:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks! 

一些要强调的关键点:

  • Counter 实例在每次迭代中执行一次全局归约,因此主要钩子也执行一次全局归约以进行遮蔽。
  • Counter 类在其 __call__() 方法的开头调用 Join.notify_join_context(),因为这是在其每次迭代的集体通信之前的位置(即其全局归约)。
  • is_last_joiner 参数用于确定后处理中的广播源。
  • 我们传递 sync_max_count 关键字参数给上下文管理器,然后将其转发到 Counter 的连接钩子
相关文章
|
15天前
|
Android开发 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(二十)(2)
PyTorch 2.2 中文官方教程(二十)
42 0
PyTorch 2.2 中文官方教程(二十)(2)
|
15天前
|
iOS开发 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(二十)(1)
PyTorch 2.2 中文官方教程(二十)
45 0
PyTorch 2.2 中文官方教程(二十)(1)
|
15天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(4)
PyTorch 2.2 中文官方教程(十八)
53 1
|
15天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(3)
PyTorch 2.2 中文官方教程(十八)
26 1
PyTorch 2.2 中文官方教程(十八)(3)
|
15天前
|
并行计算 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十七)(4)
PyTorch 2.2 中文官方教程(十七)
25 2
PyTorch 2.2 中文官方教程(十七)(4)
|
15天前
|
PyTorch 算法框架/工具 机器学习/深度学习
PyTorch 2.2 中文官方教程(十七)(2)
PyTorch 2.2 中文官方教程(十七)
38 1
PyTorch 2.2 中文官方教程(十七)(2)
|
15天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十五)(1)
PyTorch 2.2 中文官方教程(十五)
46 1
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十四)(4)
PyTorch 2.2 中文官方教程(十四)
63 1
PyTorch 2.2 中文官方教程(十四)(4)
|
15天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十四)(2)
PyTorch 2.2 中文官方教程(十四)
49 1
|
3月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)