Similar to the PyTorch multi-XPU training with DDP example, this example further extends multi-XPU training with PyTorch Fully Sharded Data Parallel (FSDP).
This example is modified from the official IPEX FSDP tutorial.
There are four key steps to run this example:

1. Import intel_extension_for_pytorch and oneccl_bindings_for_pytorch.
2. Initialize the process group with the ccl backend provided by oneccl_bindings_for_pytorch.
3. Set the mpi and ccl related environment variables.
4. Run the script with the mpirun command rather than use mp.spawn inside your Python script.

The first 3 steps are exactly the same as in the example PyTorch multi-XPU training with DDP. Details about step 4 are described in the following sections.
"""
Import Intel® Extension for PyTorch and Intel® oneCCL Bindings for PyTorch
"""
import argparse
import os
# Import Intel® Extension for PyTorch and Intel® oneCCL Bindings for PyTorch
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms
"""
Set the initialize the process group backend as Intel® oneCCL Bindings for Pytorch
"""
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12334"
# initialize the process group by Intel® oneCCL Bindings for Pytorch\*
dist.init_process_group("ccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

"""
Change the device related logic from 'rank' to '"xpu:{}".format(rank)'
"""
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
# XPU device should be formatted as string, replace the rank with '"xpu:{}".format(rank)'
ddp_loss = torch.zeros(2).to("xpu:{}".format(rank))
if sampler:
sampler.set_epoch(epoch)
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to("xpu:{}".format(rank)), target.to("xpu:{}".format(rank))
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target, reduction="sum")
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(data)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
print("Train Epoch: {} \tLoss: {:.6f}".format(epoch, ddp_loss[0] / ddp_loss[1]))
"""
Change the device related logic from 'rank' to '"xpu:{}".format(rank)'
"""
def test(model, rank, world_size, test_loader):
model.eval()
# correct = 0
# XPU device should be formatted as string, replace the rank with '"xpu:{}".format(rank)'
ddp_loss = torch.zeros(3).to("xpu:{}".format(rank))
with torch.no_grad():
for data, target in test_loader:
data, target = (
data.to("xpu:{}".format(rank)),
target.to("xpu:{}".format(rank)),
)
output = model(data)
ddp_loss[0] += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
ddp_loss[2] += len(data)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
test_loss = ddp_loss[0] / ddp_loss[2]
print(
"Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
test_loss,
int(ddp_loss[1]),
int(ddp_loss[2]),
100.0 * ddp_loss[1] / ddp_loss[2],
)
)
"""
Change the device related logic from 'rank' to '"xpu:{}".format(rank)'.
Specify the argument `device_ids` as XPU device ("xpu:{}".format(rank)) in FSDP API.
"""
def fsdp_main(rank, world_size, args):
setup(rank, world_size)
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)
sampler1 = DistributedSampler(
dataset1, rank=rank, num_replicas=world_size, shuffle=True
)
sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
train_kwargs = {"batch_size": args.batch_size, "sampler": sampler1}
test_kwargs = {"batch_size": args.test_batch_size, "sampler": sampler2}
xpu_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": False}
train_kwargs.update(xpu_kwargs)
test_kwargs.update(xpu_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
# my_auto_wrap_policy = functools.partial(
# size_based_auto_wrap_policy, min_num_params=100
# )
device = torch.device("xpu:{}".format(rank))
torch.xpu.set_device(device)
init_start_event = torch.xpu.Event(enable_timing=True)
init_end_event = torch.xpu.Event(enable_timing=True)
model = Net().to(device)
# Specify the argument `device_ids` as XPU device ("xpu:{}".format(rank)) in FSDP API.
model = FSDP(model, device_id=device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
init_start_event.record() # type: ignore
for epoch in range(1, args.epochs + 1):
train(
args,
model,
rank,
world_size,
train_loader,
optimizer,
epoch,
sampler=sampler1,
)
test(model, rank, world_size, test_loader)
scheduler.step()
init_end_event.record() # type: ignore
if rank == 0:
print(
f"XPU event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec"
)
print(f"{model}")
if args.save_model:
# use a barrier to make sure training is done on all ranks
dist.barrier()
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")
cleanup()
"""
Replace CUDA runtime API with XPU runtime API.
"""
if __name__ == "__main__":
# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1000,
metavar="N",
help="input batch size for testing (default: 1000)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
metavar="N",
help="number of epochs to train (default: 14)",
)
parser.add_argument(
"--lr",
type=float,
default=1.0,
metavar="LR",
help="learning rate (default: 1.0)",
)
parser.add_argument(
"--gamma",
type=float,
default=0.7,
metavar="M",
help="Learning rate step gamma (default: 0.7)",
)
parser.add_argument(
"--no-cuda", action="store_true", default=False, help="disables CUDA training"
)
parser.add_argument(
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument(
"--save-model",
action="store_true",
default=False,
help="For Saving the current Model",
)
args = parser.parse_args()
torch.manual_seed(args.seed)
# This does not work as expected on XPU, so we use MPI instead.
# WORLD_SIZE = torch.xpu.device_count()
# mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
######### use mpi instead of mp.spawn #########
mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
mpi_rank = int(os.environ.get("PMI_RANK", -1))
if mpi_world_size > 0:
os.environ["RANK"] = str(mpi_rank)
os.environ["WORLD_SIZE"] = str(mpi_world_size)
else:
# set the default rank and world size to 0 and 1
os.environ["RANK"] = str(os.environ.get("RANK", 0))
os.environ["WORLD_SIZE"] = str(os.environ.get("WORLD_SIZE", 1))
fsdp_main(mpi_rank, mpi_world_size, args)
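Note that the listing keeps the tutorial's size-based auto-wrap policy commented out (the size_based_auto_wrap_policy lines above the FSDP call). Wrapping the whole Net in a single FSDP unit is fine for a model this small, but if you want FSDP to shard submodules automatically, a sketch along the following lines should work; it assumes the same device and Net as in fsdp_main and is not part of the original example:

import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every submodule holding at least 100 parameters in its own FSDP unit.
my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100)
model = FSDP(
    Net().to(device),
    auto_wrap_policy=my_auto_wrap_policy,
    device_id=device,
)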
The main difference between this example and the official PyTorch FSDP example is that MPI is used to start the parallel processing instead of mp.spawn.
Let’s focus on the last part of the code as below:

- This does not work as expected on XPU, so we use MPI instead.
- WORLD_SIZE = torch.xpu.device_count()
- mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
+ ######### use mpi instead of mp.spawn #########
+ mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
+ mpi_rank = int(os.environ.get("PMI_RANK", -1))
+ if mpi_world_size > 0:
+     os.environ["RANK"] = str(mpi_rank)
+     os.environ["WORLD_SIZE"] = str(mpi_world_size)
+ else:
+     # set the default rank and world size to 0 and 1
+     os.environ["RANK"] = str(os.environ.get("RANK", 0))
+     os.environ["WORLD_SIZE"] = str(os.environ.get("WORLD_SIZE", 1))
+ fsdp_main(mpi_rank, mpi_world_size, args)
Extract the RANK and WORLD_SIZE from the environment variables set by MPI, and then pass them to the fsdp_main function.
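As a hypothetical refactor (not part of the original script), the same extraction could live in a small helper, which keeps the environment handling in one place:

def get_rank_and_world_size():
    # Hypothetical helper: prefer the PMI_* variables set by mpirun,
    # otherwise fall back to RANK/WORLD_SIZE or single-process defaults.
    world_size = int(os.environ.get("PMI_SIZE", -1))
    rank = int(os.environ.get("PMI_RANK", -1))
    if world_size < 1:
        rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    return rank, world_size


# usage: rank, world_size = get_rank_and_world_size(); fsdp_main(rank, world_size, args)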
Then run the script with the mpirun command instead of using mp.spawn inside the Python script:
mpirun -n 2 -l python xpu-fsdp-demo.py
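To confirm that each MPI rank binds to its own XPU before launching the full training, a small hypothetical check script (not part of the original example) can be started the same way, e.g. mpirun -n 2 -l python xpu_rank_check.py. It reuses the same ccl process-group initialization as the example above:

# xpu_rank_check.py - hypothetical helper, not part of the original example
import os

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12334")
rank = int(os.environ.get("PMI_RANK", 0))
world_size = int(os.environ.get("PMI_SIZE", 1))
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)

dist.init_process_group("ccl", rank=rank, world_size=world_size)
torch.xpu.set_device(rank)
print(
    "rank {}/{} is using xpu:{} of {} visible XPU device(s)".format(
        dist.get_rank(),
        dist.get_world_size(),
        torch.xpu.current_device(),
        torch.xpu.device_count(),
    )
)
dist.destroy_process_group()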