
[1/N] Initial implementation of local SPMD support #8810

Open
wants to merge 16 commits into master from lsiyuan/local-spmd-impl

Conversation

@lsy323 lsy323 commented Mar 9, 2025

Before this PR, users had to use all global devices in their SPMD program, which limits the flexibility of running MPMD + SPMD in a multi-host environment.

This PR enables local SPMD via the environment variable XLA_USE_LOCAL_SPMD. Local SPMD and global SPMD can be switched within the same Python program by toggling the environment variable on and off.

Usage

Example usage is:

import os

import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()

# Enable local SPMD and build a mesh over this host's devices only.
os.environ['XLA_USE_LOCAL_SPMD'] = '1'
local_mesh = Mesh(local_device_ids, mesh_shape, axis_names)
# Run local SPMD program
...

# Switch back to global SPMD with a mesh over all global devices.
os.environ['XLA_USE_LOCAL_SPMD'] = '0'
global_mesh = Mesh(global_device_ids, global_mesh_shape, axis_names)
# Run global SPMD program
...

An example script is here

To make local SPMD work in a multi-host setting, we need to ensure:

  1. In the lowered HLO graph, the global ordinals are logical indices (starting from zero) instead of physical device ids.
  2. During XLA compilation, XLA is configured to target only the local devices. Both requirements are illustrated in the sketch below.
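
For illustration, a minimal sketch of both requirements, assuming a host whose addressable devices have physical ids 8..15 (the exact ids depend on the topology):

physical_ids = list(range(8, 16))

# Requirement 1: the HLO sharding annotation uses logical indices starting at zero.
logical_ids = [d - min(physical_ids) for d in physical_ids]  # [0, 1, ..., 7]

# Requirement 2: the compiled program targets only these local devices, i.e.
# num_partitions equals the local device count rather than the global one.
num_partitions = len(physical_ids)  # 8, not the global device count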

Implementation:

SPMD Python API:

  • Allow creating a mesh with only local devices.
  • Update _get_tile_assignment to support a local device mesh, so that the device ids start from 0 in the HLO sharding annotation (a normalization sketch follows this list).
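
A minimal sketch of that normalization, assuming a NumPy device mesh (the helper name is hypothetical; the real change lives in _get_tile_assignment):

import numpy as np

def _normalize_device_ids(device_mesh: np.ndarray) -> np.ndarray:
    """Shift physical device ids so the tile assignment starts at 0."""
    device_id_min = np.min(device_mesh)
    if device_id_min == 0:
        # Already logical indices (global SPMD or host 0); avoid the copy.
        return device_mesh
    return device_mesh - device_id_min

# e.g. a 2x4 local mesh on the second host of a two-host v4-16:
print(_normalize_device_ids(np.arange(8, 16).reshape(2, 4)))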

Sharding Utilities:

  • Support local SPMD in GetShardReplicaAndIndicesForDevices. This is achieved by deriving the global ordinal from the physical device and the SPMD sharding annotation (see the sketch below).
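
Conceptually (a rough Python sketch of the idea, not the C++ implementation in GetShardReplicaAndIndicesForDevices):

# Hypothetical device list for a host that owns TPU:8..TPU:11 under local SPMD.
devices = ["TPU:8", "TPU:9", "TPU:10", "TPU:11"]

# The ordinal used to index into the sharding's tile assignment is the device's
# position among the participating devices, not its physical id.
ordinal = {device: index for index, device in enumerate(devices)}
print(ordinal["TPU:10"])  # 2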

PJRT computation client:

  • (The env var is used here.) Configure the XLA compilation options to use only the local devices for the compiled program when local SPMD is enabled (see the sketch below).
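
In Python terms, the partition-count selection amounts to something like the following sketch (the actual logic lives in the C++ PjRtComputationClient::Compile):

import os

import torch_xla.runtime as xr

use_local_spmd = os.environ.get('XLA_USE_LOCAL_SPMD', '0') == '1'

# Compile for only the addressable devices under local SPMD,
# otherwise for all global devices.
num_partitions = (xr.addressable_runtime_device_count()
                  if use_local_spmd else xr.global_runtime_device_count())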

Test:

  • Added a C++ test for the sharding util change for the local SPMD use case.
  • Verified on a multi-host TPU VM with an example script running a small VAE under local SPMD, using different input resolutions on different hosts. Note that global SPMD and local SPMD are not bridged; we need an API to create a global SPMD tensor from local SPMD tensors.

Future work:

  • Avoid using an env var for local SPMD control.

@lsy323 lsy323 force-pushed the lsiyuan/local-spmd-impl branch from 0833b24 to 3865f67 on March 10, 2025 02:58
@lsy323 lsy323 marked this pull request as ready for review March 10, 2025 05:24
@lsy323 lsy323 changed the title local spmd impl [1/N] Initial implementation of local SPMD support Mar 10, 2025

lsy323 commented Mar 10, 2025

I tried hard to avoid using an env var to control local SPMD enablement (in this commit).

My thinking was:

  1. We need to find a way in PjrtComputationClient::Compile to determine the target devices of the XLA program we want to compile and execute.

  2. The sharding information is stored in two places:
    a) the xla::OpSharding on the lowered xla::XlaOp that carries the sharding annotation
    b) the SPMD mesh

Determining the target devices from the SPMD mesh is much more straightforward. However, the mesh is not passed down to LTC, and LTC cannot easily be extended to pass the SPMD mesh from the Python layer (because torch_xla works on tensors, attaching a mesh to each sharded tensor seems redundant).

Therefore we have to derive the target devices from the xla::OpSharding.

  3. Each XLA computation is constructed from a dedicated lowering context, so I planned to derive the target device information from the xla::OpSharding of the lowered XlaOp by checking the tile_assignment_devices field.

  4. After XLA lowering, we create a compilation instance; we store the target device info in the compilation instance object and retrieve it in the computation client before compilation.

The issue I hit is in step 3:
It's not a valid solution to derive the target SPMD devices from the tile_assignment_devices of xla::OpSharding; in some cases it's empty.

Since many test cases were failing with the above attempt, and there are many flavors of sharding annotation, I decided to use the env var to control local SPMD enablement, which is clean since it's only used in PjRtComputationClient.


lsy323 commented Mar 10, 2025

The GPU test failure is unrelated; head is failing with the same error.

@lsy323 lsy323 requested review from tengyifei and qihqi March 10, 2025 17:20
@miladm miladm assigned miladm and lsy323 and unassigned miladm Mar 11, 2025
@miladm miladm requested a review from pgmoka March 11, 2025 22:51
@miladm miladm added the distributed SPMD and other distributed things. label Mar 11, 2025
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
assert mesh.size() == num_devices or mesh.size() == num_local_devices, \
Collaborator

This seems like a smell. Why do we only check the mesh size during mark_sharding? Furthermore, should we check that the mesh is a local mesh only when XLA_USE_LOCAL_SPMD==1?

Collaborator

Seems that these checks should be inside the Mesh constructor IMO

Collaborator Author

This is the mesh constructor : )

@rpsilva-aws rpsilva-aws Mar 13, 2025

I agree, should we branch the local and global variants? These asserts are technically alleviating the constraints for both cases - e.g., I can define a mesh size of 32 for 64 global devices, if that host only has 32 addressable devices.

On the downside, I can see how this will complicate in how we get rid of the env var.

Collaborator Author

We should check that the mesh contains only local devices under local SPMD; updated the condition.

Collaborator Author

Done

Collaborator

> This is the mesh constructor : )

Wait, this is not the mesh constructor. This is the mark_sharding function. It seems very strange that we're doing these localized checks in the mark_sharding function. IMO this check should be moved to the Mesh constructor.

@tengyifei
Collaborator

> We need to find a way in PjrtComputationClient::Compile to determine the target devices

Do you know how JAX figures out which target devices to run the graph on? E.g., if I create two smaller meshes [0, 1] and [2, 3] in JAX on a v6e-4 and jit a computation that uses both meshes, I suppose JAX can figure out that all 4 devices are involved in the computation and launch the graph on all 4 devices (although some ops will only use 2 out of 4 devices)?

std::vector<std::string> devices = {"TPU:8", "TPU:9", "TPU:10", "TPU:11",
"TPU:12", "TPU:13", "TPU:14", "TPU:15"};

// 1D tiled
Collaborator

For the different tiling methods, could we not do parameterized testing to significantly save code and make things more readable?

INSTANTIATE_TEST_CASE_P reference: https://www.sandordargo.com/blog/2019/04/24/parameterized-testing-with-gtest

Collaborator Author

Thank you for the suggestion! But I personally think it's hard to parameterize in this case: we would need to programmatically derive the mesh and the expected output from the test parameters, which is hard to generalize across the cases (1D tile, 2D tile, etc.). I feel the current implementation, even if it's long, is more readable.

Collaborator

I suppose making parameterized functions in Python might be easier than in C++. If we want to keep the current approach, splitting these into separate tests might make things easier to test.

Currently if this test fails, it is hard to tell which case it is related to.


lsy323 commented Mar 13, 2025

> We need to find a way in PjrtComputationClient::Compile to determine the target devices
>
> Do you know how JAX figures out which target devices to run the graph on? E.g., if I create two smaller meshes [0, 1] and [2, 3] in JAX on a v6e-4 and jit a computation that uses both meshes, I suppose JAX can figure out that all 4 devices are involved in the computation and launch the graph on all 4 devices (although some ops will only use 2 out of 4 devices)?

I'm not sure if it's possible to use multiple meshes in a single computation. I think within the same SPMD program we have to use only one mesh. I can give it a try and update here later.


rpsilva-aws commented Mar 13, 2025

Nice, this is great! Thanks for following up on it :)

A few other general questions:

> Note that global SPMD and local SPMD are not bridged; we need an API to create a global SPMD tensor from local SPMD tensors.

This is interesting, do we have a plan for this? We were stumbling on this too.

nit: We could consider adding a helper xs method for generating local SPMD meshes (similar to the one we have for a 1D mesh) that includes the following, taken from your ref above, given a partition_spec and mesh_shape:

import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

def local_spmd_mesh(mesh_shape, partition_spec):  # hypothetical helper name
    # Assumes every host exposes the same number of addressable devices.
    process_id = xr.process_index()
    num_local_devices = xr.addressable_runtime_device_count()
    device_id_start = process_id * num_local_devices
    device_ids = np.arange(device_id_start, device_id_start + num_local_devices)
    return Mesh(device_ids, mesh_shape, partition_spec)

> It's not a valid solution to derive the target SPMD devices from the tile_assignment_devices of xla::OpSharding; in some cases it's empty.

The tile_assignment_devices is empty, for an existing OpSharding of type OTHER?


@@ -559,7 +559,10 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
.set_allow_spmd_sharding_propagation_to_output(
{instance.allow_spmd_sharding_propagation_to_output});

int num_partitions = client_->device_count();
int num_partitions = GetNumGlobalDevices();
if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
Collaborator

nit: We could do bool use_local_spmd = runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false); and use in both places below.

@@ -589,11 +592,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
}

// TODO(244391366) verify this is correct for the collectives ops
xla::DeviceAssignment device_assignment(1, client_->device_count());
xla::DeviceAssignment device_assignment(1, num_partitions);
// DeviceAssignment values must be the PjRtDevice ID, so we need to
// unwind the global ordinal mapping.
Collaborator

We can adapt the comment, since we don't only "unwind the global ordinal mapping" now.

for (const auto& [device_id, global_ordinal] : global_ordinals_) {
device_assignment(0, global_ordinal) = device_id;
if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
auto local_pjrt_devices = client_->addressable_devices();
Collaborator

nit: const auto&

@@ -613,10 +613,14 @@ IfrtComputationClient::ExecuteReplicated(
return data_handles;
}

size_t IfrtComputationClient::GetNumDevices() const {
size_t IfrtComputationClient::GetNumLocalDevices() const {
@rpsilva-aws rpsilva-aws Mar 13, 2025

nit: We could also consider naming to include "addressable"/"visible". Technically, we can have N addressable devices for a process, but not necessarily all the host's devices - which could be considered to be local devices too. If we have 2 processes in one host, each with one addressable device, we can think of how we want the terminology to play out here.

@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
return 0;
}

return client->GetNumDevices();
return client->GetNumLocalDevices();
Collaborator

Q: Where do we use this?

Collaborator Author

I just renamed the API to say explicitly that it's the number of local devices; the original name was confusing.

we need to normalize the physical device ids to generate the correct HLO
sharding annotation.
"""
device_id_min = np.min(device_mesh)
@rpsilva-aws rpsilva-aws Mar 13, 2025

nit: We could avoid the copy if device_id_min == 0 by exiting early (e.g. thousands of hosts).

# device ids are continous
if os.environ['XLA_USE_LOCAL_SPMD'] == '1':
# In local SPMD mesh only contains local devices.
min_device_idx = xr.process_index() * xr.addressable_runtime_device_count(
@rpsilva-aws rpsilva-aws Mar 13, 2025

This assumes homogeneous distribution of addressable devices - will this be a requirement? Say we have a different amount of addressable devices per MPMD, e.g. [0,1], [2,3,4,5,6,7].

Collaborator Author

Yes, this is a requirement for now
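
For reference, a minimal sketch of the assumption this relies on (every host exposes the same number of addressable devices):

import torch_xla.runtime as xr

# The first local device id is recovered as process_index * addressable count,
# which only holds if the addressable device count is the same on every host.
min_device_idx = xr.process_index() * xr.addressable_runtime_device_count()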

@pgmoka pgmoka left a comment

Not much to add. I am interested in follow-ups to other comments


@tengyifei
Collaborator

> I'm not sure if it's possible to use multiple meshes in a single computation. I think within the same SPMD program we have to use only one mesh. I can give it a try and update here later.

OK, also to answer my own question, JAX does these things:

  • When compiling, it looks at the input/output/intermediate shardings to find the device assignment in all of those, and validates that they're the same devices [1]. That means one cannot use two non-overlapping meshes in a single jitted computation.
  • Once it has done that, it just determines the number of partitions based on the number of devices [2].

So this roughly translates to every sharded tensor having a mesh and the device assignment is derived from their meshes.

> (Because torch_xla works on tensors, attaching a mesh to each sharded tensor seems redundant)
> Therefore we have to derive the target devices from the xla::OpSharding.

Is it possible to put a DeviceAssignment object next to the xla::OpSharding object? The DeviceAssignment object would hold a vector of device IDs. It sounds like we already store xla::OpSharding for each sharded tensor, so storing an extra field shouldn't be that much overhead and can handle the situations where xla::OpSharding is insufficient.
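
A rough sketch of that pairing, written in Python for brevity and with hypothetical names (the real objects would be the C++ xla::OpSharding and xla::DeviceAssignment):

from dataclasses import dataclass
from typing import Any, List

@dataclass
class ShardingWithDevices:  # hypothetical wrapper
    op_sharding: Any        # the lowered sharding annotation
    device_ids: List[int]   # PjRt device ids the computation should target

    @property
    def num_partitions(self) -> int:
        return len(self.device_ids)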

sharding.tile_assignment_dimensions().end(), 1,
[](int a, int b) { return a * b; });
std::unordered_map<int, int> device_index =
build_index_map(devices, num_tiles);
Collaborator

If I don't fully shard the tensor over the mesh, is num_tiles the correct value to provide as num_mesh_devices to build_index_map? It seems that we'd need to actually plumb in the mesh (or at least the mesh device count) somehow.

For example, if I'm doing local SPMD with a 2D mesh of 2x2 and axis name 'x', 'y', then later I shard a tensor only over the x axis. Would num_tiles be 2 or 4?

// devices, host 1 has devices TPU:{4, 5, 6, 7}. In this case
// the global ordinal of TPU:4 is 0, TPU:5 is 1, and so on.

int global_ordinal =
Collaborator

nit: should this still be called global_ordinal? or maybe mesh_ordinal, based on your explanation? "global" sounds like it's talking about all devices across different hosts.

@tengyifei
Collaborator

tengyifei commented Mar 17, 2025

Sorry for the late comment -- I'm honestly a bit worried about the subtle distinction between local device IDs and global device IDs, particularly the normalize_logical_mesh function and the logic scattered across Python and C++. I think I understand what's happening well enough to have a suggestion. Here's my understanding -- LMK if I got any part wrong:

  • In order to do "local SPMD", we need to tell XLA that num_partitions == 4, or whatever the number of local addressable devices is. We also need to set the IDs in any OpSharding proto to be local IDs, i.e. begin at 0. Finally, we need to create a xla::DeviceAssignment that maps these device IDs to PjRtDevice::id()s. It's a bit confusing because the PJRT device ID is what everyone else thinks of as "device IDs" when it comes to XLA, and the PJRT device ID is in fact sparse (e.g. it can be 100001 in the case of multi-slice). So some parts of our code refer to the "local/global device ID" as "global ordinals" instead, which is really just the index of a device in the mesh. XLA thinks in terms of the PjRtDevice::id()s, and our mesh ordinal is just an intermediate contract. In the extreme case, we could shuffle them arbitrarily, and as long as we provide an appropriate xla::DeviceAssignment, that will result in the correct collectives.
  • Currently, the mark_sharding call implementations also lack abstraction. They directly build a xla::OpSharding proto, and the same xla::OpSharding proto is later given to the XLA compiler, which requires that the IDs in there start from 0. IIUC that's what forces us to normalize_logical_mesh. If XLA didn't have this requirement, we could avoid normalize_logical_mesh and would just need to have the right xla::DeviceAssignment.
  • As a result, we have many kinds of device ID concepts. There's a "global device ID" that is dense and goes from 0 to the number of chips in the environment. There's a "local device ID" that's derived by subtracting their minimum from a set of device IDs, whose correctness is guaranteed by a subtle check in xs.Mesh that requires device IDs to be locally addressable iff in local SPMD mode. Not to mention there's a third device ID, the non-dense PjRTDevice::id(). The worst part is that they're all untyped integers.

I'm wondering if it makes sense to do a prefactor, to hide the xla::OpSharding proto from the public API. I think instead of storing xla::OpSharding, we'd probably create a torch_xla::OpSharding type that stores the same things, except that it always stores global IDs (i.e. the ones we use to index into xr.global_runtime_device_attributes()). We could also store any other necessary information that lets us recover the number of partitions.
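
For illustration only, a rough sketch of such a wrapper (hypothetical names, written in Python for brevity even though the proposed type would live in C++):

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class OpShardingSpec:  # hypothetical stand-in for the proposed torch_xla::OpSharding
    tile_assignment_dims: Tuple[int, ...]
    # Dense global device ids, i.e. indices into xr.global_runtime_device_attributes().
    global_device_ids: List[int]

    def num_partitions(self) -> int:
        return len(self.global_device_ids)

    def to_xla_tile_assignment_devices(self) -> List[int]:
        # Only at compile time are the ids rebased to start from 0, as the XLA
        # compiler requires; the Python layer never needs to normalize them.
        base = min(self.global_device_ids)
        return [d - base for d in self.global_device_ids]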

Once we've done this refactor, the Python layer can reason in terms of torch_xla::OpSharding objects instead of xla::OpSharding, so it's not forced to normalize the device IDs. We'll still have two kinds of device IDs (the global dense ID and the sparse PjRTDevice::id()), but that's better than having three.

If this feature is no longer urgently required, I wonder if it's possible to do this refactor. I think it could also let us support other kinds of SPMD (e.g. using 2 chips out of 4 chips) instead of being restricted to either full local SPMD or global SPMD.
