[1/N] Initial implementation of local SPMD support #8810
base: master
Changes from all commits: 2d2f08a, c4aa854, 2f33433, b633c76, ba4b480, 7561786, ff12d44, 9873129, e418350, e2d157b, 3865f67, d3feb5f, 2f1fc1a, a1eaaeb, e97401d, 4bfd7b5
@@ -613,10 +613,14 @@ IfrtComputationClient::ExecuteReplicated(
   return data_handles;
 }

-size_t IfrtComputationClient::GetNumDevices() const {
+size_t IfrtComputationClient::GetNumLocalDevices() const {
   return client_->addressable_device_count();
 }

+size_t IfrtComputationClient::GetNumGlobalDevices() const {
+  return client_->device_count();
+}
+
 std::string IfrtComputationClient::GetDefaultDevice() const {
   return IfrtDeviceToString(client_->addressable_devices()[0]);
 }

Review comment: nit: We could also consider a name that includes "addressable" or "visible". Technically, a process can have N addressable devices without those being all of the host's devices, and the host's other devices could be considered local too. If we have 2 processes on one host, each with one addressable device, we can think about how we want the terminology to play out here.
@@ -559,7 +559,10 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
           .set_allow_spmd_sharding_propagation_to_output(
               {instance.allow_spmd_sharding_propagation_to_output});

-      int num_partitions = client_->device_count();
+      int num_partitions = GetNumGlobalDevices();
+      if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
+        num_partitions = GetNumLocalDevices();
+      }
       compile_options.executable_build_options.set_num_partitions(
           num_partitions);
       compile_options.executable_build_options.set_num_replicas(1);

Review comment: nit: We could do …
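For context, here is a standalone sketch of the env-gated partition count. GetEnvBoolStandIn is a hypothetical simplification of runtime::sys_util::GetEnvBool, and the device counts are made up (2 hosts with 4 chips each); this illustrates the switch, not the PR's implementation.

#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical stand-in for runtime::sys_util::GetEnvBool.
bool GetEnvBoolStandIn(const char* name, bool defval) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) return defval;
  std::string value(raw);
  return value == "1" || value == "true" || value == "TRUE";
}

int main() {
  int num_global_devices = 8;  // assumed: 2 hosts x 4 chips
  int num_local_devices = 4;   // assumed: this process's addressable devices
  // Local SPMD partitions the program over this process's devices only.
  int num_partitions = GetEnvBoolStandIn("XLA_USE_LOCAL_SPMD", false)
                           ? num_local_devices
                           : num_global_devices;
  std::cout << "num_partitions = " << num_partitions << "\n";
}

Running with XLA_USE_LOCAL_SPMD=1 yields 4 partitions instead of 8.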
@@ -589,11 +592,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       }

       // TODO(244391366) verify this is correct for the collectives ops
-      xla::DeviceAssignment device_assignment(1, client_->device_count());
+      xla::DeviceAssignment device_assignment(1, num_partitions);
       // DeviceAssignment values must be the PjRtDevice ID, so we need to
       // unwind the global ordinal mapping.
-      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
-        device_assignment(0, global_ordinal) = device_id;
-      }
+      if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
+        auto local_pjrt_devices = client_->addressable_devices();
+        for (int i = 0; i < local_pjrt_devices.size(); ++i) {
+          device_assignment(0, i) = local_pjrt_devices[i]->id();
+        }
+      } else {
+        for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+          device_assignment(0, global_ordinal) = device_id;
+        }
+      }
       compile_options.executable_build_options.set_device_assignment(
           device_assignment);

Review comment: We can adapt the comment, since we don't only "unwind the global ordinal mapping" now.

Review comment (on the addressable_devices() loop): nit …
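To make the two assignment branches concrete, here is a standalone simulation (not the PR's code): std::vector stands in for the 1 x num_partitions xla::DeviceAssignment, and the PjRt device IDs assume host 1 of a hypothetical 2-host x 4-chip setup.

#include <iostream>
#include <map>
#include <vector>

int main() {
  // Global SPMD: global_ordinals maps PjRt device ID -> global ordinal,
  // and the assignment is indexed by global ordinal (all 8 devices).
  std::map<int, int> global_ordinals = {{0, 0}, {1, 1}, {2, 2}, {3, 3},
                                        {4, 4}, {5, 5}, {6, 6}, {7, 7}};
  std::vector<int> global_assignment(global_ordinals.size());
  for (const auto& [device_id, global_ordinal] : global_ordinals) {
    global_assignment[global_ordinal] = device_id;  // [0, 1, ..., 7]
  }

  // Local SPMD: logical partition i is simply the i-th addressable device,
  // so host 1 (device IDs 4..7) gets the assignment [4, 5, 6, 7].
  std::vector<int> local_device_ids = {4, 5, 6, 7};  // assumed IDs
  std::vector<int> local_assignment(local_device_ids.size());
  for (size_t i = 0; i < local_device_ids.size(); ++i) {
    local_assignment[i] = local_device_ids[i];
  }

  for (int id : local_assignment) std::cout << id << ' ';  // prints: 4 5 6 7
  std::cout << '\n';
}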
@@ -649,7 +659,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(

       CreateCompileHandlesCounter()->AddValue(1);
     }
-
     return computations;
   }

@@ -917,10 +926,14 @@ PjRtComputationClient::ExecuteReplicated(
   return data_handles;
 }

-size_t PjRtComputationClient::GetNumDevices() const {
+size_t PjRtComputationClient::GetNumLocalDevices() const {
   return client_->addressable_device_count();
 }

+size_t PjRtComputationClient::GetNumGlobalDevices() const {
+  return client_->device_count();
+}
+
 std::string PjRtComputationClient::GetDefaultDevice() const {
   return PjRtDeviceToString(client_->addressable_devices()[0]);
 }
@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
       return 0;
     }

-    return client->GetNumDevices();
+    return client->GetNumLocalDevices();
  }
};

Review comment: Q: Where do we use this?

Author reply: I just renamed the API to say explicitly that it is the number of local devices; the original name was confusing.
@@ -85,10 +85,20 @@ std::vector<int64_t> TileAssignmentDimensions(
 // order of the output corresponds to the order of the `devices`, which can be
 // arbitrarily set by the caller.
 std::unordered_map<int, int> build_index_map(
-    const std::vector<std::string>& devices) {
+    const std::vector<std::string>& devices, size_t num_mesh_devices) {
   std::unordered_map<int, int> device_index;
   for (int i = 0; i < devices.size(); ++i) {
-    int global_ordinal = ParseDeviceString(devices[i]).ordinal();
+    // The global ordinal here is the device's ordinal in the mesh, which can
+    // be different from the physical device index.
+    // We only support 2 cases here:
+    // 1. Mesh contains all global devices.
+    // 2. Mesh contains only local devices (in the multi-host scenario).
+    // Example: In multi-host v6e-8, each host has a mesh of its local
+    // devices; host 1 has devices TPU:{4, 5, 6, 7}. In this case the
+    // global ordinal of TPU:4 is 0, TPU:5 is 1, and so on.
+    int global_ordinal =
+        ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
     device_index[global_ordinal] = i;
   }
   return device_index;

Review comment (on the modulo line): nit: should this still be called …
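The following standalone sketch walks through the modulo folding for the host-1 example in the comment above (local devices TPU:4 through TPU:7 fold onto mesh ordinals 0 through 3). The string parsing is a stand-in for ParseDeviceString; only the arithmetic mirrors the diff.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Assumed: host 1 of a multi-host v6e-8, meshing only its 4 local devices.
  std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
  int num_mesh_devices = static_cast<int>(devices.size());

  std::unordered_map<int, int> device_index;
  for (int i = 0; i < static_cast<int>(devices.size()); ++i) {
    // Stand-in for ParseDeviceString(devices[i]).ordinal().
    int physical_ordinal =
        std::stoi(devices[i].substr(devices[i].find(':') + 1));
    int mesh_ordinal = physical_ordinal % num_mesh_devices;  // 4 % 4 == 0, ...
    device_index[mesh_ordinal] = i;
  }

  for (const auto& [ordinal, index] : device_index) {
    std::cout << "mesh ordinal " << ordinal << " -> devices[" << index << "]\n";
  }
}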
@@ -371,7 +381,12 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
       shard_indices[i] = std::make_pair(global_ordinal, indices);
     }
   } else if (sharding.type() == xla::OpSharding::OTHER) {
-    auto device_index = build_index_map(devices);
+    size_t num_tiles =
+        std::accumulate(sharding.tile_assignment_dimensions().begin(),
+                        sharding.tile_assignment_dimensions().end(), 1,
+                        [](int a, int b) { return a * b; });
+    std::unordered_map<int, int> device_index =
+        build_index_map(devices, num_tiles);
     std::vector<int64_t> tile_assignment_devices(
         sharding.tile_assignment_devices().begin(),
         sharding.tile_assignment_devices().end());

Review comment: If I don't fully shard the tensor over the mesh, is … For example, if I'm doing local SPMD with a 2D mesh of …
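For reference, this standalone snippet reproduces the num_tiles computation with the same std::accumulate pattern; the {2, 2} dimensions are hypothetical. Note that for layouts in the style of replicate_on_last_tile_dim, the product also counts the replication dimension, which is what the review question above appears to be probing.

#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Assumed tile assignment dimensions for a hypothetical 2x2 tiling.
  std::vector<int> tile_assignment_dimensions = {2, 2};
  size_t num_tiles =
      std::accumulate(tile_assignment_dimensions.begin(),
                      tile_assignment_dimensions.end(), 1,
                      [](int a, int b) { return a * b; });
  std::cout << "num_tiles = " << num_tiles << '\n';  // prints: num_tiles = 4
}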
@@ -442,7 +457,6 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
   }
   TF_VLOG(5) << "ShardTensor with sharding type(" << sharding.type()
              << ")... and minibatch = " << minibatch << std::endl;
-  auto device_index = build_index_map(devices);
   std::vector<at::Tensor> shards(devices.size());
   if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED ||
       sharding.type() == xla::OpSharding::UNKNOWN) {
Review comment: For the different tiled methods, could we not do parameterized testing to significantly save code and make things more readable? INSTANTIATE_TEST_CASE_P reference: https://www.sandordargo.com/blog/2019/04/24/parameterized-testing-with-gtest

Author reply: Thank you for the suggestion! But I personally think it's hard to parameterize in this case: we would need to programmatically derive the mesh and the expected output from the test parameters, which is hard to generalize (1d tile, 2d tile, etc.). Even though the current implementation is long, I feel it is more readable.

Review comment: I suppose making parameterized functions in Python might be easier than in C++. If we want to keep the current approach, splitting these into separately named tests might make things easier to test; currently, if this test fails, it is hard to tell which case it relates to. A minimal sketch of the parameterized approach follows below.
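As a rough illustration of the reviewer's INSTANTIATE_TEST_CASE_P suggestion (modern gtest spells it INSTANTIATE_TEST_SUITE_P), here is a minimal value-parameterized sketch. The fixture, case data, and shard-count check are hypothetical and do not come from the PR; the point is that naming each case addresses the "hard to tell which case failed" concern.

// Link against gtest_main; this is a sketch, not the PR's test file.
#include <gtest/gtest.h>

#include <string>
#include <vector>

struct ShardCase {
  std::vector<int> mesh_shape;  // e.g. {4} for 1-D tiling, {2, 2} for 2-D
  int expected_num_shards;
};

class ShardTensorTest : public ::testing::TestWithParam<ShardCase> {};

TEST_P(ShardTensorTest, ProducesExpectedShardCount) {
  const ShardCase& c = GetParam();
  int num_shards = 1;
  for (int d : c.mesh_shape) num_shards *= d;  // stand-in for real sharding
  EXPECT_EQ(num_shards, c.expected_num_shards);
}

INSTANTIATE_TEST_SUITE_P(
    TiledShardings, ShardTensorTest,
    ::testing::Values(ShardCase{{4}, 4},       // 1-D tiling
                      ShardCase{{2, 2}, 4}),   // 2-D tiling
    [](const ::testing::TestParamInfo<ShardCase>& info) {
      // Distinct names make the failing case visible in the test output.
      return info.index == 0 ? std::string("Tile1D") : std::string("Tile2D");
    });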