Skip to content

Commit

Permalink
feat (core): allow nested prompt templates
Browse files Browse the repository at this point in the history
  • Loading branch information
taamedag authored and taamedag committed Nov 11, 2024
1 parent 3b0b7cf commit 6905b5d
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 20 deletions.
47 changes: 29 additions & 18 deletions libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union

import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
Expand All @@ -36,7 +28,6 @@
if TYPE_CHECKING:
from langchain_core.documents import Document


FormatOutputType = TypeVar("FormatOutputType")


Expand Down Expand Up @@ -260,27 +251,47 @@ async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
"""
return self.format_prompt(**kwargs)

def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
def partial(
self, **kwargs: Union[str, Callable[[], str], BasePromptTemplate]
) -> BasePromptTemplate:
"""Return a partial of the prompt template.
Args:
kwargs: Union[str, Callable[[], str], partial variables to set.
kwargs: Union[str, Callable[[], str], BasePromptTemplate],
partial variables to set.
Returns:
BasePromptTemplate: A partial of the prompt template.
"""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
input_vars = set(self.input_variables).difference(kwargs)
partial_vars = {}
for key, partial_var in kwargs.items():
if isinstance(partial_var, BasePromptTemplate):
# Prepare partial arguments, excluding the current key
new_kwargs = kwargs.copy()
new_kwargs.pop(key)
partial_var = partial_var.partial(**new_kwargs)
partial_vars[key] = partial_var
prompt_dict.update(
{
"input_variables": list(input_vars),
"partial_variables": {**kwargs, **partial_vars},
}
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)

def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
}
partial_kwargs = {}
for k, v in self.partial_variables.items():
if isinstance(v, BasePromptTemplate):
# Propagate partial variables and kwargs to nested prompt templates
partial_kwargs[k] = v.format(**kwargs)
elif callable(v):
partial_kwargs[k] = v()
else:
partial_kwargs[k] = v
return {**partial_kwargs, **kwargs}

@abstractmethod
Expand Down
13 changes: 13 additions & 0 deletions libs/core/langchain_core/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pydantic import BaseModel, model_validator

from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
Expand Down Expand Up @@ -104,12 +105,24 @@ def pre_init_validation(cls, values: dict) -> Any:
)

if values["template_format"]:
# Collect nested partial variables from
# BasePromptTemplate instances in partial_variables
nested_partial_vars = {
key
for partial_var in values["partial_variables"].values()
if isinstance(partial_var, BasePromptTemplate)
for key in partial_var.partial_variables
}

# Filter template variables based on
# partial_variables and nested_partial_vars
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
and var not in nested_partial_vars
]

return values
Expand Down
56 changes: 54 additions & 2 deletions libs/core/tests/unit_tests/prompts/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,58 @@ def test_partial() -> None:
assert result == "This is a foo test."


def test_nested_prompt_template_as_partial() -> None:
"""Test prompt with PromptTemplate as partial variable."""
template_nested = "{bar}"
prompt_nested = PromptTemplate(input_variables=["bar"], template=template_nested)

template = "This is a {foo} test."
prompt = PromptTemplate(input_variables=["foo"], template=template)
assert prompt.template == template
assert prompt.input_variables == ["foo"]

new_prompt = prompt.partial(foo=prompt_nested)
assert new_prompt.input_variables == []
assert new_prompt.partial_variables["foo"].input_variables == ["bar"]
assert new_prompt.partial_variables["foo"].partial_variables == {}
result = new_prompt.format(bar="bar")
assert result == "This is a bar test."

new_prompt = prompt.partial(foo=prompt_nested, bar="bar")
assert new_prompt.input_variables == []
assert new_prompt.partial_variables["foo"].input_variables == []
assert new_prompt.partial_variables["foo"].partial_variables == {"bar": "bar"}
result = new_prompt.format()
assert result == "This is a bar test."


def test_nested_prompt_template_with_shared_variable() -> None:
"""Test prompt with PromptTemplate as partial variable, sharing another variable."""
template_nested = "{bar}"
prompt_nested = PromptTemplate(
input_variables=["bar", "foo"], template=template_nested
)

template = "This is a {foo} {bar} test."
prompt = PromptTemplate(input_variables=["foo", "bar"], template=template)
assert prompt.template == template
assert prompt.input_variables == ["bar", "foo"]

new_prompt = prompt.partial(foo=prompt_nested)
assert new_prompt.input_variables == ["bar"]
assert new_prompt.partial_variables["foo"].input_variables == ["bar"]
assert new_prompt.partial_variables["foo"].partial_variables == {}
result = new_prompt.format(bar="bar")
assert result == "This is a bar bar test."

new_prompt = prompt.partial(foo=prompt_nested, bar="bar")
assert new_prompt.input_variables == []
assert new_prompt.partial_variables["foo"].input_variables == []
assert new_prompt.partial_variables["foo"].partial_variables == {"bar": "bar"}
result = new_prompt.format()
assert result == "This is a bar bar test."


@pytest.mark.requires("jinja2")
def test_prompt_from_jinja2_template() -> None:
"""Test prompts can be constructed from a jinja2 template."""
Expand Down Expand Up @@ -508,7 +560,7 @@ def test_prompt_jinja2_missing_input_variables() -> None:

@pytest.mark.requires("jinja2")
def test_prompt_jinja2_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
"""Test warning is raised when there are too many input variables."""
template = "This is a {{ foo }} test."
input_variables = ["foo", "bar"]
with pytest.warns(UserWarning):
Expand All @@ -525,7 +577,7 @@ def test_prompt_jinja2_extra_input_variables() -> None:

@pytest.mark.requires("jinja2")
def test_prompt_jinja2_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong."""
"""Test warning is raised when name of input variable is wrong."""
template = "This is a {{ foo }} test."
input_variables = ["bar"]
with pytest.warns(UserWarning):
Expand Down

0 comments on commit 6905b5d

Please sign in to comment.