diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 042b2d77ea80f..096f432f3eed9 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -7,15 +7,7 @@ from collections.abc import Mapping from functools import cached_property from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Optional, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -36,7 +28,6 @@ if TYPE_CHECKING: from langchain_core.documents import Document - FormatOutputType = TypeVar("FormatOutputType") @@ -260,27 +251,47 @@ async def aformat_prompt(self, **kwargs: Any) -> PromptValue: """ return self.format_prompt(**kwargs) - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: + def partial( + self, **kwargs: Union[str, Callable[[], str], BasePromptTemplate] + ) -> BasePromptTemplate: """Return a partial of the prompt template. Args: - kwargs: Union[str, Callable[[], str], partial variables to set. + kwargs: Union[str, Callable[[], str], BasePromptTemplate], + partial variables to set. Returns: BasePromptTemplate: A partial of the prompt template. """ prompt_dict = self.__dict__.copy() - prompt_dict["input_variables"] = list( - set(self.input_variables).difference(kwargs) + input_vars = set(self.input_variables).difference(kwargs) + partial_vars = {} + for key, partial_var in kwargs.items(): + if isinstance(partial_var, BasePromptTemplate): + # Prepare partial arguments, excluding the current key + new_kwargs = kwargs.copy() + new_kwargs.pop(key) + partial_var = partial_var.partial(**new_kwargs) + partial_vars[key] = partial_var + prompt_dict.update( + { + "input_variables": list(input_vars), + "partial_variables": {**kwargs, **partial_vars}, + } ) - prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} return type(self)(**prompt_dict) def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]: # Get partial params: - partial_kwargs = { - k: v if not callable(v) else v() for k, v in self.partial_variables.items() - } + partial_kwargs = {} + for k, v in self.partial_variables.items(): + if isinstance(v, BasePromptTemplate): + # Propagate partial variables and kwargs to nested prompt templates + partial_kwargs[k] = v.format(**kwargs) + elif callable(v): + partial_kwargs[k] = v() + else: + partial_kwargs[k] = v return {**partial_kwargs, **kwargs} @abstractmethod diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 325ee067ecaf9..d7b7ec6128277 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, model_validator +from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts.string import ( DEFAULT_FORMATTER_MAPPING, PromptTemplateFormat, @@ -104,12 +105,24 @@ def pre_init_validation(cls, values: dict) -> Any: ) if values["template_format"]: + # Collect nested partial variables from + # BasePromptTemplate instances in partial_variables + nested_partial_vars = { + key + for partial_var in values["partial_variables"].values() + if isinstance(partial_var, BasePromptTemplate) + for key in partial_var.partial_variables + } + + # Filter template variables based on + # partial_variables and nested_partial_vars values["input_variables"] = [ var for var in get_template_variables( values["template"], values["template_format"] ) if var not in values["partial_variables"] + and var not in nested_partial_vars ] return values diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index d56654d874d5b..f9be5f4450178 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -403,6 +403,58 @@ def test_partial() -> None: assert result == "This is a foo test." +def test_nested_prompt_template_as_partial() -> None: + """Test prompt with PromptTemplate as partial variable.""" + template_nested = "{bar}" + prompt_nested = PromptTemplate(input_variables=["bar"], template=template_nested) + + template = "This is a {foo} test." + prompt = PromptTemplate(input_variables=["foo"], template=template) + assert prompt.template == template + assert prompt.input_variables == ["foo"] + + new_prompt = prompt.partial(foo=prompt_nested) + assert new_prompt.input_variables == [] + assert new_prompt.partial_variables["foo"].input_variables == ["bar"] + assert new_prompt.partial_variables["foo"].partial_variables == {} + result = new_prompt.format(bar="bar") + assert result == "This is a bar test." + + new_prompt = prompt.partial(foo=prompt_nested, bar="bar") + assert new_prompt.input_variables == [] + assert new_prompt.partial_variables["foo"].input_variables == [] + assert new_prompt.partial_variables["foo"].partial_variables == {"bar": "bar"} + result = new_prompt.format() + assert result == "This is a bar test." + + +def test_nested_prompt_template_with_shared_variable() -> None: + """Test prompt with PromptTemplate as partial variable, sharing another variable.""" + template_nested = "{bar}" + prompt_nested = PromptTemplate( + input_variables=["bar", "foo"], template=template_nested + ) + + template = "This is a {foo} {bar} test." + prompt = PromptTemplate(input_variables=["foo", "bar"], template=template) + assert prompt.template == template + assert prompt.input_variables == ["bar", "foo"] + + new_prompt = prompt.partial(foo=prompt_nested) + assert new_prompt.input_variables == ["bar"] + assert new_prompt.partial_variables["foo"].input_variables == ["bar"] + assert new_prompt.partial_variables["foo"].partial_variables == {} + result = new_prompt.format(bar="bar") + assert result == "This is a bar bar test." + + new_prompt = prompt.partial(foo=prompt_nested, bar="bar") + assert new_prompt.input_variables == [] + assert new_prompt.partial_variables["foo"].input_variables == [] + assert new_prompt.partial_variables["foo"].partial_variables == {"bar": "bar"} + result = new_prompt.format() + assert result == "This is a bar bar test." + + @pytest.mark.requires("jinja2") def test_prompt_from_jinja2_template() -> None: """Test prompts can be constructed from a jinja2 template.""" @@ -508,7 +560,7 @@ def test_prompt_jinja2_missing_input_variables() -> None: @pytest.mark.requires("jinja2") def test_prompt_jinja2_extra_input_variables() -> None: - """Test error is raised when there are too many input variables.""" + """Test warning is raised when there are too many input variables.""" template = "This is a {{ foo }} test." input_variables = ["foo", "bar"] with pytest.warns(UserWarning): @@ -525,7 +577,7 @@ def test_prompt_jinja2_extra_input_variables() -> None: @pytest.mark.requires("jinja2") def test_prompt_jinja2_wrong_input_variables() -> None: - """Test error is raised when name of input variable is wrong.""" + """Test warning is raised when name of input variable is wrong.""" template = "This is a {{ foo }} test." input_variables = ["bar"] with pytest.warns(UserWarning):