
[1/N] Initial implementation of local SPMD support #8810

Open. Wants to merge 16 commits into base: master.
107 changes: 107 additions & 0 deletions test/cpp/test_xla_sharding.cpp
@@ -222,6 +222,113 @@ TEST_F(XLAShardingTest, ShardTensor) {
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
}

TEST_F(XLAShardingTest, ShardTensorLocalMesh) {
// Test sharding with a local mesh.
std::vector<std::string> devices = {"TPU:8", "TPU:9", "TPU:10", "TPU:11",
"TPU:12", "TPU:13", "TPU:14", "TPU:15"};

// 1D tiled
Collaborator:

For the different tiling methods, could we not use parameterized testing to significantly cut down the code and make things more readable?

INSTANTIATE_TEST_CASE_P reference: https://www.sandordargo.com/blog/2019/04/24/parameterized-testing-with-gtest
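
For illustration, a rough sketch of what a value-parameterized version of this test could look like (the struct, fixture, and instantiation names below are illustrative only, not code from this PR):

struct ShardTensorLocalMeshCase {
  std::vector<int64_t> tensor_dims;           // input tensor shape
  xla::OpSharding sharding;                   // sharding to apply
  std::vector<int64_t> expected_first_shard;  // expected shards[0].sizes()
  std::vector<int64_t> expected_last_shard;   // expected shards[7].sizes()
};

class ShardTensorLocalMeshTest
    : public ::testing::TestWithParam<ShardTensorLocalMeshCase> {};

TEST_P(ShardTensorLocalMeshTest, ShardsToExpectedSizes) {
  const ShardTensorLocalMeshCase& c = GetParam();
  std::vector<std::string> devices = {"TPU:8",  "TPU:9",  "TPU:10", "TPU:11",
                                      "TPU:12", "TPU:13", "TPU:14", "TPU:15"};
  at::Tensor tensor = at::ones(c.tensor_dims, at::TensorOptions(at::kFloat));
  xla::Shape tensor_shape =
      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
  auto sharding_spec =
      std::make_shared<XLATensor::ShardingSpec>(c.sharding, tensor_shape);
  auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                          /*padded=*/false);
  EXPECT_EQ(shards.size(), devices.size());
  EXPECT_EQ(shards.front().sizes(), c10::IntArrayRef(c.expected_first_shard));
  EXPECT_EQ(shards.back().sizes(), c10::IntArrayRef(c.expected_last_shard));
}

// Each tiling case in the test would become one more Values() entry.
INSTANTIATE_TEST_CASE_P(
    LocalMesh, ShardTensorLocalMeshTest,
    ::testing::Values(ShardTensorLocalMeshCase{
        /*tensor_dims=*/{8, 7, 4},
        /*sharding=*/xla::HloSharding::Replicate().ToProto(),
        /*expected_first_shard=*/{8, 7, 4},
        /*expected_last_shard=*/{8, 7, 4}}));

Each entry then reports its own pass/fail, which also helps with telling which case a failure belongs to.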

Collaborator Author:

Thank you for the suggestion! But I personally think it's hard to parameterize in this case: we would need to programmatically derive the mesh and the expected output from the test parameters, and that's hard to generalize across the cases (1D tile, 2D tile, etc.). I feel the current implementation, even if it's long, is more readable.

Collaborator:

I suppose writing parameterized functions in Python might be easier than in C++. If we want to keep the current approach, splitting this into separate tests might make things easier to test.

Currently if this test fails, it is hard to tell which case it is related to.

at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::OpSharding sharding =
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
devices.size())
.ToProto();
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
for (auto shard : shards) {
EXPECT_EQ(shard.sizes(), c10::ArrayRef<long>({1}));
}

// 2D tiled, the first dim is halved and the last replicated. The last shard
// size should be smaller in dim=1 because it's not evenly divisible.
tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::Array2D<int64_t> mesh({
{0, 1, 2, 3},
{4, 5, 6, 7},
});
sharding = xla::HloSharding::Tile(mesh).ToProto();
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({4, 1, 4}));

// 3D tiled, the first dim is replicated and the last halved. The last shard
// size should be smaller in dim=1 because it's not evenly divisible.
xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));

// Replicated, all shards should be identical.
sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 7, 4}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 7, 4}));

// 4D tiled, the first and second dims are replicated and the last halved. The
// last shard size should be smaller in dim=2 because it's not evenly
// divisible.
tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
sharding = xla::HloSharding::Tile(tesseract).ToProto();
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 1, 2}));

// 4D tiled and padded, all shard sizes should be identical.
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/true);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));

// 5D tiled, the first and second dims are replicated and the last halved. The
// last shard size should be smaller in dim=3 because it's not evenly
// divisible.
tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
hypercube.FillIota(0);
sharding = xla::HloSharding::Tile(hypercube).ToProto();
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 3, 2}));

// 5D tiled and padded, all shard sizes should be identical.
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/true);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
}

TEST_F(XLAShardingTest, ShardTensorMultiHost) {
std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};

7 changes: 5 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1482,7 +1482,7 @@ void InitXlaModuleBindings(py::module m) {
if (UseVirtualDevice()) {
return 1;
} else {
return runtime::GetComputationClient()->GetNumDevices();
return runtime::GetComputationClient()->GetNumLocalDevices();
}
});
m.def("_xla_get_all_devices", []() {
@@ -1500,13 +1500,16 @@
m.def("_xla_get_runtime_devices",
[]() { return runtime::GetComputationClient()->GetLocalDevices(); });
m.def("_xla_num_runtime_devices", []() -> int64_t {
return runtime::GetComputationClient()->GetNumDevices();
return runtime::GetComputationClient()->GetNumLocalDevices();
});
m.def("_xla_get_all_runtime_devices", []() {
std::vector<std::string> all_devices =
runtime::GetComputationClient()->GetAllDevices();
return all_devices;
});
m.def("_xla_num_global_devices", []() -> int64_t {
return runtime::GetComputationClient()->GetNumGlobalDevices();
});
m.def(
"_xla_real_devices",
[](const std::optional<std::vector<std::string>> devices) {
4 changes: 3 additions & 1 deletion torch_xla/csrc/runtime/computation_client.h
@@ -374,7 +374,9 @@ class ComputationClient {

virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;

virtual size_t GetNumDevices() const = 0;
virtual size_t GetNumLocalDevices() const = 0;

virtual size_t GetNumGlobalDevices() const = 0;

virtual std::vector<std::string> GetLocalDevices() const = 0;

6 changes: 5 additions & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -613,10 +613,14 @@ IfrtComputationClient::ExecuteReplicated(
return data_handles;
}

size_t IfrtComputationClient::GetNumDevices() const {
size_t IfrtComputationClient::GetNumLocalDevices() const {
Collaborator @rpsilva-aws (Mar 13, 2025):

nit: We could also consider naming that includes "addressable"/"visible". Technically, a process can have N addressable devices without that being all of the host's devices, which could also be considered local devices. If we have 2 processes on one host, each with one addressable device, we should think about how we want the terminology to play out here.

return client_->addressable_device_count();
}

size_t IfrtComputationClient::GetNumGlobalDevices() const {
return client_->device_count();
}

std::string IfrtComputationClient::GetDefaultDevice() const {
return IfrtDeviceToString(client_->addressable_devices()[0]);
}
4 changes: 3 additions & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -79,7 +79,9 @@ class IfrtComputationClient : public ComputationClient {
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) override;

size_t GetNumDevices() const override;
size_t GetNumLocalDevices() const override;

size_t GetNumGlobalDevices() const override;

std::string GetDefaultDevice() const override;

25 changes: 19 additions & 6 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -559,7 +559,10 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
.set_allow_spmd_sharding_propagation_to_output(
{instance.allow_spmd_sharding_propagation_to_output});

int num_partitions = client_->device_count();
int num_partitions = GetNumGlobalDevices();
if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
Collaborator:

nit: We could do bool use_local_spmd = runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false); and use it in both places below.
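
i.e. roughly (a sketch of the suggestion, reusing the helpers this PR already adds):

bool use_local_spmd =
    runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false);
// Local SPMD partitions over the addressable devices only.
int num_partitions =
    use_local_spmd ? GetNumLocalDevices() : GetNumGlobalDevices();
// ...and `use_local_spmd` is then reused below when building the device
// assignment, instead of reading the environment variable a second time.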

num_partitions = GetNumLocalDevices();
}
compile_options.executable_build_options.set_num_partitions(
num_partitions);
compile_options.executable_build_options.set_num_replicas(1);
@@ -589,11 +592,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
}

// TODO(244391366) verify this is correct for the collectives ops
xla::DeviceAssignment device_assignment(1, client_->device_count());
xla::DeviceAssignment device_assignment(1, num_partitions);
// DeviceAssignment values must be the PjRtDevice ID, so we need to
// unwind the global ordinal mapping.
Collaborator:

We should adapt the comment, since we no longer only "unwind the global ordinal mapping" here.
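
For example, the comment could be reworded along these lines (wording suggestion only, not part of the PR):

// DeviceAssignment values must be the PjRtDevice ID. For global SPMD we unwind
// the global ordinal mapping; for local SPMD we take the addressable devices
// directly, in their local order.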

for (const auto& [device_id, global_ordinal] : global_ordinals_) {
device_assignment(0, global_ordinal) = device_id;
if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
auto local_pjrt_devices = client_->addressable_devices();
Collaborator:

nit: const auto&

for (int i = 0; i < local_pjrt_devices.size(); ++i) {
device_assignment(0, i) = local_pjrt_devices[i]->id();
}
} else {
for (const auto& [device_id, global_ordinal] : global_ordinals_) {
device_assignment(0, global_ordinal) = device_id;
}
}
compile_options.executable_build_options.set_device_assignment(
device_assignment);
@@ -649,7 +659,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(

CreateCompileHandlesCounter()->AddValue(1);
}

return computations;
}

@@ -917,10 +926,14 @@ PjRtComputationClient::ExecuteReplicated(
return data_handles;
}

size_t PjRtComputationClient::GetNumDevices() const {
size_t PjRtComputationClient::GetNumLocalDevices() const {
return client_->addressable_device_count();
}

size_t PjRtComputationClient::GetNumGlobalDevices() const {
return client_->device_count();
}

std::string PjRtComputationClient::GetDefaultDevice() const {
return PjRtDeviceToString(client_->addressable_devices()[0]);
}
4 changes: 3 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -86,7 +86,9 @@ class PjRtComputationClient : public ComputationClient {
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) override;

size_t GetNumDevices() const override;
size_t GetNumLocalDevices() const override;

size_t GetNumGlobalDevices() const override;

std::string GetDefaultDevice() const override;

2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_impl.cpp
@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
return 0;
}

return client->GetNumDevices();
return client->GetNumLocalDevices();
Collaborator:

Q: Where do we use this?

Collaborator Author:

I just renamed the API to say explicitly that it's the number of local devices; the original name was confusing.

}
};

8 changes: 4 additions & 4 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -1422,10 +1422,10 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));

std::vector<runtime::ComputationClient::CompileInstance> instances;
instances.push_back({std::move(computation), coll.device.toString(),
runtime::GetComputationClient()->GetCompilationDevices(
coll.device.toString(), devices),
&shape, should_wrap_parameter, is_sharded});
instances.emplace_back(std::move(computation), coll.device.toString(),
runtime::GetComputationClient()->GetCompilationDevices(
coll.device.toString(), devices),
&shape, should_wrap_parameter, is_sharded);
instances.front().eager_mode = UseEagerMode();
if (use_autosharding) {
TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
22 changes: 18 additions & 4 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -85,10 +85,20 @@ std::vector<int64_t> TileAssignmentDimensions(
// order of the output corresponds to the order of the `devices`, which can be
// arbitrarily set by the caller.
std::unordered_map<int, int> build_index_map(
const std::vector<std::string>& devices) {
const std::vector<std::string>& devices, size_t num_mesh_devices) {
std::unordered_map<int, int> device_index;
for (int i = 0; i < devices.size(); ++i) {
int global_ordinal = ParseDeviceString(devices[i]).ordinal();
// The global ordinal here is the device's ordinal in the mesh, which
// can be different from the physical device index.
// We only support 2 cases here:
// 1. Mesh contains all global devices.
// 2. Mesh contains only local devices. (in multi-host scenario)
// Example: In multi-host v6e-8, each host has a mesh of its local
// devices, host 1 has devices TPU:{4, 5, 6, 7}. In this case
// the global ordinal of TPU:4 is 0, TPU:5 is 1, and so on.

int global_ordinal =
Collaborator:

nit: should this still be called global_ordinal? or maybe mesh_ordinal, based on your explanation? "global" sounds like it's talking about all devices across different hosts.

ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
device_index[global_ordinal] = i;
}
return device_index;
@@ -371,7 +381,12 @@
shard_indices[i] = std::make_pair(global_ordinal, indices);
}
} else if (sharding.type() == xla::OpSharding::OTHER) {
auto device_index = build_index_map(devices);
size_t num_tiles =
std::accumulate(sharding.tile_assignment_dimensions().begin(),
sharding.tile_assignment_dimensions().end(), 1,
[](int a, int b) { return a * b; });
std::unordered_map<int, int> device_index =
build_index_map(devices, num_tiles);
Collaborator:

If I don't fully shard the tensor over the mesh, is num_tiles the correct value to provide as num_mesh_devices to build_index_map? It seems that we'd need to actually plumb in the mesh (or at least the mesh device count) somehow.

For example, if I'm doing local SPMD with a 2D mesh of 2x2 and axis names 'x', 'y', and I later shard a tensor only over the x axis, would num_tiles be 2 or 4?
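
For what it's worth, sharding over only one axis of a mesh is usually expressed as a last-tile-dim-replicated (partial) sharding, in which case the product of tile_assignment_dimensions() still covers the whole mesh. A sketch in the style of the tests above, assuming the sharding is built with xla::HloSharding::PartialTile (illustrative, not code from this PR):

// Rank-2 tensor sharded along mesh axis 'x' only, on a 2x2 mesh {0, 1, 2, 3}.
xla::Array<int64_t> tiles(std::vector<int64_t>{2, 1, 2});
tiles.FillIota(0);
xla::OpSharding partial = xla::HloSharding::PartialTile(tiles).ToProto();
// tile_assignment_dimensions() is {2, 1, 2} with replicate_on_last_tile_dim
// set, so the product is 4 (all mesh devices) even though the tensor is only
// split two ways.

If the sharding were instead a plain Tile over just two devices, the product would be 2, which is why plumbing the mesh device count through explicitly seems safer.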

std::vector<int64_t> tile_assignment_devices(
sharding.tile_assignment_devices().begin(),
sharding.tile_assignment_devices().end());
@@ -442,7 +457,6 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
}
TF_VLOG(5) << "ShardTensor with sharding type(" << sharding.type()
<< ")... and minibatch = " << minibatch << std::endl;
auto device_index = build_index_map(devices);
std::vector<at::Tensor> shards(devices.size());
if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED ||
sharding.type() == xla::OpSharding::UNKNOWN) {