diff --git a/app/models/dtos.py b/app/models/dtos.py index 55eedcde..d52d0487 100644 --- a/app/models/dtos.py +++ b/app/models/dtos.py @@ -25,9 +25,7 @@ class Template(BaseModel): class SendMessageResponse(BaseModel): used_model: str = Field(..., alias="usedModel") - sent_at: datetime = Field( - alias="sentAt", default_factory=datetime.utcnow - ) + sent_at: datetime = Field(alias="sentAt", default_factory=datetime.utcnow) content: dict diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py index 3ab4ff0f..82160189 100644 --- a/app/services/guidance_wrapper.py +++ b/app/services/guidance_wrapper.py @@ -1,7 +1,6 @@ import guidance from app.config import LLMModelConfig -from app.models.dtos import Content, ContentType from app.services.guidance_functions import truncate @@ -21,7 +20,7 @@ def __init__( self.handlebars = handlebars self.parameters = parameters - def query(self) -> Content: + def query(self) -> dict: """Get response from a chosen LLM model. Returns: @@ -33,8 +32,9 @@ def query(self) -> Content: """ import re + pattern = r'{{(?:gen|geneach|set) [\'"]([^\'"]+)[\'"]$}}' - var_names = re.findall(pattern, input_string) + var_names = re.findall(pattern, self.handlebars) template = guidance(self.handlebars) result = template( @@ -47,7 +47,9 @@ def query(self) -> Content: raise result._exception generated_vars = { - var_name: result[var_name] for var_name in var_names if var_name in result + var_name: result[var_name] + for var_name in var_names + if var_name in result } return generated_vars @@ -69,7 +71,7 @@ def is_up(self) -> bool: content = ( GuidanceWrapper(model=self.model, handlebars=handlebars) .query() - .text_content + .get("response") ) return content == "1" diff --git a/tests/services/guidance_wrapper_test.py b/tests/services/guidance_wrapper_test.py index cb0f540c..84fe9333 100644 --- a/tests/services/guidance_wrapper_test.py +++ b/tests/services/guidance_wrapper_test.py @@ -33,7 +33,7 @@ def test_query_success(mocker): result = guidance_wrapper.query() assert isinstance(result, dict) - assert result['response'] == "the output" + assert result["response"] == "the output" def test_query_using_truncate_function(mocker): @@ -58,8 +58,8 @@ def test_query_using_truncate_function(mocker): result = guidance_wrapper.query() assert isinstance(result, dict) - assert result['answer'] == "the output" - assert result['response'] == "the" + assert result["answer"] == "the output" + assert result["response"] == "the" def test_query_missing_required_params(mocker): @@ -83,4 +83,4 @@ def test_query_missing_required_params(mocker): result = guidance_wrapper.query() assert isinstance(result, dict) - assert result['response'] == "the output" + assert result["response"] == "the output"