From a5ca08844718a162271af9ae6bdd80a0bccf9f1b Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 18 Dec 2023 10:45:13 -0800 Subject: [PATCH] add methods to deserialize prompts that were old --- libs/core/langchain_core/load/load.py | 11 +- libs/core/langchain_core/load/mapping.py | 159 ++++++++++++++++++ .../unit_tests/runnables/test_history.py | 3 +- 3 files changed, 169 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index e97deede28bf0..89980bc07436e 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -3,11 +3,16 @@ import os from typing import Any, Dict, List, Optional -from langchain_core.load.mapping import SERIALIZABLE_MAPPING +from langchain_core.load.mapping import ( + OLD_PROMPT_TEMPLATE_FORMATS, + SERIALIZABLE_MAPPING, +) from langchain_core.load.serializable import Serializable DEFAULT_NAMESPACES = ["langchain", "langchain_core", "langchain_community"] +ALL_SERIALIZABLE_MAPPINGS = {**SERIALIZABLE_MAPPING, **OLD_PROMPT_TEMPLATE_FORMATS} + class Reviver: """Reviver for JSON objects.""" @@ -67,13 +72,13 @@ def __call__(self, value: Dict[str, Any]) -> Any: if namespace[0] in DEFAULT_NAMESPACES: # Get the importable path key = tuple(namespace + [name]) - if key not in SERIALIZABLE_MAPPING: + if key not in ALL_SERIALIZABLE_MAPPINGS: raise ValueError( "Trying to deserialize something that cannot " "be deserialized in current version of langchain-core: " f"{key}" ) - import_path = SERIALIZABLE_MAPPING[key] + import_path = ALL_SERIALIZABLE_MAPPINGS[key] # Split into module and name import_dir, import_obj = import_path[:-1], import_path[-1] # Import module diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index 923525f0fdb9d..dcec6affb406b 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -476,3 +476,162 @@ "RunnableRetry", ), } + +# Needed for backwards compatibility for a few versions where we serialized +# with langchain_core +OLD_PROMPT_TEMPLATE_FORMATS = { + ( + "langchain_core", + "prompts", + "base", + "BasePromptTemplate", + ): ( + "langchain_core", + "prompts", + "base", + "BasePromptTemplate", + ), + ( + "langchain_core", + "prompts", + "prompt", + "PromptTemplate", + ): ( + "langchain_core", + "prompts", + "prompt", + "PromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "MessagesPlaceholder", + ): ( + "langchain_core", + "prompts", + "chat", + "MessagesPlaceholder", + ), + ( + "langchain_core", + "prompts", + "chat", + "ChatPromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "ChatPromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "HumanMessagePromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "HumanMessagePromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "SystemMessagePromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "SystemMessagePromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "BaseMessagePromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "BaseMessagePromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "BaseChatPromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "BaseChatPromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "ChatMessagePromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "ChatMessagePromptTemplate", + ), + ( + "langchain_core", + "prompts", + "few_shot_with_templates", + "FewShotPromptWithTemplates", + ): ( + "langchain_core", + "prompts", + "few_shot_with_templates", + "FewShotPromptWithTemplates", + ), + ( + "langchain_core", + "prompts", + "pipeline", + "PipelinePromptTemplate", + ): ( + "langchain_core", + "prompts", + "pipeline", + "PipelinePromptTemplate", + ), + ( + "langchain_core", + "prompts", + "string", + "StringPromptTemplate", + ): ( + "langchain_core", + "prompts", + "string", + "StringPromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "BaseStringMessagePromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "BaseStringMessagePromptTemplate", + ), + ( + "langchain_core", + "prompts", + "chat", + "AIMessagePromptTemplate", + ): ( + "langchain_core", + "prompts", + "chat", + "AIMessagePromptTemplate", + ), +} diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 193a779021f50..72cfcbf77cdd6 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -19,7 +19,8 @@ def test_interfaces() -> None: def _get_get_session_history( - *, store: Optional[Dict[str, Any]] = None + *, + store: Optional[Dict[str, Any]] = None, ) -> Callable[..., ChatMessageHistory]: chat_history_store = store if store is not None else {}