Skip to content

Commit

Permalink
Implement hub callback priorities (#2530)
Browse files Browse the repository at this point in the history
* Implement hub callback priorities
* Update glue/core/hub.py
* Address review comments
* Move yield loop out of subscriber loop
* Add test
* Update docstrings
---------
Co-authored-by: Derek Homeier <[email protected]>
Co-authored-by: Thomas Robitaille <[email protected]>
  • Loading branch information
rosteen authored Feb 5, 2025
1 parent d59f708 commit 92c9bff
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
20 changes: 15 additions & 5 deletions glue/core/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def __init__(self, *args):

def subscribe(self, subscriber, message_class,
handler=None,
filter=lambda x: True):
filter=lambda x: True,
priority=10):
"""Subscribe an object to a type of message class.
:param subscriber: The subscribing object
Expand All @@ -88,6 +89,11 @@ def subscribe(self, subscriber, message_class,
are only passed to the subscriber if filter(message) == True.
The default is to always pass messages.
:param priority:
An optional integer to set the priority of the handler. Handlers
are sorted such that higher priority handlers get called first
when broadcasting a message.
Raises:
InvalidMessage: If the input class isn't a
Expand All @@ -113,7 +119,7 @@ def subscribe(self, subscriber, message_class,
if subscriber not in self._subscriptions:
self._subscriptions[subscriber] = HubCallbackContainer()

self._subscriptions[subscriber][message_class] = handler, filter
self._subscriptions[subscriber][message_class] = handler, filter, priority

def is_subscribed(self, subscriber, message):
"""
Expand Down Expand Up @@ -160,9 +166,10 @@ def _find_handlers(self, message):
"""Yields all (subscriber, handler) pairs that should receive a message
"""
# self._subscriptions:
# subscriber => { message type => (filter, handler)}
# subscriber => { message type => (filter, handler, priority)}

# loop over subscribed objects
prioritized_handlers = []
for subscriber, subscriptions in list(self._subscriptions.items()):

# subscriptions to message or its superclasses
Expand All @@ -175,9 +182,12 @@ def _find_handlers(self, message):
# narrow to the most-specific message
candidate = max(messages, key=_mro_count)

handler, test = subscriptions[candidate]
handler, test, priority = subscriptions[candidate]
if test(message):
yield subscriber, handler
prioritized_handlers.append((subscriber, handler, priority))

for subscriber, handler, _ in sorted(prioritized_handlers, key=lambda x: x[2], reverse=True):
yield subscriber, handler

@contextmanager
def ignore_callbacks(self, ignore_type):
Expand Down
11 changes: 8 additions & 3 deletions glue/core/hub_callback_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class HubCallbackContainer(object):
def __init__(self):
self.callbacks = {}

def _wrap(self, handler, filter):
def _wrap(self, handler, filter, priority):
"""
Given a function/method, this will automatically wrap a method using
weakref to avoid circular references.
Expand Down Expand Up @@ -57,6 +57,8 @@ def _wrap(self, handler, filter):

value += (filter, None)

value += (priority,)

return value

def _auto_remove(self, method_instance):
Expand Down Expand Up @@ -91,6 +93,9 @@ def __getitem__(self, message_class):
inst = callback[3]()
result += (partial(func, inst),)

# Add priority
result += (callback[4],)

return result

def __iter__(self):
Expand All @@ -108,8 +113,8 @@ def is_bound_method(func):
return hasattr(func, '__func__') and getattr(func, '__self__', None) is not None

def __setitem__(self, message_class, value):
handler, filter = value
self.callbacks[message_class] = self._wrap(handler, filter)
handler, filter, priority = value
self.callbacks[message_class] = self._wrap(handler, filter, priority)

def pop(self, message_class):
return self.callbacks.pop(message_class)
Expand Down
42 changes: 42 additions & 0 deletions glue/core/tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,48 @@ def test_invalid_init(self):
assert exc.value.args[0] == ("Inputs must be HubListener, data, "
"subset, or data collection objects")

@pytest.mark.parametrize('priorities', (False, True))
def test_handler_priorities(self, priorities):
"""
Test that handlers are called in order of descending priority, if set,
in order subscribed to otherwise.
"""
msg, _, subscriber1 = self.get_subscription()
_, _, subscriber2 = self.get_subscription()
_, _, subscriber3 = self.get_subscription()

class Handlers:

def __init__(self):
self.priority_test_val = 0

def handler1(self, msg):
self.priority_test_val += 10

def handler2(self, msg):
global priority_test_val
self.priority_test_val *= 2

def handler3(self, msg):
self.priority_test_val = 18

handlers = Handlers()

if priorities:
self.hub.subscribe(subscriber1, msg, handlers.handler1, priority=100)
self.hub.subscribe(subscriber2, msg, handlers.handler2)
self.hub.subscribe(subscriber3, msg, handlers.handler3, priority=200)
else:
self.hub.subscribe(subscriber1, msg, handlers.handler1)
self.hub.subscribe(subscriber2, msg, handlers.handler2)
self.hub.subscribe(subscriber3, msg, handlers.handler3)

msg_instance = msg("Test")

self.hub.broadcast(msg_instance)

assert handlers.priority_test_val == (56 if priorities else 18)


class TestHubListener(object):
"""This is a dumb test, I know. Fixated on code coverage"""
Expand Down

0 comments on commit 92c9bff

Please sign in to comment.