From 94d78e52745e440fcced7a47e8ce5708d8338e77 Mon Sep 17 00:00:00 2001 From: Timor Morrien Date: Sat, 22 Jul 2023 16:10:16 +0200 Subject: [PATCH] Development: Add truncate function (#15) --- app/services/guidance_functions.py | 8 +++++++ app/services/guidance_wrapper.py | 2 ++ tests/services/guidance_functions_test.py | 21 ++++++++++++++++++ tests/services/guidance_wrapper_test.py | 26 +++++++++++++++++++++++ 4 files changed, 57 insertions(+) create mode 100644 app/services/guidance_functions.py create mode 100644 tests/services/guidance_functions_test.py diff --git a/app/services/guidance_functions.py b/app/services/guidance_functions.py new file mode 100644 index 00000000..a703775d --- /dev/null +++ b/app/services/guidance_functions.py @@ -0,0 +1,8 @@ +def truncate(history: list[any], max_length: int): + if max_length == 0: + return [] + + if max_length > 0: + return history[:max_length] + + return history[max_length:] diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py index 3d6d653a..c990f47c 100644 --- a/app/services/guidance_wrapper.py +++ b/app/services/guidance_wrapper.py @@ -2,6 +2,7 @@ from app.config import LLMModelConfig from app.models.dtos import Content, ContentType +from app.services.guidance_functions import truncate class GuidanceWrapper: @@ -34,6 +35,7 @@ def query(self) -> Content: template = guidance(self.handlebars) result = template( llm=self._get_llm(), + truncate=truncate, **self.parameters, ) diff --git a/tests/services/guidance_functions_test.py b/tests/services/guidance_functions_test.py new file mode 100644 index 00000000..182c3c4b --- /dev/null +++ b/tests/services/guidance_functions_test.py @@ -0,0 +1,21 @@ +import pytest +from app.services.guidance_functions import truncate + + +@pytest.mark.parametrize( + "history,max_length,expected", + [ + ([], -2, []), + ([], 0, []), + ([], 2, []), + ([1, 2, 3], 0, []), + # Get the last n elements + ([1, 2, 3], -4, [1, 2, 3]), + ([1, 2, 3], -2, [2, 3]), + # Get the first n elements + ([1, 2, 3], 2, [1, 2]), + ([1, 2, 3], 4, [1, 2, 3]), + ], +) +def test_truncate(history, max_length, expected): + assert truncate(history, max_length) == expected diff --git a/tests/services/guidance_wrapper_test.py b/tests/services/guidance_wrapper_test.py index b20bb666..526d8e52 100644 --- a/tests/services/guidance_wrapper_test.py +++ b/tests/services/guidance_wrapper_test.py @@ -34,6 +34,32 @@ def test_query_success(mocker): assert result.text_content == "the output" +def test_query_using_truncate_function(mocker): + mocker.patch.object( + GuidanceWrapper, + "_get_llm", + return_value=guidance.llms.Mock("the output"), + ) + + handlebars = """{{#user~}}I want a response to the following query: + {{query}}{{~/user}}{{#assistant~}} + {{gen 'answer' temperature=0.0 max_tokens=500}}{{~/assistant}} + {{set 'response' (truncate answer 3)}} + """ + + guidance_wrapper = GuidanceWrapper( + model=llm_model_config, + handlebars=handlebars, + parameters={"query": "Some query"}, + ) + + result = guidance_wrapper.query() + + assert isinstance(result, Content) + assert result.type == ContentType.TEXT + assert result.text_content == "the" + + def test_query_missing_required_params(mocker): mocker.patch.object( GuidanceWrapper,