From 1ecc9efcaf82bad41a7541adf1b523b9f0bdbf7c Mon Sep 17 00:00:00 2001 From: Zengxiang Lu <24939744@qq.com> Date: Sun, 26 Jan 2025 11:28:19 +0800 Subject: [PATCH] Refactor recursive functions to iterative to prevent stack overflow Convert `_match_prefix_helper`, `_insert_helper`, `_print_helper`, and `_total_size_helper` from recursive to iterative implementations using stacks or loop. This prevents potential stack overflow errors with deep trees while maintaining the same functionality. Also update the test cases to use tensor values instead of strings for better testing of the radix cache's core functionality. --- python/sglang/srt/mem_cache/radix_cache.py | 110 ++++++++++++--------- 1 file changed, 62 insertions(+), 48 deletions(-) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 3bf87b54299..5b24634126b 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -111,14 +111,12 @@ def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: if self.disable: return [], self.root_node - value = [] - last_node = [self.root_node] - self._match_prefix_helper(self.root_node, key, value, last_node) + value, last_node = self._match_prefix_helper(self.root_node, key) if value: value = torch.concat(value) else: value = torch.tensor([], dtype=torch.int32) - return value, last_node[0] + return value, last_node def insert(self, key: List, value=None): if self.disable: @@ -191,7 +189,7 @@ def pretty_print(self): print(f"#tokens: {self.total_size()}") def total_size(self): - return self._total_size_helper(self.root_node) + return self._total_size_helper() def evict(self, num_tokens: int, evict_callback: Callable): if self.disable: @@ -253,24 +251,22 @@ def protected_size(self): ##### Internal Helper Functions ##### - def _match_prefix_helper( - self, node: TreeNode, key: List, value, last_node: TreeNode - ): + def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.time() - if len(key) == 0: - return - - if key[0] in node.children.keys(): + value = [] + while len(key) > 0 and key[0] in node.children.keys(): child = node.children[key[0]] prefix_len = _key_match(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) value.append(new_node.value) - last_node[0] = new_node + node = new_node + break else: value.append(child.value) - last_node[0] = child - self._match_prefix_helper(child, key[prefix_len:], value, last_node) + node = child + key = key[prefix_len:] + return value, node def _split_node(self, key, child: TreeNode, split_len: int): # new_node -> child @@ -291,22 +287,25 @@ def _insert_helper(self, node: TreeNode, key: List, value): if len(key) == 0: return 0 - if key[0] in node.children.keys(): + total_prefix_length = 0 + while len(key) > 0 and key[0] in node.children.keys(): child = node.children[key[0]] + child.last_access_time = time.time() prefix_len = _key_match(child.key, key) - + total_prefix_length += prefix_len if prefix_len == len(child.key): if prefix_len == len(key): - return prefix_len + break else: key = key[prefix_len:] value = value[prefix_len:] - return prefix_len + self._insert_helper(child, key, value) - - new_node = self._split_node(child.key, child, prefix_len) - return prefix_len + self._insert_helper( - new_node, key[prefix_len:], value[prefix_len:] - ) + node = child + else: + new_node = self._split_node(child.key, child, prefix_len) + key = key[prefix_len:] + value = value[prefix_len:] + node = new_node + break if len(key): new_node = TreeNode() @@ -315,12 +314,21 @@ def _insert_helper(self, node: TreeNode, key: List, value): new_node.value = value node.children[key[0]] = new_node self.evictable_size_ += len(value) - return 0 + return total_prefix_length def _print_helper(self, node: TreeNode, indent: int): - for _, child in node.children.items(): - print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}") - self._print_helper(child, indent=indent + 2) + """Prints the radix tree in a human-readable format.""" + stack = [(node, indent)] + while stack: + current_node, current_indent = stack.pop() + print( + " " * current_indent, + len(current_node.key), + current_node.key[:10], + f"r={current_node.lock_ref}", + ) + for _, child in current_node.children.items(): + stack.append((child, current_indent + 2)) def _delete_leaf(self, node): for k, v in node.parent.children.items(): @@ -329,13 +337,14 @@ def _delete_leaf(self, node): del node.parent.children[k] self.evictable_size_ -= len(node.key) - def _total_size_helper(self, node: TreeNode): - if node.evicted: - return 0 - x = len(node.value) - for child in node.children.values(): - x += self._total_size_helper(child) - return x + def _total_size_helper(self): + total_size = 0 + stack = [self.root_node] + while stack: + current_node = stack.pop() + total_size += len(current_node.value) + stack.extend(current_node.children.values()) + return total_size def _collect_leaves(self): ret_list = [] @@ -353,20 +362,25 @@ def _collect_leaves(self): if __name__ == "__main__": tree = RadixCache(None, None, False) + a = torch.Tensor([1, 2, 3]) + b = torch.Tensor([1, 2, 4]) + c = torch.Tensor([1, 3, 5]) + tree.insert([], torch.Tensor([])) + tree.insert(a.tolist(), a) + tree.insert(b.tolist(), b) + val, node = tree.match_prefix(c.tolist()) + tree.insert([1, 1, 3, 5], torch.Tensor([1, 1, 3, 5])) + tree.insert([1, 1, 4, 5], torch.Tensor([1, 1, 4, 5])) + + val, node = tree.match_prefix([]) - tree.insert("Hello") - tree.insert("Hello") - tree.insert("Hello_L.A.!") - # tree.insert("Hello_world! Happy") - # tree.insert("I love you!") tree.pretty_print() - # print(tree.match_prefix("I love you! aha")) + def evict_callback(x): + print("evict", x) + return len(x) - # def evict_callback(x): - # print("evict", x) - # return len(x) - - # tree.evict(5, evict_callback) - # tree.evict(10, evict_callback) - # tree.pretty_print() + tree.evict(1, evict_callback) + tree.evict(1, evict_callback) + tree.evict(1, evict_callback) + tree.pretty_print()