Commit 272dc9e: Fix python API bugs
merrymercy committed Nov 22, 2022 (1 parent: a81d203)
Showing 3 changed files with 10 additions and 17 deletions.
tensorflow/compiler/xla/client/xla_builder.h (2 additions, 2 deletions)
@@ -1573,8 +1573,8 @@ class XlaBuilder {
   XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
                       int64_t concat_dimension, int64_t split_count,
                       absl::Span<const ReplicaGroup> replica_groups,
-                      const std::optional<ChannelHandle>& channel_id=std::nullopt,
-                      const std::optional<bool> use_global_device_ids=std::nullopt);
+                      const std::optional<ChannelHandle>& channel_id,
+                      const std::optional<bool> use_global_device_ids);
 
   // Creates an op with the given opcode and the output shape.
   virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
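With the `=std::nullopt` defaults dropped from the declaration, callers of this internal `AllToAllArray` overload must now spell out both optionals. A minimal sketch of an affected call site, assuming a caller inside the builder that previously relied on the defaults (the surrounding variables are hypothetical, not from the commit):

```cpp
// Hypothetical call site after this change: channel_id and
// use_global_device_ids must be passed explicitly.
XlaOp all_to_all = AllToAllArray(operand, /*split_dimension=*/0,
                                 /*concat_dimension=*/0, /*split_count=*/4,
                                 replica_groups,
                                 /*channel_id=*/std::nullopt,
                                 /*use_global_device_ids=*/std::nullopt);
```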
tensorflow/compiler/xla/python/xla_compiler.cc (6 additions, 13 deletions)
@@ -480,18 +480,6 @@ void BuildXlaCompilerSubmodule(py::module& m) {
           &HloPrintOptions::is_in_nested_computation,
           &HloPrintOptions::set_is_in_nested_computation);
 
-  // Added by Alpa
-  py::class_<HloSharding> hlo_sharding_class(m, "HloSharding");
-  hlo_sharding_class
-      .def(py::init([](const py::bytes& serialized_hlo_sharding_proto) {
-        OpSharding proto;
-        proto.ParseFromString(std::string(serialized_hlo_sharding_proto));
-        return ValueOrThrow(HloSharding::FromProto(proto));
-      }))
-      .def("proto_tuple", [](const HloSharding& hlo_sharding) {
-        return hlo_sharding.ToProto();
-      });
-
   py::class_<HloModule, std::shared_ptr<HloModule>> hlo_module_class(
       m, "HloModule");
   hlo_module_class.def_property_readonly("name", &HloModule::name)
@@ -812,6 +800,11 @@ void BuildXlaCompilerSubmodule(py::module& m) {

   py::class_<HloSharding> hlo_sharding(m, "HloSharding");
   hlo_sharding.def_static("from_proto", &xla::HloSharding::FromProto)
+      .def(py::init([](const py::bytes& serialized_hlo_sharding_proto) {
+        OpSharding proto;
+        proto.ParseFromString(std::string(serialized_hlo_sharding_proto));
+        return ValueOrThrow(HloSharding::FromProto(proto));
+      }))
       .def("__eq__", [](const xla::HloSharding& a,
                         const xla::HloSharding& b) { return a == b; })
       .def("__hash__",
@@ -864,7 +857,7 @@ void BuildXlaCompilerSubmodule(py::module& m) {
m.def("set_hlo_module_output_shardings", &spmd::SetHloModuleOutputShardings);
m.def("set_hlo_module_input_shardings", &spmd::SetHloModuleInputShardings);
m.def("get_grad_sync_channel_ids", &spmd::GetGradSyncChannelIds);
m.def("get_alpa_jaxlib_version", [] { return "0.1.1"; });
m.def("get_alpa_jaxlib_version", [] { return "0.2.2"; });

m.def(
"run_auto_sharding",
tensorflow/compiler/xla/service/spmd/alpa_compiler.cc (2 additions, 2 deletions)
@@ -133,7 +133,7 @@ Status RunAutoShardingPass(HloModule* hlo_module,
   layout_insensitive_algsimp_opts.set_enable_dot_strength_reduction(false);  // Added by Alpa
 
   if (hlo_module->config().use_spmd_partitioning()) {
-    HloPassPipeline spmd_pipeline("spmd-partitioner");
+    HloPassPipeline spmd_pipeline("run-auto-sharding");
     AddHloVerifier(&spmd_pipeline);
     const int64_t num_partitions = hlo_module->config().num_partitions();
     if (num_partitions > 1) {
@@ -199,7 +199,7 @@ Status RunSpmdPartitionerPass(HloModule* hlo_module,

   // TODO(yonghao): TF Profiler Traceme
   if (hlo_module->config().use_spmd_partitioning()) {
-    HloPassPipeline spmd_pipeline("spmd-partitioner");
+    HloPassPipeline spmd_pipeline("run-spmd-partitioner");
     const int64_t num_partitions = hlo_module->config().num_partitions();
     if (num_partitions > 1) {
       spmd_pipeline.AddPass<ShardingPropagation>(
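Both entry points previously named their pipeline "spmd-partitioner", so pipeline log lines and between-pass HLO dumps from `RunAutoShardingPass` and `RunSpmdPartitionerPass` were indistinguishable. A sketch of the naming convention after the rename; this is an inference from the diff (the pipeline name is what `HloPassPipeline` reports in its logging and dump names), not an explanation stated in the commit:

```cpp
// Each entry point now labels its own pipeline, so output such as
// "Running HLO pass pipeline ..." identifies which path produced it.
HloPassPipeline auto_sharding_pipeline("run-auto-sharding");     // RunAutoShardingPass
HloPassPipeline partitioner_pipeline("run-spmd-partitioner");    // RunSpmdPartitionerPass
```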
