Distributed Training

Process Communication

When spawning multiple processes for distributed training, there is some terminology to be aware of. The world size is the total number of processes used in training. In DDP this is typically the number of GPUs, since you use one process per GPU. The rank is the unique ID of a process among the processes and spans 0, ..., world_size - 1. Ranks are important for coordinating inter-process communication.

A launcher from scratch (which you wouldn't do in practice) would look something like:

"""run.py:"""
#!/usr/bin/env python
import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
""" Distributed function to be implemented later. """
pass

def init_process(rank, size, fn, backend='gloo'):
""" Initialize the distributed environment. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend, rank=rank, world_size=size)
fn(rank, size)


if __name__ == "__main__":
world_size = 2
processes = []
if "google.colab" in sys.modules:
print("Running in Google Colab")
mp.get_context("spawn")
else:
mp.set_start_method("spawn")
for rank in range(world_size):
p = mp.Process(target=init_process, args=(rank, world_size, run))
p.start()
processes.append(p)

for p in processes:
p.join()

In point-to-point communication, one process transfers data to another. In the torch.distributed package this is done with send and recv for blocking communication and isend and irecv for non-blocking communication. Point-to-point communication is typically used when you want fine-grained control over how processes communicate.
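As a minimal sketch, the run() stub from the launcher above could be filled in to pass a tensor from rank 0 to rank 1 with blocking calls (the non-blocking variants return a request object you can wait() on):

def run(rank, size):
    """ Blocking point-to-point: rank 0 sends a tensor to rank 1. """
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor=tensor, dst=1)   # blocks until rank 1 has received
    else:
        dist.recv(tensor=tensor, src=0)   # blocks until the data arrives
    print(f"Rank {rank} has data {tensor[0]}")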

In most cases, distributed training employs collective communication, which involves communication across all processes within a group. A group is a subset of all processes and is created by passing a list of ranks to dist.new_group(ranks). By default a collective acts on all processes in the world, but you can restrict it to a particular group, e.g. dist.all_reduce(tensor, op, group).

As an example, the collective all_reduce reduces the tensors from all processes and makes the result available to every process:

[Figure: all_reduce]
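A minimal sketch of all_reduce, again as a run() body for the launcher above (the group of ranks [0, 1] assumes world_size is 2):

def run(rank, size):
    """ Every rank ends up with the sum of the tensors from all ranks. """
    group = dist.new_group([0, 1])
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print(f"Rank {rank} has data {tensor[0]}")   # prints 2 on both ranks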

Some other important collectives to remember:

  • dist.broadcast(tensor, src, group) - copies tensor from src to all processes
  • dist.all_gather(tensor_list, tensor, group) - copies tensor from all processes to tensor_list on all processes

[Figures: broadcast and all_gather]
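A hedged sketch of both collectives as another run() body (CPU tensors with the gloo backend from the launcher above):

def run(rank, size):
    """ broadcast copies rank 0's tensor everywhere; all_gather collects one tensor per rank. """
    tensor = torch.full((1,), float(rank))
    dist.broadcast(tensor, src=0)                     # every rank now holds rank 0's value
    local = torch.full((1,), float(rank))
    gathered = [torch.zeros(1) for _ in range(size)]
    dist.all_gather(gathered, local)                  # gathered == [0.0, 1.0, ...] on every rank
    print(f"Rank {rank}: broadcast={tensor.item()}, gathered={[t.item() for t in gathered]}")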

There are many more collectives with diverse functions: some simply copy tensors around, while others have distinct communication patterns, such as the following (see the sketch after this list):

  • dist.reduce(tensor, dst, op, group) - applies op across the tensors from all processes and stores the result only in the process with rank dst (an int)
  • dist.scatter(tensor, scatter_list, src, group) - copies the i-th tensor scatter_list[i] to the i-th process
  • dist.gather(tensor, gather_list, dst, group) - copies tensor from all processes to process with rank dst
  • dist.barrier(group) - blocks all processes in group until each one has entered this function
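A hedged sketch of scatter and gather (note that scatter_list and gather_list are only provided on the src/dst rank):

def run(rank, size):
    """ Rank 0 scatters one chunk per rank, each rank works on it, rank 0 gathers the results. """
    chunk = torch.zeros(1)
    scatter_list = [torch.full((1,), float(i)) for i in range(size)] if rank == 0 else None
    dist.scatter(chunk, scatter_list=scatter_list, src=0)
    chunk += 10                                       # local work on each rank
    gather_list = [torch.zeros(1) for _ in range(size)] if rank == 0 else None
    dist.gather(chunk, gather_list=gather_list, dst=0)
    dist.barrier()                                    # wait for every rank to reach this point
    if rank == 0:
        print([t.item() for t in gather_list])        # [10.0, 11.0, ...]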

PyTorch provides many commutative operators for element-wise reductions (this is only a subset; see the sketch after the list):

  • dist.ReduceOp.SUM
  • dist.ReduceOp.PRODUCT
  • dist.ReduceOp.MAX
  • dist.ReduceOp.MIN
  • dist.ReduceOp.BAND
  • dist.ReduceOp.BOR
  • dist.ReduceOp.BXOR
  • dist.ReduceOp.PREMUL_SUM
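A small sketch of passing an op: with ReduceOp.MAX, every rank learns the largest rank value in the world.

def run(rank, size):
    tensor = torch.tensor([float(rank)])
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    print(f"Rank {rank} sees max = {tensor.item()}")   # size - 1 on every rank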

DDP Internals

What happens under the hood with DDP?

Key Idea

DDP adds autograd hooks to each parameter to track when gradients become available. Parameter gradients are organized into buckets, and during backpropagation each bucket is all-reduced to obtain the mean gradients for that bucket across all processes.

When you wrap ddp_model = DDP(model, device_ids=[rank]), the DDP constructor takes a reference to the local module and broadcasts its state_dict() from rank 0 to all other processes so that every replica starts from the same state. Each DDP process then creates a local Reducer, which handles gradient synchronization during the backward pass.

The Reducer organizes parameter gradients into buckets and reduces one bucket at a time. The mapping from parameter gradients to buckets is determined at construction time, with the bucket size configurable via bucket_cap_mb. Buckets are laid out in (roughly) the reverse order of Model.parameters(), which matches the order in which gradients become available during the backward pass. The Reducer also registers one autograd hook per parameter, so DDP is notified as soon as each gradient is ready.

During the backward pass, the autograd hooks set up during construction fire to trigger gradient synchronization. Once every gradient in a bucket is ready, the Reducer launches an allreduce on that bucket to compute the mean gradient across all processes. When all buckets have been kicked off, DDP blocks until every allreduce finishes. Finally, the averaged gradients are written to param.grad for the weight update; after the allreduce, all processes hold the same gradients.

[Figure: DDP gradient bucketing and allreduce during the backward pass]
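A hedged sketch of how this surfaces in the API; the tiny model is a placeholder, the snippet assumes a single node with one GPU per rank, and bucket_cap_mb is the knob mentioned above (it defaults to 25 MB):

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# assumes dist.init_process_group() has already run and this process owns GPU `rank`
rank = dist.get_rank()
model = nn.Linear(10, 10).to(rank)
ddp_model = DDP(model, device_ids=[rank], bucket_cap_mb=25)   # state_dict broadcast happens here

out = ddp_model(torch.randn(20, 10, device=rank))
out.sum().backward()   # per-parameter hooks fire; buckets are allreduced as they fill
# every rank now holds identical, averaged gradients in param.grad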

Communication Backends

Why do we have different backends? DDP and torch.distributed are merely abstractions; the backend specifies how the processes actually communicate with one another. For instance, gloo is a convenient default because it implements all point-to-point and collective operations on CPU, and a subset of collectives (notably broadcast and all_reduce) on GPU.

Key Idea

DDP is the API that orchestrates when you sync; the backend decides how you sync.

By default PyTorch comes with the gloo and nccl backends, but mpi is not included (you need to install it manually).

Key Idea

Which backend to use:

  • Use the nccl backend for distributed training on CUDA GPUs. (NCCL stands for NVIDIA Collective Communication Library.)
  • Use the gloo backend for distributed training on CPU.
  • Use the xccl backend for distributed training on Intel XPU devices.
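A hedged sketch of selecting the backend at runtime (the torch.xpu check is an assumption about builds that ship XPU support):

import torch

def pick_backend() -> str:
    if torch.cuda.is_available():
        return "nccl"                                   # NVIDIA GPUs
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xccl"                                   # Intel XPU builds
    return "gloo"                                       # CPU fallback

# dist.init_process_group(backend=pick_backend(), rank=rank, world_size=world_size)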

Training Script

In practice, you will have the process with global rank 0 own the experiment (e.g. do the logging), while gradient computation happens on every rank.

The overall steps are as follows for distributed code:

  1. Spawn N processes for N GPUs
  2. Create a process group (dist.init_process_group)
  3. Use DistributedSampler
  4. Wrap model with DDP
  5. Close process group (dist.destroy_process_group)

First, you must call mp.spawn() (typically in main) to spawn the processes. At this point they are not aware of each other and have not been assigned a GPU. Then, before instantiating the model, each process must call:

def init_distributed(local_rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)

The call to dist.init_process_group() is a blocking point: each process waits there until all processes have joined the group.

torch.cuda.set_device(local_rank) ensures that each process has its own exclusive GPU. Note that instead of doing device = "cuda" you will also have to specify the device ordinal, either by calling torch.cuda.set_device(local_rank) or by creating device = torch.device("cuda", local_rank) and passing it around.
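Putting the five steps together, a hedged end-to-end skeleton (the linear model, random dataset, and hyperparameters are placeholders; init_distributed is the helper defined above):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def train(local_rank, world_size):
    init_distributed(local_rank, world_size)                      # step 2: join the process group

    dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False)   # step 3

    model = torch.nn.Linear(10, 1).to(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])               # step 4
    opt = torch.optim.SGD(ddp_model.parameters(), lr=1e-2)

    for epoch in range(3):
        sampler.set_epoch(epoch)
        for x, y in loader:
            x, y = x.to(local_rank), y.to(local_rank)
            loss = F.mse_loss(ddp_model(x), y)
            opt.zero_grad()
            loss.backward()                                       # gradients allreduced here
            opt.step()
        if local_rank == 0:
            print(f"epoch {epoch} done")                          # rank 0 owns logging

    dist.destroy_process_group()                                  # step 5

if __name__ == "__main__":
    world_size = torch.cuda.device_count()                        # one process per GPU
    mp.spawn(train, args=(world_size,), nprocs=world_size)        # step 1: spawn N processes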

Key Idea

A torch.device contains a device type (e.g. "cuda", "cpu", "mps") and an optional device ordinal for that type. If only the type is specified, e.g. torch.device('cuda'), the implicit index is the current device from torch.cuda.current_device(), which is device 0 by default.

You can construct a device with the optional device ordinals:

# create device with device ordinal 0
torch.device('cuda', 0)

# same
torch.device('cuda:0')

# same if torch.accelerator.current_accelerator() == cuda
torch.device(0)

Methods that take a device accept either the object or the string, so tensor.to(torch.device('cuda:1')) is equivalent to tensor.to('cuda:1').

When using a DistributedSampler in the DataLoader, make sure shuffle=False is set on the DataLoader (shuffling is handled by the sampler). With DistributedSampler you must call sampler.set_epoch(epoch) at the beginning of each epoch so the index shuffling differs between epochs. Note also that indices may be duplicated if the number of samples is not evenly divisible by the number of processes (because each process needs the same number of samples); see the sketch below.
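A hedged sketch (dataset and num_epochs are placeholders); drop_last=True on the sampler avoids the duplicated indices mentioned above at the cost of dropping a few trailing samples:

sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)        # different shuffle order each epoch
    for batch in loader:
        ...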

When using torchmetrics.Metric in a distributed setting (with a defined dist_reduce_fx), note that all ranks must call compute() to sync their results, even if rank 0 is the only one that logs those metrics. Otherwise, construct the metric with sync_on_compute=False.
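A hedged sketch with torchmetrics (ddp_model, val_loader, and local_rank are placeholders; sync_on_compute defaults to True):

import torchmetrics

acc = torchmetrics.classification.MulticlassAccuracy(num_classes=10).to(local_rank)

for x, y in val_loader:
    acc.update(ddp_model(x.to(local_rank)), y.to(local_rank))

val_acc = acc.compute()             # collective: every rank must call this
if local_rank == 0:
    print(f"val acc: {val_acc.item():.3f}")
acc.reset()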

It is also important to keep synchronization overhead low, so only call synchronization code when necessary. For instance, in the training loop it is often unnecessary to all-reduce the training loss across ranks. For early stopping, however, all ranks must see the same validation loss, so it is recommended to all-reduce the validation loss once per epoch (which is infrequent enough to be cheap).
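A hedged sketch of syncing the validation loss once per epoch (local_val_loss, best_loss, min_delta, patience, and local_rank are placeholders):

val_loss = torch.tensor([local_val_loss], device=local_rank)
dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
val_loss /= dist.get_world_size()                 # mean validation loss, identical on every rank

if val_loss.item() > best_loss - min_delta:
    patience -= 1                                 # every rank makes the same early-stopping decision
else:
    best_loss = val_loss.item()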

Pitfall

Avoid using num_workers > 0 for DDP dataloading. It can cause hard-to-debug OOM errors and processes being killed with SIGKILL.

Resources

  • Good high level overview of HPC and DDP here
  • Medium article on implementing DDP here