This example shows how to run pure PyTorch DistributedDataParallel (DDP) in a single-machine, multi-XPU environment. It can also provide some insight into how to use Intel XPUs with more advanced distributed training techniques such as Fully Sharded Data Parallel (FSDP), DeepSpeed, and Accelerate.

This example is inspired by the official IPEX DDP example.
There are three keys to running this example:

- `intel_extension_for_pytorch` and `oneccl_bindings_for_pytorch`.
- `mpi`- and `ccl`-related environment variables.
- The `mpirun` command.

Refer to the following sections for details.
Besides PyTorch built with the `xpu` backend, `intel_extension_for_pytorch` and `oneccl_bindings_for_pytorch` are critical packages for running this example. You can install them with the following commands:
```bash
# according to ipex: https://pytorch-extension.intel.com/installation?platform=gpu&version=v2.6.10%2Bxpu&os=linux%2Fwsl2&package=pip
python -m pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
python -m pip install intel-extension-for-pytorch==2.6.10+xpu oneccl_bind_pt==2.6.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
```
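To sanity-check the installation, you can run a short snippet like the one below. It is only an illustrative sketch; the version strings in the comments are examples and depend on the exact wheels you installed.

```python
# Quick check that the XPU stack is installed and at least one device is visible.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401

print(torch.__version__)         # e.g. 2.6.0+xpu
print(ipex.__version__)          # e.g. 2.6.10+xpu
print(torch.xpu.is_available())  # should be True on a machine with an Intel GPU
print(torch.xpu.device_count())  # number of visible XPU devices
```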
Make sure to import `intel_extension_for_pytorch` and `oneccl_bindings_for_pytorch` in your script. They are necessary for multi-XPU training, whether your script explicitly uses them or not.
```python
import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)


if __name__ == "__main__":
    torch.xpu.manual_seed(123)  # set a seed number
    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
    mpi_rank = int(os.environ.get("PMI_RANK", -1))
    if mpi_world_size > 0:
        os.environ["RANK"] = str(mpi_rank)
        os.environ["WORLD_SIZE"] = str(mpi_world_size)
    else:
        # set the default rank and world size to 0 and 1
        os.environ["RANK"] = str(os.environ.get("RANK", 0))
        os.environ["WORLD_SIZE"] = str(os.environ.get("WORLD_SIZE", 1))
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # your master address
    os.environ["MASTER_PORT"] = "29500"  # your master port

    # Initialize the process group with ccl backend
    dist.init_process_group(backend="ccl")

    # For single-node distributed training, local_rank is the same as global rank
    local_rank = dist.get_rank()
    # Only set device for distributed training on GPU
    device = "xpu:{}".format(local_rank)
    model = Model().to(device)
    if dist.get_world_size() > 1:
        model = DDP(model, device_ids=[device])

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss().to(device)
    for i in range(3):
        print("Running Iteration: {} on device {}".format(i, device))
        input = torch.randn(2, 4).to(device)
        labels = torch.randn(2, 5).to(device)
        # clear gradients accumulated from the previous iteration
        optimizer.zero_grad()
        # forward
        print("Running forward: {} on device {}".format(i, device))
        res = model(input)
        # loss
        print("Running loss: {} on device {}".format(i, device))
        L = loss_fn(res, labels)
        # backward
        print("Running backward: {} on device {}".format(i, device))
        L.backward()
        # update
        print("Running optim: {} on device {}".format(i, device))
        optimizer.step()

    # clean up the process group
    dist.destroy_process_group()
```
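If you want to confirm that DDP is actually keeping the replicas synchronized, a small check like the sketch below can be appended after the training loop (at the same indentation level, before `dist.destroy_process_group()`). It is not part of the original example: after each optimizer step every rank should hold identical parameters, so broadcasting rank 0's weights should change nothing on the other ranks.

```python
# Optional sanity check (not part of the original example): after the optimizer
# step, all DDP replicas should hold identical parameters, so broadcasting the
# weights from rank 0 should leave the other ranks unchanged.
with torch.no_grad():
    before = [p.clone() for p in model.parameters()]
    for p in model.parameters():
        dist.broadcast(p, src=0)
    in_sync = all(torch.equal(a, b) for a, b in zip(before, model.parameters()))
    print("Rank {}: parameters in sync with rank 0: {}".format(dist.get_rank(), in_sync))
```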
Make sure to set the following environment variables before running the script:

```bash
source __oneapi_install_path__/ccl/latest/env/vars.sh
source __oneapi_install_path__/mpi/latest/env/vars.sh
```
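If you are unsure whether these scripts were sourced in the current shell, you can check for a couple of the variables they typically export. `CCL_ROOT` and `I_MPI_ROOT` below are typical examples, not an exhaustive or guaranteed list; the exact set depends on your oneAPI version.

```python
# Rough check (an assumption, not from the original example) that the oneAPI
# CCL/MPI environment scripts were sourced in the current shell.
import os

for var in ("CCL_ROOT", "I_MPI_ROOT"):
    print("{} = {}".format(var, os.environ.get(var, "<not set>")))
```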
`mpirun` is the recommended way to launch the script for multi-XPU training:

```bash
mpirun -n 2 -l python your_script.py
```

Here `-n 2` launches two processes (one per XPU device) and `-l` prefixes each line of output with the rank that produced it. You should then see the DDP training run as expected.