Skip to content

Commit

Permalink
add methods to deserialize prompts that were old
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 committed Dec 18, 2023
1 parent 23eb480 commit a5ca088
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 4 deletions.
11 changes: 8 additions & 3 deletions libs/core/langchain_core/load/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import os
from typing import Any, Dict, List, Optional

from langchain_core.load.mapping import SERIALIZABLE_MAPPING
from langchain_core.load.mapping import (
OLD_PROMPT_TEMPLATE_FORMATS,
SERIALIZABLE_MAPPING,
)
from langchain_core.load.serializable import Serializable

DEFAULT_NAMESPACES = ["langchain", "langchain_core", "langchain_community"]

ALL_SERIALIZABLE_MAPPINGS = {**SERIALIZABLE_MAPPING, **OLD_PROMPT_TEMPLATE_FORMATS}


class Reviver:
"""Reviver for JSON objects."""
Expand Down Expand Up @@ -67,13 +72,13 @@ def __call__(self, value: Dict[str, Any]) -> Any:
if namespace[0] in DEFAULT_NAMESPACES:
# Get the importable path
key = tuple(namespace + [name])
if key not in SERIALIZABLE_MAPPING:
if key not in ALL_SERIALIZABLE_MAPPINGS:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{key}"
)
import_path = SERIALIZABLE_MAPPING[key]
import_path = ALL_SERIALIZABLE_MAPPINGS[key]
# Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
Expand Down
159 changes: 159 additions & 0 deletions libs/core/langchain_core/load/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,162 @@
"RunnableRetry",
),
}

# Needed for backwards compatibility for a few versions where we serialized
# with langchain_core
OLD_PROMPT_TEMPLATE_FORMATS = {
(
"langchain_core",
"prompts",
"base",
"BasePromptTemplate",
): (
"langchain_core",
"prompts",
"base",
"BasePromptTemplate",
),
(
"langchain_core",
"prompts",
"prompt",
"PromptTemplate",
): (
"langchain_core",
"prompts",
"prompt",
"PromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"MessagesPlaceholder",
): (
"langchain_core",
"prompts",
"chat",
"MessagesPlaceholder",
),
(
"langchain_core",
"prompts",
"chat",
"ChatPromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"ChatPromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"HumanMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"HumanMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"SystemMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"SystemMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"BaseMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"BaseMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"BaseChatPromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"BaseChatPromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"ChatMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"ChatMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"few_shot_with_templates",
"FewShotPromptWithTemplates",
): (
"langchain_core",
"prompts",
"few_shot_with_templates",
"FewShotPromptWithTemplates",
),
(
"langchain_core",
"prompts",
"pipeline",
"PipelinePromptTemplate",
): (
"langchain_core",
"prompts",
"pipeline",
"PipelinePromptTemplate",
),
(
"langchain_core",
"prompts",
"string",
"StringPromptTemplate",
): (
"langchain_core",
"prompts",
"string",
"StringPromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"BaseStringMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"BaseStringMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"AIMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"AIMessagePromptTemplate",
),
}
3 changes: 2 additions & 1 deletion libs/core/tests/unit_tests/runnables/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def test_interfaces() -> None:


def _get_get_session_history(
*, store: Optional[Dict[str, Any]] = None
*,
store: Optional[Dict[str, Any]] = None,
) -> Callable[..., ChatMessageHistory]:
chat_history_store = store if store is not None else {}

Expand Down

0 comments on commit a5ca088

Please sign in to comment.