From 9da169327b65501127f925f353271c751f9ea9f9 Mon Sep 17 00:00:00 2001 From: Sina Date: Thu, 25 Jul 2024 03:23:52 +0000 Subject: [PATCH] Add option to define new global prompt constants --- chainlite/__init__.py | 2 ++ chainlite/load_prompt.py | 57 ++++++++++++++++++++++++++------------ tests/constants.prompt | 6 ++++ tests/test_llm_generate.py | 34 +++++++++++++++++++++-- 4 files changed, 78 insertions(+), 21 deletions(-) create mode 100644 tests/constants.prompt diff --git a/chainlite/__init__.py b/chainlite/__init__.py index e9bff61..ed95f64 100644 --- a/chainlite/__init__.py +++ b/chainlite/__init__.py @@ -11,6 +11,7 @@ write_prompt_logs_to_file, string_to_json, ) +from .load_prompt import register_prompt_constants from .utils import get_logger @@ -24,4 +25,5 @@ "chain", "get_all_configured_engines", "string_to_json", + "register_prompt_constants", ] diff --git a/chainlite/load_prompt.py b/chainlite/load_prompt.py index c62d973..73cf2f8 100644 --- a/chainlite/load_prompt.py +++ b/chainlite/load_prompt.py @@ -5,10 +5,12 @@ from zoneinfo import ZoneInfo # Python 3.9 and later from jinja2 import Environment, FileSystemLoader -from langchain_core.prompts import (AIMessagePromptTemplate, - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate) +from langchain_core.prompts import ( + AIMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) jinja2_comment_pattern = re.compile(r"{#.*?#}", re.DOTALL) @@ -61,7 +63,25 @@ def load_template_file(template_file: str, keep_indentation: bool) -> str: return raw_template -def add_template_constants( +added_template_constants = {} + + +def register_prompt_constants(constant_name_to_value_map: dict) -> None: + """ + Make constant values available to all prompt templates. + By default, current_year, today and location are set, and you can overwrite them or add new constants using this method. + + Args: + constant_name_to_value_map (dict): A dictionary where keys are constant names and values are the corresponding constant values. + + Returns: + None + """ + for k, v in constant_name_to_value_map.items(): + added_template_constants[k] = v + + +def add_constants_to_template( chat_prompt_template: ChatPromptTemplate, ) -> ChatPromptTemplate: # always make these useful constants available in a template @@ -69,16 +89,15 @@ def add_template_constants( pacific_zone = ZoneInfo("America/Los_Angeles") today = datetime.now(pacific_zone).date() - current_year = today.year - today = today.strftime("%B %d, %Y") # e.g. May 30, 2024 - location = "the U.S." - chatbot_name = "WikiChat" - chat_prompt_template = chat_prompt_template.partial( - today=today, - current_year=current_year, - location=location, - chatbot_name=chatbot_name, - ) + template_constants = { + "current_year": today.year, + "today": today.strftime("%B %d, %Y"), # e.g. May 30, 2024 + "location": "the U.S.", + } + for k, v in added_template_constants.items(): + template_constants[k] = v + + chat_prompt_template = chat_prompt_template.partial(**template_constants) return chat_prompt_template @@ -166,11 +185,13 @@ def _prompt_blocks_to_chat_messages( # only keep the distillation_instruction and the last input assert distillation_instruction is not None message_prompt_templates = [ - SystemMessagePromptTemplate.from_template(distillation_instruction, template_format="jinja2"), + SystemMessagePromptTemplate.from_template( + distillation_instruction, template_format="jinja2" + ), message_prompt_templates[-1], ] chat_prompt_template = ChatPromptTemplate.from_messages(message_prompt_templates) - chat_prompt_template = add_template_constants(chat_prompt_template) + chat_prompt_template = add_constants_to_template(chat_prompt_template) if distillation_instruction is None: # if distillation instruction is not provided, will default to instruction block_type, distillation_instruction = tuple( @@ -180,7 +201,7 @@ def _prompt_blocks_to_chat_messages( distillation_instruction = ( ( - add_template_constants( + add_constants_to_template( ChatPromptTemplate.from_template( distillation_instruction, template_format="jinja2" ) diff --git a/tests/constants.prompt b/tests/constants.prompt new file mode 100644 index 0000000..12b2aef --- /dev/null +++ b/tests/constants.prompt @@ -0,0 +1,6 @@ +# instruction +Today's date is {{ today }}. +The current year is {{ current_year }}. + +# input +{{ question }} \ No newline at end of file diff --git a/tests/test_llm_generate.py b/tests/test_llm_generate.py index b6281fd..40d267a 100644 --- a/tests/test_llm_generate.py +++ b/tests/test_llm_generate.py @@ -1,3 +1,5 @@ +from datetime import datetime +from zoneinfo import ZoneInfo import pytest from langchain_core.runnables import RunnableLambda @@ -6,7 +8,8 @@ llm_generation_chain, load_config_from_file, write_prompt_logs_to_file, - get_all_configured_engines + get_all_configured_engines, + register_prompt_constants, ) from chainlite.llm_config import GlobalVars @@ -22,11 +25,13 @@ {"topic": "Rabbits"}, ] -test_engine="gpt-4o-mini" +test_engine = "gpt-4o-mini" + @pytest.mark.asyncio(scope="session") async def test_llm_generate(): - print(get_all_configured_engines()) + logger.info("All registered engines: %s", str(get_all_configured_engines())) + # Check that the config file has been loaded properly assert GlobalVars.all_llm_endpoints assert GlobalVars.prompt_dirs @@ -58,6 +63,29 @@ async def test_readme_example(): ).ainvoke({"topic": "Life as a PhD student"}) +@pytest.mark.asyncio(scope="session") +async def test_constants(): + pacific_zone = ZoneInfo("America/Los_Angeles") + today = datetime.now(pacific_zone).date().strftime("%B %d, %Y") # e.g. May 30, 2024 + response = await llm_generation_chain( + template_file="tests/constants.prompt", + engine=test_engine, + max_tokens=10, + temperature=0, + ).ainvoke({"question": "what is today's date?"}) + assert today in response + + # overwrite "today" + register_prompt_constants({"today": "Thursday"}) + response = await llm_generation_chain( + template_file="tests/constants.prompt", + engine=test_engine, + max_tokens=10, + temperature=0, + ).ainvoke({"question": "what day of the week is today?"}) + assert "thursday" in response.lower() + + @pytest.mark.asyncio(scope="session") async def test_batching(): response = await llm_generation_chain(