From b4243137bc2a3219c484e4b73a0e17c27d4e1b1a Mon Sep 17 00:00:00 2001 From: Aarya2004 Date: Thu, 7 Nov 2024 12:05:59 -0500 Subject: [PATCH 1/4] fix(pipeline partial): Fixed bug caused by partial method in PipelinePromptTemplate --- libs/core/langchain_core/prompts/pipeline.py | 31 +++++++++++-- .../prompts/test_pipeline_prompt.py | 45 +++++++++++++++++++ 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/prompts/pipeline.py b/libs/core/langchain_core/prompts/pipeline.py index e25a0a7f72461..15b126f0bdc3d 100644 --- a/libs/core/langchain_core/prompts/pipeline.py +++ b/libs/core/langchain_core/prompts/pipeline.py @@ -8,8 +8,16 @@ from langchain_core.prompts.chat import BaseChatPromptTemplate -def _get_inputs(inputs: dict, input_variables: list[str]) -> dict: - return {k: inputs[k] for k in input_variables} +def _get_inputs(inputs: dict, input_variables: list[str], partial_variables: Optional[dict] = None) -> dict: + result_dict = {} + if partial_variables is not None and len(partial_variables) != 0: + result_dict = {partial_k: partial_val for partial_k, partial_val in partial_variables.items()} + for k in input_variables: + if k in inputs: + result_dict[k] = inputs[k] + if k not in inputs and k not in result_dict: + raise ValueError(f"Input {k} was not provided and is not a partial") + return result_dict class PipelinePromptTemplate(BasePromptTemplate): @@ -58,7 +66,7 @@ def format_prompt(self, **kwargs: Any) -> PromptValue: A formatted string. """ for k, prompt in self.pipeline_prompts: - _inputs = _get_inputs(kwargs, prompt.input_variables) + _inputs = _get_inputs(kwargs, prompt.input_variables, prompt.partial_variables) if isinstance(prompt, BaseChatPromptTemplate): kwargs[k] = prompt.format_messages(**_inputs) else: @@ -76,7 +84,7 @@ async def aformat_prompt(self, **kwargs: Any) -> PromptValue: A formatted string. """ for k, prompt in self.pipeline_prompts: - _inputs = _get_inputs(kwargs, prompt.input_variables) + _inputs = _get_inputs(kwargs, prompt.input_variables, prompt.partial_variables) if isinstance(prompt, BaseChatPromptTemplate): kwargs[k] = await prompt.aformat_messages(**_inputs) else: @@ -106,6 +114,21 @@ async def aformat(self, **kwargs: Any) -> str: """ return (await self.aformat_prompt(**kwargs)).to_string() + def partial(self, **kwargs: dict[str, str]) -> None: + """Return a partial of the prompt template. + + Args: + kwargs: dict[str, str], partial variables to set. + + Returns: + BasePromptTemplate: A partial of the prompt template. + """ + for partial_var, partial_input in kwargs.items(): + for k, prompt in self.pipeline_prompts: + if partial_var in prompt.input_variables: + prompt.partial_variables[partial_var] = partial_input + + @property def _prompt_type(self) -> str: raise ValueError diff --git a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py index f62af4c75777e..70d8953cd0288 100644 --- a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py @@ -32,6 +32,51 @@ def test_multi_variable_pipeline() -> None: assert output == "okay jim deep" +def test_partial_with_prompt_template() -> None: + full_template = """{introduction} + {example} + {start}""" + full_prompt = PromptTemplate.from_template(full_template) + + introduction_template = """You are impersonating {person}.""" + introduction_prompt = PromptTemplate.from_template(introduction_template) + + example_template = """Here's an example of an interaction: + Q: {example_q} + A: {example_a}""" + example_prompt = PromptTemplate.from_template(example_template) + + start_template = """Now, do this for real! + Q: {input} + A:""" + start_prompt = PromptTemplate.from_template(start_template) + + input_prompts = [ + ("introduction", introduction_prompt), + ("example", example_prompt), + ("start", start_prompt), + ] + pipeline_prompt = PipelinePromptTemplate( + final_prompt=full_prompt, pipeline_prompts=input_prompts + ) + + pipeline_prompt.partial(person="Elon Musk") + pipeline_prompt.partial(invalid_partial="Hello, I am invalid") + + ret = pipeline_prompt.format( + example_q="What's your favorite car?", + example_a="Tesla", + input="What's your favorite social media site?", + ) + assert ret == """You are impersonating Elon Musk. + Here's an example of an interaction: + Q: What's your favorite car? + A: Tesla + Now, do this for real! + Q: What's your favorite social media site? + A:""" + + async def test_partial_with_chat_prompts() -> None: prompt_a = ChatPromptTemplate( input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")] From cb70c16f8df32606ca1d5033b9a19a67dfdadf5c Mon Sep 17 00:00:00 2001 From: Aarya2004 Date: Wed, 27 Nov 2024 02:15:58 -0500 Subject: [PATCH 2/4] refactor(partial): Made changes more concise and reduced code changes --- libs/core/langchain_core/prompts/pipeline.py | 38 +++++++++---------- .../prompts/test_pipeline_prompt.py | 7 +++- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/libs/core/langchain_core/prompts/pipeline.py b/libs/core/langchain_core/prompts/pipeline.py index 15b126f0bdc3d..ace28ec3d0166 100644 --- a/libs/core/langchain_core/prompts/pipeline.py +++ b/libs/core/langchain_core/prompts/pipeline.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Callable from typing import Optional as Optional from pydantic import model_validator @@ -8,16 +8,12 @@ from langchain_core.prompts.chat import BaseChatPromptTemplate -def _get_inputs(inputs: dict, input_variables: list[str], partial_variables: Optional[dict] = None) -> dict: - result_dict = {} - if partial_variables is not None and len(partial_variables) != 0: - result_dict = {partial_k: partial_val for partial_k, partial_val in partial_variables.items()} +def _get_inputs(inputs: dict, input_variables: list[str]) -> dict: + ret = {} for k in input_variables: if k in inputs: - result_dict[k] = inputs[k] - if k not in inputs and k not in result_dict: - raise ValueError(f"Input {k} was not provided and is not a partial") - return result_dict + ret[k] = inputs[k] + return ret class PipelinePromptTemplate(BasePromptTemplate): @@ -66,7 +62,7 @@ def format_prompt(self, **kwargs: Any) -> PromptValue: A formatted string. """ for k, prompt in self.pipeline_prompts: - _inputs = _get_inputs(kwargs, prompt.input_variables, prompt.partial_variables) + _inputs = _get_inputs(kwargs, prompt.input_variables) if isinstance(prompt, BaseChatPromptTemplate): kwargs[k] = prompt.format_messages(**_inputs) else: @@ -84,7 +80,7 @@ async def aformat_prompt(self, **kwargs: Any) -> PromptValue: A formatted string. """ for k, prompt in self.pipeline_prompts: - _inputs = _get_inputs(kwargs, prompt.input_variables, prompt.partial_variables) + _inputs = _get_inputs(kwargs, prompt.input_variables) if isinstance(prompt, BaseChatPromptTemplate): kwargs[k] = await prompt.aformat_messages(**_inputs) else: @@ -114,20 +110,22 @@ async def aformat(self, **kwargs: Any) -> str: """ return (await self.aformat_prompt(**kwargs)).to_string() - def partial(self, **kwargs: dict[str, str]) -> None: - """Return a partial of the prompt template. + # Ignoring the type below since partial makes modifications rather + # than returning a new template + def partial(self, **kwargs: str | Callable[[], str]) -> None: # type: ignore[override] + """Add partial arguments to prompts in pipeline_prompts Args: kwargs: dict[str, str], partial variables to set. - - Returns: - BasePromptTemplate: A partial of the prompt template. """ - for partial_var, partial_input in kwargs.items(): - for k, prompt in self.pipeline_prompts: + for i, string_and_prompt in enumerate(self.pipeline_prompts): + k, prompt = string_and_prompt + prompt_kwargs = {} + for partial_var, partial_input in kwargs.items(): if partial_var in prompt.input_variables: - prompt.partial_variables[partial_var] = partial_input - + prompt_kwargs[partial_var] = partial_input + if len(prompt_kwargs) > 0: + self.pipeline_prompts[i] = (k, prompt.partial(**prompt_kwargs)) @property def _prompt_type(self) -> str: diff --git a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py index 70d8953cd0288..889f7a8bd3175 100644 --- a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py @@ -62,19 +62,22 @@ def test_partial_with_prompt_template() -> None: pipeline_prompt.partial(person="Elon Musk") pipeline_prompt.partial(invalid_partial="Hello, I am invalid") - + ret = pipeline_prompt.format( example_q="What's your favorite car?", example_a="Tesla", input="What's your favorite social media site?", ) - assert ret == """You are impersonating Elon Musk. + assert ( + ret + == """You are impersonating Elon Musk. Here's an example of an interaction: Q: What's your favorite car? A: Tesla Now, do this for real! Q: What's your favorite social media site? A:""" + ) async def test_partial_with_chat_prompts() -> None: From ce6dce8a75a803fee5d20963044545019a6460c0 Mon Sep 17 00:00:00 2001 From: Aarya2004 Date: Wed, 27 Nov 2024 02:25:20 -0500 Subject: [PATCH 3/4] fix(type errors): Fixing type annotations --- libs/core/langchain_core/prompts/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/prompts/pipeline.py b/libs/core/langchain_core/prompts/pipeline.py index ace28ec3d0166..068c430005a9a 100644 --- a/libs/core/langchain_core/prompts/pipeline.py +++ b/libs/core/langchain_core/prompts/pipeline.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, Union from typing import Optional as Optional from pydantic import model_validator @@ -112,7 +112,7 @@ async def aformat(self, **kwargs: Any) -> str: # Ignoring the type below since partial makes modifications rather # than returning a new template - def partial(self, **kwargs: str | Callable[[], str]) -> None: # type: ignore[override] + def partial(self, **kwargs: Union[str, Callable[[], str]]) -> None: # type: ignore[override] """Add partial arguments to prompts in pipeline_prompts Args: From f55743ae4fbabfaf50590978a5e555f8f9ba0548 Mon Sep 17 00:00:00 2001 From: Aarya2004 Date: Wed, 27 Nov 2024 02:33:29 -0500 Subject: [PATCH 4/4] refactor(linting): Fixed some linting issues --- libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py index 889f7a8bd3175..a7f00bf68dbff 100644 --- a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py @@ -56,8 +56,9 @@ def test_partial_with_prompt_template() -> None: ("example", example_prompt), ("start", start_prompt), ] - pipeline_prompt = PipelinePromptTemplate( - final_prompt=full_prompt, pipeline_prompts=input_prompts + pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg] + final_prompt=full_prompt, + pipeline_prompts=input_prompts, # type: ignore[arg-type] ) pipeline_prompt.partial(person="Elon Musk")