Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add special topic for agent direct messaging #4385

Merged
merged 5 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from asyncio import Future, Task
from typing import Any, Dict, Set

from autogen_core.components._type_prefix_subscription import TypePrefixSubscription
from autogen_core.base._type_prefix_subscription import TypePrefixSubscription

from ..base import Subscription, TopicId
from ..components import TypeSubscription
Expand Down
12 changes: 12 additions & 0 deletions python/packages/autogen-core/src/autogen_core/base/_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ._subscription import Subscription, UnboundSubscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
from ._type_prefix_subscription import TypePrefixSubscription

T = TypeVar("T", bound=Agent)

Expand Down Expand Up @@ -149,6 +150,7 @@ async def register(
factory: Callable[[], Self | Awaitable[Self]],
*,
skip_class_subscriptions: bool = False,
skip_direct_message_subscription: bool = False,
) -> AgentType:
agent_type = AgentType(type)
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
Expand All @@ -166,6 +168,16 @@ async def register(
for subscription in subscriptions:
await runtime.add_subscription(subscription)

if not skip_direct_message_subscription:
# Additionally adds a special prefix subscription for this agent to receive direct messages
await runtime.add_subscription(
TypePrefixSubscription(
# The prefix MUST include ":" to avoid collisions with other agents
topic_type_prefix=agent_type.type + ":",
agent_type=agent_type.type,
)
)

# TODO: deduplication
for _message_type, serializer in cls._handles_types():
runtime.add_message_serializer(serializer)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import uuid

from ..base import AgentId, Subscription, TopicId
from ..base.exceptions import CantHandleException
from ._agent_id import AgentId
from ._subscription import Subscription
from ._topic import TopicId
from .exceptions import CantHandleException


class TypePrefixSubscription(Subscription):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
The :mod:`autogen_core.components` module provides building blocks for creating single agents
"""

from ..base._type_prefix_subscription import TypePrefixSubscription
from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._routed_agent import RoutedAgent, TypeRoutedAgent, event, message_handler, rpc
from ._type_prefix_subscription import TypePrefixSubscription
from ._type_subscription import TypeSubscription
from ._types import FunctionCall

Expand Down
4 changes: 2 additions & 2 deletions python/packages/autogen-core/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ async def get_subscribed_recipients() -> List[AgentId]:
)

subscriptions1 = get_current_subscriptions()
assert len(subscriptions1) == 1
assert len(subscriptions1) == 2
recipients1 = await get_subscribed_recipients()
assert AgentId(type="worker1", key="default") in recipients1

Expand Down Expand Up @@ -388,7 +388,7 @@ async def get_subscribed_recipients() -> List[AgentId]:
)

subscriptions3 = get_current_subscriptions()
assert len(subscriptions3) == 1
assert len(subscriptions3) == 2
assert first_subscription_id not in [x.id for x in subscriptions3]

recipients3 = await get_subscribed_recipients()
Expand Down
Loading