From 2f47a2225998a76bbedde32488de37b056ea94e1 Mon Sep 17 00:00:00 2001 From: Cyril MARLIN Date: Tue, 20 Apr 2021 15:24:40 +0200 Subject: [PATCH] [Feature] bind delete_keys parameter on tf_client update_priority, like client API --- reverb/cc/ops/client.cc | 15 +++++++++++++-- reverb/tf_client.py | 7 ++++++- reverb/tf_client_test.py | 34 +++++++++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/reverb/cc/ops/client.cc b/reverb/cc/ops/client.cc index 5f7a6e40..b9831bb5 100644 --- a/reverb/cc/ops/client.cc +++ b/reverb/cc/ops/client.cc @@ -70,6 +70,7 @@ REGISTER_OP("ReverbClientUpdatePriorities") .Input("table: string") .Input("keys: uint64") .Input("priorities: double") + .Input("keys_to_delete: uint64") .Doc(R"doc( Blocking call to update the priorities of a collection of items. Keys that could not be found in table `table` on server are ignored and does not impact the rest @@ -187,7 +188,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel { const tensorflow::Tensor* keys; OP_REQUIRES_OK(context, context->input("keys", &keys)); const tensorflow::Tensor* priorities; - OP_REQUIRES_OK(context, context->input("priorities", &priorities)); + OP_REQUIRES_OK(context, context->input("priorities", &priorities)); + const tensorflow::Tensor* keys_to_delete; + OP_REQUIRES_OK(context, context->input("keys_to_delete", &keys_to_delete)); OP_REQUIRES( context, keys->dims() == 1, @@ -197,6 +200,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel { "Tensors `keys` and `priorities` do not match in shape (", keys->shape().DebugString(), " vs. ", priorities->shape().DebugString(), ")")); + OP_REQUIRES( + context, keys_to_delete->dims() == 1, + InvalidArgument("Tensors `keys_to_delete` must be of rank 1.")); std::string table_str = table->scalar()(); std::vector updates; @@ -207,6 +213,11 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel { updates.push_back(std::move(update)); } + std::vector deletes; + for (int i = 0; i < keys_to_delete->dim_size(0); i++) { + deletes.push_back(keys_to_delete->flat()(i)); + } + // The call will only fail if the Reverb-server is brought down during an // active call (e.g preempted). When this happens the request is retried and // since MutatePriorities sets `wait_for_ready` the request will no be sent @@ -214,7 +225,7 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel { // this retry in this tight loop. absl::Status status; do { - status = resource->client()->MutatePriorities(table_str, updates, {}); + status = resource->client()->MutatePriorities(table_str, updates, deletes); } while (absl::IsUnavailable(status) || absl::IsDeadlineExceeded(status)); OP_REQUIRES_OK(context, ToTensorflowStatus(status)); } diff --git a/reverb/tf_client.py b/reverb/tf_client.py index 2d43b4da..67e6ac23 100644 --- a/reverb/tf_client.py +++ b/reverb/tf_client.py @@ -117,6 +117,7 @@ def update_priorities(self, table: str, keys: tf.Tensor, priorities: tf.Tensor, + keys_to_delete: Optional[tf.Tensor] = None, name: str = None): """Creates op for updating priorities of existing items in the replay. @@ -126,16 +127,20 @@ def update_priorities(self, table: Probability table to update. keys: Keys of the items to update. Must be same length as `priorities`. priorities: New priorities for `keys`. Must be same length as `keys`. + keys_to_delete: Keys of the items to delete name: Optional name for the operation. Returns: A tf-op for performing the update. """ + if keys_to_delete is None: + keys_to_delete = tf.constant([], dtype=tf.uint64) + with tf.name_scope(name, f'{self._name}_update_priorities', ['update_priorities']) as scope: return gen_client_ops.reverb_client_update_priorities( - self._handle, table, keys, priorities, name=scope) + self._handle, table, keys, priorities, keys_to_delete, name=scope) def dataset(self, table: str, diff --git a/reverb/tf_client_test.py b/reverb/tf_client_test.py index 516b15a8..02f6703c 100644 --- a/reverb/tf_client_test.py +++ b/reverb/tf_client_test.py @@ -25,7 +25,6 @@ from reverb import tf_client import tensorflow.compat.v1 as tf - def make_server(): return server.Server( tables=[ @@ -161,6 +160,39 @@ def test_priority_update_is_applied(self): self.fail('Updated item was not found') + def test_delete_key_is_applied(self): + # Start with 4 items + for i in range(4): + self._client.insert([np.array([i], dtype=np.uint32)], {'dist': 1}) + + # Until we have recieved all 4 items. + items = {} + while len(items) < 4: + item = next(self._client.sample('dist'))[0] + items[item.info.key] = item.info.probability + + # remove 2 items + items_to_keep = [*items.keys()][:2] + items_to_remove = [*items.keys()][2:] + with self.session() as session: + client = tf_client.TFClient(self._client.server_address) + for key in items_to_remove: + update_op = client.update_priorities( + table=tf.constant('dist'), + keys=tf.constant([], dtype=tf.uint64), + priorities=tf.constant([], dtype=tf.float64), + keys_to_delete=tf.constant([key], dtype=tf.uint64)) + self.assertIsNone(session.run(update_op)) + + # 2 remaining items must persist + final_items = {} + for _ in range(1000): + item = next(self._client.sample('dist'))[0] + self.assertTrue(item.info.key in items_to_keep) + final_items[item.info.key] = item.info.probability + self.assertEqual(len(final_items), 2) + + class InsertOpTest(tf.test.TestCase): @classmethod