Distributed Training with DistributedDataParallel

This is the final part of a 3-part series covering multiprocessing, distributed communication, and distributed training in PyTorch.

In this article, let’s see how we can make use of torch.nn.parallel.DistributedDataParallel for distributed training.

What is it?

DistributedDataParallel (DDP) lets us leverage the computing power of multiple devices (GPUs/CPUs) across multiple machines to accelerate training. It employs data parallelism, i.e. concurrent execution of the same function on multiple computing devices. We wrap our model with DDP and then launch the training script multiple times: one or more processes on each of one or more machines.

Here is a summary of how DDP works: each process holds a replica of our nn.Module. The initial model state (from the process with rank 0) is broadcast to all processes. At each training step, the gradients are averaged across all processes and used to update every replica's parameters identically.

This has the same effect as training with a larger batch size, thus reducing the total training time. Its design ensures that it is mathematically equivalent to training on a single process.
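To make that equivalence concrete (a quick sketch, using notation that isn't in the original post): if the world size is \(W\) and each process uses a per-process batch size of \(b\), one DDP step roughly behaves like a single-process step on an effective batch of \(W \cdot b\) samples, since the synchronized gradient is the average of the per-process gradients,

\[ g = \frac{1}{W} \sum_{r=0}^{W-1} g_r, \]

where \(g_r\) is the gradient computed by the process with rank \(r\) on its own mini-batch.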

For its algorithmic and engineering details, check out the DDP paper. I may write an article covering it in the future, it’s pretty interesting :smile:

A local training example

We start with this piece of local training code on MNIST:

Import necessary libraries:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

Define an arbitrary model:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.n1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.n2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.n1(x)
        x = F.relu_(x)
        x = F.max_pool2d(x, 2)

        x = self.conv2(x)
        x = self.n2(x)
        x = F.relu_(x)
        x = F.max_pool2d(x, 2)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu_(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=-1)

        return x

We also define two utility functions: one that builds an MNIST data loader and one that runs a single training epoch:

def get_mnist_loader(train=True, **kwargs):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST('./data', train=train, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, **kwargs)
    return loader

def train(model, optimizer, train_loader, device):
    for X, target in train_loader:
        X = X.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        pred = model(X)
        loss = F.nll_loss(pred, target)
        loss.backward()
        optimizer.step()

The good thing is that we won’t need to touch any of the code above later on.

Finally, define our main function:

if __name__ == "__main__":
    cuda = True
    device = torch.device('cuda' if cuda else 'cpu')
    
    model = Net().to(device)
    train_loader = get_mnist_loader(train=True, batch_size=64, shuffle=True)
    optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)

    model.train()
    for _ in range(5):
        train(model, optimizer, train_loader, device)

Nothing out of the ordinary here. Most local training code follows the same routine as this example. Now let’s add distributed training to this code.

Distributed training with DDP

The code

DDP uses torch.distributed under the hood. To use DDP, we first have to initialize the process group and the torch.distributed package. In-depth discussion of the initialization process as well as the package itself can be found in the previous article.

Most commonly, we initialize the process group over TCP using the environment variable method. We need the master node’s address and port (used to set up all communication), as well as the world size and the rank of the current process. We use the NCCL backend for GPU training and the Gloo backend for CPU training:

import torch.distributed as dist
...
if __name__ == "__main__":
    cuda = True
    # Init torch.distributed
    dist.init_process_group('nccl' if cuda else 'gloo')
    device = torch.device('cuda' if cuda else 'cpu')
    ...

The init call is equivalent to dist.init_process_group(..., init_method='env://'). This also means that all the information needed to initialize the process group should be provided by the environment variables, but how? We will handle this at launch time using torchrun, but let’s not worry about that for now.
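If you want to sanity-check the script without any launcher, one option (not part of the original example) is to fill in those environment variables yourself before the init call reads them, e.g. for a single CPU-only process:

import os
import torch.distributed as dist

# Hypothetical single-process smoke test: provide the values that
# init_method='env://' expects to find in the environment.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('LOCAL_RANK', '0')

dist.init_process_group('gloo')  # a "world" of one process, CPU-only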

We then simply wrap our model with DistributedDataParallel. Let’s start with the easy case where we run one process per node, for example when we have multiple nodes with one GPU each. Since we don’t have to worry about picking the correct device (GPU) for each process, wrapping is as simple as:

from torch.nn.parallel import DistributedDataParallel as DDP
...
if __name__ == "__main__":
    cuda = True
    # Init torch.distributed
    dist.init_process_group('nccl' if cuda else 'gloo')
    device = torch.device('cuda' if cuda else 'cpu')
    # Define and wrap model with DDP
    model = Net().to(device)
    model = DDP(model)
    ...

And voilà, we are ready to go!
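One side note (not covered by the example above, and it does mean touching the data loader after all): with plain DDP every process iterates over the full MNIST dataset. If you want each process to train on a distinct shard of the data, the usual approach is to pair the DataLoader with torch.utils.data.distributed.DistributedSampler, roughly like this:

import torch
from torchvision import datasets, transforms
from torch.utils.data.distributed import DistributedSampler

def get_mnist_loader(train=True, **kwargs):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST('./data', train=train, download=True, transform=transform)
    # Each process iterates over its own 1/WORLD_SIZE shard of the dataset.
    sampler = DistributedSampler(dataset, shuffle=kwargs.pop('shuffle', True))
    return torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)

The sampler reads the world size and rank from torch.distributed, so the process group must already be initialized when the loader is built, and you would call loader.sampler.set_epoch(epoch) at the start of each epoch so the shuffling differs between epochs.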

torchrun

torch.distributed.launch used to be the go-to way to launch distributed training, but it is being deprecated in favor of torchrun. Documentation for torchrun can be found here. When using torchrun, useful environment variables are made available to each process, including MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, and LOCAL_RANK.
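As a quick illustration (this snippet is not part of the training script above), each process can read its own identity from those variables; the fallback defaults keep it runnable as a plain single process too:

import os

# These values are provided by torchrun at launch time.
world_size = int(os.getenv('WORLD_SIZE', 1))
rank = int(os.getenv('RANK', 0))
local_rank = int(os.getenv('LOCAL_RANK', 0))
print(f"process {rank}/{world_size} (local rank {local_rank}), "
      f"rendezvous at {os.getenv('MASTER_ADDR')}:{os.getenv('MASTER_PORT')}")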

For the example above, we can launch the training process by running the following command for each node:

torchrun --nnodes=$NUM_NODES --nproc_per_node=1 --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR train.py

where $NUM_NODES is the number of nodes, $NODE_RANK is the rank of the current node (0 for the master node), and $MASTER_ADDR is the address of the master node. To test it locally, set --nnodes=1, --node_rank=0 and --master_addr=127.0.0.1, or use the --standalone flag as shown below.

torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_PROCESSES train.py

where $NUM_PROCESSES is the number of processes. This is useful for the case where we have one node with multiple GPUs.

Multiple GPUs per node

Now, let’s consider the case where we want to run multiple processes per node. In practice, this often corresponds to the multiple nodes, multiple GPUs per node scenario. In this case, each process within a node must be assigned to a specific GPU.

Assuming that we know the number of GPUs \(N\) on node A, we will launch \(N\) processes on node A. Again, if we launch the processes using torchrun, we will have WORLD_SIZE, RANK, and LOCAL_RANK in our environment variables. LOCAL_RANK denotes the rank of a process within its node, while RANK denotes the rank of a process across all nodes, i.e. globally. For example, with 2 nodes and 4 processes each, RANK runs from 0 to 7, while LOCAL_RANK runs from 0 to 3 on each node.

Launching \(N\) processes within a node will give the processes LOCAL_RANKs from 0 to \(N-1\). We can use them as indices to access the corresponding GPUs.

...
if __name__ == "__main__":
    cuda = True
    # Init torch.distributed
    dist.init_process_group('nccl' if cuda else 'gloo')
    # Get LOCAL_RANK
    LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))
    # Set device
    if cuda:
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device(f'cuda:{LOCAL_RANK}')
    else:
        device = torch.device('cpu')
    # Define and wrap model with DDP
    model = Net().to(device)
    if cuda:
        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
    else:
        model = DDP(model)
    ...

From this point on, use device wherever a device is needed, instead of re-deriving it, to avoid unnecessary complications.

The launch command is similar to the previous example, but we need to set --nproc_per_node to an appropriate value, e.g. the number of GPUs on each node. For example:

Node A (4 GPUs, 192.168.1.100, master):

torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr=192.168.1.100 train.py

Node B (2 GPUs, 192.168.1.101):

torchrun --nnodes=2 --nproc_per_node=2 --node_rank=1 --master_addr=192.168.1.100 train.py

Saving and loading

Since the model is synchronized across all processes, a storage-efficient way to save and load checkpoints is to save just once, commonly on the process with RANK=0. This requires the other processes to somehow have access to that same checkpoint file (e.g. through a shared filesystem). If that is infeasible, we can instead save one checkpoint per node (on the process with LOCAL_RANK=0); since the replicas are identical, every process within a node then has access to a checkpoint file that is the same across all nodes.

Thus we can do something like this:

def save(model, optimizer, path):
    LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))
    if LOCAL_RANK == 0:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, path)

def load(model, optimizer, path):
    LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))
    checkpoint = torch.load(path, map_location={'cuda:0': f'cuda:{LOCAL_RANK}'})
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

And to save and load in main:

...
if __name__ == "__main__":
    ...
    # Save the model and optimizer (on process local rank 0)
    save(model, optimizer, 'checkpoint.pth')
    # Wait until the saving is done
    dist.barrier()
    # Load the model on all processes
    load(model, optimizer, 'checkpoint.pth')

If we load right after saving, it is necessary to synchronize the processes with dist.barrier() to make sure process 0 has finished saving before the others read the file. After the load we should also synchronize the processes; however, assuming we continue training after the load, the backward pass of DDP will synchronize them anyway.

Also note that in torch.load we pass a dictionary to map_location instead of a specific device. This maps every cuda:0 location tag in the saved file to cuda:LOCAL_RANK, which makes sure the tensors are loaded onto the correct device. It works for both CPU and GPU: if the model was saved from CPU there are no cuda:0 tags, so the remapping is simply a no-op.
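A slightly simpler alternative (a small variation, not what the snippet above uses) is to pass the target device directly, which torch.load also accepts; here device is assumed to be the torch.device selected earlier in main:

# Equivalent alternative: remap every tensor in the checkpoint onto `device`,
# whatever it is (cuda:LOCAL_RANK or cpu).
checkpoint = torch.load(path, map_location=device)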

That covers the basics I wanted to go through. For more details, please refer to the DDP API doc, DDP note, and DDP tutorial.

See you next time!