Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core: Added partial functionality to PipelinePromptTemplate #28377

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions libs/core/langchain_core/prompts/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Callable, Union
from typing import Optional as Optional

from pydantic import model_validator
Expand All @@ -9,7 +9,11 @@


def _get_inputs(inputs: dict, input_variables: list[str]) -> dict:
return {k: inputs[k] for k in input_variables}
ret = {}
for k in input_variables:
if k in inputs:
ret[k] = inputs[k]
return ret


class PipelinePromptTemplate(BasePromptTemplate):
Expand Down Expand Up @@ -106,6 +110,23 @@ async def aformat(self, **kwargs: Any) -> str:
"""
return (await self.aformat_prompt(**kwargs)).to_string()

# Ignoring the type below since partial makes modifications rather
# than returning a new template
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> None: # type: ignore[override]
"""Add partial arguments to prompts in pipeline_prompts

Args:
kwargs: dict[str, str], partial variables to set.
"""
for i, string_and_prompt in enumerate(self.pipeline_prompts):
k, prompt = string_and_prompt
prompt_kwargs = {}
for partial_var, partial_input in kwargs.items():
if partial_var in prompt.input_variables:
prompt_kwargs[partial_var] = partial_input
if len(prompt_kwargs) > 0:
self.pipeline_prompts[i] = (k, prompt.partial(**prompt_kwargs))

@property
def _prompt_type(self) -> str:
raise ValueError
Expand Down
49 changes: 49 additions & 0 deletions libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,55 @@ def test_multi_variable_pipeline() -> None:
assert output == "okay jim deep"


def test_partial_with_prompt_template() -> None:
full_template = """{introduction}
{example}
{start}"""
full_prompt = PromptTemplate.from_template(full_template)

introduction_template = """You are impersonating {person}."""
introduction_prompt = PromptTemplate.from_template(introduction_template)

example_template = """Here's an example of an interaction:
Q: {example_q}
A: {example_a}"""
example_prompt = PromptTemplate.from_template(example_template)

start_template = """Now, do this for real!
Q: {input}
A:"""
start_prompt = PromptTemplate.from_template(start_template)

input_prompts = [
("introduction", introduction_prompt),
("example", example_prompt),
("start", start_prompt),
]
pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=full_prompt,
pipeline_prompts=input_prompts, # type: ignore[arg-type]
)

pipeline_prompt.partial(person="Elon Musk")
pipeline_prompt.partial(invalid_partial="Hello, I am invalid")

ret = pipeline_prompt.format(
example_q="What's your favorite car?",
example_a="Tesla",
input="What's your favorite social media site?",
)
assert (
ret
== """You are impersonating Elon Musk.
Here's an example of an interaction:
Q: What's your favorite car?
A: Tesla
Now, do this for real!
Q: What's your favorite social media site?
A:"""
)


async def test_partial_with_chat_prompts() -> None:
prompt_a = ChatPromptTemplate(
input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")]
Expand Down
Loading