Skip to content

Commit

Permalink
Development: Add V2 Messages Endpoint (#34)
Browse files Browse the repository at this point in the history
Co-authored-by: Timor Morrien <[email protected]>
  • Loading branch information
MichaelOwenDyer and Hialus authored Nov 24, 2023
1 parent fe3176a commit 74de5c7
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 63 deletions.
15 changes: 14 additions & 1 deletion app/models/dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@ class ContentType(str, Enum):
TEXT = "text"


# V1 API only
class Content(BaseModel):
text_content: str = Field(..., alias="textContent")
type: ContentType


class SendMessageRequest(BaseModel):
class Template(BaseModel):
id: int
content: str

template: Template
preferred_model: str = Field(..., alias="preferredModel")
parameters: dict


# V1 API only
class SendMessageResponse(BaseModel):
class Message(BaseModel):
sent_at: datetime = Field(
Expand All @@ -39,6 +40,18 @@ class Message(BaseModel):
message: Message


class SendMessageRequestV2(BaseModel):
template: str
preferred_model: str = Field(..., alias="preferredModel")
parameters: dict


class SendMessageResponseV2(BaseModel):
used_model: str = Field(..., alias="usedModel")
sent_at: datetime = Field(alias="sentAt", default_factory=datetime.utcnow)
content: dict


class ModelStatus(BaseModel):
model: str
status: LLMStatus
Expand Down
66 changes: 52 additions & 14 deletions app/routes/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,37 @@
InvalidModelException,
)
from app.dependencies import TokenPermissionsValidator
from app.models.dtos import SendMessageRequest, SendMessageResponse
from app.models.dtos import (
SendMessageRequest,
SendMessageResponse,
Content,
ContentType,
SendMessageRequestV2,
SendMessageResponseV2,
)
from app.services.circuit_breaker import CircuitBreaker
from app.services.guidance_wrapper import GuidanceWrapper
from app.config import settings

router = APIRouter(tags=["messages"])


@router.post(
"/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())]
)
def send_message(body: SendMessageRequest) -> SendMessageResponse:
def execute_call(template, preferred_model, parameters) -> dict:
try:
model = settings.pyris.llms[body.preferred_model]
model = settings.pyris.llms[preferred_model]
except ValueError as e:
raise InvalidModelException(str(e))

guidance = GuidanceWrapper(
model=model,
handlebars=body.template.content,
parameters=body.parameters,
handlebars=template,
parameters=parameters,
)

try:
content = CircuitBreaker.protected_call(
return CircuitBreaker.protected_call(
func=guidance.query,
cache_key=body.preferred_model,
cache_key=preferred_model,
accepted_exceptions=(
KeyError,
SyntaxError,
Expand All @@ -52,13 +56,47 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse:
except Exception as e:
raise InternalServerException(str(e))

# Turn content into an array if it's not already
if not isinstance(content, list):
content = [content]

@router.post(
"/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())]
)
def send_message(body: SendMessageRequest) -> SendMessageResponse:
generated_vars = execute_call(
body.template.content, body.preferred_model, body.parameters
)

# V1: Throw an exception if no 'response' variable was generated
if "response" not in generated_vars:
raise InternalServerException(
str(ValueError("The handlebars do not generate 'response'"))
)

return SendMessageResponse(
usedModel=body.preferred_model,
message=SendMessageResponse.Message(
sentAt=datetime.now(timezone.utc), content=content
sentAt=datetime.now(timezone.utc),
content=[
Content(
type=ContentType.TEXT,
textContent=generated_vars[
"response"
], # V1: only return the 'response' variable
)
],
),
)


@router.post(
"/api/v2/messages", dependencies=[Depends(TokenPermissionsValidator())]
)
def send_message_v2(body: SendMessageRequestV2) -> SendMessageResponseV2:
generated_vars = execute_call(
body.template, body.preferred_model, body.parameters
)

return SendMessageResponseV2(
usedModel=body.preferred_model,
sentAt=datetime.now(timezone.utc),
content=generated_vars, # V2: return all generated variables
)
25 changes: 18 additions & 7 deletions app/services/guidance_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import guidance
import re

from app.config import LLMModelConfig
from app.models.dtos import Content, ContentType
from app.services.guidance_functions import truncate


Expand All @@ -21,17 +21,25 @@ def __init__(
self.handlebars = handlebars
self.parameters = parameters

def query(self) -> Content:
def query(self) -> dict:
"""Get response from a chosen LLM model.
Returns:
Text content object with LLM's response.
Raises:
Reraises exception from guidance package
ValueError: if handlebars do not generate 'response'
"""

# Perform a regex search to find the names of the variables
# being generated in the program. This regex matches strings like:
# {{gen 'response' temperature=0.0 max_tokens=500}}
# {{#geneach 'values' num_iterations=3}}
# {{set 'answer' (truncate response 3)}}
# and extracts the variable names 'response', 'values', and 'answer'
pattern = r'{{#?(?:gen|geneach|set) +[\'"]([^\'"]+)[\'"]'
var_names = re.findall(pattern, self.handlebars)

template = guidance(self.handlebars)
result = template(
llm=self._get_llm(),
Expand All @@ -42,10 +50,13 @@ def query(self) -> Content:
if isinstance(result._exception, Exception):
raise result._exception

if "response" not in result:
raise ValueError("The handlebars do not generate 'response'")
generated_vars = {
var_name: result[var_name]
for var_name in var_names
if var_name in result
}

return Content(type=ContentType.TEXT, textContent=result["response"])
return generated_vars

def is_up(self) -> bool:
"""Check if the chosen LLM model is up.
Expand All @@ -64,7 +75,7 @@ def is_up(self) -> bool:
content = (
GuidanceWrapper(model=self.model, handlebars=handlebars)
.query()
.text_content
.get("response")
)
return content == "1"

Expand Down
52 changes: 45 additions & 7 deletions tests/routes/messages_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
from freezegun import freeze_time
from app.models.dtos import Content, ContentType
from app.services.guidance_wrapper import GuidanceWrapper
import app.config as config

Expand Down Expand Up @@ -31,9 +30,9 @@ def test_send_message(test_client, headers, mocker):
mocker.patch.object(
GuidanceWrapper,
"query",
return_value=Content(
type=ContentType.TEXT, textContent="some content"
),
return_value={
"response": "some content",
},
autospec=True,
)

Expand All @@ -51,9 +50,11 @@ def test_send_message(test_client, headers, mocker):
"query": "Some query",
},
}
response = test_client.post("/api/v1/messages", headers=headers, json=body)
assert response.status_code == 200
assert response.json() == {
response_v1 = test_client.post(
"/api/v1/messages", headers=headers, json=body
)
assert response_v1.status_code == 200
assert response_v1.json() == {
"usedModel": "GPT35_TURBO",
"message": {
"sentAt": "2023-06-16T01:21:34+00:00",
Expand All @@ -62,6 +63,43 @@ def test_send_message(test_client, headers, mocker):
}


@freeze_time("2023-06-16 03:21:34 +02:00")
@pytest.mark.usefixtures("model_configs")
def test_send_message_v2(test_client, headers, mocker):
mocker.patch.object(
GuidanceWrapper,
"query",
return_value={
"response": "some content",
},
autospec=True,
)

body = {
"template": "{{#user~}}I want a response to the following query:\
{{query}}{{~/user}}{{#assistant~}}\
{{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}",
"preferredModel": "GPT35_TURBO",
"parameters": {
"course": "Intro to Java",
"exercise": "Fun With Sets",
"query": "Some query",
},
}

response_v2 = test_client.post(
"/api/v2/messages", headers=headers, json=body
)
assert response_v2.status_code == 200
assert response_v2.json() == {
"usedModel": "GPT35_TURBO",
"sentAt": "2023-06-16T01:21:34+00:00",
"content": {
"response": "some content",
},
}


def test_send_message_missing_model(test_client, headers):
response = test_client.post("/api/v1/messages", headers=headers, json={})
assert response.status_code == 404
Expand Down
41 changes: 7 additions & 34 deletions tests/services/guidance_wrapper_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import guidance

from app.models.dtos import Content, ContentType
from app.services.guidance_wrapper import GuidanceWrapper
from app.config import OpenAIConfig

Expand Down Expand Up @@ -33,9 +32,8 @@ def test_query_success(mocker):

result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the output"
assert isinstance(result, dict)
assert result["response"] == "the output"


def test_query_using_truncate_function(mocker):
Expand All @@ -59,9 +57,9 @@ def test_query_using_truncate_function(mocker):

result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the"
assert isinstance(result, dict)
assert result["answer"] == "the output"
assert result["response"] == "the"


def test_query_missing_required_params(mocker):
Expand All @@ -84,30 +82,5 @@ def test_query_missing_required_params(mocker):
with pytest.raises(KeyError, match="Command/variable 'query' not found!"):
result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the output"


def test_query_handlebars_not_generate_response(mocker):
mocker.patch.object(
GuidanceWrapper,
"_get_llm",
return_value=guidance.llms.Mock("the output"),
)

handlebars = "Not a valid handlebars"
guidance_wrapper = GuidanceWrapper(
model=llm_model_config,
handlebars=handlebars,
parameters={"query": "Something"},
)

with pytest.raises(
ValueError, match="The handlebars do not generate 'response'"
):
result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the output"
assert isinstance(result, dict)
assert result["response"] == "the output"

0 comments on commit 74de5c7

Please sign in to comment.