diff --git a/app/models/dtos.py b/app/models/dtos.py
index 2fe7c28f..55eedcde 100644
--- a/app/models/dtos.py
+++ b/app/models/dtos.py
@@ -13,11 +13,6 @@ class ContentType(str, Enum):
     TEXT = "text"
 
 
-class Content(BaseModel):
-    text_content: str = Field(..., alias="textContent")
-    type: ContentType
-
-
 class SendMessageRequest(BaseModel):
     class Template(BaseModel):
         id: int
@@ -29,14 +24,11 @@ class Template(BaseModel):
 
 
 class SendMessageResponse(BaseModel):
-    class Message(BaseModel):
-        sent_at: datetime = Field(
-            alias="sentAt", default_factory=datetime.utcnow
-        )
-        content: list[Content]
-
     used_model: str = Field(..., alias="usedModel")
-    message: Message
+    sent_at: datetime = Field(
+        alias="sentAt", default_factory=datetime.utcnow
+    )
+    content: dict
 
 
 class ModelStatus(BaseModel):
diff --git a/app/routes/messages.py b/app/routes/messages.py
index ee8e98d5..4dc873d1 100644
--- a/app/routes/messages.py
+++ b/app/routes/messages.py
@@ -52,13 +52,8 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse:
     except Exception as e:
         raise InternalServerException(str(e))
 
-    # Turn content into an array if it's not already
-    if not isinstance(content, list):
-        content = [content]
-
     return SendMessageResponse(
         usedModel=body.preferred_model,
-        message=SendMessageResponse.Message(
-            sentAt=datetime.now(timezone.utc), content=content
-        ),
+        sentAt=datetime.now(timezone.utc),
+        content=content,
     )
diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py
index 34b88802..3ab4ff0f 100644
--- a/app/services/guidance_wrapper.py
+++ b/app/services/guidance_wrapper.py
@@ -32,6 +32,11 @@ def query(self) -> Content:
-            ValueError: if handlebars do not generate 'response'
+            Exception: the exception raised by the guidance program, if any
         """
 
+        # Collect the names of the variables the handlebars program generates
+        import re
+        pattern = r'{{#?(?:gen|geneach|set) [\'"]([^\'"]+)[\'"][^}]*}}'
+        var_names = re.findall(pattern, self.handlebars)
+
         template = guidance(self.handlebars)
         result = template(
             llm=self._get_llm(),
@@ -42,10 +47,11 @@
         if isinstance(result._exception, Exception):
             raise result._exception
 
-        if "response" not in result:
-            raise ValueError("The handlebars do not generate 'response'")
+        generated_vars = {
+            var_name: result[var_name] for var_name in var_names if var_name in result
+        }
 
-        return Content(type=ContentType.TEXT, textContent=result["response"])
+        return generated_vars
 
     def is_up(self) -> bool:
         """Check if the chosen LLM model is up.
diff --git a/tests/routes/messages_test.py b/tests/routes/messages_test.py
index e0f09786..725514cf 100644
--- a/tests/routes/messages_test.py
+++ b/tests/routes/messages_test.py
@@ -31,9 +31,9 @@ def test_send_message(test_client, headers, mocker):
     mocker.patch.object(
         GuidanceWrapper,
         "query",
-        return_value=Content(
-            type=ContentType.TEXT, textContent="some content"
-        ),
+        return_value={
+            "response": "some content",
+        },
         autospec=True,
     )
 
@@ -55,9 +55,9 @@ def test_send_message(test_client, headers, mocker):
     assert response.status_code == 200
     assert response.json() == {
         "usedModel": "GPT35_TURBO",
-        "message": {
-            "sentAt": "2023-06-16T01:21:34+00:00",
-            "content": [{"textContent": "some content", "type": "text"}],
+        "sentAt": "2023-06-16T01:21:34+00:00",
+        "content": {
+            "response": "some content",
         },
     }
diff --git a/tests/services/guidance_wrapper_test.py b/tests/services/guidance_wrapper_test.py
index 502da4ba..cb0f540c 100644
--- a/tests/services/guidance_wrapper_test.py
+++ b/tests/services/guidance_wrapper_test.py
@@ -1,7 +1,6 @@
 import pytest
 import guidance
 
-from app.models.dtos import Content, ContentType
 from app.services.guidance_wrapper import GuidanceWrapper
 from app.config import OpenAIConfig
 
@@ -33,9 +32,8 @@ def test_query_success(mocker):
 
     result = guidance_wrapper.query()
 
-    assert isinstance(result, Content)
-    assert result.type == ContentType.TEXT
-    assert result.text_content == "the output"
+    assert isinstance(result, dict)
+    assert result["response"] == "the output"
 
 
 def test_query_using_truncate_function(mocker):
@@ -59,9 +57,9 @@ def test_query_using_truncate_function(mocker):
 
     result = guidance_wrapper.query()
 
-    assert isinstance(result, Content)
-    assert result.type == ContentType.TEXT
-    assert result.text_content == "the"
+    assert isinstance(result, dict)
+    assert result["answer"] == "the output"
+    assert result["response"] == "the"
 
 
 def test_query_missing_required_params(mocker):
@@ -84,30 +82,5 @@ def test_query_missing_required_params(mocker):
     with pytest.raises(KeyError, match="Command/variable 'query' not found!"):
         result = guidance_wrapper.query()
 
-        assert isinstance(result, Content)
-        assert result.type == ContentType.TEXT
-        assert result.text_content == "the output"
-
-
-def test_query_handlebars_not_generate_response(mocker):
-    mocker.patch.object(
-        GuidanceWrapper,
-        "_get_llm",
-        return_value=guidance.llms.Mock("the output"),
-    )
-
-    handlebars = "Not a valid handlebars"
-    guidance_wrapper = GuidanceWrapper(
-        model=llm_model_config,
-        handlebars=handlebars,
-        parameters={"query": "Something"},
-    )
-
-    with pytest.raises(
-        ValueError, match="The handlebars do not generate 'response'"
-    ):
-        result = guidance_wrapper.query()
-
-        assert isinstance(result, Content)
-        assert result.type == ContentType.TEXT
-        assert result.text_content == "the output"
+        assert isinstance(result, dict)
+        assert result["response"] == "the output"