support local compilation cache in disc backend #20

Open · wants to merge 4 commits into base: acc
10 changes: 9 additions & 1 deletion torch_xla/csrc/runtime/disc/BUILD
@@ -1,3 +1,4 @@
load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library")
load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
@@ -37,7 +38,12 @@ ptxla_cc_library(
"-DGOOGLE_CUDA",
]
)

cc_proto_library(
name = "disc_compiler_result_proto",
srcs = [
"compile_result.proto",
],
)
ptxla_cc_library(
name = "disc_utils",
srcs = ["disc_utils.cc"],
@@ -58,6 +64,7 @@ ptxla_cc_library(
deps = [
":disc_ral",
":disc_utils",
":disc_compiler_result_proto",
"//torch_xla/csrc/runtime:tf_logging",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc/runtime:env_vars",
@@ -79,3 +86,4 @@ ptxla_cc_test(
"@tsl//tsl/platform:test_main",
]
)

17 changes: 17 additions & 0 deletions torch_xla/csrc/runtime/disc/compile_result.proto
@@ -0,0 +1,17 @@
syntax = "proto3";

package torch_xla.runtime.disc;

option cc_enable_arenas = true;

message DataSpec {
string device = 1;
int32 dtype = 2;
}
message DISCCompileResult {
bytes ral_library = 1;
bytes ral_meta_pb = 2;
repeated DataSpec input_specs = 3;
repeated DataSpec output_specs = 4;
repeated string devices = 5;
}
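
As a rough illustration of the new message, the following minimal sketch populates a DISCCompileResult and round-trips it through the standard proto3 generated API. It assumes the generated header lands at torch_xla/csrc/runtime/disc/compile_result.pb.h (as wired up by the cc_proto_library target in the BUILD change above); the field values are purely illustrative.

#include <string>

#include "torch_xla/csrc/runtime/disc/compile_result.pb.h"

std::string RoundTripExample() {
  torch_xla::runtime::disc::DISCCompileResult result_pb;
  result_pb.set_ral_library("<compiled RAL library bytes>");  // opaque binary blob
  result_pb.set_ral_meta_pb("<RAL metadata bytes>");          // opaque binary blob
  auto* spec = result_pb.add_input_specs();
  spec->set_device("cuda:0");
  spec->set_dtype(6);  // integer value of an at::ScalarType, as used below
  result_pb.add_devices("CUDA:0");

  // proto3 generated API: serialize to bytes and parse back.
  std::string blob = result_pb.SerializeAsString();
  torch_xla::runtime::disc::DISCCompileResult parsed;
  parsed.ParseFromString(blob);
  return parsed.input_specs(0).device();  // "cuda:0"
}
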
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/disc/disc_compile.cc
@@ -4,6 +4,7 @@

#include <filesystem>

#include "torch_xla/csrc/runtime/disc/compile_result.pb.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
@@ -98,7 +99,6 @@ DISCComplationResult Compile(mlir::ModuleOp &module,
res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt"));
res.inputs = inputs;
res.outputs = outputs;

return res;
}

1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/disc/disc_ral.h
@@ -33,6 +33,7 @@ class RalContext {
~RalContext();

std::vector<at::Tensor> Execute(const std::vector<at::Tensor>& inputs);
DISCComplationResult GetDiscResult() { return disc_result_; }

private:
void BindingInputs(const std::vector<at::Tensor>& inputs,
71 changes: 69 additions & 2 deletions torch_xla/csrc/runtime/disc_computation_client.cc
@@ -16,10 +16,12 @@
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/disc/compile_result.pb.h"
#include "torch_xla/csrc/runtime/disc/disc_compile.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "xla/client/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/float_normalization.h"
#include "xla/service/gpu/gpu_float_support.h"
@@ -136,8 +138,6 @@ std::vector<ComputationClient::DataPtr> DISCComputationClient::TransferToDevice(
auto dtype =
at::TensorOptions(TorchTypeFromXlaType(tensor->shape().element_type()));
auto ret = at::empty(sizes, dtype).contiguous();
// tensor->populate_fn(tensor, ret.data_ptr(),
// ret.element_size() * ret.numel());
std::memcpy(ret.data_ptr(), tensor->data(),
ret.element_size() * ret.numel());

@@ -406,5 +406,72 @@ int DISCComputationClient::GetProcessIndex() const { return local_rank_; }

int DISCComputationClient::GetNumProcesses() const { return world_size_; }

[Review comment on SerializeComputation: Add some blank lines between functions.]

std::string DISCComputationClient::SerializeComputation(
const ComputationPtr computation) {
auto client = dynamic_cast<DISCComputation*>(computation.get());
auto hlo_proto = client->computation().proto();
auto result = client->executable->GetDiscResult();
torch_xla::runtime::disc::DISCCompileResult result_pb;
result_pb.set_ral_library(result.ral_lib);
result_pb.set_ral_meta_pb(result.ral_mate_pb);
for (const auto& input : result.inputs) {
auto data_meta = result_pb.add_input_specs();
data_meta->set_device(input.device);
data_meta->set_dtype(static_cast<int>(input.scalar_type));
}
for (const auto& output : result.outputs) {
auto data_meta = result_pb.add_output_specs();
data_meta->set_device(output.device);
data_meta->set_dtype(static_cast<int>(output.scalar_type));
}
for (auto device : computation->devices()) {
result_pb.add_devices(device);
}
return absl::StrCat(hlo_proto.SerializeAsString(),
":::", result_pb.SerializeAsString());
}
ComputationClient::ComputationPtr DISCComputationClient::DeserializeComputation(
const std::string& serialized) {
std::vector<std::string> parts = absl::StrSplit(serialized, ":::");
if (parts.size() != 2) {
XLA_ERROR() << "Invalid serialized computation, should have 2 parts with "
"separator ':::', got "
<< parts.size();
}
if (parts[1].size() > std::numeric_limits<int>::max()) {
XLA_ERROR() << "Serialized DISCCompileResult proto too large (>2GB)\n";
}
xla::HloModuleProto hlo_proto;
disc::DISCCompileResult result_proto;
hlo_proto.ParseFromString(parts[0]);
result_proto.ParseFromString(parts[1]);

disc::DISCComplationResult compile_result;
compile_result.ral_lib = result_proto.ral_library();
compile_result.ral_mate_pb = result_proto.ral_meta_pb();
for (const auto& input : result_proto.input_specs()) {
disc::DataMeta data_meta;
data_meta.device = input.device();
data_meta.scalar_type = static_cast<at::ScalarType>(input.dtype());
compile_result.inputs.push_back(data_meta);
}
for (const auto& output : result_proto.output_specs()) {
disc::DataMeta data_meta;
data_meta.device = output.device();
data_meta.scalar_type = static_cast<at::ScalarType>(output.dtype());
compile_result.outputs.push_back(data_meta);
}
std::vector<std::string> devices;
for (const auto& device : result_proto.devices()) {
devices.push_back(device);
}

auto ral_context = std::make_unique<disc::RalContext>(compile_result);
auto computation = std::make_shared<DISCComputation>(
std::move(xla::XlaComputation(hlo_proto)), devices,
std::move(ral_context));
return computation;
}

} // namespace runtime
} // namespace torch_xla
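
To show how these two methods enable a local compilation cache, here is a minimal sketch of a caller that persists a compiled computation to disk and restores it later. The cache path and helper names are hypothetical; only SerializeComputation and DeserializeComputation come from this change.

#include <fstream>
#include <sstream>
#include <string>

#include "torch_xla/csrc/runtime/disc_computation_client.h"

namespace torch_xla {
namespace runtime {

// Hypothetical helper: serialize a compiled computation and write it to disk.
void SaveToCache(DISCComputationClient& client,
                 const ComputationClient::ComputationPtr& computation,
                 const std::string& cache_path) {
  const std::string blob = client.SerializeComputation(computation);
  std::ofstream out(cache_path, std::ios::binary);
  out.write(blob.data(), blob.size());
}

// Hypothetical helper: read a cached blob back and rebuild the computation.
ComputationClient::ComputationPtr LoadFromCache(DISCComputationClient& client,
                                                const std::string& cache_path) {
  std::ifstream in(cache_path, std::ios::binary);
  std::stringstream buffer;
  buffer << in.rdbuf();
  return client.DeserializeComputation(buffer.str());
}

}  // namespace runtime
}  // namespace torch_xla
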
13 changes: 5 additions & 8 deletions torch_xla/csrc/runtime/disc_computation_client.h
@@ -11,6 +11,8 @@ namespace runtime {

class DISCComputationClient : public ComputationClient {
public:
const std::string DefaultDevicePrefix = "CUDA:";

DISCComputationClient();
~DISCComputationClient();

@@ -55,15 +57,10 @@ class DISCComputationClient : public ComputationClient {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

std::string SerializeComputation(const ComputationPtr computation) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

ComputationPtr DeserializeComputation(
const std::string& serialized) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}
std::string SerializeComputation(const ComputationPtr computation) override;

ComputationClient::ComputationPtr DeserializeComputation(
const std::string& serialized) override;
torch::lazy::hash_t HashCompilationEnv() override {
// TODO(wangang.wa): Improve this function.
return torch::lazy::hash_t();