diff --git a/app/models/dtos.py b/app/models/dtos.py index 2fe7c28f..8dba80a6 100644 --- a/app/models/dtos.py +++ b/app/models/dtos.py @@ -13,6 +13,7 @@ class ContentType(str, Enum): TEXT = "text" +# V1 API only class Content(BaseModel): text_content: str = Field(..., alias="textContent") type: ContentType @@ -20,7 +21,6 @@ class Content(BaseModel): class SendMessageRequest(BaseModel): class Template(BaseModel): - id: int content: str template: Template @@ -28,6 +28,7 @@ class Template(BaseModel): parameters: dict +# V1 API only class SendMessageResponse(BaseModel): class Message(BaseModel): sent_at: datetime = Field( @@ -39,6 +40,18 @@ class Message(BaseModel): message: Message +class SendMessageRequestV2(BaseModel): + template: str + preferred_model: str = Field(..., alias="preferredModel") + parameters: dict + + +class SendMessageResponseV2(BaseModel): + used_model: str = Field(..., alias="usedModel") + sent_at: datetime = Field(alias="sentAt", default_factory=datetime.utcnow) + content: dict + + class ModelStatus(BaseModel): model: str status: LLMStatus diff --git a/app/routes/messages.py b/app/routes/messages.py index ee8e98d5..49f96f62 100644 --- a/app/routes/messages.py +++ b/app/routes/messages.py @@ -11,7 +11,14 @@ InvalidModelException, ) from app.dependencies import TokenPermissionsValidator -from app.models.dtos import SendMessageRequest, SendMessageResponse +from app.models.dtos import ( + SendMessageRequest, + SendMessageResponse, + Content, + ContentType, + SendMessageRequestV2, + SendMessageResponseV2, +) from app.services.circuit_breaker import CircuitBreaker from app.services.guidance_wrapper import GuidanceWrapper from app.config import settings @@ -19,25 +26,22 @@ router = APIRouter(tags=["messages"]) -@router.post( - "/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())] -) -def send_message(body: SendMessageRequest) -> SendMessageResponse: +def execute_call(template, preferred_model, parameters) -> dict: try: - model = settings.pyris.llms[body.preferred_model] + model = settings.pyris.llms[preferred_model] except ValueError as e: raise InvalidModelException(str(e)) guidance = GuidanceWrapper( model=model, - handlebars=body.template.content, - parameters=body.parameters, + handlebars=template, + parameters=parameters, ) try: - content = CircuitBreaker.protected_call( + return CircuitBreaker.protected_call( func=guidance.query, - cache_key=body.preferred_model, + cache_key=preferred_model, accepted_exceptions=( KeyError, SyntaxError, @@ -52,13 +56,47 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse: except Exception as e: raise InternalServerException(str(e)) - # Turn content into an array if it's not already - if not isinstance(content, list): - content = [content] + +@router.post( + "/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())] +) +def send_message(body: SendMessageRequest) -> SendMessageResponse: + generated_vars = execute_call( + body.template.content, body.preferred_model, body.parameters + ) + + # V1: Throw an exception if no 'response' variable was generated + if "response" not in generated_vars: + raise InternalServerException( + str(ValueError("The handlebars do not generate 'response'")) + ) return SendMessageResponse( usedModel=body.preferred_model, message=SendMessageResponse.Message( - sentAt=datetime.now(timezone.utc), content=content + sentAt=datetime.now(timezone.utc), + content=[ + Content( + type=ContentType.TEXT, + textContent=generated_vars[ + "response" + ], # V1: only return the 'response' variable + ) + ], ), ) + + +@router.post( + "/api/v2/messages", dependencies=[Depends(TokenPermissionsValidator())] +) +def send_message_v2(body: SendMessageRequestV2) -> SendMessageResponseV2: + generated_vars = execute_call( + body.template, body.preferred_model, body.parameters + ) + + return SendMessageResponseV2( + usedModel=body.preferred_model, + sentAt=datetime.now(timezone.utc), + content=generated_vars, # V2: return all generated variables + ) diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py index 34b88802..834a8fd3 100644 --- a/app/services/guidance_wrapper.py +++ b/app/services/guidance_wrapper.py @@ -1,7 +1,7 @@ import guidance +import re from app.config import LLMModelConfig -from app.models.dtos import Content, ContentType from app.services.guidance_functions import truncate @@ -21,7 +21,7 @@ def __init__( self.handlebars = handlebars self.parameters = parameters - def query(self) -> Content: + def query(self) -> dict: """Get response from a chosen LLM model. Returns: @@ -29,9 +29,17 @@ def query(self) -> Content: Raises: Reraises exception from guidance package - ValueError: if handlebars do not generate 'response' """ + # Perform a regex search to find the names of the variables + # being generated in the program. This regex matches strings like: + # {{gen 'response' temperature=0.0 max_tokens=500}} + # {{#geneach 'values' num_iterations=3}} + # {{set 'answer' (truncate response 3)}} + # and extracts the variable names 'response', 'values', and 'answer' + pattern = r'{{#?(?:gen|geneach|set) +[\'"]([^\'"]+)[\'"]' + var_names = re.findall(pattern, self.handlebars) + template = guidance(self.handlebars) result = template( llm=self._get_llm(), @@ -42,10 +50,13 @@ def query(self) -> Content: if isinstance(result._exception, Exception): raise result._exception - if "response" not in result: - raise ValueError("The handlebars do not generate 'response'") + generated_vars = { + var_name: result[var_name] + for var_name in var_names + if var_name in result + } - return Content(type=ContentType.TEXT, textContent=result["response"]) + return generated_vars def is_up(self) -> bool: """Check if the chosen LLM model is up. @@ -64,7 +75,7 @@ def is_up(self) -> bool: content = ( GuidanceWrapper(model=self.model, handlebars=handlebars) .query() - .text_content + .get("response") ) return content == "1" diff --git a/tests/routes/messages_test.py b/tests/routes/messages_test.py index e0f09786..418f82f9 100644 --- a/tests/routes/messages_test.py +++ b/tests/routes/messages_test.py @@ -1,6 +1,5 @@ import pytest from freezegun import freeze_time -from app.models.dtos import Content, ContentType from app.services.guidance_wrapper import GuidanceWrapper import app.config as config @@ -31,9 +30,9 @@ def test_send_message(test_client, headers, mocker): mocker.patch.object( GuidanceWrapper, "query", - return_value=Content( - type=ContentType.TEXT, textContent="some content" - ), + return_value={ + "response": "some content", + }, autospec=True, ) @@ -51,9 +50,11 @@ def test_send_message(test_client, headers, mocker): "query": "Some query", }, } - response = test_client.post("/api/v1/messages", headers=headers, json=body) - assert response.status_code == 200 - assert response.json() == { + response_v1 = test_client.post( + "/api/v1/messages", headers=headers, json=body + ) + assert response_v1.status_code == 200 + assert response_v1.json() == { "usedModel": "GPT35_TURBO", "message": { "sentAt": "2023-06-16T01:21:34+00:00", @@ -62,6 +63,43 @@ def test_send_message(test_client, headers, mocker): } +@freeze_time("2023-06-16 03:21:34 +02:00") +@pytest.mark.usefixtures("model_configs") +def test_send_message_v2(test_client, headers, mocker): + mocker.patch.object( + GuidanceWrapper, + "query", + return_value={ + "response": "some content", + }, + autospec=True, + ) + + body = { + "template": "{{#user~}}I want a response to the following query:\ + {{query}}{{~/user}}{{#assistant~}}\ + {{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}", + "preferredModel": "GPT35_TURBO", + "parameters": { + "course": "Intro to Java", + "exercise": "Fun With Sets", + "query": "Some query", + }, + } + + response_v2 = test_client.post( + "/api/v2/messages", headers=headers, json=body + ) + assert response_v2.status_code == 200 + assert response_v2.json() == { + "usedModel": "GPT35_TURBO", + "sentAt": "2023-06-16T01:21:34+00:00", + "content": { + "response": "some content", + }, + } + + def test_send_message_missing_model(test_client, headers): response = test_client.post("/api/v1/messages", headers=headers, json={}) assert response.status_code == 404 diff --git a/tests/services/guidance_wrapper_test.py b/tests/services/guidance_wrapper_test.py index 502da4ba..84fe9333 100644 --- a/tests/services/guidance_wrapper_test.py +++ b/tests/services/guidance_wrapper_test.py @@ -1,7 +1,6 @@ import pytest import guidance -from app.models.dtos import Content, ContentType from app.services.guidance_wrapper import GuidanceWrapper from app.config import OpenAIConfig @@ -33,9 +32,8 @@ def test_query_success(mocker): result = guidance_wrapper.query() - assert isinstance(result, Content) - assert result.type == ContentType.TEXT - assert result.text_content == "the output" + assert isinstance(result, dict) + assert result["response"] == "the output" def test_query_using_truncate_function(mocker): @@ -59,9 +57,9 @@ def test_query_using_truncate_function(mocker): result = guidance_wrapper.query() - assert isinstance(result, Content) - assert result.type == ContentType.TEXT - assert result.text_content == "the" + assert isinstance(result, dict) + assert result["answer"] == "the output" + assert result["response"] == "the" def test_query_missing_required_params(mocker): @@ -84,30 +82,5 @@ def test_query_missing_required_params(mocker): with pytest.raises(KeyError, match="Command/variable 'query' not found!"): result = guidance_wrapper.query() - assert isinstance(result, Content) - assert result.type == ContentType.TEXT - assert result.text_content == "the output" - - -def test_query_handlebars_not_generate_response(mocker): - mocker.patch.object( - GuidanceWrapper, - "_get_llm", - return_value=guidance.llms.Mock("the output"), - ) - - handlebars = "Not a valid handlebars" - guidance_wrapper = GuidanceWrapper( - model=llm_model_config, - handlebars=handlebars, - parameters={"query": "Something"}, - ) - - with pytest.raises( - ValueError, match="The handlebars do not generate 'response'" - ): - result = guidance_wrapper.query() - - assert isinstance(result, Content) - assert result.type == ContentType.TEXT - assert result.text_content == "the output" + assert isinstance(result, dict) + assert result["response"] == "the output"