From 16af282429219334390b64822efe0779a6da66af Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 21 Nov 2023 23:04:20 -0800 Subject: [PATCH] BUGFIX: add prompt imports for backwards compat (#13702) --- libs/core/langchain_core/prompts/__init__.py | 3 +- libs/core/langchain_core/prompts/prompt.py | 4 --- .../tests/unit_tests/prompts/test_imports.py | 1 - libs/langchain/langchain/prompts/__init__.py | 4 +-- libs/langchain/langchain/prompts/base.py | 4 +++ libs/langchain/langchain/prompts/chat.py | 7 +++++ libs/langchain/langchain/prompts/few_shot.py | 7 ++++- libs/langchain/langchain/prompts/loading.py | 23 +++++++++++++-- libs/langchain/langchain/prompts/pipeline.py | 4 +-- libs/langchain/langchain/prompts/prompt.py | 5 +++- libs/langchain/langchain/tools/base.py | 6 ++++ .../tests/unit_tests/prompts/__init__.py | 1 + .../tests/unit_tests/prompts/test_base.py | 16 +++++++++++ .../tests/unit_tests/prompts/test_chat.py | 21 ++++++++++++++ .../tests/unit_tests/prompts/test_few_shot.py | 11 ++++++++ .../prompts/test_few_shot_with_templates.py | 7 +++++ .../tests/unit_tests/prompts/test_imports.py | 28 +++++++++++++++++++ .../tests/unit_tests/prompts/test_loading.py | 17 +++++++++++ .../tests/unit_tests/prompts/test_pipeline.py | 7 +++++ .../tests/unit_tests/prompts/test_prompt.py | 7 +++++ .../tests/unit_tests/tools/test_base.py | 18 ++++++++++++ 21 files changed, 186 insertions(+), 15 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/prompts/__init__.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_base.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_chat.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_few_shot.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_imports.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_loading.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_pipeline.py create mode 100644 libs/langchain/tests/unit_tests/prompts/test_prompt.py create mode 100644 libs/langchain/tests/unit_tests/tools/test_base.py diff --git a/libs/core/langchain_core/prompts/__init__.py b/libs/core/langchain_core/prompts/__init__.py index bce1fc4d5d3eb..e625578b5d4e3 100644 --- a/libs/core/langchain_core/prompts/__init__.py +++ b/libs/core/langchain_core/prompts/__init__.py @@ -41,7 +41,7 @@ from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates from langchain_core.prompts.loading import load_prompt from langchain_core.prompts.pipeline import PipelinePromptTemplate -from langchain_core.prompts.prompt import Prompt, PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import ( StringPromptTemplate, check_valid_template, @@ -62,7 +62,6 @@ "HumanMessagePromptTemplate", "MessagesPlaceholder", "PipelinePromptTemplate", - "Prompt", "PromptTemplate", "StringPromptTemplate", "SystemMessagePromptTemplate", diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index c192f46c51fcf..012b926953701 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -244,7 +244,3 @@ def from_template( partial_variables=_partial_variables, **kwargs, ) - - -# For backwards compatibility. -Prompt = PromptTemplate diff --git a/libs/core/tests/unit_tests/prompts/test_imports.py b/libs/core/tests/unit_tests/prompts/test_imports.py index 250afa23513bf..a9cacf91570fe 100644 --- a/libs/core/tests/unit_tests/prompts/test_imports.py +++ b/libs/core/tests/unit_tests/prompts/test_imports.py @@ -13,7 +13,6 @@ "HumanMessagePromptTemplate", "MessagesPlaceholder", "PipelinePromptTemplate", - "Prompt", "PromptTemplate", "StringPromptTemplate", "SystemMessagePromptTemplate", diff --git a/libs/langchain/langchain/prompts/__init__.py b/libs/langchain/langchain/prompts/__init__.py index d8ff6a9ad99cd..9a966fa552617 100644 --- a/libs/langchain/langchain/prompts/__init__.py +++ b/libs/langchain/langchain/prompts/__init__.py @@ -44,7 +44,6 @@ HumanMessagePromptTemplate, MessagesPlaceholder, PipelinePromptTemplate, - Prompt, PromptTemplate, StringPromptTemplate, SystemMessagePromptTemplate, @@ -52,6 +51,7 @@ ) from langchain.prompts.example_selector import NGramOverlapExampleSelector +from langchain.prompts.prompt import Prompt __all__ = [ "AIMessagePromptTemplate", @@ -67,11 +67,11 @@ "MessagesPlaceholder", "NGramOverlapExampleSelector", "PipelinePromptTemplate", - "Prompt", "PromptTemplate", "SemanticSimilarityExampleSelector", "StringPromptTemplate", "SystemMessagePromptTemplate", "load_prompt", "FewShotChatMessagePromptTemplate", + "Prompt", ] diff --git a/libs/langchain/langchain/prompts/base.py b/libs/langchain/langchain/prompts/base.py index 7c20569823ccb..a315ec92d8e10 100644 --- a/libs/langchain/langchain/prompts/base.py +++ b/libs/langchain/langchain/prompts/base.py @@ -1,3 +1,4 @@ +from langchain_core.prompt_values import StringPromptValue from langchain_core.prompts import ( BasePromptTemplate, StringPromptTemplate, @@ -6,6 +7,7 @@ jinja2_formatter, validate_jinja2, ) +from langchain_core.prompts.string import _get_jinja2_variables_from_template __all__ = [ "jinja2_formatter", @@ -14,4 +16,6 @@ "get_template_variables", "StringPromptTemplate", "BasePromptTemplate", + "StringPromptValue", + "_get_jinja2_variables_from_template", ] diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index c44c196cf4797..049f19382c3cf 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -1,3 +1,4 @@ +from langchain_core.prompt_values import ChatPromptValue, ChatPromptValueConcrete from langchain_core.prompts.chat import ( AIMessagePromptTemplate, BaseChatPromptTemplate, @@ -8,6 +9,8 @@ HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, + _convert_to_message, + _create_template_from_message_type, ) __all__ = [ @@ -20,4 +23,8 @@ "SystemMessagePromptTemplate", "BaseChatPromptTemplate", "ChatPromptTemplate", + "ChatPromptValue", + "ChatPromptValueConcrete", + "_convert_to_message", + "_create_template_from_message_type", ] diff --git a/libs/langchain/langchain/prompts/few_shot.py b/libs/langchain/langchain/prompts/few_shot.py index ab8e24098ede1..67b3106271d79 100644 --- a/libs/langchain/langchain/prompts/few_shot.py +++ b/libs/langchain/langchain/prompts/few_shot.py @@ -1,6 +1,11 @@ from langchain_core.prompts.few_shot import ( FewShotChatMessagePromptTemplate, FewShotPromptTemplate, + _FewShotPromptTemplateMixin, ) -__all__ = ["FewShotPromptTemplate", "FewShotChatMessagePromptTemplate"] +__all__ = [ + "FewShotPromptTemplate", + "FewShotChatMessagePromptTemplate", + "_FewShotPromptTemplateMixin", +] diff --git a/libs/langchain/langchain/prompts/loading.py b/libs/langchain/langchain/prompts/loading.py index df0f62f8503f8..b1dfcd9069120 100644 --- a/libs/langchain/langchain/prompts/loading.py +++ b/libs/langchain/langchain/prompts/loading.py @@ -1,4 +1,23 @@ -from langchain_core.prompts.loading import load_prompt, load_prompt_from_config +from langchain_core.prompts.loading import ( + _load_examples, + _load_few_shot_prompt, + _load_output_parser, + _load_prompt, + _load_prompt_from_file, + _load_template, + load_prompt, + load_prompt_from_config, +) from langchain_core.utils.loading import try_load_from_hub -__all__ = ["load_prompt_from_config", "load_prompt", "try_load_from_hub"] +__all__ = [ + "load_prompt_from_config", + "load_prompt", + "try_load_from_hub", + "_load_examples", + "_load_few_shot_prompt", + "_load_output_parser", + "_load_prompt", + "_load_prompt_from_file", + "_load_template", +] diff --git a/libs/langchain/langchain/prompts/pipeline.py b/libs/langchain/langchain/prompts/pipeline.py index 88e73e16f3352..88e2cc79ab25b 100644 --- a/libs/langchain/langchain/prompts/pipeline.py +++ b/libs/langchain/langchain/prompts/pipeline.py @@ -1,3 +1,3 @@ -from langchain_core.prompts.pipeline import PipelinePromptTemplate +from langchain_core.prompts.pipeline import PipelinePromptTemplate, _get_inputs -__all__ = ["PipelinePromptTemplate"] +__all__ = ["PipelinePromptTemplate", "_get_inputs"] diff --git a/libs/langchain/langchain/prompts/prompt.py b/libs/langchain/langchain/prompts/prompt.py index 047d55adfed6f..5e35f878bac0d 100644 --- a/libs/langchain/langchain/prompts/prompt.py +++ b/libs/langchain/langchain/prompts/prompt.py @@ -1,3 +1,6 @@ from langchain_core.prompts.prompt import PromptTemplate -__all__ = ["PromptTemplate"] +# For backwards compatibility. +Prompt = PromptTemplate + +__all__ = ["PromptTemplate", "Prompt"] diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index ff81eaa895620..fd34a1a5a5ab8 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -4,6 +4,9 @@ StructuredTool, Tool, ToolException, + _create_subset_model, + _get_filtered_args, + _SchemaConfig, create_schema_from_function, tool, ) @@ -16,4 +19,7 @@ "Tool", "StructuredTool", "tool", + "_SchemaConfig", + "_create_subset_model", + "_get_filtered_args", ] diff --git a/libs/langchain/tests/unit_tests/prompts/__init__.py b/libs/langchain/tests/unit_tests/prompts/__init__.py new file mode 100644 index 0000000000000..dc72afe0c4dab --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/__init__.py @@ -0,0 +1 @@ +"""Test prompt functionality.""" diff --git a/libs/langchain/tests/unit_tests/prompts/test_base.py b/libs/langchain/tests/unit_tests/prompts/test_base.py new file mode 100644 index 0000000000000..00e2fccf217e5 --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_base.py @@ -0,0 +1,16 @@ +from langchain.prompts.base import __all__ + +EXPECTED_ALL = [ + "BasePromptTemplate", + "StringPromptTemplate", + "StringPromptValue", + "_get_jinja2_variables_from_template", + "check_valid_template", + "get_template_variables", + "jinja2_formatter", + "validate_jinja2", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/prompts/test_chat.py b/libs/langchain/tests/unit_tests/prompts/test_chat.py new file mode 100644 index 0000000000000..515642417743d --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_chat.py @@ -0,0 +1,21 @@ +from langchain.prompts.chat import __all__ + +EXPECTED_ALL = [ + "AIMessagePromptTemplate", + "BaseChatPromptTemplate", + "BaseMessagePromptTemplate", + "BaseStringMessagePromptTemplate", + "ChatMessagePromptTemplate", + "ChatPromptTemplate", + "ChatPromptValue", + "ChatPromptValueConcrete", + "HumanMessagePromptTemplate", + "MessagesPlaceholder", + "SystemMessagePromptTemplate", + "_convert_to_message", + "_create_template_from_message_type", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/prompts/test_few_shot.py b/libs/langchain/tests/unit_tests/prompts/test_few_shot.py new file mode 100644 index 0000000000000..248f309f24432 --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_few_shot.py @@ -0,0 +1,11 @@ +from langchain.prompts.few_shot import __all__ + +EXPECTED_ALL = [ + "FewShotChatMessagePromptTemplate", + "FewShotPromptTemplate", + "_FewShotPromptTemplateMixin", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py b/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py new file mode 100644 index 0000000000000..012ac4dcb07fc --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py @@ -0,0 +1,7 @@ +from langchain.prompts.few_shot_with_templates import __all__ + +EXPECTED_ALL = ["FewShotPromptWithTemplates"] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/prompts/test_imports.py b/libs/langchain/tests/unit_tests/prompts/test_imports.py new file mode 100644 index 0000000000000..6ec17789ad043 --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_imports.py @@ -0,0 +1,28 @@ +from langchain.prompts import __all__ + +EXPECTED_ALL = [ + "AIMessagePromptTemplate", + "BaseChatPromptTemplate", + "BasePromptTemplate", + "ChatMessagePromptTemplate", + "ChatPromptTemplate", + "FewShotPromptTemplate", + "FewShotPromptWithTemplates", + "HumanMessagePromptTemplate", + "LengthBasedExampleSelector", + "MaxMarginalRelevanceExampleSelector", + "MessagesPlaceholder", + "NGramOverlapExampleSelector", + "PipelinePromptTemplate", + "Prompt", + "PromptTemplate", + "SemanticSimilarityExampleSelector", + "StringPromptTemplate", + "SystemMessagePromptTemplate", + "load_prompt", + "FewShotChatMessagePromptTemplate", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/prompts/test_loading.py b/libs/langchain/tests/unit_tests/prompts/test_loading.py new file mode 100644 index 0000000000000..8a14876a8431f --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_loading.py @@ -0,0 +1,17 @@ +from langchain.prompts.loading import __all__ + +EXPECTED_ALL = [ + "_load_examples", + "_load_few_shot_prompt", + "_load_output_parser", + "_load_prompt", + "_load_prompt_from_file", + "_load_template", + "load_prompt", + "load_prompt_from_config", + "try_load_from_hub", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/prompts/test_pipeline.py b/libs/langchain/tests/unit_tests/prompts/test_pipeline.py new file mode 100644 index 0000000000000..f261db6737899 --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_pipeline.py @@ -0,0 +1,7 @@ +from langchain.prompts.pipeline import __all__ + +EXPECTED_ALL = ["PipelinePromptTemplate", "_get_inputs"] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/prompts/test_prompt.py b/libs/langchain/tests/unit_tests/prompts/test_prompt.py new file mode 100644 index 0000000000000..6e9120cec5519 --- /dev/null +++ b/libs/langchain/tests/unit_tests/prompts/test_prompt.py @@ -0,0 +1,7 @@ +from langchain.prompts.prompt import __all__ + +EXPECTED_ALL = ["Prompt", "PromptTemplate"] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/langchain/tests/unit_tests/tools/test_base.py b/libs/langchain/tests/unit_tests/tools/test_base.py new file mode 100644 index 0000000000000..8c00a846b5879 --- /dev/null +++ b/libs/langchain/tests/unit_tests/tools/test_base.py @@ -0,0 +1,18 @@ +from langchain.tools.base import __all__ + +EXPECTED_ALL = [ + "BaseTool", + "SchemaAnnotationError", + "StructuredTool", + "Tool", + "ToolException", + "_SchemaConfig", + "_create_subset_model", + "_get_filtered_args", + "create_schema_from_function", + "tool", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL)