[Feature] bind delete_keys parameter on tf_client update_priority, like client API
cmarlin committed Apr 20, 2021
1 parent 2b74d54 commit 2f47a22
Showing 3 changed files with 52 additions and 4 deletions.
15 changes: 13 additions & 2 deletions reverb/cc/ops/client.cc
@@ -70,6 +70,7 @@ REGISTER_OP("ReverbClientUpdatePriorities")
.Input("table: string")
.Input("keys: uint64")
.Input("priorities: double")
.Input("keys_to_delete: uint64")
.Doc(R"doc(
Blocking call to update the priorities of a collection of items. Keys that could
not be found in table `table` on server are ignored and do not impact the rest
@@ -187,7 +188,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
const tensorflow::Tensor* keys;
OP_REQUIRES_OK(context, context->input("keys", &keys));
const tensorflow::Tensor* priorities;
OP_REQUIRES_OK(context, context->input("priorities", &priorities));
OP_REQUIRES_OK(context, context->input("priorities", &priorities));
const tensorflow::Tensor* keys_to_delete;
OP_REQUIRES_OK(context, context->input("keys_to_delete", &keys_to_delete));

OP_REQUIRES(
context, keys->dims() == 1,
@@ -197,6 +200,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
"Tensors `keys` and `priorities` do not match in shape (",
keys->shape().DebugString(), " vs. ",
priorities->shape().DebugString(), ")"));
OP_REQUIRES(
context, keys_to_delete->dims() == 1,
InvalidArgument("Tensors `keys_to_delete` must be of rank 1."));

std::string table_str = table->scalar<tstring>()();
std::vector<KeyWithPriority> updates;
@@ -207,14 +213,19 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
updates.push_back(std::move(update));
}

std::vector<uint64_t> deletes;
for (int i = 0; i < keys_to_delete->dim_size(0); i++) {
deletes.push_back(keys_to_delete->flat<tensorflow::uint64>()(i));
}

// The call will only fail if the Reverb-server is brought down during an
// active call (e.g. preempted). When this happens the request is retried and
// since MutatePriorities sets `wait_for_ready` the request will not be sent
// before the server is brought up again. It is therefore no problem to have
// this retry in this tight loop.
absl::Status status;
do {
status = resource->client()->MutatePriorities(table_str, updates, {});
status = resource->client()->MutatePriorities(table_str, updates, deletes);
} while (absl::IsUnavailable(status) || absl::IsDeadlineExceeded(status));
OP_REQUIRES_OK(context, ToTensorflowStatus(status));
}
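
For reference, the "like client API" in the commit message refers to the plain Python client, where priority updates and deletions already travel in a single request. Below is a minimal sketch of that call, assuming the usual reverb.Client.mutate_priorities(table, updates, deletes) signature; the server address, table name, and sampled keys are placeholders.

import reverb

# Placeholder address; assumes a running Reverb server with a 'dist' table.
client = reverb.Client('localhost:8000')

# Item keys are the uint64 `info.key` values obtained when sampling.
sampled = [next(client.sample('dist'))[0] for _ in range(3)]
keep_key, *delete_keys = [item.info.key for item in sampled]

# One request updates a priority and deletes items, which is the behaviour
# the TF op gains above via the new `keys_to_delete` input.
client.mutate_priorities(table='dist', updates={keep_key: 0.5},
                         deletes=delete_keys)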
7 changes: 6 additions & 1 deletion reverb/tf_client.py
@@ -117,6 +117,7 @@ def update_priorities(self,
table: str,
keys: tf.Tensor,
priorities: tf.Tensor,
keys_to_delete: Optional[tf.Tensor] = None,
name: str = None):
"""Creates op for updating priorities of existing items in the replay.
@@ -126,16 +127,20 @@ def update_priorities(self,
table: Probability table to update.
keys: Keys of the items to update. Must be same length as `priorities`.
priorities: New priorities for `keys`. Must be same length as `keys`.
keys_to_delete: Keys of the items to delete. If None, no items are deleted.
name: Optional name for the operation.
Returns:
A tf-op for performing the update.
"""

if keys_to_delete is None:
keys_to_delete = tf.constant([], dtype=tf.uint64)

with tf.name_scope(name, f'{self._name}_update_priorities',
['update_priorities']) as scope:
return gen_client_ops.reverb_client_update_priorities(
self._handle, table, keys, priorities, name=scope)
self._handle, table, keys, priorities, keys_to_delete, name=scope)

def dataset(self,
table: str,
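
From graph-mode TensorFlow the new argument is used roughly as in the sketch below, which mirrors the test added next; the server address and item key are placeholders, and the 'dist' table name follows the test setup. With keys_to_delete left as None the op behaves exactly as before, since the wrapper substitutes an empty uint64 tensor.

import tensorflow.compat.v1 as tf
from reverb import tf_client

# Placeholder address; assumes a running Reverb server with a 'dist' table.
client = tf_client.TFClient('localhost:8000')

# Delete one item by key without touching any priorities.
update_op = client.update_priorities(
    table=tf.constant('dist'),
    keys=tf.constant([], dtype=tf.uint64),
    priorities=tf.constant([], dtype=tf.float64),
    keys_to_delete=tf.constant([1234], dtype=tf.uint64))  # Placeholder key.

with tf.Session() as session:
  session.run(update_op)  # Blocking call; returns None on success.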
34 changes: 33 additions & 1 deletion reverb/tf_client_test.py
@@ -25,7 +25,6 @@
from reverb import tf_client
import tensorflow.compat.v1 as tf


def make_server():
return server.Server(
tables=[
@@ -161,6 +160,39 @@ def test_priority_update_is_applied(self):
self.fail('Updated item was not found')


def test_delete_key_is_applied(self):
# Start with 4 items
for i in range(4):
self._client.insert([np.array([i], dtype=np.uint32)], {'dist': 1})

# Sample until we have received all 4 items.
items = {}
while len(items) < 4:
item = next(self._client.sample('dist'))[0]
items[item.info.key] = item.info.probability

# Remove 2 of the items.
items_to_keep = [*items.keys()][:2]
items_to_remove = [*items.keys()][2:]
with self.session() as session:
client = tf_client.TFClient(self._client.server_address)
for key in items_to_remove:
update_op = client.update_priorities(
table=tf.constant('dist'),
keys=tf.constant([], dtype=tf.uint64),
priorities=tf.constant([], dtype=tf.float64),
keys_to_delete=tf.constant([key], dtype=tf.uint64))
self.assertIsNone(session.run(update_op))

# 2 remaining items must persist
final_items = {}
for _ in range(1000):
item = next(self._client.sample('dist'))[0]
self.assertTrue(item.info.key in items_to_keep)
final_items[item.info.key] = item.info.probability
self.assertEqual(len(final_items), 2)


class InsertOpTest(tf.test.TestCase):

@classmethod
