From cd0920c8164d65fa939311d40c304898e1c68fa6 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Wed, 6 Dec 2023 17:55:34 -0800 Subject: [PATCH] fix: TOOLS-2656 exact node ID matches with another node ID with the same prefix --- lib/live_cluster/client/cluster.py | 9 ++++++--- lib/utils/lookup_dict.py | 6 ++++++ test/unit/live_cluster/client/test_cluster.py | 14 ++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/lib/live_cluster/client/cluster.py b/lib/live_cluster/client/cluster.py index 4c6c1174..5e9cca50 100644 --- a/lib/live_cluster/client/cluster.py +++ b/lib/live_cluster/client/cluster.py @@ -430,9 +430,12 @@ def get_node(self, node) -> list[Node]: # Me must now look for exact matches. # Can't use "if not in self.node_lookup" here because we need to check for - # exact matches. Unless using node id than this condition requires they provide ip:port. - if node in self.node_lookup.keys(): - return self.node_lookup[node] + # exact matches. Unless using node id than this condition requires they provide + # ip:port. + try: + return [self.node_lookup.get_exact(node)] + except KeyError: + pass node_matches = self.node_lookup[node] match = None diff --git a/lib/utils/lookup_dict.py b/lib/utils/lookup_dict.py index 573ac2e5..99dafc9c 100644 --- a/lib/utils/lookup_dict.py +++ b/lib/utils/lookup_dict.py @@ -149,6 +149,12 @@ def get(self, k) -> list[ValueType]: keys = self.get_key(k) return [self._kv[key] for key in keys] + def get_exact(self, k) -> ValueType: + if k in self._kv: + return self._kv[k] + + raise KeyError("Unable to find key '%s'" % (k)) + def remove(self, k): keys = self.get_key(k) diff --git a/test/unit/live_cluster/client/test_cluster.py b/test/unit/live_cluster/client/test_cluster.py index 8886a966..dfdb3314 100644 --- a/test/unit/live_cluster/client/test_cluster.py +++ b/test/unit/live_cluster/client/test_cluster.py @@ -198,6 +198,11 @@ async def test_get_node(self): n = await self.get_info_mock("A0000000000000" + str(i), ip=ip, port=port) cl.update_node(n) + n = await self.get_info_mock("A", ip="1.1.1.1", port=3000) + cl.update_node(n) + n = await self.get_info_mock("AB", ip="2.2.2.2", port=3000) + cl.update_node(n) + expected = [ "192.168.0.1:3000", "192.168.0.1:3001", @@ -277,6 +282,15 @@ async def test_get_node(self): self.assertCountEqual(expected, actual) + expected = [ + "1.1.1.1:3000", + ] + + actual = cl.get_node("A") + actual = map(lambda x: x.key, actual) + + self.assertCountEqual(expected, actual) + async def test_get_nodes(self): cl = await self.get_cluster_mock(3)