From c4adcfb2d7d6195757d84dec8171d34a081dfa01 Mon Sep 17 00:00:00 2001
From: Sina
Date: Thu, 24 Oct 2024 05:02:48 +0000
Subject: [PATCH] Allow prompts without an `instruction` block

---
 chainlite/load_prompt.py   | 13 +++++++++----
 tests/test_llm_generate.py | 13 +++++++++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/chainlite/load_prompt.py b/chainlite/load_prompt.py
index c7c1310..a46e82f 100644
--- a/chainlite/load_prompt.py
+++ b/chainlite/load_prompt.py
@@ -139,8 +139,8 @@ def _split_prompt_to_blocks(prompt: str) -> List[Tuple[str, str]]:
 
     # check the prompt format is correct
     assert (
-        len([b for b in block_indices if b[1] == "instruction"]) == 1
-    ), "Prompts should contain exactly one instruction block"
+        len([b for b in block_indices if b[1] == "instruction"]) <= 1
+    ), "Prompts should contain at most one instruction block"
     num_distillation_instruction = len(
         [b for b in block_indices if b[1] == "distillation_instruction"]
     )
@@ -166,12 +166,12 @@ def _split_prompt_to_blocks(prompt: str) -> List[Tuple[str, str]]:
 
     block_indices_with_end = block_indices + [(len(prompt), "end", "end")]
     blocks = []
     for i in range(len(block_indices)):
-        block_string = prompt[
+        block_content = prompt[
             block_indices_with_end[i][0]
             + len(block_indices_with_end[i][2]) : block_indices_with_end[i + 1][0]
         ].strip()
-        blocks.append((block_indices_with_end[i][1], block_string))
+        blocks.append((block_indices_with_end[i][1], block_content))
 
     return blocks
 
@@ -181,6 +181,11 @@ def _prompt_blocks_to_chat_messages(
 ) -> Tuple[ChatPromptTemplate, str | None]:
     message_prompt_templates = []
     distillation_instruction = None
+
+    # Add an instruction block if it is not present
+    if len([b for b in blocks if b[0] == "instruction"]) == 0:
+        blocks = [("instruction", "")] + blocks
+
     if is_distilled:
         assert "distillation_instruction" in [
             b[0] for b in blocks
diff --git a/tests/test_llm_generate.py b/tests/test_llm_generate.py
index c1cb29a..a074f7b 100644
--- a/tests/test_llm_generate.py
+++ b/tests/test_llm_generate.py
@@ -74,6 +74,19 @@ async def test_string_prompts():
         temperature=0,
     ).ainvoke({"variable": "Y"})
     assert "The value of Y is six" in response
+
+    # Without instruction block
+    response = await llm_generation_chain(
+        template_file="",
+        template_blocks=[
+            ("input", "what is X?"),
+            ("output", "The value of X is one"),
+            ("input", "what is {{ variable }}?"),
+        ],
+        engine=test_engine,
+        max_tokens=10,
+        temperature=0,
+    ).ainvoke({"variable": "Y"})
 
     write_prompt_logs_to_file("tests/llm_input_outputs.jsonl")
 
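
The normalization added in _prompt_blocks_to_chat_messages can be read as the
following standalone sketch: when a prompt has no instruction block, an empty
one is prepended so downstream code can keep treating the first block as the
instruction. The helper name ensure_instruction_block is hypothetical; in the
patch this logic lives inline in _prompt_blocks_to_chat_messages.

    from typing import List, Tuple

    def ensure_instruction_block(
        blocks: List[Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
        # Prepend an empty instruction block when none is present,
        # mirroring the inline check added by this patch.
        if len([b for b in blocks if b[0] == "instruction"]) == 0:
            blocks = [("instruction", "")] + blocks
        return blocks

    # The first block is now always an instruction, possibly empty:
    assert ensure_instruction_block([("input", "what is X?")])[0] == ("instruction", "")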
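
The new test exercises the same path through the public API. A minimal usage
sketch outside pytest, assuming llm_generation_chain is importable from the
chainlite package top level and with a placeholder engine id (the test suite
uses its own test_engine value):

    import asyncio

    from chainlite import llm_generation_chain

    chain = llm_generation_chain(
        template_file="",
        template_blocks=[
            # No ("instruction", ...) entry; the patch now inserts an empty one.
            ("input", "what is X?"),
            ("output", "The value of X is one"),
            ("input", "what is {{ variable }}?"),
        ],
        engine="gpt-4o",  # placeholder engine id
        max_tokens=10,
        temperature=0,
    )
    response = asyncio.run(chain.ainvoke({"variable": "Y"}))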