Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Add last message query generator #210

Merged
1 change: 1 addition & 0 deletions src/canopy/chat_engine/query_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base import QueryGenerator
from .function_calling import FunctionCallingQueryGenerator
from .last_message import LastMessageQueryGenerator
36 changes: 36 additions & 0 deletions src/canopy/chat_engine/query_generator/last_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List

from canopy.chat_engine.query_generator import QueryGenerator
from canopy.models.data_models import Messages, Query, Role


class LastMessageQueryGenerator(QueryGenerator):
"""
Returns the last message as a query without running any LLMs. This can be
considered as the most basic query generation. Please use other query generators
for more accurate results.
"""

izellevy marked this conversation as resolved.
Show resolved Hide resolved
def generate(self,
messages: Messages,
max_prompt_tokens: int) -> List[Query]:
"""
max_prompt_token is dismissed since we do not consume any token for
generating the queries.
"""

if len(messages) == 0:
raise ValueError("Passed chat history does not contain any messages. "
"Please include at least one message in the history.")

last_message = messages[-1]

if last_message.role != Role.USER:
raise ValueError(f"Expected a UserMessage, got {type(last_message)}.")

return [Query(text=last_message.content)]

async def agenerate(self,
messages: Messages,
max_prompt_tokens: int) -> List[Query]:
return self.generate(messages, max_prompt_tokens)
41 changes: 41 additions & 0 deletions tests/unit/query_generators/test_last_message_query_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from canopy.chat_engine.query_generator import LastMessageQueryGenerator
from canopy.models.data_models import UserMessage, Query, AssistantMessage


@pytest.fixture
def sample_messages():
return [
UserMessage(content="What is photosynthesis?")
]


@pytest.fixture
def query_generator():
return LastMessageQueryGenerator()


def test_generate(query_generator, sample_messages):
expected = [Query(text=sample_messages[-1].content)]
actual = query_generator.generate(sample_messages, 0)
assert actual == expected


@pytest.mark.asyncio
async def test_agenerate(query_generator, sample_messages):
expected = [Query(text=sample_messages[-1].content)]
actual = await query_generator.agenerate(sample_messages, 0)
assert actual == expected


def test_generate_fails_with_empty_history(query_generator):
with pytest.raises(ValueError):
query_generator.generate([], 0)


def test_generate_fails_with_no_user_message(query_generator):
with pytest.raises(ValueError):
query_generator.generate([
AssistantMessage(content="Hi! How can I help you?")
], 0)