PyTorch 2.2 中文官方教程(十九)(3)https://developer.aliyun.com/article/1482623
要求
- PyTorch 1.10+
- 使用分布式数据并行开始
- 使用 ZeroRedundancyOptimizer 对优化器状态进行分片
什么是Join
?
在使用分布式数据并行开始 - 基本用例中,您看到了使用DistributedDataParallel执行数据并行训练的一般框架。这隐式地在每次反向传播中安排所有规约,以在各个 rank 之间同步梯度。这种集体通信需要来自进程组中所有 rank 的参与,因此,如果一个 rank 的输入较少,那么其他 rank 将挂起或出错(取决于后端)。更一般地说,对于执行每次迭代同步集体通信的任何类,这个问题都会持续存在。
Join
是一个上下文管理器,用于围绕您的每个 rank 训练循环,以便在不均匀输入下进行训练。上下文管理器允许提前耗尽输入的 rank(即join提前)来模拟尚未加入的 rank 执行的集体通信。通信被模拟的方式由钩子指定。
使用Join
与DistributedDataParallel
PyTorch 的DistributedDataParallel与Join
上下文管理器完全兼容。以下是一个示例用法:
import os import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.distributed.algorithms.join import Join from torch.nn.parallel import DistributedDataParallel as DDP BACKEND = "nccl" WORLD_SIZE = 2 NUM_INPUTS = 5 def worker(rank): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '29500' dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE) model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) # Rank 1 gets one more input than rank 0 inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)] num_inputs = 0 with Join([model]): for input in inputs: num_inputs += 1 loss = model(input).sum() loss.backward() print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!") def main(): mp.spawn(worker, nprocs=WORLD_SIZE, join=True) if __name__ == "__main__": main()
这将产生以下输出(其中来自 rank 0 和 rank 1 的print()
可能是任意顺序):
Rank 0 has exhausted all 5 of its inputs! Rank 1 has exhausted all 6 of its inputs!
注意
DistributedDataParallel在引入这个通用的Join
上下文管理器之前提供了自己的join()上下文管理器。在上面的示例中,使用with Join([model]):
等同于使用with model.join():
。现有的DistributedDataParallel.join()
的一个限制是它不允许多个参与类,例如DistributedDataParallel
和ZeroRedundancyOptimizer一起。
使用Join
与DistributedDataParallel
和ZeroRedundancyOptimizer
Join
上下文管理器不仅适用于单个类,还适用于多个类一起。PyTorch 的ZeroRedundancyOptimizer
也与上下文管理器兼容,因此,在这里,我们将检查如何修改之前的示例以同时使用DistributedDataParallel
和ZeroRedundancyOptimizer
:
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO from torch.optim import Adam def worker(rank): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '29500' dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE) model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) optim = ZeRO(model.parameters(), Adam, lr=0.01) # Rank 1 gets one more input than rank 0 inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)] num_inputs = 0 # Pass both `model` and `optim` into `Join()` with Join([model, optim]): for input in inputs: num_inputs += 1 loss = model(input).sum() loss.backward() optim.step() print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
这将产生与之前相同的输出。显著的变化是额外将ZeroRedundancyOptimizer
实例传递给Join()
。
传递关键字参数
类可以提供关键字参数,在运行时修改它们在上下文管理器中的行为。例如,DistributedDataParallel
提供了一个参数divide_by_initial_world_size
,确定梯度是由初始世界大小还是有效世界大小(即非加入等级的数量)除以。这样的关键字参数可以直接传递到上下文管理器中。
with Join([model, optim], divide_by_initial_world_size=False): for input in inputs: ...
警告
传递给上下文管理器的关键字参数在所有参与类之间共享。这不应该是一个限制,因为我们不希望出现多个Joinable
需要相同参数的不同设置的情况。尽管如此,这是需要记住的一点。
Join
是如何工作的?
现在我们已经看到了如何使用Join
上下文管理器的一些初步示例,让我们深入了解它的工作原理。这将为您提供对其提供的全部功能的更深入了解,并为您准备好制作自己的自定义类。在这里,我们将介绍Join
类以及支持类Joinable
和JoinHook
。
Joinable
首先,与Join
上下文管理器兼容的类必须继承自抽象基类Joinable
。特别是,Joinable
必须实现:
join_hook(self, **kwargs) -> JoinHook
这将返回Joinable
的JoinHook
实例,确定加入的进程应如何模拟Joinable
执行的每次迭代集体通信。
join_device(self) -> torch.device
这将返回一个设备,该设备将由Join
上下文管理器用于执行集体通信,例如torch.device("cuda:0")
或torch.device("cpu")
。
join_process_group(self) -> ProcessGroup
这将返回要由Join
上下文管理器用于执行集体通信的进程组。
特别是,join_device
和join_process_group
是必需的属性,以确保上下文管理器可以安排加入和未加入进程之间的集体通信。一个用法是使用全局归约在每次迭代中计算非加入进程的数量。另一个用法是实现throw_on_early_termination=True
所需的机制,我们将在下面稍后解释。
DistributedDataParallel
和ZeroRedundancyOptimizer
已经继承自Joinable
并实现了上述方法,这就是为什么我们可以直接在之前的示例中使用它们。
Joinable
类应确保调用Joinable
构造函数,因为它初始化了一个JoinConfig
实例,该实例在上下文管理器内部用于确保正确性。这将保存在每个Joinable
中作为一个字段_join_config
。
JoinHook
接下来,让我们来分解JoinHook
类。一个JoinHook
提供了两个进入上下文管理器的入口点:
main_hook(self) -> None
这个钩子在每个已加入的等级中被重复调用,同时存在一个尚未加入的等级。它旨在模拟每个训练迭代中由Joinable
执行的集体通信(例如,在一个前向传递、反向传递和优化器步骤中)。
post_hook(self, is_last_joiner: bool) -> None
这个钩子在所有等级都加入后被调用。它传递了一个额外的bool
参数is_last_joiner
,指示该等级是否是最后加入的等级之一。该参数可能对同步有用。
为了给出这些钩子可能看起来像什么的具体示例,ZeroRedundancyOptimizer
提供的主要钩子每次执行一步优化器,因为加入的等级仍然负责更新和同步其参数的片段,DistributedDataParallel
提供的后钩子将最终更新的模型从最后加入的等级之一广播到所有等级,以确保它在所有等级上都是相同的。
Join
最后,让我们看看这些如何适应 Join
类本身。
__init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)
正如我们在前面的示例中看到的,构造函数接受参与训练循环的 Joinable
类的列表。这些应该是在每次迭代中执行集体通信的类。
enable
是一个 bool
,如果您知道不会有不均匀的输入,可以将其设置为 False
,在这种情况下,上下文管理器类似于 contextlib.nullcontext()
变得无效。这也可能会禁用参与的 Joinable
中的与连接相关的计算。
throw_on_early_termination
是一个 bool
,如果检测到不均匀的输入,可以将其设置为 True
,以便每个等级在那一刻引发异常。这对于不符合上下文管理器要求的情况非常有用,这种情况最典型的是当来自不同类的集体通信可能任意交错时,例如在使用具有 SyncBatchNorm
层的模型时使用 DistributedDataParallel
。在这种情况下,应将此参数设置为 True
,以便应用逻辑可以捕获异常并确定如何继续。
- 核心逻辑发生在
__exit__()
方法中,当存在未连接的等级时循环调用每个Joinable
的主要钩子,然后一旦所有等级都加入,调用它们的后处理钩子。主要钩子和后处理钩子都按照传入的Joinable
的顺序进行迭代。 - 上下文管理器需要来自未连接进程的心跳。因此,每个
Joinable
类应在其每次迭代的集体通信之前调用Join.notify_join_context()
。上下文管理器将确保只有第一个传入的Joinable
实际发送心跳。
警告
如上所述关于 throw_on_early_termination
,Join
上下文管理器与某些类的组合不兼容。Joinable
的 JoinHook
必须是可序列化的,因为每个钩子在继续下一个之前完全执行。换句话说,两个钩子不能重叠。此外,目前主要钩子和后处理钩子都按照相同的确定性顺序进行迭代。如果这看起来是一个主要限制,我们可以修改 API 以允许自定义排序。
使玩具类与 Join
兼容
由于上一节介绍了几个概念,让我们通过一个玩具示例来实践。在这里,我们将实现一个类,该类在其等级加入之前计算所有等级看到的输入数量。这应该提供一个基本的想法,说明您如何使自己的类与 Join
上下文管理器兼容。
具体来说,以下代码使每个等级打印出(1)在其加入之前所有等级看到的输入数量和(2)所有等级看到的总输入数量。
import os import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.distributed.algorithms.join import Join, Joinable, JoinHook BACKEND = "nccl" WORLD_SIZE = 2 NUM_INPUTS = 5 class CounterJoinHook(JoinHook): r""" Join hook for :class:`Counter`. Arguments: counter (Counter): the :class:`Counter` object using this hook. sync_max_count (bool): whether to sync the max count once all ranks join. """ def __init__( self, counter, sync_max_count ): self.counter = counter self.sync_max_count = sync_max_count def main_hook(self): r""" Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor. """ t = torch.zeros(1, device=self.counter.device) dist.all_reduce(t) def post_hook(self, is_last_joiner: bool): r""" Synchronizes the max count across all :class:`Counter` s if ``sync_max_count=True``. """ if not self.sync_max_count: return rank = dist.get_rank(self.counter.process_group) common_rank = self.counter.find_common_rank(rank, is_last_joiner) if rank == common_rank: self.counter.max_count = self.counter.count.detach().clone() dist.broadcast(self.counter.max_count, src=common_rank) class Counter(Joinable): r""" Example :class:`Joinable` that counts the number of training iterations that it participates in. """ def __init__(self, device, process_group): super(Counter, self).__init__() self.device = device self.process_group = process_group self.count = torch.tensor([0], device=device).float() self.max_count = torch.tensor([0], device=device).float() def __call__(self): r""" Counts the number of inputs processed on this iteration by all ranks by all-reducing a dim-1 one tensor; increments its own internal count. """ Join.notify_join_context(self) t = torch.ones(1, device=self.device).float() dist.all_reduce(t) self.count += t def join_hook(self, **kwargs) -> JoinHook: r""" Return a join hook that shadows the all-reduce in :meth:`__call__`. This join hook supports the following keyword arguments: sync_max_count (bool, optional): whether to synchronize the maximum count across all ranks once all ranks join; default is ``False``. """ sync_max_count = kwargs.get("sync_max_count", False) return CounterJoinHook(self, sync_max_count) @property def join_device(self) -> torch.device: return self.device @property def join_process_group(self): return self.process_group def find_common_rank(self, rank, to_consider): r""" Returns the max rank of the ones to consider over the process group. """ common_rank = torch.tensor([rank if to_consider else -1], device=self.device) dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group) common_rank = common_rank.item() return common_rank def worker(rank): assert torch.cuda.device_count() >= WORLD_SIZE os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '29500' dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE) counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD) inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)] with Join([counter], sync_max_count=True): for _ in inputs: counter() print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!") print(f"{int(counter.max_count.item())} inputs processed across all ranks!") def main(): mp.spawn(worker, nprocs=WORLD_SIZE, join=True) if __name__ == "__main__": main()
由于等级 0 看到 5 个输入,等级 1 看到 6 个输入,因此产生输出:
10 inputs processed before rank 0 joined! 11 inputs processed across all ranks! 11 inputs processed before rank 1 joined! 11 inputs processed across all ranks!
一些要强调的关键点:
Counter
实例在每次迭代中执行一次全局归约,因此主要钩子也执行一次全局归约以进行遮蔽。Counter
类在其__call__()
方法的开头调用Join.notify_join_context()
,因为这是在其每次迭代的集体通信之前的位置(即其全局归约)。is_last_joiner
参数用于确定后处理中的广播源。- 我们传递
sync_max_count
关键字参数给上下文管理器,然后将其转发到Counter
的连接钩子