Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debug radixcache: refactor recursive helper methods #3029

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Changes from 16 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
b39aaae
Refactor radix cache helper functions to use iterative approaches
luzengxiangcn Feb 5, 2025
61e9125
Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn Feb 6, 2025
64b46fd
Update radix_cache.py
xiezhq-hermann Feb 8, 2025
2adaa3e
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 8, 2025
2e86323
Update radix_cache.py
xiezhq-hermann Feb 8, 2025
e569ec9
Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn Feb 8, 2025
f9fdf43
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 9, 2025
33ed800
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 10, 2025
6ac095e
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 12, 2025
439b87a
Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn Feb 18, 2025
d040e7a
Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn Feb 20, 2025
83fa8c5
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 20, 2025
b5a0afa
Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn Feb 21, 2025
f6b0587
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 22, 2025
b391c22
Merge branch 'main' into debug_radixcache_stack_overflow
luzengxiangcn Feb 24, 2025
401e7a4
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 24, 2025
8a23e1b
Merge branch 'main' into debug_radixcache_stack_overflow
xiezhq-hermann Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 47 additions & 41 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number | Diff line number | Diff line change
@@ -111,14 +111,12 @@ def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
if self.disable:
return [], self.root_node

value = []
last_node = [self.root_node]
self._match_prefix_helper(self.root_node, key, value, last_node)
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.concat(value)
else:
value = torch.tensor([], dtype=torch.int32)
return value, last_node[0]
return value, last_node

def insert(self, key: List, value=None):
if self.disable:
@@ -191,7 +189,7 @@ def pretty_print(self):
print(f"#tokens: {self.total_size()}")

def total_size(self):
return self._total_size_helper(self.root_node)
return self._total_size_helper()

def evict(self, num_tokens: int, evict_callback: Callable):
if self.disable:
@@ -253,24 +251,23 @@ def protected_size(self):

##### Internal Helper Functions #####

def _match_prefix_helper(
self, node: TreeNode, key: List, value, last_node: TreeNode
):
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
if len(key) == 0:
return

if key[0] in node.children.keys():
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
last_node[0] = new_node
node = new_node
break
else:
value.append(child.value)
last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
node = child
key = key[prefix_len:]
return value, node

def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child
@@ -291,22 +288,18 @@ def _insert_helper(self, node: TreeNode, key: List, value):
if len(key) == 0:
return 0

if key[0] in node.children.keys():
child = node.children[key[0]]
prefix_len = _key_match(child.key, key)
total_prefix_length = 0
while len(key) > 0 and key[0] in node.children.keys():
node = node.children[key[0]]
node.last_access_time = time.time()
prefix_len = _key_match(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:]
value = value[prefix_len:]

if prefix_len == len(child.key):
if prefix_len == len(key):
return prefix_len
else:
key = key[prefix_len:]
value = value[prefix_len:]
return prefix_len + self._insert_helper(child, key, value)

new_node = self._split_node(child.key, child, prefix_len)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
if prefix_len < len(node.key):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node

if len(key):
new_node = TreeNode()
@@ -315,12 +308,21 @@ def _insert_helper(self, node: TreeNode, key: List, value):
new_node.value = value
node.children[key[0]] = new_node
self.evictable_size_ += len(value)
return 0
return total_prefix_length

def _print_helper(self, node: TreeNode, indent: int):
for _, child in node.children.items():
print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
self._print_helper(child, indent=indent + 2)
"""Prints the radix tree in a human-readable format."""
stack = [(node, indent)]
while stack:
current_node, current_indent = stack.pop()
print(
" " * current_indent,
len(current_node.key),
current_node.key[:10],
f"r={current_node.lock_ref}",
)
for _, child in current_node.children.items():
stack.append((child, current_indent + 2))

def _delete_leaf(self, node):
for k, v in node.parent.children.items():
@@ -329,13 +331,17 @@ def _delete_leaf(self, node):
del node.parent.children[k]
self.evictable_size_ -= len(node.key)

def _total_size_helper(self, node: TreeNode):
if node.evicted:
return 0
x = len(node.value)
for child in node.children.values():
x += self._total_size_helper(child)
return x
def _total_size_helper(self):
total_size = 0
stack = [self.root_node]
while stack:
current_node = stack.pop()
total_size += len(current_node.value)
for child in current_node.children.values():
if child.evicted:
continue
stack.append(child)
return total_size

def _collect_leaves(self):
ret_list = []
Expand Down
Loading