[1/N] Initial implementation of local SPMD support #8810
I tried hard to avoid using an env var to control local SPMD enablement (in this commit). My thought is:
Determining the target devices from the SPMD mesh is much more straightforward. However, the mesh is not passed down to LTC, and LTC cannot easily be extended to pass the SPMD mesh from the Python layer (because torch_xla works on tensors, attaching the mesh to each sharded tensor seems redundant). Therefore we have to derive the target devices from the sharding annotations.
The issue I have is in step 3: many test cases fail with the above attempt, and there are many flavors of sharding annotation. So I decided to use the env var to control local SPMD enablement, which is clean since it's only used in the
The GPU test failure is unrelated; head is failing with the same error.
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
assert mesh.size() == num_devices or mesh.size() == num_local_devices, \
This seems like a smell. Why do we only check the mesh size during mark_sharding? Furthermore, should we check that the mesh is a local mesh only when XLA_USE_LOCAL_SPMD==1?
Seems that these checks should be inside the Mesh constructor, IMO.
This is the mesh constructor : )
I agree, should we branch the local and global variants? These asserts technically relax the constraints for both cases - e.g., I can define a mesh of size 32 for 64 global devices, if that host only has 32 addressable devices.
On the downside, I can see how this will complicate how we get rid of the env var.
We should check that the mesh only contains local devices under local SPMD; updated the condition.
Done
This is the mesh constructor : )
Wait, this is not the mesh constructor. This is the mark_sharding function. It seems very strange that we're doing these localized checks in the mark_sharding function. IMO this check should be moved to the Mesh constructor.
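If the check were moved as suggested, a minimal sketch of what it could look like (assuming the XLA_USE_LOCAL_SPMD toggle from this PR; the Mesh class below is a simplified stand-in, not torch_xla's actual xs.Mesh):

```python
# Hedged sketch only: validate mesh size in the constructor instead of
# mark_sharding, branching on the XLA_USE_LOCAL_SPMD env var.
import os
import torch_xla.runtime as xr


class Mesh:
  def __init__(self, device_ids, mesh_shape, axis_names=None):
    use_local_spmd = os.environ.get('XLA_USE_LOCAL_SPMD', '0') == '1'
    expected = (xr.addressable_runtime_device_count()
                if use_local_spmd else xr.global_runtime_device_count())
    assert len(device_ids) == expected, (
        f'Mesh has {len(device_ids)} devices, expected {expected} '
        f'({"local" if use_local_spmd else "global"} SPMD)')
    self.device_ids = device_ids
    self.mesh_shape = mesh_shape
    self.axis_names = axis_names
```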
Do you know how JAX figures out which target devices to run the graph on? E.g. if I create two smaller meshes
std::vector<std::string> devices = {"TPU:8", "TPU:9", "TPU:10", "TPU:11",
                                    "TPU:12", "TPU:13", "TPU:14", "TPU:15"};

// 1D tiled
For the different tiled methods, could we not do parameterized testing to significantly save code and make things more readable?
INSTANTIATE_TEST_CASE_P reference: https://www.sandordargo.com/blog/2019/04/24/parameterized-testing-with-gtest
Thank you for the suggestion! But I personally think it's hard to parameterize in this case: we would need to programmatically derive the mesh and expected output from the test parameters, which is hard to generalize across the cases (1D tile, 2D tile, etc.). I feel like the current implementation, even though it's long, is more readable.
I suppose making parameterized functions in Python might be easier than in C++. If we want to keep the current approach, splitting these into separate tests might make things easier to test.
Currently, if this test fails, it is hard to tell which case it is related to.
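As a rough illustration of the Python option mentioned above, a hedged sketch of a parameterized test (test names, mesh shapes, and expected sharding strings are placeholders, not taken from the PR):

```python
# Hedged sketch: one parameterized test body covering several tiling variants,
# so a failure reports which named case broke.
from absl.testing import parameterized
import unittest


class TileAssignmentTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('one_d_tiled', (8,), ('x',), '{devices=[8]0,1,2,3,4,5,6,7}'),
      ('two_d_tiled', (4, 2), ('x', 'y'), '{devices=[4,2]0,1,2,3,4,5,6,7}'),
  )
  def test_tile_assignment(self, mesh_shape, axis_names, expected_sharding):
    # Placeholder body: build the mesh from mesh_shape/axis_names, mark the
    # sharding, and compare the generated HLO sharding to expected_sharding.
    self.assertIsInstance(expected_sharding, str)


if __name__ == '__main__':
  unittest.main()
```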
I'm not sure if it's possible to use multiple meshes in a single computation. I think in the same SPMD program we have to use only one mesh. I can give it a try and update here later.
Nice, this is great! Thanks for following up on it :) A few other general questions:
This is interesting, do we have a plan for this? We were stumbling on this too.
nit: We could consider adding a helper
@@ -559,7 +559,10 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
.set_allow_spmd_sharding_propagation_to_output(
{instance.allow_spmd_sharding_propagation_to_output});

int num_partitions = client_->device_count();
int num_partitions = GetNumGlobalDevices();
if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
nit: We could do bool use_local_spmd = runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false); and use it in both places below.
@@ -589,11 +592,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
}

// TODO(244391366) verify this is correct for the collectives ops
xla::DeviceAssignment device_assignment(1, client_->device_count());
xla::DeviceAssignment device_assignment(1, num_partitions);
// DeviceAssignment values must be the PjRtDevice ID, so we need to
// unwind the global ordinal mapping.
We can adapt the comment, since we don't only "unwind the global ordinal mapping" now.
for (const auto& [device_id, global_ordinal] : global_ordinals_) {
device_assignment(0, global_ordinal) = device_id;
if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
auto local_pjrt_devices = client_->addressable_devices();
nit: const auto&
@@ -613,10 +613,14 @@ IfrtComputationClient::ExecuteReplicated(
return data_handles;
}

size_t IfrtComputationClient::GetNumDevices() const {
size_t IfrtComputationClient::GetNumLocalDevices() const {
nit: We could also consider a name that includes "addressable"/"visible". Technically, we can have N addressable devices for a process, but not necessarily all of the host's devices - which could be considered local devices too. If we have 2 processes on one host, each with one addressable device, we can think about how we want the terminology to play out here.
@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
return 0;
}

return client->GetNumDevices();
return client->GetNumLocalDevices();
Q: Where do we use this?
I just renamed the API to say explicitly that it's the number of local devices; the original name was confusing.
we need to normalize the physical device ids to generate the correct HLO
sharding annotation.
"""
device_id_min = np.min(device_mesh)
nit: We could avoid the copy if device_id_min == 0 by exiting early (e.g. thousands of hosts).
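A small sketch of the early-exit suggestion (the function name below is hypothetical, not the PR's actual helper):

```python
# Hedged sketch: skip the subtraction (and the resulting copy) when the
# device ids are already zero-based, as in the global-SPMD case.
import numpy as np


def _normalize_device_ids(device_mesh: np.ndarray) -> np.ndarray:
  device_id_min = np.min(device_mesh)
  if device_id_min == 0:
    # Already zero-based; return the mesh as-is and avoid allocating a copy.
    return device_mesh
  return device_mesh - device_id_min
```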
# device ids are continous
if os.environ['XLA_USE_LOCAL_SPMD'] == '1':
# In local SPMD mesh only contains local devices.
min_device_idx = xr.process_index() * xr.addressable_runtime_device_count(
This assumes a homogeneous distribution of addressable devices - will this be a requirement? Say we have a different number of addressable devices per MPMD program, e.g. [0,1], [2,3,4,5,6,7].
Yes, this is a requirement for now.
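To make the assumption concrete, a small illustration with hypothetical device counts (not from the PR):

```python
# Why process_index() * addressable_runtime_device_count() assumes the same
# number of addressable devices on every host (values are hypothetical).

# Homogeneous: host 0 owns devices [0, 1, 2, 3], host 1 owns [4, 5, 6, 7].
min_device_idx = 1 * 4  # == 4, matches host 1's smallest device id

# Heterogeneous: host 0 owns [0, 1], host 1 owns [2, 3, 4, 5, 6, 7].
min_device_idx = 1 * 6  # == 6, but host 1's smallest device id is actually 2
```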
Not much to add. I am interested in follow-ups to the other comments.
OK, also to answer my own question, JAX does these things:
So this roughly translates to every sharded tensor having a mesh, and the device assignment being derived from their meshes.
Is it possible to put a
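To make the JAX behavior described above concrete, a hedged sketch (assumes at least 4 addressable devices; names and shapes are illustrative):

```python
# Each sharded array carries a sharding whose mesh determines the devices the
# computation runs on; jit derives the device assignment from the inputs.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('x', 'y'))
sharding = NamedSharding(mesh, P('x', None))

x = jax.device_put(np.arange(16.0).reshape(4, 4), sharding)
y = jax.jit(lambda a: a * 2)(x)  # device assignment comes from x's sharding
print(y.sharding)
```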
sharding.tile_assignment_dimensions().end(), 1,
[](int a, int b) { return a * b; });
std::unordered_map<int, int> device_index =
build_index_map(devices, num_tiles);
If I don't fully shard the tensor over the mesh, is num_tiles the correct value to provide as num_mesh_devices to build_index_map? It seems that we'd need to actually plumb in the mesh (or at least the mesh device count) somehow.
For example, if I'm doing local SPMD with a 2D mesh of 2x2 and axis names 'x', 'y', then later I shard a tensor only over the x axis. Would num_tiles be 2 or 4?
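For reference, a sketch of the scenario in the question (illustrative only, using the public torch_xla SPMD API; assumes 4 devices):

```python
# A 2x2 mesh with axes ('x', 'y'); the tensor is sharded over 'x' only and
# replicated over 'y' (partial sharding).
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()  # assumed to be 4 here
mesh = xs.Mesh(np.arange(num_devices), (2, 2), ('x', 'y'))

t = torch.zeros(8, 8, device=xm.xla_device())
xs.mark_sharding(t, mesh, ('x', None))  # shard dim 0 over 'x'; 'y' replicated
```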
// devices, host 1 has devices TPU:{4, 5, 6, 7}. In this case
// the global ordinal of TPU:4 is 0, TPU:5 is 1, and so on.

int global_ordinal =
nit: should this still be called global_ordinal? Or maybe mesh_ordinal, based on your explanation? "global" sounds like it's talking about all devices across different hosts.
Sorry for the late comment -- I'm honestly a bit worried about the subtle distinction between local device IDs and global device IDs, particularly the
I'm wondering if it makes sense to do a prefactor, to hide the
Once we've done this refactor, the Python layer can reason in terms of
If this feature is no longer urgently required, I wonder if it is possible to do this refactor? I think it could also let us support other kinds of SPMD (e.g. using 2 chips out of 4 chips) instead of being restricted to either full local SPMD or global SPMD.
Before this PR, users must use all global devices in their SPMD program, which limits the flexibility of running MPMD + SPMD in a multi-host environment.
This PR enables local SPMD by setting the environment variable XLA_USE_LOCAL_SPMD. Local SPMD and global SPMD can be switched within the same Python program by toggling the environment variable.

Usage
Example usage is:
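A minimal sketch of toggling between local and global SPMD via XLA_USE_LOCAL_SPMD (mesh shapes, tensor sizes, and the device-id arithmetic below are illustrative, not taken from the linked script):

```python
# Hedged usage sketch: switch between local and global SPMD in one program.
import os
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Local SPMD: the mesh only contains this host's addressable devices.
os.environ['XLA_USE_LOCAL_SPMD'] = '1'
n_local = xr.addressable_runtime_device_count()
local_ids = np.arange(n_local) + xr.process_index() * n_local
local_mesh = xs.Mesh(local_ids, (n_local,), ('data',))
t = torch.randn(16, 16, device=xm.xla_device())
xs.mark_sharding(t, local_mesh, ('data', None))

# Global SPMD: the mesh spans every device across all hosts.
os.environ['XLA_USE_LOCAL_SPMD'] = '0'
n_global = xr.global_runtime_device_count()
global_mesh = xs.Mesh(np.arange(n_global), (n_global,), ('data',))
u = torch.randn(16, 16, device=xm.xla_device())
xs.mark_sharding(u, global_mesh, ('data', None))
```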
An example script is here
To make local SPMD work in a multi-host setting, we need to ensure:
Implementation:
SPMD Python API: update _get_tile_assignment to support a local device mesh, so that the device ids will start from 0 in the HLO sharding annotation.
Sharding Utilities: GetShardReplicaAndIndicesForDevices. This is achieved by deriving the global ordinal from the physical device and the SPMD sharding annotation.
PJRT computation client:
Test:
Future work: