Skip to content

Commit

Permalink
ChatSession: Split native content out of message class (#136668)
Browse files Browse the repository at this point in the history
Split native content out of message class
  • Loading branch information
balloob authored Jan 28, 2025
1 parent 48a9154 commit 5690516
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 73 deletions.
3 changes: 1 addition & 2 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,11 +1101,10 @@ async def recognize_intent(
"speech", ""
)
chat_session.async_add_message(
conversation.ChatMessage(
conversation.Content(
role="assistant",
agent_id=agent_id,
content=speech,
native=intent_response,
)
)
conversation_result = conversation.ConversationResult(
Expand Down
11 changes: 9 additions & 2 deletions homeassistant/components/conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,28 @@
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session
from .session import (
ChatSession,
Content,
ConverseError,
NativeContent,
async_get_chat_session,
)
from .trace import ConversationTraceEventType, async_conversation_trace_append

__all__ = [
"DOMAIN",
"HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"ChatMessage",
"ChatSession",
"Content",
"ConversationEntity",
"ConversationEntityFeature",
"ConversationInput",
"ConversationResult",
"ConversationTraceEventType",
"ConverseError",
"NativeContent",
"async_conversation_trace_append",
"async_converse",
"async_get_agent_info",
Expand Down
5 changes: 2 additions & 3 deletions homeassistant/components/conversation/default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
from .session import ChatMessage, async_get_chat_session
from .session import Content, async_get_chat_session
from .trace import ConversationTraceEventType, async_conversation_trace_append

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -374,11 +374,10 @@ async def async_process(self, user_input: ConversationInput) -> ConversationResu

speech: str = response.speech.get("plain", {}).get("speech", "")
chat_session.async_add_message(
ChatMessage(
Content(
role="assistant",
agent_id=user_input.agent_id,
content=speech,
native=response,
)
)

Expand Down
36 changes: 17 additions & 19 deletions homeassistant/components/conversation/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def async_get_chat_session(
else:
history = ChatSession(hass, conversation_id, user_input.agent_id)

message: ChatMessage = ChatMessage(
message: Content = Content(
role="user",
agent_id=user_input.agent_id,
content=user_input.text,
Expand Down Expand Up @@ -169,23 +169,21 @@ def as_conversation_result(self) -> ConversationResult:


@dataclass
class ChatMessage[_NativeT]:
"""Base class for chat messages.
class Content:
"""Base class for chat messages."""

When role is native, the content is to be ignored and message
is only meant for storing the native object.
"""

role: Literal["system", "assistant", "user", "native"]
role: Literal["system", "assistant", "user"]
agent_id: str | None
content: str
native: _NativeT | None = field(default=None)

# Validate in post-init that if role is native, there is no content and a native object exists
def __post_init__(self) -> None:
"""Validate native message."""
if self.role == "native" and self.native is None:
raise ValueError("Native message must have a native object")

@dataclass(frozen=True)
class NativeContent[_NativeT]:
"""Native content."""

role: str = field(init=False, default="native")
agent_id: str
content: _NativeT


@dataclass
Expand All @@ -196,15 +194,15 @@ class ChatSession[_NativeT]:
conversation_id: str
agent_id: str | None
user_name: str | None = None
messages: list[ChatMessage[_NativeT]] = field(
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
messages: list[Content | NativeContent[_NativeT]] = field(
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
)
extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None
last_updated: datetime = field(default_factory=dt_util.utcnow)

@callback
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
"""Process intent."""
if message.role == "system":
raise ValueError("Cannot add system messages to history")
Expand All @@ -216,7 +214,7 @@ def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
@callback
def async_get_messages(
self, agent_id: str | None = None
) -> list[ChatMessage[_NativeT]]:
) -> list[Content | NativeContent[_NativeT]]:
"""Get messages for a specific agent ID.
This will filter out any native message tied to other agent IDs.
Expand Down Expand Up @@ -328,7 +326,7 @@ async def async_update_llm_data(
self.llm_api = llm_api
self.user_name = user_name
self.extra_system_prompt = extra_system_prompt
self.messages[0] = ChatMessage(
self.messages[0] = Content(
role="system",
agent_id=user_input.agent_id,
content=prompt,
Expand Down
26 changes: 12 additions & 14 deletions homeassistant/components/openai_conversation/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,13 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar


def _chat_message_convert(
message: conversation.ChatMessage[ChatCompletionMessageParam],
agent_id: str | None,
message: conversation.Content
| conversation.NativeContent[ChatCompletionMessageParam],
) -> ChatCompletionMessageParam:
"""Convert any native chat message for this agent to the native format."""
if message.native is not None and message.agent_id == agent_id:
return message.native
if message.role == "native":
# mypy doesn't understand that checking role ensures content type
return message.content # type: ignore[return-value]
return cast(
ChatCompletionMessageParam,
{"role": message.role, "content": message.content},
Expand Down Expand Up @@ -157,14 +158,15 @@ async def async_process(
async with conversation.async_get_chat_session(
self.hass, user_input
) as session:
return await self._async_call_api(user_input, session)
return await self._async_handle_message(user_input, session)

async def _async_call_api(
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatSession[ChatCompletionMessageParam],
) -> conversation.ConversationResult:
"""Call the API."""
assert user_input.agent_id
options = self.entry.options

try:
Expand All @@ -185,8 +187,7 @@ async def _async_call_api(
]

messages = [
_chat_message_convert(message, user_input.agent_id)
for message in session.async_get_messages()
_chat_message_convert(message) for message in session.async_get_messages()
]

client = self.entry.runtime_data
Expand All @@ -212,11 +213,10 @@ async def _async_call_api(
messages.append(_message_convert(response))

session.async_add_message(
conversation.ChatMessage(
conversation.Content(
role=response.role,
agent_id=user_input.agent_id,
content=response.content or "",
native=messages[-1],
),
)

Expand All @@ -237,11 +237,9 @@ async def _async_call_api(
)
)
session.async_add_message(
conversation.ChatMessage(
role="native",
conversation.NativeContent(
agent_id=user_input.agent_id,
content="",
native=messages[-1],
content=messages[-1],
)
)

Expand Down
51 changes: 18 additions & 33 deletions tests/components/conversation/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def test_cleanup(
assert chat_session.conversation_id != conversation_id
conversation_id = chat_session.conversation_id
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
Expand Down Expand Up @@ -127,12 +127,6 @@ async def test_cleanup(
assert len(chat_session.messages) == 2


def test_chat_message() -> None:
"""Test chat message."""
with pytest.raises(ValueError):
session.ChatMessage(role="native", agent_id=None, content="", native=None)


async def test_add_message(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
Expand All @@ -144,27 +138,27 @@ async def test_add_message(

with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(role="system", agent_id=None, content="")
session.Content(role="system", agent_id=None, content="")
)

# No 2 user messages in a row
assert chat_session.messages[1].role == "user"

with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(role="user", agent_id=None, content="")
session.Content(role="user", agent_id=None, content="")
)

# No 2 assistant messages in a row
chat_session.async_add_message(
session.ChatMessage(role="assistant", agent_id=None, content="")
session.Content(role="assistant", agent_id=None, content="")
)
assert len(chat_session.messages) == 3
assert chat_session.messages[-1].role == "assistant"

with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(role="assistant", agent_id=None, content="")
session.Content(role="assistant", agent_id=None, content="")
)


Expand All @@ -177,52 +171,46 @@ async def test_message_filtering(
) as chat_session:
messages = chat_session.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == session.ChatMessage(
assert messages[0] == session.Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == session.ChatMessage(
assert messages[1] == session.Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)

chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
native="assistant-reply-native",
)
)
# Different agent, native messages will be filtered out.
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="another-mock-agent-id", content="", native=1
)
session.NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)
session.NativeContent(agent_id="mock-agent-id", content=1)
)
# A non-native message from another agent is not filtered out.
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
native=1,
)
)

Expand All @@ -231,17 +219,14 @@ async def test_message_filtering(
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5

assert messages[2] == session.ChatMessage(
assert messages[2] == session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
native="assistant-reply-native",
)
assert messages[3] == session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)
assert messages[4] == session.ChatMessage(
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == session.Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)


Expand Down Expand Up @@ -361,7 +346,7 @@ async def test_extra_systen_prompt(
user_llm_prompt=None,
)
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
Expand Down Expand Up @@ -401,7 +386,7 @@ async def test_extra_systen_prompt(
user_llm_prompt=None,
)
chat_session.async_add_message(
session.ChatMessage(
session.Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
Expand Down

0 comments on commit 5690516

Please sign in to comment.