Skip to content

Commit

Permalink
Expose KeyValueStore of the distributed service as a python binding.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 448820874
  • Loading branch information
yashk2810 authored and tensorflower-gardener committed May 15, 2022
1 parent cd33ee9 commit 73c1b1e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
12 changes: 11 additions & 1 deletion tensorflow/compiler/xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,17 @@ PYBIND11_MODULE(xla_extension, m) {
std::shared_ptr<DistributedRuntimeClient>>
distributed_runtime_client(m, "DistributedRuntimeClient");
distributed_runtime_client.def("connect", &DistributedRuntimeClient::Connect)
.def("shutdown", &DistributedRuntimeClient::Shutdown);
.def("shutdown", &DistributedRuntimeClient::Shutdown)
.def(
"blocking_key_value_get",
[](DistributedRuntimeClient& client, std::string key,
int64_t timeout_in_ms) {
return client.BlockingKeyValueGet(
key, absl::Milliseconds(timeout_in_ms));
},
py::arg("key"), py::arg("timeout_in_ms"))
.def("key_value_set", &DistributedRuntimeClient::KeyValueSet,
py::arg("key"), py::arg("value"));

m.def(
"get_distributed_runtime_service",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes.
_version = 66
_version = 67

# Version number for MLIR:Python components.
mlir_api_version = 18
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ class DistributedRuntimeService:
class DistributedRuntimeClient:
def connect(self) -> _Status: ...
def shutdown(self) -> _Status: ...
def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ...
def key_value_set(self, key: str, value: str) -> _Status: ...

def get_distributed_runtime_service(
address: str,
Expand Down

0 comments on commit 73c1b1e

Please sign in to comment.