Skip to content

Commit

Permalink
BUGFIX: add prompt imports for backwards compat (#13702)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Nov 22, 2023
1 parent 78da341 commit 16af282
Show file tree
Hide file tree
Showing 21 changed files with 186 additions and 15 deletions.
3 changes: 1 addition & 2 deletions libs/core/langchain_core/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import Prompt, PromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
StringPromptTemplate,
check_valid_template,
Expand All @@ -62,7 +62,6 @@
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
Expand Down
4 changes: 0 additions & 4 deletions libs/core/langchain_core/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,3 @@ def from_template(
partial_variables=_partial_variables,
**kwargs,
)


# For backwards compatibility.
Prompt = PromptTemplate
1 change: 0 additions & 1 deletion libs/core/tests/unit_tests/prompts/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@
HumanMessagePromptTemplate,
MessagesPlaceholder,
PipelinePromptTemplate,
Prompt,
PromptTemplate,
StringPromptTemplate,
SystemMessagePromptTemplate,
load_prompt,
)

from langchain.prompts.example_selector import NGramOverlapExampleSelector
from langchain.prompts.prompt import Prompt

__all__ = [
"AIMessagePromptTemplate",
Expand All @@ -67,11 +67,11 @@
"MessagesPlaceholder",
"NGramOverlapExampleSelector",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
"Prompt",
]
4 changes: 4 additions & 0 deletions libs/langchain/langchain/prompts/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts import (
BasePromptTemplate,
StringPromptTemplate,
Expand All @@ -6,6 +7,7 @@
jinja2_formatter,
validate_jinja2,
)
from langchain_core.prompts.string import _get_jinja2_variables_from_template

__all__ = [
"jinja2_formatter",
Expand All @@ -14,4 +16,6 @@
"get_template_variables",
"StringPromptTemplate",
"BasePromptTemplate",
"StringPromptValue",
"_get_jinja2_variables_from_template",
]
7 changes: 7 additions & 0 deletions libs/langchain/langchain/prompts/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_core.prompt_values import ChatPromptValue, ChatPromptValueConcrete
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
BaseChatPromptTemplate,
Expand All @@ -8,6 +9,8 @@
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
_convert_to_message,
_create_template_from_message_type,
)

__all__ = [
Expand All @@ -20,4 +23,8 @@
"SystemMessagePromptTemplate",
"BaseChatPromptTemplate",
"ChatPromptTemplate",
"ChatPromptValue",
"ChatPromptValueConcrete",
"_convert_to_message",
"_create_template_from_message_type",
]
7 changes: 6 additions & 1 deletion libs/langchain/langchain/prompts/few_shot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from langchain_core.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
_FewShotPromptTemplateMixin,
)

__all__ = ["FewShotPromptTemplate", "FewShotChatMessagePromptTemplate"]
__all__ = [
"FewShotPromptTemplate",
"FewShotChatMessagePromptTemplate",
"_FewShotPromptTemplateMixin",
]
23 changes: 21 additions & 2 deletions libs/langchain/langchain/prompts/loading.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
from langchain_core.prompts.loading import load_prompt, load_prompt_from_config
from langchain_core.prompts.loading import (
_load_examples,
_load_few_shot_prompt,
_load_output_parser,
_load_prompt,
_load_prompt_from_file,
_load_template,
load_prompt,
load_prompt_from_config,
)
from langchain_core.utils.loading import try_load_from_hub

__all__ = ["load_prompt_from_config", "load_prompt", "try_load_from_hub"]
__all__ = [
"load_prompt_from_config",
"load_prompt",
"try_load_from_hub",
"_load_examples",
"_load_few_shot_prompt",
"_load_output_parser",
"_load_prompt",
"_load_prompt_from_file",
"_load_template",
]
4 changes: 2 additions & 2 deletions libs/langchain/langchain/prompts/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.pipeline import PipelinePromptTemplate, _get_inputs

__all__ = ["PipelinePromptTemplate"]
__all__ = ["PipelinePromptTemplate", "_get_inputs"]
5 changes: 4 additions & 1 deletion libs/langchain/langchain/prompts/prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from langchain_core.prompts.prompt import PromptTemplate

__all__ = ["PromptTemplate"]
# For backwards compatibility.
Prompt = PromptTemplate

__all__ = ["PromptTemplate", "Prompt"]
6 changes: 6 additions & 0 deletions libs/langchain/langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
StructuredTool,
Tool,
ToolException,
_create_subset_model,
_get_filtered_args,
_SchemaConfig,
create_schema_from_function,
tool,
)
Expand All @@ -16,4 +19,7 @@
"Tool",
"StructuredTool",
"tool",
"_SchemaConfig",
"_create_subset_model",
"_get_filtered_args",
]
1 change: 1 addition & 0 deletions libs/langchain/tests/unit_tests/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test prompt functionality."""
16 changes: 16 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from langchain.prompts.base import __all__

EXPECTED_ALL = [
"BasePromptTemplate",
"StringPromptTemplate",
"StringPromptValue",
"_get_jinja2_variables_from_template",
"check_valid_template",
"get_template_variables",
"jinja2_formatter",
"validate_jinja2",
]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
21 changes: 21 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from langchain.prompts.chat import __all__

EXPECTED_ALL = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BaseMessagePromptTemplate",
"BaseStringMessagePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"ChatPromptValue",
"ChatPromptValueConcrete",
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"SystemMessagePromptTemplate",
"_convert_to_message",
"_create_template_from_message_type",
]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
11 changes: 11 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_few_shot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from langchain.prompts.few_shot import __all__

EXPECTED_ALL = [
"FewShotChatMessagePromptTemplate",
"FewShotPromptTemplate",
"_FewShotPromptTemplateMixin",
]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from langchain.prompts.few_shot_with_templates import __all__

EXPECTED_ALL = ["FewShotPromptWithTemplates"]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
28 changes: 28 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from langchain.prompts import __all__

EXPECTED_ALL = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"HumanMessagePromptTemplate",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"MessagesPlaceholder",
"NGramOverlapExampleSelector",
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
17 changes: 17 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from langchain.prompts.loading import __all__

EXPECTED_ALL = [
"_load_examples",
"_load_few_shot_prompt",
"_load_output_parser",
"_load_prompt",
"_load_prompt_from_file",
"_load_template",
"load_prompt",
"load_prompt_from_config",
"try_load_from_hub",
]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
7 changes: 7 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from langchain.prompts.pipeline import __all__

EXPECTED_ALL = ["PipelinePromptTemplate", "_get_inputs"]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
7 changes: 7 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from langchain.prompts.prompt import __all__

EXPECTED_ALL = ["Prompt", "PromptTemplate"]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
18 changes: 18 additions & 0 deletions libs/langchain/tests/unit_tests/tools/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from langchain.tools.base import __all__

EXPECTED_ALL = [
"BaseTool",
"SchemaAnnotationError",
"StructuredTool",
"Tool",
"ToolException",
"_SchemaConfig",
"_create_subset_model",
"_get_filtered_args",
"create_schema_from_function",
"tool",
]


def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

0 comments on commit 16af282

Please sign in to comment.