Skip to content

Commit

Permalink
Reformat code
Browse files Browse the repository at this point in the history
  • Loading branch information
monst committed Oct 9, 2023
1 parent 5414ed4 commit 153d2b7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
4 changes: 1 addition & 3 deletions app/models/dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ class Template(BaseModel):

class SendMessageResponse(BaseModel):
used_model: str = Field(..., alias="usedModel")
sent_at: datetime = Field(
alias="sentAt", default_factory=datetime.utcnow
)
sent_at: datetime = Field(alias="sentAt", default_factory=datetime.utcnow)
content: dict


Expand Down
12 changes: 7 additions & 5 deletions app/services/guidance_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import guidance

from app.config import LLMModelConfig
from app.models.dtos import Content, ContentType
from app.services.guidance_functions import truncate


Expand All @@ -21,7 +20,7 @@ def __init__(
self.handlebars = handlebars
self.parameters = parameters

def query(self) -> Content:
def query(self) -> dict:
"""Get response from a chosen LLM model.
Returns:
Expand All @@ -33,8 +32,9 @@ def query(self) -> Content:
"""

import re

pattern = r'{{(?:gen|geneach|set) [\'"]([^\'"]+)[\'"]$}}'
var_names = re.findall(pattern, input_string)
var_names = re.findall(pattern, self.handlebars)

template = guidance(self.handlebars)
result = template(
Expand All @@ -47,7 +47,9 @@ def query(self) -> Content:
raise result._exception

generated_vars = {
var_name: result[var_name] for var_name in var_names if var_name in result
var_name: result[var_name]
for var_name in var_names
if var_name in result
}

return generated_vars
Expand All @@ -69,7 +71,7 @@ def is_up(self) -> bool:
content = (
GuidanceWrapper(model=self.model, handlebars=handlebars)
.query()
.text_content
.get("response")
)
return content == "1"

Expand Down
8 changes: 4 additions & 4 deletions tests/services/guidance_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_query_success(mocker):
result = guidance_wrapper.query()

assert isinstance(result, dict)
assert result['response'] == "the output"
assert result["response"] == "the output"


def test_query_using_truncate_function(mocker):
Expand All @@ -58,8 +58,8 @@ def test_query_using_truncate_function(mocker):
result = guidance_wrapper.query()

assert isinstance(result, dict)
assert result['answer'] == "the output"
assert result['response'] == "the"
assert result["answer"] == "the output"
assert result["response"] == "the"


def test_query_missing_required_params(mocker):
Expand All @@ -83,4 +83,4 @@ def test_query_missing_required_params(mocker):
result = guidance_wrapper.query()

assert isinstance(result, dict)
assert result['response'] == "the output"
assert result["response"] == "the output"

0 comments on commit 153d2b7

Please sign in to comment.