Skip to content

Commit

Permalink
update with comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Yancey1989 committed Nov 27, 2024
1 parent 4353509 commit 9913ef0
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 13 deletions.
4 changes: 2 additions & 2 deletions bazel/disc.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ cc_library(

cc_import(
name="disc_ral_cuda",
shared_library = ":libral_base_context.so",
shared_library = "build/libral_base_context.so",
)

cc_import(
name="disc_custom_op",
shared_library = ":libdisc_custom_ops.so",
shared_library = "build/libdisc_custom_ops.so",
)

genrule(
Expand Down
8 changes: 0 additions & 8 deletions torch_xla/csrc/runtime/disc/disc_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,6 @@ DISCComplationResult Compile(mlir::ModuleOp &module,
res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt"));
res.inputs = inputs;
res.outputs = outputs;
DISCCompileResult result;
result.set_ral_library(output_fname);
result.set_ral_meta_pb(absl::StrCat(output_fname, ".pbtxt"));
for (const auto &input : inputs) {
auto data_meta = result.add_input_specs();
data_meta->set_device(input.device);
data_meta->set_dtype(static_cast<int>(input.scalar_type));
}
return res;
}

Expand Down
6 changes: 3 additions & 3 deletions torch_xla/csrc/runtime/disc_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ std::vector<ComputationClient::DataPtr> DISCComputationClient::TransferToDevice(
auto dtype =
at::TensorOptions(TorchTypeFromXlaType(tensor->shape().element_type()));
auto ret = at::empty(sizes, dtype).contiguous();
// tensor->populate_fn(tensor, ret.data_ptr(),
// ret.element_size() * ret.numel());
std::memcpy(ret.data_ptr(), tensor->data(),
ret.element_size() * ret.numel());

Expand Down Expand Up @@ -407,6 +405,7 @@ size_t DISCComputationClient::GetNumDevices() const { return world_size_; }
int DISCComputationClient::GetProcessIndex() const { return local_rank_; }

int DISCComputationClient::GetNumProcesses() const { return world_size_; }

std::string DISCComputationClient::SerializeComputation(
const ComputationPtr computation) {
auto client = dynamic_cast<DISCComputation*>(computation.get());
Expand All @@ -428,7 +427,8 @@ std::string DISCComputationClient::SerializeComputation(
for (auto device : computation->devices()) {
result_pb.add_devices(device);
}
return hlo_proto.SerializeAsString() + ":::" + result_pb.SerializeAsString();
return absl::StrCat(hlo_proto.SerializeAsString(),
":::", result_pb.SerializeAsString());
}
ComputationClient::ComputationPtr DISCComputationClient::DeserializeComputation(
const std::string& serialized) {
Expand Down

0 comments on commit 9913ef0

Please sign in to comment.