
Chenguang Xiao

Research Fellow, University of Birmingham (c.g.xiao@outlook.com)

Multi-GPU training with Intel XPU using PyTorch DDP

This example shows how to run pure PyTorch DistributedDataParallel (DDP) in a single-machine, multi-XPU environment. It can also serve as a starting point for using Intel XPU with more advanced distributed training techniques such as Fully Sharded Data Parallel (FSDP), DeepSpeed, and Accelerate.

This example is inspired by the official IPEX DDP example.

There are four key steps to run this example:

  1. Install Intel XPU PyTorch and related packages.
  2. Import intel_extension_for_pytorch and oneccl_bindings_for_pytorch.
  3. Set the mpi and ccl related environment variables.
  4. Launch the script with mpirun command.

Refer to the following sections for details.

1. Python environment setup

In addition to the XPU-backend build of PyTorch, intel_extension_for_pytorch and oneccl_bindings_for_pytorch are critical packages for running this example. You can install them with the following commands:

# according to ipex: https://pytorch-extension.intel.com/installation?platform=gpu&version=v2.6.10%2Bxpu&os=linux%2Fwsl2&package=pip
python -m pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
python -m pip install intel-extension-for-pytorch==2.6.10+xpu oneccl_bind_pt==2.6.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
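
After installation, a quick sanity check confirms that the XPU backend is visible to PyTorch (the exact values printed depend on your hardware):

import torch
import intel_extension_for_pytorch  # noqa: F401

print(torch.__version__)        # expect a 2.6.0+xpu build
print(torch.xpu.is_available()) # True if an XPU device is visible
print(torch.xpu.device_count()) # number of XPU devices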

2. Import support packages

Make sure to import intel_extension_for_pytorch and oneccl_bindings_for_pytorch in your script. Importing intel_extension_for_pytorch registers XPU device support, and importing oneccl_bindings_for_pytorch registers the ccl backend with torch.distributed, so both imports are necessary for multi-XPU training even if your script never calls them explicitly.

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)


if __name__ == "__main__":
    torch.xpu.manual_seed(123)  # seed the XPU RNG for reproducibility
    # PMI_SIZE and PMI_RANK are set by the Intel MPI process manager (mpirun)
    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
    mpi_rank = int(os.environ.get("PMI_RANK", -1))
    if mpi_world_size > 0:
        os.environ["RANK"] = str(mpi_rank)
        os.environ["WORLD_SIZE"] = str(mpi_world_size)
    else:
        # set the default rank and world size to 0 and 1
        os.environ["RANK"] = str(os.environ.get("RANK", 0))
        os.environ["WORLD_SIZE"] = str(os.environ.get("WORLD_SIZE", 1))
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # your master address
    os.environ["MASTER_PORT"] = "29500"  # your master port

    # Initialize the process group with ccl backend
    dist.init_process_group(backend="ccl")

    # For single-node distributed training, local_rank is the same as global rank
    local_rank = dist.get_rank()
    # Only set device for distributed training on GPU
    device = "xpu:{}".format(local_rank)
    model = Model().to(device)
    if dist.get_world_size() > 1:
        model = DDP(model, device_ids=[device])

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss().to(device)
    for i in range(3):
        print("Running Iteration: {} on device {}".format(i, device))
        input = torch.randn(2, 4).to(device)
        labels = torch.randn(2, 5).to(device)
        # forward
        print("Running forward: {} on device {}".format(i, device))
        res = model(input)
        # loss
        print("Running loss: {} on device {}".format(i, device))
        L = loss_fn(res, labels)
        # backward
        print("Running backward: {} on device {}".format(i, device))
        optimizer.zero_grad()  # clear gradients accumulated in the previous iteration
        L.backward()
        # update
        print("Running optim: {} on device {}".format(i, device))
        optimizer.step()

    # release the communication resources held by the process group
    dist.destroy_process_group()
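
The loop above feeds random tensors to keep the example minimal. In a real job, each rank should see a distinct shard of the training data. A minimal sketch using torch.utils.data.distributed.DistributedSampler (to be placed after dist.init_process_group, with device defined as above):

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# a toy dataset standing in for real training data
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 5))
sampler = DistributedSampler(dataset)  # splits indices across ranks
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffle differently each epoch
    for input, labels in loader:
        input, labels = input.to(device), labels.to(device)
        # forward / loss / backward / step as in the loop above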

3. Set MPI and CCL environment variables

Make sure to set the following environment variables before running the script:

source __oneapi_install_path__/ccl/latest/env/vars.sh

source __oneapi_install_path__/mpi/latest/env/vars.sh
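
For example, assuming the default oneAPI installation location /opt/intel/oneapi (replace it with your own install path if it differs):

source /opt/intel/oneapi/ccl/latest/env/vars.sh
source /opt/intel/oneapi/mpi/latest/env/vars.sh

These scripts put mpirun on the PATH and export the library paths that the ccl backend needs at runtime.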

4. Launch the script with mpirun

mpirun is the recommended way to launch the script for multi-XPU training. Here, -n 2 launches two ranks (one per XPU device) and -l prefixes each output line with the rank that produced it:

mpirun -n 2 -l python your_script.py
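
Before running the full DDP script, it can be worth smoke-testing the ccl rendezvous on its own. The following sketch (a hypothetical file check_ccl.py, using the same PMI-based bootstrap as the main example) performs a single all_reduce across ranks:

# check_ccl.py
import os

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed as dist

os.environ["RANK"] = os.environ.get("PMI_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("PMI_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group(backend="ccl")
rank = dist.get_rank()
t = torch.ones(1).to("xpu:{}".format(rank))
dist.all_reduce(t)  # sums the tensor across all ranks
print("rank {}: all_reduce result = {}".format(rank, t.item()))  # expect world_size
dist.destroy_process_group()

Launch it with mpirun -n 2 -l python check_ccl.py; each rank should print 2.0.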

If everything is set up correctly, each rank prints its own iteration logs and the DDP training runs to completion on all XPUs.