From 828913293d3ed920a2e978206102efdafdcd5da1 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 8 Nov 2024 13:32:23 +0800 Subject: [PATCH 1/4] disc cache --- torch_xla/csrc/runtime/disc/BUILD | 10 ++- .../csrc/runtime/disc/compile_result.proto | 17 +++++ torch_xla/csrc/runtime/disc/disc_compile.cc | 10 ++- torch_xla/csrc/runtime/disc/disc_ral.h | 1 + .../csrc/runtime/disc_computation_client.cc | 70 +++++++++++++++++++ .../csrc/runtime/disc_computation_client.h | 16 ++--- 6 files changed, 114 insertions(+), 10 deletions(-) create mode 100644 torch_xla/csrc/runtime/disc/compile_result.proto diff --git a/torch_xla/csrc/runtime/disc/BUILD b/torch_xla/csrc/runtime/disc/BUILD index 999aa85ea64..96f0eb8d279 100755 --- a/torch_xla/csrc/runtime/disc/BUILD +++ b/torch_xla/csrc/runtime/disc/BUILD @@ -1,3 +1,4 @@ +load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library") load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -37,7 +38,12 @@ ptxla_cc_library( "-DGOOGLE_CUDA", ] ) - +cc_proto_library( + name = "disc_compiler_result_proto", + srcs = [ + "compile_result.proto", + ], +) ptxla_cc_library( name = "disc_utils", srcs = ["disc_utils.cc"], @@ -58,6 +64,7 @@ ptxla_cc_library( deps = [ ":disc_ral", ":disc_utils", + ":disc_compiler_result_proto", "//torch_xla/csrc/runtime:tf_logging", "//torch_xla/csrc/runtime:sys_util", "//torch_xla/csrc/runtime:env_vars", @@ -79,3 +86,4 @@ ptxla_cc_test( "@tsl//tsl/platform:test_main", ] ) + diff --git a/torch_xla/csrc/runtime/disc/compile_result.proto b/torch_xla/csrc/runtime/disc/compile_result.proto new file mode 100644 index 00000000000..acb9d18b87c --- /dev/null +++ b/torch_xla/csrc/runtime/disc/compile_result.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package torch_xla.runtime.disc; + +option cc_enable_arenas = true; + +message DataSpec { + string device = 1; + int32 dtype = 2; +} +message DISCCompileResult { + bytes ral_library = 1; + bytes ral_meta_pb = 2; + repeated DataSpec input_specs = 3; + repeated DataSpec output_specs = 4; + repeated string devices = 5; +} \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/disc_compile.cc b/torch_xla/csrc/runtime/disc/disc_compile.cc index 053535f5e2e..da11465851e 100644 --- a/torch_xla/csrc/runtime/disc/disc_compile.cc +++ b/torch_xla/csrc/runtime/disc/disc_compile.cc @@ -4,6 +4,7 @@ #include +#include "torch_xla/csrc/runtime/disc/compile_result.pb.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" @@ -98,7 +99,14 @@ DISCComplationResult Compile(mlir::ModuleOp &module, res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt")); res.inputs = inputs; res.outputs = outputs; - + DISCCompileResult result; + result.set_ral_library(output_fname); + result.set_ral_meta_pb(absl::StrCat(output_fname, ".pbtxt")); + for (const auto &input : inputs) { + auto data_meta = result.add_input_specs(); + data_meta->set_device(input.device); + data_meta->set_dtype(static_cast(input.scalar_type)); + } return res; } diff --git a/torch_xla/csrc/runtime/disc/disc_ral.h b/torch_xla/csrc/runtime/disc/disc_ral.h index f47431689c5..b850a3c6ef6 100644 --- a/torch_xla/csrc/runtime/disc/disc_ral.h +++ b/torch_xla/csrc/runtime/disc/disc_ral.h @@ -33,6 +33,7 @@ class RalContext { ~RalContext(); std::vector Execute(const std::vector& inputs); + DISCComplationResult GetDiscResult() { return disc_result_; } private: void BindingInputs(const std::vector& inputs, diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index 6465551dbde..56f2729de3a 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -16,10 +16,12 @@ #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/disc/compile_result.pb.h" #include "torch_xla/csrc/runtime/disc/disc_compile.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/float_normalization.h" #include "xla/service/gpu/gpu_float_support.h" @@ -405,6 +407,74 @@ size_t DISCComputationClient::GetNumDevices() const { return world_size_; } int DISCComputationClient::GetProcessIndex() const { return local_rank_; } int DISCComputationClient::GetNumProcesses() const { return world_size_; } +std::string DISCComputationClient::SerializeComputation( + const ComputationPtr computation) { + auto client = dynamic_cast(computation.get()); + auto hlo_proto = client->computation().proto(); + auto result = client->executable->GetDiscResult(); + torch_xla::runtime::disc::DISCCompileResult result_pb; + result_pb.set_ral_library(result.ral_lib); + result_pb.set_ral_meta_pb(result.ral_mate_pb); + for (const auto& input : result.inputs) { + auto data_meta = result_pb.add_input_specs(); + data_meta->set_device(input.device); + data_meta->set_dtype(static_cast(input.scalar_type)); + } + for (const auto& output : result.outputs) { + auto data_meta = result_pb.add_output_specs(); + data_meta->set_device(output.device); + data_meta->set_dtype(static_cast(output.scalar_type)); + } + for (auto device : computation->devices()) { + result_pb.add_devices(device); + } + return hlo_proto.SerializeAsString() + ":::" + result_pb.SerializeAsString(); +} +ComputationClient::ComputationPtr DISCComputationClient::DeserializeComputation( + const std::string& serialized) { + std::vector parts = absl::StrSplit(serialized, ":::"); + if (parts.size() != 2) { + XLA_ERROR() << "Invalid serialized computation, should have 2 parts with " + "separator ':::', got " + << parts.size(); + } + if (parts[1].size() > std::numeric_limits::max()) { + XLA_ERROR() << "Serialized DISCCompileResult proto too large (>2GB)\n"; + } + xla::HloModuleProto hlo_proto; + disc::DISCCompileResult result_proto; + hlo_proto.ParseFromString(parts[0]); + XLA_VLOG(0) << "DeserializeComputation: " << hlo_proto.DebugString() << "\n"; + result_proto.ParseFromString(parts[1]); + XLA_VLOG(0) << "DeserializeComputation: " << result_proto.DebugString() + << "\n"; + + disc::DISCComplationResult compile_result; + compile_result.ral_lib = result_proto.ral_library(); + compile_result.ral_mate_pb = result_proto.ral_meta_pb(); + for (const auto& input : result_proto.input_specs()) { + disc::DataMeta data_meta; + data_meta.device = input.device(); + data_meta.scalar_type = static_cast(input.dtype()); + compile_result.inputs.push_back(data_meta); + } + for (const auto& output : result_proto.output_specs()) { + disc::DataMeta data_meta; + data_meta.device = output.device(); + data_meta.scalar_type = static_cast(output.dtype()); + compile_result.outputs.push_back(data_meta); + } + std::vector devices; + for (const auto& device : result_proto.devices()) { + devices.push_back(device); + } + + auto ral_context = std::make_unique(compile_result); + auto computation = std::make_shared( + std::move(xla::XlaComputation(hlo_proto)), devices, + std::move(ral_context)); + return computation; +} } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h index 0701d7b3591..e4aeb223768 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.h +++ b/torch_xla/csrc/runtime/disc_computation_client.h @@ -11,6 +11,8 @@ namespace runtime { class DISCComputationClient : public ComputationClient { public: + const std::string DefaultDevicePrefix = "CUDA:"; + DISCComputationClient(); ~DISCComputationClient(); @@ -55,15 +57,13 @@ class DISCComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - std::string SerializeComputation(const ComputationPtr computation) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - - ComputationPtr DeserializeComputation( - const std::string& serialized) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + std::string SerializeComputation(const ComputationPtr computation) override; + ComputationClient::ComputationPtr DeserializeComputation( + const std::string& serialized) override; + //{ + // XLA_ERROR() << __FUNCTION__ << " not implemented"; + //} torch::lazy::hash_t HashCompilationEnv() override { // TODO(wangang.wa): Improve this function. return torch::lazy::hash_t(); From 4353509925abf45b74e6654b98cb717237280609 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 8 Nov 2024 16:40:31 +0800 Subject: [PATCH 2/4] update --- torch_xla/csrc/runtime/disc_computation_client.cc | 3 --- torch_xla/csrc/runtime/disc_computation_client.h | 3 --- 2 files changed, 6 deletions(-) diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index 56f2729de3a..6393238431d 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -444,10 +444,7 @@ ComputationClient::ComputationPtr DISCComputationClient::DeserializeComputation( xla::HloModuleProto hlo_proto; disc::DISCCompileResult result_proto; hlo_proto.ParseFromString(parts[0]); - XLA_VLOG(0) << "DeserializeComputation: " << hlo_proto.DebugString() << "\n"; result_proto.ParseFromString(parts[1]); - XLA_VLOG(0) << "DeserializeComputation: " << result_proto.DebugString() - << "\n"; disc::DISCComplationResult compile_result; compile_result.ral_lib = result_proto.ral_library(); diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h index e4aeb223768..f95a184a41b 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.h +++ b/torch_xla/csrc/runtime/disc_computation_client.h @@ -61,9 +61,6 @@ class DISCComputationClient : public ComputationClient { ComputationClient::ComputationPtr DeserializeComputation( const std::string& serialized) override; - //{ - // XLA_ERROR() << __FUNCTION__ << " not implemented"; - //} torch::lazy::hash_t HashCompilationEnv() override { // TODO(wangang.wa): Improve this function. return torch::lazy::hash_t(); From 9913ef008ab22eccdaa592d74772e752a6a30db3 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 27 Nov 2024 16:42:05 +0800 Subject: [PATCH 3/4] update with comment --- bazel/disc.BUILD | 4 ++-- torch_xla/csrc/runtime/disc/disc_compile.cc | 8 -------- torch_xla/csrc/runtime/disc_computation_client.cc | 6 +++--- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/bazel/disc.BUILD b/bazel/disc.BUILD index 0ca499011d3..465b6f81065 100644 --- a/bazel/disc.BUILD +++ b/bazel/disc.BUILD @@ -25,12 +25,12 @@ cc_library( cc_import( name="disc_ral_cuda", - shared_library = ":libral_base_context.so", + shared_library = "build/libral_base_context.so", ) cc_import( name="disc_custom_op", - shared_library = ":libdisc_custom_ops.so", + shared_library = "build/libdisc_custom_ops.so", ) genrule( diff --git a/torch_xla/csrc/runtime/disc/disc_compile.cc b/torch_xla/csrc/runtime/disc/disc_compile.cc index da11465851e..2037efa0f55 100644 --- a/torch_xla/csrc/runtime/disc/disc_compile.cc +++ b/torch_xla/csrc/runtime/disc/disc_compile.cc @@ -99,14 +99,6 @@ DISCComplationResult Compile(mlir::ModuleOp &module, res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt")); res.inputs = inputs; res.outputs = outputs; - DISCCompileResult result; - result.set_ral_library(output_fname); - result.set_ral_meta_pb(absl::StrCat(output_fname, ".pbtxt")); - for (const auto &input : inputs) { - auto data_meta = result.add_input_specs(); - data_meta->set_device(input.device); - data_meta->set_dtype(static_cast(input.scalar_type)); - } return res; } diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index 6393238431d..44271bd1227 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -138,8 +138,6 @@ std::vector DISCComputationClient::TransferToDevice( auto dtype = at::TensorOptions(TorchTypeFromXlaType(tensor->shape().element_type())); auto ret = at::empty(sizes, dtype).contiguous(); - // tensor->populate_fn(tensor, ret.data_ptr(), - // ret.element_size() * ret.numel()); std::memcpy(ret.data_ptr(), tensor->data(), ret.element_size() * ret.numel()); @@ -407,6 +405,7 @@ size_t DISCComputationClient::GetNumDevices() const { return world_size_; } int DISCComputationClient::GetProcessIndex() const { return local_rank_; } int DISCComputationClient::GetNumProcesses() const { return world_size_; } + std::string DISCComputationClient::SerializeComputation( const ComputationPtr computation) { auto client = dynamic_cast(computation.get()); @@ -428,7 +427,8 @@ std::string DISCComputationClient::SerializeComputation( for (auto device : computation->devices()) { result_pb.add_devices(device); } - return hlo_proto.SerializeAsString() + ":::" + result_pb.SerializeAsString(); + return absl::StrCat(hlo_proto.SerializeAsString(), + ":::", result_pb.SerializeAsString()); } ComputationClient::ComputationPtr DISCComputationClient::DeserializeComputation( const std::string& serialized) { From 0f192743edc4a23fae045f1eaa2980c53cc0f796 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 27 Nov 2024 16:46:22 +0800 Subject: [PATCH 4/4] update with comment --- bazel/disc.BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bazel/disc.BUILD b/bazel/disc.BUILD index 465b6f81065..0ca499011d3 100644 --- a/bazel/disc.BUILD +++ b/bazel/disc.BUILD @@ -25,12 +25,12 @@ cc_library( cc_import( name="disc_ral_cuda", - shared_library = "build/libral_base_context.so", + shared_library = ":libral_base_context.so", ) cc_import( name="disc_custom_op", - shared_library = "build/libdisc_custom_ops.so", + shared_library = ":libdisc_custom_ops.so", ) genrule(