avatar

Chenguang Xiao

Research Fellow University of Birmingham c.g.xiao@outlook.com

Multi-XPU training with PyTorch FSDP

Similar to the PyTorch multi-XPU training with DDP example, this example extends multi-XPU training to PyTorch Fully Sharded Data Parallel (FSDP).

This example is modified from the official IPEX FSDP tutorial.

There are four key steps to run this example:

  1. Install Intel XPU PyTorch and related packages.
  2. Import intel_extension_for_pytorch and oneccl_bindings_for_pytorch.
  3. Set the mpi and ccl related environment variables.
  4. Launch the script with the mpirun command rather than using mp.spawn inside your Python script.

The first three steps are exactly the same as in the PyTorch multi-XPU training with DDP example. Details about step 4 are described in the following sections.

1. Example Python code

"""
Import Intel® extension for Pytorch and Intel® oneCCL Bindings for Pytorch
"""

import argparse
import os

# Import Intel® extension for Pytorch\* and Intel® oneCCL Bindings for Pytorch\*
import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

"""
Set the initialize the process group backend as Intel® oneCCL Bindings for Pytorch
"""


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12334"

    # initialize the process group by Intel® oneCCL Bindings for Pytorch\*
    dist.init_process_group("ccl", rank=rank, world_size=world_size)


def cleanup():
    """Tear down the default process group created by ``setup``."""
    dist.destroy_process_group()


class Net(nn.Module):
    """Small MNIST CNN: two 3x3 convolutions, max-pooling, and two linear layers.

    The 9216 input features of ``fc1`` equal 64 * 12 * 12, which matches
    single-channel 28x28 (MNIST-sized) images: 28 -> 26 -> 24 after the two
    unpadded 3x3 convolutions, then 12 after 2x2 max-pooling. The output is
    per-class log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in this exact order so that seeded parameter
        # initialisation and the state_dict keys stay the same.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Convolutional feature extractor.
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Classifier head.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)


"""
Change the device related logic from 'rank' to '"xpu:{}".format(rank)'
"""


def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    # XPU device should be formatted as string, replace the rank with '"xpu:{}".format(rank)'
    ddp_loss = torch.zeros(2).to("xpu:{}".format(rank))
    if sampler:
        sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to("xpu:{}".format(rank)), target.to("xpu:{}".format(rank))
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, reduction="sum")
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print("Train Epoch: {} \tLoss: {:.6f}".format(epoch, ddp_loss[0] / ddp_loss[1]))


"""
Change the device related logic from 'rank' to '"xpu:{}".format(rank)'
"""


def test(model, rank, world_size, test_loader):
    model.eval()
    # correct = 0
    # XPU device should be formatted as string, replace the rank with '"xpu:{}".format(rank)'
    ddp_loss = torch.zeros(3).to("xpu:{}".format(rank))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = (
                data.to("xpu:{}".format(rank)),
                target.to("xpu:{}".format(rank)),
            )
            output = model(data)
            ddp_loss[0] += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

    if rank == 0:
        test_loss = ddp_loss[0] / ddp_loss[2]
        print(
            "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
                test_loss,
                int(ddp_loss[1]),
                int(ddp_loss[2]),
                100.0 * ddp_loss[1] / ddp_loss[2],
            )
        )


"""
Change the device related logic from 'rank' to '"xpu:{}".format(rank)'.
Specify the argument `device_ids` as XPU device ("xpu:{}".format(rank)) in FSDP API.
"""


def fsdp_main(rank, world_size, args):
    setup(rank, world_size)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    sampler1 = DistributedSampler(
        dataset1, rank=rank, num_replicas=world_size, shuffle=True
    )
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

    train_kwargs = {"batch_size": args.batch_size, "sampler": sampler1}
    test_kwargs = {"batch_size": args.test_batch_size, "sampler": sampler2}
    xpu_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": False}
    train_kwargs.update(xpu_kwargs)
    test_kwargs.update(xpu_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
    # my_auto_wrap_policy = functools.partial(
    #     size_based_auto_wrap_policy, min_num_params=100
    # )
    device = torch.device("xpu:{}".format(rank))
    torch.xpu.set_device(device)

    init_start_event = torch.xpu.Event(enable_timing=True)
    init_end_event = torch.xpu.Event(enable_timing=True)

    model = Net().to(device)
    # Specify the argument `device_ids` as XPU device ("xpu:{}".format(rank)) in FSDP API.
    model = FSDP(model, device_id=device)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    init_start_event.record()  # type: ignore
    for epoch in range(1, args.epochs + 1):
        train(
            args,
            model,
            rank,
            world_size,
            train_loader,
            optimizer,
            epoch,
            sampler=sampler1,
        )
        test(model, rank, world_size, test_loader)
        scheduler.step()

    init_end_event.record()  # type: ignore

    if rank == 0:
        print(
            f"XPU event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec"
        )
        print(f"{model}")

    if args.save_model:
        # use a barrier to make sure training is done on all ranks
        dist.barrier()
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    cleanup()


"""
Replace CUDA runtime API with XPU runtime API.
"""
if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1.0,
        metavar="LR",
        help="learning rate (default: 1.0)",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # This does not work as expected on XPU, so we use MPI instead.
    # WORLD_SIZE = torch.xpu.device_count()
    # mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

    ######### use mpi instead of mp.spawn #########
    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
    mpi_rank = int(os.environ.get("PMI_RANK", -1))
    if mpi_world_size > 0:
        os.environ["RANK"] = str(mpi_rank)
        os.environ["WORLD_SIZE"] = str(mpi_world_size)
    else:
        # set the default rank and world size to 0 and 1
        os.environ["RANK"] = str(os.environ.get("RANK", 0))
        os.environ["WORLD_SIZE"] = str(os.environ.get("WORLD_SIZE", 1))

    fsdp_main(mpi_rank, mpi_world_size, args)

2. Start parallel processing with MPI instead of mp.spawn

The main difference between this example and the official PyTorch FSDP example is that MPI, rather than mp.spawn, is used to start the parallel processes. Let’s focus on the last part of the code, shown below:

    - # This does not work as expected on XPU, so we use MPI instead.
    - # WORLD_SIZE = torch.xpu.device_count()
    - # mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

    + ######### use mpi instead of mp.spawn #########
    + mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
    + mpi_rank = int(os.environ.get("PMI_RANK", -1))
    + if mpi_world_size > 0:
    +     os.environ["RANK"] = str(mpi_rank)
    +     os.environ["WORLD_SIZE"] = str(mpi_world_size)
    + else:
    +     # set the default rank and world size to 0 and 1
    +     os.environ["RANK"] = str(os.environ.get("RANK", 0))
    +     os.environ["WORLD_SIZE"] = str(os.environ.get("WORLD_SIZE", 1))

    + fsdp_main(mpi_rank, mpi_world_size, args)

The script reads the rank and world size from the environment variables set by the MPI launcher (PMI_RANK and PMI_SIZE). It exports them as RANK and WORLD_SIZE and passes them to the fsdp_main function. The script is then launched with the mpirun command instead of using mp.spawn inside the Python script:

mpirun -n 2 -l python xpu-fsdp-demo.py