Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Development: Add V2 Messages Endpoint #34

Merged
merged 16 commits into from
Nov 24, 2023
14 changes: 2 additions & 12 deletions app/models/dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ class ContentType(str, Enum):
TEXT = "text"


class Content(BaseModel):
text_content: str = Field(..., alias="textContent")
type: ContentType


class SendMessageRequest(BaseModel):
class Template(BaseModel):
id: int
Expand All @@ -29,14 +24,9 @@ class Template(BaseModel):


class SendMessageResponse(BaseModel):
class Message(BaseModel):
sent_at: datetime = Field(
alias="sentAt", default_factory=datetime.utcnow
)
content: list[Content]

used_model: str = Field(..., alias="usedModel")
message: Message
sent_at: datetime = Field(alias="sentAt", default_factory=datetime.utcnow)
content: dict


class ModelStatus(BaseModel):
Expand Down
9 changes: 2 additions & 7 deletions app/routes/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,8 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse:
except Exception as e:
raise InternalServerException(str(e))

# Turn content into an array if it's not already
if not isinstance(content, list):
content = [content]

return SendMessageResponse(
usedModel=body.preferred_model,
message=SendMessageResponse.Message(
sentAt=datetime.now(timezone.utc), content=content
),
sentAt=datetime.now(timezone.utc),
content=content,
)
22 changes: 16 additions & 6 deletions app/services/guidance_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import guidance

from app.config import LLMModelConfig
from app.models.dtos import Content, ContentType
from app.services.guidance_functions import truncate


Expand All @@ -21,7 +20,7 @@ def __init__(
self.handlebars = handlebars
self.parameters = parameters

def query(self) -> Content:
def query(self) -> dict:
"""Get response from a chosen LLM model.

Returns:
Expand All @@ -32,6 +31,14 @@ def query(self) -> Content:
ValueError: if handlebars do not generate 'response'
"""

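# Perform a regex search to find the names of the variables generated
# by the handlebars template.
# The regex is intended to match expressions such as
# {{gen 'response' temperature=0.0 max_tokens=500}}
# and to extract the variable name (here 'response').
# NOTE(review): as shown below, the pattern has `$` (end of string) directly
# before the closing `}}` and allows nothing between the closing quote and
# the braces, so it cannot match the example above - confirm the intended
# pattern.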
import re
MichaelOwenDyer marked this conversation as resolved.
Show resolved Hide resolved
pattern = r'{{(?:gen|geneach|set) [\'"]([^\'"]+)[\'"]$}}'
MichaelOwenDyer marked this conversation as resolved.
Show resolved Hide resolved
var_names = re.findall(pattern, self.handlebars)

template = guidance(self.handlebars)
result = template(
llm=self._get_llm(),
Expand All @@ -42,10 +49,13 @@ def query(self) -> Content:
if isinstance(result._exception, Exception):
raise result._exception

if "response" not in result:
raise ValueError("The handlebars do not generate 'response'")
generated_vars = {
var_name: result[var_name]
for var_name in var_names
if var_name in result
}

return Content(type=ContentType.TEXT, textContent=result["response"])
return generated_vars

def is_up(self) -> bool:
"""Check if the chosen LLM model is up.
Expand All @@ -64,7 +74,7 @@ def is_up(self) -> bool:
content = (
GuidanceWrapper(model=self.model, handlebars=handlebars)
.query()
.text_content
.get("response")
)
return content == "1"

Expand Down
12 changes: 6 additions & 6 deletions tests/routes/messages_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def test_send_message(test_client, headers, mocker):
mocker.patch.object(
GuidanceWrapper,
"query",
return_value=Content(
type=ContentType.TEXT, textContent="some content"
),
return_value={
"response": "some content",
},
autospec=True,
)

Expand All @@ -55,9 +55,9 @@ def test_send_message(test_client, headers, mocker):
assert response.status_code == 200
assert response.json() == {
"usedModel": "GPT35_TURBO",
"message": {
"sentAt": "2023-06-16T01:21:34+00:00",
"content": [{"textContent": "some content", "type": "text"}],
"sentAt": "2023-06-16T01:21:34+00:00",
"content": {
"response": "some content",
},
}

Expand Down
41 changes: 7 additions & 34 deletions tests/services/guidance_wrapper_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import guidance

from app.models.dtos import Content, ContentType
from app.services.guidance_wrapper import GuidanceWrapper
from app.config import OpenAIConfig

Expand Down Expand Up @@ -33,9 +32,8 @@ def test_query_success(mocker):

result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the output"
assert isinstance(result, dict)
assert result["response"] == "the output"


def test_query_using_truncate_function(mocker):
Expand All @@ -59,9 +57,9 @@ def test_query_using_truncate_function(mocker):

result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the"
assert isinstance(result, dict)
assert result["answer"] == "the output"
assert result["response"] == "the"


def test_query_missing_required_params(mocker):
Expand All @@ -84,30 +82,5 @@ def test_query_missing_required_params(mocker):
with pytest.raises(KeyError, match="Command/variable 'query' not found!"):
result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the output"


def test_query_handlebars_not_generate_response(mocker):
mocker.patch.object(
GuidanceWrapper,
"_get_llm",
return_value=guidance.llms.Mock("the output"),
)

handlebars = "Not a valid handlebars"
guidance_wrapper = GuidanceWrapper(
model=llm_model_config,
handlebars=handlebars,
parameters={"query": "Something"},
)

with pytest.raises(
ValueError, match="The handlebars do not generate 'response'"
):
result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the output"
assert isinstance(result, dict)
assert result["response"] == "the output"
Loading