From 9379613132023037cf2fe2e28634847cbac881e1 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 4 Sep 2024 16:59:07 -0400 Subject: [PATCH] langchain[major]: Upgrade langchain to be pydantic 2 compatible (#26050) Upgrading the langchain package to be pydantic 2 compatible. Had to remove some parts of unit tests in parsers that were relying on spying on methods since that fails with pydantic 2. The unit tests don't seem particularly good, so can be re-written at a future date. Depends on: https://github.com/langchain-ai/langchain/pull/26057 Most of this PR was done using gritql for code mods, followed by some fixes done manually to account for changes made by pydantic --- libs/langchain/Makefile | 1 - libs/langchain/langchain/agents/agent.py | 38 ++-- .../agent_toolkits/vectorstore/toolkit.py | 17 +- libs/langchain/langchain/agents/chat/base.py | 2 +- .../langchain/agents/conversational/base.py | 2 +- .../agents/conversational_chat/base.py | 2 +- libs/langchain/langchain/agents/mrkl/base.py | 2 +- .../langchain/agents/openai_assistant/base.py | 15 +- .../agents/openai_functions_agent/base.py | 11 +- .../openai_functions_multi_agent/base.py | 11 +- libs/langchain/langchain/agents/react/base.py | 2 +- .../agents/self_ask_with_search/base.py | 2 +- .../langchain/agents/structured_chat/base.py | 2 +- .../agents/structured_chat/output_parser.py | 2 +- libs/langchain/langchain/chains/api/base.py | 24 +-- libs/langchain/langchain/chains/base.py | 18 +- .../chains/combine_documents/base.py | 2 +- .../chains/combine_documents/map_reduce.py | 24 ++- .../chains/combine_documents/map_rerank.py | 31 ++-- .../chains/combine_documents/reduce.py | 8 +- .../chains/combine_documents/refine.py | 19 +- .../chains/combine_documents/stuff.py | 14 +- .../chains/constitutional_ai/models.py | 2 +- .../langchain/chains/conversation/base.py | 24 +-- .../chains/conversational_retrieval/base.py | 16 +- .../chains/elasticsearch_database/base.py | 18 +- libs/langchain/langchain/chains/flare/base.py | 2 +- libs/langchain/langchain/chains/hyde/base.py | 8 +- libs/langchain/langchain/chains/llm.py | 9 +- .../langchain/chains/llm_checker/base.py | 14 +- .../langchain/chains/llm_math/base.py | 14 +- .../chains/llm_summarization_checker/base.py | 14 +- libs/langchain/langchain/chains/mapreduce.py | 8 +- libs/langchain/langchain/chains/moderation.py | 7 +- .../langchain/langchain/chains/natbot/base.py | 14 +- .../langchain/chains/openai_functions/base.py | 2 +- .../openai_functions/citation_fuzzy_match.py | 2 +- .../chains/openai_functions/extraction.py | 2 +- .../openai_functions/qa_with_structure.py | 2 +- .../chains/openai_tools/extraction.py | 2 +- .../langchain/chains/prompt_selector.py | 2 +- .../langchain/chains/qa_generation/base.py | 2 +- .../langchain/chains/qa_with_sources/base.py | 14 +- .../chains/qa_with_sources/retrieval.py | 2 +- .../chains/qa_with_sources/vector_db.py | 7 +- .../chains/query_constructor/schema.py | 9 +- .../langchain/chains/retrieval_qa/base.py | 21 ++- .../langchain/langchain/chains/router/base.py | 8 +- .../chains/router/embedding_router.py | 8 +- .../langchain/chains/router/llm_router.py | 11 +- libs/langchain/langchain/chains/sequential.py | 30 +-- .../chains/structured_output/base.py | 2 +- libs/langchain/langchain/chains/transform.py | 2 +- libs/langchain/langchain/chat_models/base.py | 2 +- .../agents/trajectory_eval_chain.py | 7 +- .../evaluation/comparison/eval_chain.py | 7 +- .../evaluation/criteria/eval_chain.py | 7 +- .../evaluation/embedding_distance/base.py | 7 +- .../langchain/evaluation/qa/eval_chain.py | 11 +- .../langchain/evaluation/qa/generate_chain.py | 2 +- .../evaluation/scoring/eval_chain.py | 7 +- .../evaluation/string_distance/base.py | 2 +- .../langchain/indexes/vectorstore.py | 16 +- .../langchain/langchain/memory/chat_memory.py | 2 +- libs/langchain/langchain/memory/combined.py | 2 +- libs/langchain/langchain/memory/entity.py | 7 +- libs/langchain/langchain/memory/summary.py | 2 +- .../langchain/langchain/memory/vectorstore.py | 2 +- .../memory/vectorstore_token_buffer_memory.py | 2 +- .../langchain/langchain/output_parsers/fix.py | 5 +- .../output_parsers/pandas_dataframe.py | 2 +- .../langchain/output_parsers/retry.py | 7 +- .../langchain/output_parsers/structured.py | 2 +- .../langchain/output_parsers/yaml.py | 2 +- .../retrievers/contextual_compression.py | 6 +- .../retrievers/document_compressors/base.py | 6 +- .../document_compressors/chain_extract.py | 6 +- .../document_compressors/chain_filter.py | 6 +- .../document_compressors/cohere_rerank.py | 14 +- .../cross_encoder_rerank.py | 8 +- .../document_compressors/embeddings_filter.py | 7 +- .../document_compressors/listwise_rerank.py | 7 +- .../langchain/retrievers/ensemble.py | 7 +- .../langchain/retrievers/multi_vector.py | 9 +- .../langchain/retrievers/self_query/base.py | 14 +- .../retrievers/time_weighted_retriever.py | 7 +- .../langchain/smith/evaluation/config.py | 12 +- libs/langchain/scripts/check_pydantic.sh | 27 --- .../chat_models/test_base.py | 2 +- .../tests/mock_servers/robot/server.py | 2 +- .../tests/unit_tests/agents/test_agent.py | 172 +++++++++++------- .../callbacks/fake_callback_handler.py | 6 +- .../evaluation/agents/test_eval_chain.py | 2 +- .../tests/unit_tests/llms/fake_llm.py | 2 +- .../unit_tests/llms/test_fake_chat_model.py | 6 +- .../tests/unit_tests/load/test_dump.py | 12 +- .../unit_tests/output_parsers/test_fix.py | 29 --- .../unit_tests/output_parsers/test_retry.py | 50 ----- .../output_parsers/test_yaml_parser.py | 2 +- .../evaluation/test_string_run_evaluator.py | 3 +- .../tests/unit_tests/test_imports.py | 2 + .../langchain/tests/unit_tests/test_schema.py | 22 +-- .../unit_tests/utils/test_openai_functions.py | 2 +- 103 files changed, 557 insertions(+), 523 deletions(-) delete mode 100755 libs/langchain/scripts/check_pydantic.sh diff --git a/libs/langchain/Makefile b/libs/langchain/Makefile index c77591f6430d6..e06cd2e65d1a9 100644 --- a/libs/langchain/Makefile +++ b/libs/langchain/Makefile @@ -54,7 +54,6 @@ lint_tests: PYTHON_FILES=tests lint_tests: MYPY_CACHE=.mypy_cache_test lint lint_diff lint_package lint_tests: - ./scripts/check_pydantic.sh . ./scripts/lint_imports.sh [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES) [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 50797228a53b8..5fae509f343db 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -40,11 +40,12 @@ from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.utils import AddableDict from langchain_core.tools import BaseTool from langchain_core.utils.input import get_color_mapping +from pydantic import BaseModel, ConfigDict, model_validator +from typing_extensions import Self from langchain.agents.agent_iterator import AgentExecutorIterator from langchain.agents.agent_types import AgentType @@ -420,8 +421,9 @@ class RunnableAgent(BaseSingleActionAgent): individual LLM tokens will not be available in stream_log. """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def return_values(self) -> List[str]: @@ -528,8 +530,9 @@ class RunnableMultiActionAgent(BaseMultiActionAgent): individual LLM tokens will not be available in stream_log. """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def return_values(self) -> List[str]: @@ -854,8 +857,8 @@ def input_keys(self) -> List[str]: """ return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"}) - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_prompt(self) -> Self: """Validate that prompt matches format. Args: @@ -868,7 +871,7 @@ def validate_prompt(cls, values: Dict) -> Dict: ValueError: If `agent_scratchpad` is not in prompt.input_variables and prompt is not a FewShotPromptTemplate or a PromptTemplate. """ - prompt = values["llm_chain"].prompt + prompt = self.llm_chain.prompt if "agent_scratchpad" not in prompt.input_variables: logger.warning( "`agent_scratchpad` should be a variable in prompt.input_variables." @@ -881,7 +884,7 @@ def validate_prompt(cls, values: Dict) -> Dict: prompt.suffix += "\n{agent_scratchpad}" else: raise ValueError(f"Got unexpected prompt type {type(prompt)}") - return values + return self @property @abstractmethod @@ -1120,8 +1123,8 @@ def from_agent_and_tools( **kwargs, ) - @root_validator(pre=False, skip_on_failure=True) - def validate_tools(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_tools(self) -> Self: """Validate that tools are compatible with agent. Args: @@ -1133,19 +1136,20 @@ def validate_tools(cls, values: Dict) -> Dict: Raises: ValueError: If allowed tools are different than provided tools. """ - agent = values["agent"] - tools = values["tools"] - allowed_tools = agent.get_allowed_tools() + agent = self.agent + tools = self.tools + allowed_tools = agent.get_allowed_tools() # type: ignore if allowed_tools is not None: if set(allowed_tools) != set([tool.name for tool in tools]): raise ValueError( f"Allowed tools ({allowed_tools}) different than " f"provided tools ({[tool.name for tool in tools]})" ) - return values + return self - @root_validator(pre=True) - def validate_runnable_agent(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_runnable_agent(cls, values: Dict) -> Any: """Convert runnable to agent if passed in. Args: diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py index c40b3f3b7f289..71114a49eea37 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py @@ -3,10 +3,10 @@ from typing import List from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool from langchain_core.tools.base import BaseToolkit from langchain_core.vectorstores import VectorStore +from pydantic import BaseModel, ConfigDict, Field class VectorStoreInfo(BaseModel): @@ -16,8 +16,9 @@ class VectorStoreInfo(BaseModel): name: str description: str - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) class VectorStoreToolkit(BaseToolkit): @@ -26,8 +27,9 @@ class VectorStoreToolkit(BaseToolkit): vectorstore_info: VectorStoreInfo = Field(exclude=True) llm: BaseLanguageModel - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" @@ -67,8 +69,9 @@ class VectorStoreRouterToolkit(BaseToolkit): vectorstores: List[VectorStoreInfo] = Field(exclude=True) llm: BaseLanguageModel - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" diff --git a/libs/langchain/langchain/agents/chat/base.py b/libs/langchain/langchain/agents/chat/base.py index a7a16be772611..00ced776f1342 100644 --- a/libs/langchain/langchain/agents/chat/base.py +++ b/libs/langchain/langchain/agents/chat/base.py @@ -10,8 +10,8 @@ HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.chat.output_parser import ChatOutputParser diff --git a/libs/langchain/langchain/agents/conversational/base.py b/libs/langchain/langchain/agents/conversational/base.py index bbbf666e5903d..a0ef85946abd3 100644 --- a/libs/langchain/langchain/agents/conversational/base.py +++ b/libs/langchain/langchain/agents/conversational/base.py @@ -8,8 +8,8 @@ from langchain_core.callbacks import BaseCallbackManager from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/conversational_chat/base.py b/libs/langchain/langchain/agents/conversational_chat/base.py index 08ec829613a53..138933addbaec 100644 --- a/libs/langchain/langchain/agents/conversational_chat/base.py +++ b/libs/langchain/langchain/agents/conversational_chat/base.py @@ -17,8 +17,8 @@ MessagesPlaceholder, SystemMessagePromptTemplate, ) -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.conversational_chat.output_parser import ConvoOutputParser diff --git a/libs/langchain/langchain/agents/mrkl/base.py b/libs/langchain/langchain/agents/mrkl/base.py index 1b16d4d862f6f..cc4d9da5537d7 100644 --- a/libs/langchain/langchain/agents/mrkl/base.py +++ b/libs/langchain/langchain/agents/mrkl/base.py @@ -8,9 +8,9 @@ from langchain_core.callbacks import BaseCallbackManager from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool, Tool from langchain_core.tools.render import render_text_description +from pydantic import Field from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index a2f0ef74a5bff..e7d4d2bfa9c9b 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -20,10 +20,11 @@ from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import CallbackManager from langchain_core.load import dumpd -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self if TYPE_CHECKING: import openai @@ -232,14 +233,14 @@ def execute_agent(agent, tools, input): as_agent: bool = False """Use as a LangChain agent, compatible with the AgentExecutor.""" - @root_validator(pre=False, skip_on_failure=True) - def validate_async_client(cls, values: dict) -> dict: - if values["async_client"] is None: + @model_validator(mode="after") + def validate_async_client(self) -> Self: + if self.async_client is None: import openai - api_key = values["client"].api_key - values["async_client"] = openai.AsyncOpenAI(api_key=api_key) - return values + api_key = self.client.api_key + self.async_client = openai.AsyncOpenAI(api_key=api_key) + return self @classmethod def create_assistant( diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index fbb23a56e4785..5a40a665e3310 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -17,10 +17,11 @@ HumanMessagePromptTemplate, MessagesPlaceholder, ) -from langchain_core.pydantic_v1 import root_validator from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_function +from pydantic import model_validator +from typing_extensions import Self from langchain.agents import BaseSingleActionAgent from langchain.agents.format_scratchpad.openai_functions import ( @@ -58,8 +59,8 @@ def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" return [t.name for t in self.tools] - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: dict) -> dict: + @model_validator(mode="after") + def validate_prompt(self) -> Self: """Validate prompt. Args: @@ -71,13 +72,13 @@ def validate_prompt(cls, values: dict) -> dict: Raises: ValueError: If `agent_scratchpad` is not in the prompt. """ - prompt: BasePromptTemplate = values["prompt"] + prompt: BasePromptTemplate = self.prompt if "agent_scratchpad" not in prompt.input_variables: raise ValueError( "`agent_scratchpad` should be one of the variables in the prompt, " f"got {prompt.input_variables}" ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index 819fc8b1d939e..bec49e5f14d24 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -21,8 +21,9 @@ HumanMessagePromptTemplate, MessagesPlaceholder, ) -from langchain_core.pydantic_v1 import root_validator from langchain_core.tools import BaseTool +from pydantic import model_validator +from typing_extensions import Self from langchain.agents import BaseMultiActionAgent from langchain.agents.format_scratchpad.openai_functions import ( @@ -115,15 +116,15 @@ def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" return [t.name for t in self.tools] - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: dict) -> dict: - prompt: BasePromptTemplate = values["prompt"] + @model_validator(mode="after") + def validate_prompt(self) -> Self: + prompt: BasePromptTemplate = self.prompt if "agent_scratchpad" not in prompt.input_variables: raise ValueError( "`agent_scratchpad` should be one of the variables in the prompt, " f"got {prompt.input_variables}" ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/agents/react/base.py b/libs/langchain/langchain/agents/react/base.py index 93a60bbe61ba2..81a38141fe686 100644 --- a/libs/langchain/langchain/agents/react/base.py +++ b/libs/langchain/langchain/agents/react/base.py @@ -8,8 +8,8 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool, Tool +from pydantic import Field from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/self_ask_with_search/base.py b/libs/langchain/langchain/agents/self_ask_with_search/base.py index 7a1d81f7d40d1..9a642b81b1289 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/base.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/base.py @@ -7,9 +7,9 @@ from langchain_core._api import deprecated from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool, Tool +from pydantic import Field from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/structured_chat/base.py b/libs/langchain/langchain/agents/structured_chat/base.py index e1403e26f2607..a520cfecf7142 100644 --- a/libs/langchain/langchain/agents/structured_chat/base.py +++ b/libs/langchain/langchain/agents/structured_chat/base.py @@ -11,10 +11,10 @@ HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain_core.pydantic_v1 import Field from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.tools.render import ToolsRenderer +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.format_scratchpad import format_log_to_str diff --git a/libs/langchain/langchain/agents/structured_chat/output_parser.py b/libs/langchain/langchain/agents/structured_chat/output_parser.py index f73f87f4f4be4..1cdb4fe1fb4a9 100644 --- a/libs/langchain/langchain/agents/structured_chat/output_parser.py +++ b/libs/langchain/langchain/agents/structured_chat/output_parser.py @@ -8,7 +8,7 @@ from langchain_core.agents import AgentAction, AgentFinish from langchain_core.exceptions import OutputParserException from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.agents.agent import AgentOutputParser from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index 94896102dc6fd..4387cb623e06a 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -12,7 +12,8 @@ ) from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import Field, model_validator +from typing_extensions import Self from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain @@ -227,19 +228,20 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - @root_validator(pre=False, skip_on_failure=True) - def validate_api_request_prompt(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_api_request_prompt(self) -> Self: """Check that api request prompt expects the right variables.""" - input_vars = values["api_request_chain"].prompt.input_variables + input_vars = self.api_request_chain.prompt.input_variables expected_vars = {"question", "api_docs"} if set(input_vars) != expected_vars: raise ValueError( f"Input variables should be {expected_vars}, got {input_vars}" ) - return values + return self - @root_validator(pre=True) - def validate_limit_to_domains(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_limit_to_domains(cls, values: Dict) -> Any: """Check that allowed domains are valid.""" # This check must be a pre=True check, so that a default of None # won't be set to limit_to_domains if it's not provided. @@ -258,16 +260,16 @@ def validate_limit_to_domains(cls, values: Dict) -> Dict: ) return values - @root_validator(pre=False, skip_on_failure=True) - def validate_api_answer_prompt(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_api_answer_prompt(self) -> Self: """Check that api answer prompt expects the right variables.""" - input_vars = values["api_answer_chain"].prompt.input_variables + input_vars = self.api_answer_chain.prompt.input_variables expected_vars = {"question", "api_docs", "api_url", "api_response"} if set(input_vars) != expected_vars: raise ValueError( f"Input variables should be {expected_vars}, got {input_vars}" ) - return values + return self def _call( self, diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 4c1dcc07eb231..a2d9ca57ddc79 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -21,7 +21,6 @@ from langchain_core.load.dump import dumpd from langchain_core.memory import BaseMemory from langchain_core.outputs import RunInfo -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator from langchain_core.runnables import ( RunnableConfig, RunnableSerializable, @@ -29,6 +28,13 @@ run_in_executor, ) from langchain_core.runnables.utils import create_model +from pydantic import ( + BaseModel, + ConfigDict, + Field, + model_validator, + validator, +) from langchain.schema import RUN_KEY @@ -96,8 +102,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """[DEPRECATED] Use `callbacks` instead.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def get_input_schema( self, config: Optional[RunnableConfig] = None @@ -223,8 +230,9 @@ async def ainvoke( def _chain_type(self) -> str: raise NotImplementedError("Saving not supported for this chain type.") - @root_validator(pre=True) - def raise_callback_manager_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_callback_manager_deprecation(cls, values: Dict) -> Any: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: if values.get("callbacks") is not None: diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 00b6002da0ed4..2406cd4215f7a 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -10,10 +10,10 @@ ) from langchain_core.documents import Document from langchain_core.prompts import BasePromptTemplate, PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from pydantic import BaseModel, Field from langchain.chains.base import Chain diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index 229ed45740c5d..7b4885cdea4fe 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -6,9 +6,9 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model +from pydantic import BaseModel, ConfigDict, model_validator from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.reduce import ReduceDocumentsChain @@ -126,12 +126,14 @@ def output_keys(self) -> List[str]: _output_keys = _output_keys + ["intermediate_steps"] return _output_keys - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def get_reduce_chain(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_reduce_chain(cls, values: Dict) -> Any: """For backwards compatibility.""" if "combine_document_chain" in values: if "reduce_documents_chain" in values: @@ -153,16 +155,18 @@ def get_reduce_chain(cls, values: Dict) -> Dict: return values - @root_validator(pre=True) - def get_return_intermediate_steps(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_return_intermediate_steps(cls, values: Dict) -> Any: """For backwards compatibility.""" if "return_map_steps" in values: values["return_intermediate_steps"] = values["return_map_steps"] del values["return_map_steps"] return values - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided.""" if "llm_chain" not in values: raise ValueError("llm_chain must be provided") diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index 0fa346dee8bc8..61e6e226ed65d 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -6,9 +6,10 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model +from pydantic import BaseModel, ConfigDict, model_validator +from typing_extensions import Self from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain @@ -74,9 +75,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): """Return intermediate steps. Intermediate steps include the results of calling llm_chain on each document.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def get_output_schema( self, config: Optional[RunnableConfig] = None @@ -104,30 +106,31 @@ def output_keys(self) -> List[str]: _output_keys += self.metadata_keys return _output_keys - @root_validator(pre=False, skip_on_failure=True) - def validate_llm_output(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_llm_output(self) -> Self: """Validate that the combine chain outputs a dictionary.""" - output_parser = values["llm_chain"].prompt.output_parser + output_parser = self.llm_chain.prompt.output_parser if not isinstance(output_parser, RegexParser): raise ValueError( "Output parser of llm_chain should be a RegexParser," f" got {output_parser}" ) output_keys = output_parser.output_keys - if values["rank_key"] not in output_keys: + if self.rank_key not in output_keys: raise ValueError( - f"Got {values['rank_key']} as key to rank on, but did not find " + f"Got {self.rank_key} as key to rank on, but did not find " f"it in the llm_chain output keys ({output_keys})" ) - if values["answer_key"] not in output_keys: + if self.answer_key not in output_keys: raise ValueError( - f"Got {values['answer_key']} as key to return, but did not find " + f"Got {self.answer_key} as key to return, but did not find " f"it in the llm_chain output keys ({output_keys})" ) - return values + return self - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided.""" if "llm_chain" not in values: raise ValueError("llm_chain must be provided") diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index 7b2cd6c89a375..662be15e38b55 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -6,6 +6,7 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import Document +from pydantic import ConfigDict from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -204,9 +205,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): If None, it will keep trying to collapse documents to fit token_max. Otherwise, after it reaches the max number, it will throw an error""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def _collapse_chain(self) -> BaseCombineDocumentsChain: diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py index cf2f5d9e92f50..7129ac147dbb1 100644 --- a/libs/langchain/langchain/chains/combine_documents/refine.py +++ b/libs/langchain/langchain/chains/combine_documents/refine.py @@ -8,7 +8,7 @@ from langchain_core.documents import Document from langchain_core.prompts import BasePromptTemplate, format_document from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import ConfigDict, Field, model_validator from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, @@ -98,20 +98,23 @@ def output_keys(self) -> List[str]: _output_keys = _output_keys + ["intermediate_steps"] return _output_keys - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def get_return_intermediate_steps(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_return_intermediate_steps(cls, values: Dict) -> Any: """For backwards compatibility.""" if "return_refine_steps" in values: values["return_intermediate_steps"] = values["return_refine_steps"] del values["return_refine_steps"] return values - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided.""" if "initial_llm_chain" not in values: raise ValueError("initial_llm_chain must be provided") diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index cdecec0f40b82..26750d55fe84f 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -8,8 +8,8 @@ from langchain_core.language_models import LanguageModelLike from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import BasePromptTemplate, format_document -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables import Runnable, RunnablePassthrough +from pydantic import ConfigDict, Field, model_validator from langchain.chains.combine_documents.base import ( DEFAULT_DOCUMENT_PROMPT, @@ -156,12 +156,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): document_separator: str = "\n\n" """The string with which to join the formatted documents""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided. If only one variable is present in the llm_chain.prompt, diff --git a/libs/langchain/langchain/chains/constitutional_ai/models.py b/libs/langchain/langchain/chains/constitutional_ai/models.py index 8058553eb257a..7f9a623459913 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/models.py +++ b/libs/langchain/langchain/chains/constitutional_ai/models.py @@ -1,6 +1,6 @@ """Models for the Constitutional AI chain.""" -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel class ConstitutionalPrinciple(BaseModel): diff --git a/libs/langchain/langchain/chains/conversation/base.py b/libs/langchain/langchain/chains/conversation/base.py index dee4b9dbc0cce..881eed49e8c61 100644 --- a/libs/langchain/langchain/chains/conversation/base.py +++ b/libs/langchain/langchain/chains/conversation/base.py @@ -1,11 +1,12 @@ """Chain that carries on a conversation and calls an LLM.""" -from typing import Dict, List +from typing import List from langchain_core._api import deprecated from langchain_core.memory import BaseMemory from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import ConfigDict, Field, model_validator +from typing_extensions import Self from langchain.chains.conversation.prompt import PROMPT from langchain.chains.llm import LLMChain @@ -110,9 +111,10 @@ def get_session_history(session_id: str) -> InMemoryChatMessageHistory: input_key: str = "input" #: :meta private: output_key: str = "response" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @classmethod def is_lc_serializable(cls) -> bool: @@ -123,17 +125,17 @@ def input_keys(self) -> List[str]: """Use this since so some prompt vars come from history.""" return [self.input_key] - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt_input_variables(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_prompt_input_variables(self) -> Self: """Validate that prompt input variables are consistent.""" - memory_keys = values["memory"].memory_variables - input_key = values["input_key"] + memory_keys = self.memory.memory_variables + input_key = self.input_key if input_key in memory_keys: raise ValueError( f"The input key {input_key} was also found in the memory keys " f"({memory_keys}) - please provide keys that don't overlap." ) - prompt_variables = values["prompt"].input_variables + prompt_variables = self.prompt.input_variables expected_keys = memory_keys + [input_key] if set(expected_keys) != set(prompt_variables): raise ValueError( @@ -141,4 +143,4 @@ def validate_prompt_input_variables(cls, values: Dict) -> Dict: f"{prompt_variables}, but got {memory_keys} as inputs from " f"memory, and {input_key} as the normal input key." ) - return values + return self diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 4251ced2fdf1a..3c653c433362e 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -18,10 +18,10 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import RunnableConfig from langchain_core.vectorstores import VectorStore +from pydantic import BaseModel, ConfigDict, Field, model_validator from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -96,10 +96,11 @@ class BaseConversationalRetrievalChain(Chain): """If specified, the chain will return a fixed response if no docs are found for the question. """ - class Config: - allow_population_by_field_name = True - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -482,8 +483,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): def _chain_type(self) -> str: return "chat-vector-db" - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: warnings.warn( "`ChatVectorDBChain` is deprecated - " "please use `from langchain.chains import ConversationalRetrievalChain`" diff --git a/libs/langchain/langchain/chains/elasticsearch_database/base.py b/libs/langchain/langchain/chains/elasticsearch_database/base.py index 89875f2d8a425..85bf7de93d29a 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/base.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/base.py @@ -9,8 +9,9 @@ from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.output_parsers.json import SimpleJsonOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator from langchain_core.runnables import Runnable +from pydantic import ConfigDict, model_validator +from typing_extensions import Self from langchain.chains.base import Chain from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT @@ -51,17 +52,18 @@ class ElasticsearchDatabaseChain(Chain): return_intermediate_steps: bool = False """Whether or not to return the intermediate steps along with the final answer.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=False, skip_on_failure=True) - def validate_indices(cls, values: dict) -> dict: - if values["include_indices"] and values["ignore_indices"]: + @model_validator(mode="after") + def validate_indices(self) -> Self: + if self.include_indices and self.ignore_indices: raise ValueError( "Cannot specify both 'include_indices' and 'ignore_indices'." ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 53f3dd1e40e82..caf64fe77aa40 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -11,9 +11,9 @@ from langchain_core.messages import AIMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import Runnable +from pydantic import Field from langchain.chains.base import Chain from langchain.chains.flare.prompts import ( diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py index 833999127b659..cf64b29cada46 100644 --- a/libs/langchain/langchain/chains/hyde/base.py +++ b/libs/langchain/langchain/chains/hyde/base.py @@ -14,6 +14,7 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable +from pydantic import ConfigDict from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP @@ -29,9 +30,10 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): base_embeddings: Embeddings llm_chain: Runnable - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index feb3468fc44e4..f69aad7838bde 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -23,7 +23,6 @@ from langchain_core.outputs import ChatGeneration, Generation, LLMResult from langchain_core.prompt_values import PromptValue from langchain_core.prompts import BasePromptTemplate, PromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.runnables import ( Runnable, RunnableBinding, @@ -32,6 +31,7 @@ ) from langchain_core.runnables.configurable import DynamicRunnable from langchain_core.utils.input import get_colored_text +from pydantic import ConfigDict, Field from langchain.chains.base import Chain @@ -95,9 +95,10 @@ def is_lc_serializable(self) -> bool: If false, will return a bunch of extra information about the generation.""" llm_kwargs: dict = Field(default_factory=dict) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/llm_checker/base.py b/libs/langchain/langchain/chains/llm_checker/base.py index ea2bc546a5791..bfff3d1b40dd8 100644 --- a/libs/langchain/langchain/chains/llm_checker/base.py +++ b/libs/langchain/langchain/chains/llm_checker/base.py @@ -9,7 +9,7 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -100,12 +100,14 @@ class LLMCheckerChain(Chain): input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: if "llm" in values: warnings.warn( "Directly instantiating an LLMCheckerChain with an llm is deprecated. " diff --git a/libs/langchain/langchain/chains/llm_math/base.py b/libs/langchain/langchain/chains/llm_math/base.py index e7fd89dcd542c..5bc51bf253e54 100644 --- a/libs/langchain/langchain/chains/llm_math/base.py +++ b/libs/langchain/langchain/chains/llm_math/base.py @@ -14,7 +14,7 @@ ) from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -156,12 +156,14 @@ async def acall_model(state: ChainState, config: RunnableConfig): input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: try: import numexpr # noqa: F401 except ImportError: diff --git a/libs/langchain/langchain/chains/llm_summarization_checker/base.py b/libs/langchain/langchain/chains/llm_summarization_checker/base.py index f177f401529e4..c7d075dbf4405 100644 --- a/libs/langchain/langchain/chains/llm_summarization_checker/base.py +++ b/libs/langchain/langchain/chains/llm_summarization_checker/base.py @@ -10,7 +10,7 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -105,12 +105,14 @@ class LLMSummarizationCheckerChain(Chain): max_checks: int = 2 """Maximum number of times to check the assertions. Default to double-checking.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: if "llm" in values: warnings.warn( "Directly instantiating an LLMSummarizationCheckerChain with an llm is " diff --git a/libs/langchain/langchain/chains/mapreduce.py b/libs/langchain/langchain/chains/mapreduce.py index 1eaccf67a850f..c153b7f9a6dae 100644 --- a/libs/langchain/langchain/chains/mapreduce.py +++ b/libs/langchain/langchain/chains/mapreduce.py @@ -14,6 +14,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate from langchain_text_splitters import TextSplitter +from pydantic import ConfigDict from langchain.chains import ReduceDocumentsChain from langchain.chains.base import Chain @@ -77,9 +78,10 @@ def from_params( **kwargs, ) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index a4b3551491c45..670b4773e3d63 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -6,8 +6,8 @@ AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import check_package_version, get_from_dict_or_env +from pydantic import Field, model_validator from langchain.chains.base import Chain @@ -40,8 +40,9 @@ class OpenAIModerationChain(Chain): openai_organization: Optional[str] = None openai_pre_1_0: bool = Field(default=None) - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: """Validate that api key and python package exists in environment.""" openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" diff --git a/libs/langchain/langchain/chains/natbot/base.py b/libs/langchain/langchain/chains/natbot/base.py index e92131ff35cad..aca7fca8a7d8b 100644 --- a/libs/langchain/langchain/chains/natbot/base.py +++ b/libs/langchain/langchain/chains/natbot/base.py @@ -9,8 +9,8 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import StrOutputParser -from langchain_core.pydantic_v1 import root_validator from langchain_core.runnables import Runnable +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.natbot.prompt import PROMPT @@ -59,12 +59,14 @@ class NatBotChain(Chain): previous_command: str = "" #: :meta private: output_key: str = "command" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: if "llm" in values: warnings.warn( "Directly instantiating an NatBotChain with an llm is deprecated. " diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index 568d992de990f..729313ab6b4d8 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -19,11 +19,11 @@ PydanticAttrOutputFunctionsParser, ) from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils.function_calling import ( PYTHON_TO_JSON_TYPES, convert_to_openai_function, ) +from pydantic import BaseModel from langchain.chains import LLMChain from langchain.chains.structured_output.base import ( diff --git a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py index 038489d13a696..e9a83e8abc6e8 100644 --- a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py +++ b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -5,8 +5,8 @@ from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.output_parsers.openai_functions import PydanticOutputFunctionsParser from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable +from pydantic import BaseModel, Field from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import get_llm_kwargs diff --git a/libs/langchain/langchain/chains/openai_functions/extraction.py b/libs/langchain/langchain/chains/openai_functions/extraction.py index ec76ad8a6cbbe..f6b9debad4959 100644 --- a/libs/langchain/langchain/chains/openai_functions/extraction.py +++ b/libs/langchain/langchain/chains/openai_functions/extraction.py @@ -7,7 +7,7 @@ PydanticAttrOutputFunctionsParser, ) from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain diff --git a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py index f13e2f9e522f5..5bab1c1ced60a 100644 --- a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py +++ b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py @@ -10,8 +10,8 @@ ) from langchain_core.prompts import PromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import BaseModel, Field from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import get_llm_kwargs diff --git a/libs/langchain/langchain/chains/openai_tools/extraction.py b/libs/langchain/langchain/chains/openai_tools/extraction.py index 55251f5186784..189ef77d9d4d1 100644 --- a/libs/langchain/langchain/chains/openai_tools/extraction.py +++ b/libs/langchain/langchain/chains/openai_tools/extraction.py @@ -4,9 +4,9 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers.openai_tools import PydanticToolsParser from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable from langchain_core.utils.function_calling import convert_pydantic_to_openai_function +from pydantic import BaseModel _EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned \ in the following passage together with their properties. diff --git a/libs/langchain/langchain/chains/prompt_selector.py b/libs/langchain/langchain/chains/prompt_selector.py index 453e2ea03577f..4014cdc1fbbbf 100644 --- a/libs/langchain/langchain/chains/prompt_selector.py +++ b/libs/langchain/langchain/chains/prompt_selector.py @@ -5,7 +5,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field class BasePromptSelector(BaseModel, ABC): diff --git a/libs/langchain/langchain/chains/qa_generation/base.py b/libs/langchain/langchain/chains/qa_generation/base.py index b66b8a5442599..a55c078610124 100644 --- a/libs/langchain/langchain/chains/qa_generation/base.py +++ b/libs/langchain/langchain/chains/qa_generation/base.py @@ -7,8 +7,8 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from pydantic import Field from langchain.chains.base import Chain from langchain.chains.llm import LLMChain diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index aed2d57cf91e5..495f0ab79e78b 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -15,7 +15,7 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains import ReduceDocumentsChain from langchain.chains.base import Chain @@ -97,9 +97,10 @@ def from_chain_type( ) return cls(combine_documents_chain=combine_documents_chain, **kwargs) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -120,8 +121,9 @@ def output_keys(self) -> List[str]: _output_keys = _output_keys + ["source_documents"] return _output_keys - @root_validator(pre=True) - def validate_naming(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_naming(cls, values: Dict) -> Any: """Fix backwards compatibility in naming.""" if "combine_document_chain" in values: values["combine_documents_chain"] = values.pop("combine_document_chain") diff --git a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py index f4a924c8dd162..95485b9fe0e3c 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py +++ b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py @@ -7,8 +7,8 @@ CallbackManagerForChainRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever +from pydantic import Field from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain diff --git a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py index ca594994f3a6b..6330db38bc9b4 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py +++ b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py @@ -8,8 +8,8 @@ CallbackManagerForChainRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.vectorstores import VectorStore +from pydantic import Field, model_validator from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain @@ -61,8 +61,9 @@ async def _aget_docs( ) -> List[Document]: raise NotImplementedError("VectorDBQAWithSourcesChain does not support async") - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: warnings.warn( "`VectorDBQAWithSourcesChain` is deprecated - " "please use `from langchain.chains import RetrievalQAWithSourcesChain`" diff --git a/libs/langchain/langchain/chains/query_constructor/schema.py b/libs/langchain/langchain/chains/query_constructor/schema.py index 585addc7919a3..56103d9a9fede 100644 --- a/libs/langchain/langchain/chains/query_constructor/schema.py +++ b/libs/langchain/langchain/chains/query_constructor/schema.py @@ -1,4 +1,4 @@ -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel, ConfigDict class AttributeInfo(BaseModel): @@ -8,6 +8,7 @@ class AttributeInfo(BaseModel): description: str type: str - class Config: - arbitrary_types_allowed = True - frozen = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + frozen=True, + ) diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 689dd8b0c217b..3d688474bd39a 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -16,9 +16,9 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict, Field, model_validator from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -47,10 +47,11 @@ class BaseRetrievalQA(Chain): return_source_documents: bool = False """Return the source documents or not.""" - class Config: - allow_population_by_field_name = True - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -309,16 +310,18 @@ class VectorDBQA(BaseRetrievalQA): search_kwargs: Dict[str, Any] = Field(default_factory=dict) """Extra search args.""" - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: warnings.warn( "`VectorDBQA` is deprecated - " "please use `from langchain.chains import RetrievalQA`" ) return values - @root_validator(pre=True) - def validate_search_type(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_search_type(cls, values: Dict) -> Any: """Validate search type.""" if "search_type" in values: search_type = values["search_type"] diff --git a/libs/langchain/langchain/chains/router/base.py b/libs/langchain/langchain/chains/router/base.py index d0b680dd952b1..fa489c8110c33 100644 --- a/libs/langchain/langchain/chains/router/base.py +++ b/libs/langchain/langchain/chains/router/base.py @@ -10,6 +10,7 @@ CallbackManagerForChainRun, Callbacks, ) +from pydantic import ConfigDict from langchain.chains.base import Chain @@ -60,9 +61,10 @@ class MultiRouteChain(Chain): """If True, use default_chain when an invalid destination name is provided. Defaults to False.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/router/embedding_router.py b/libs/langchain/langchain/chains/router/embedding_router.py index a1bc126a49f2a..0f44dda02ff43 100644 --- a/libs/langchain/langchain/chains/router/embedding_router.py +++ b/libs/langchain/langchain/chains/router/embedding_router.py @@ -9,6 +9,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict from langchain.chains.router.base import RouterChain @@ -19,9 +20,10 @@ class EmbeddingRouterChain(RouterChain): vectorstore: VectorStore routing_keys: List[str] = ["query"] - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/router/llm_router.py b/libs/langchain/langchain/chains/router/llm_router.py index 132f350e819c9..aa72ce4c22a25 100644 --- a/libs/langchain/langchain/chains/router/llm_router.py +++ b/libs/langchain/langchain/chains/router/llm_router.py @@ -13,8 +13,9 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator from langchain_core.utils.json import parse_and_check_json_markdown +from pydantic import model_validator +from typing_extensions import Self from langchain.chains import LLMChain from langchain.chains.router.base import RouterChain @@ -100,9 +101,9 @@ class RouteQuery(TypedDict): llm_chain: LLMChain """LLM chain used to perform routing""" - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: dict) -> dict: - prompt = values["llm_chain"].prompt + @model_validator(mode="after") + def validate_prompt(self) -> Self: + prompt = self.llm_chain.prompt if prompt.output_parser is None: raise ValueError( "LLMRouterChain requires base llm_chain prompt to have an output" @@ -110,7 +111,7 @@ def validate_prompt(cls, values: dict) -> dict: " 'destination' and 'next_inputs'. Received a prompt with no output" " parser." ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/sequential.py b/libs/langchain/langchain/chains/sequential.py index e75300f7cbf84..b19f65e9aa2c3 100644 --- a/libs/langchain/langchain/chains/sequential.py +++ b/libs/langchain/langchain/chains/sequential.py @@ -6,8 +6,9 @@ AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain_core.pydantic_v1 import root_validator from langchain_core.utils.input import get_color_mapping +from pydantic import ConfigDict, model_validator +from typing_extensions import Self from langchain.chains.base import Chain @@ -20,9 +21,10 @@ class SequentialChain(Chain): output_variables: List[str] #: :meta private: return_all: bool = False - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -40,8 +42,9 @@ def output_keys(self) -> List[str]: """ return self.output_variables - @root_validator(pre=True) - def validate_chains(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_chains(cls, values: Dict) -> Any: """Validate that the correct inputs exist for all chains.""" chains = values["chains"] input_variables = values["input_variables"] @@ -129,9 +132,10 @@ class SimpleSequentialChain(Chain): input_key: str = "input" #: :meta private: output_key: str = "output" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -149,10 +153,10 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - @root_validator(pre=False, skip_on_failure=True) - def validate_chains(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_chains(self) -> Self: """Validate that chains are all single input/output.""" - for chain in values["chains"]: + for chain in self.chains: if len(chain.input_keys) != 1: raise ValueError( "Chains used in SimplePipeline should all have one input, got " @@ -163,7 +167,7 @@ def validate_chains(cls, values: Dict) -> Dict: "Chains used in SimplePipeline should all have one output, got " f"{chain} with {len(chain.output_keys)} outputs." ) - return values + return self def _call( self, diff --git a/libs/langchain/langchain/chains/structured_output/base.py b/libs/langchain/langchain/chains/structured_output/base.py index 14526d014ccc4..0cd8be4f144d3 100644 --- a/libs/langchain/langchain/chains/structured_output/base.py +++ b/libs/langchain/langchain/chains/structured_output/base.py @@ -18,13 +18,13 @@ PydanticToolsParser, ) from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable from langchain_core.utils.function_calling import ( convert_to_openai_function, convert_to_openai_tool, ) from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import BaseModel @deprecated( diff --git a/libs/langchain/langchain/chains/transform.py b/libs/langchain/langchain/chains/transform.py index e95bfc9a85a84..2812722369b72 100644 --- a/libs/langchain/langchain/chains/transform.py +++ b/libs/langchain/langchain/chains/transform.py @@ -8,7 +8,7 @@ AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.chains.base import Chain diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index ce50f7b7b095d..63411ea152993 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -30,11 +30,11 @@ generate_from_stream, ) from langchain_core.messages import AnyMessage, BaseMessage -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import BaseTool from langchain_core.tracers import RunLog, RunLogPatch +from pydantic import BaseModel from typing_extensions import TypeAlias __all__ = [ diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py index 791b95f64cf7a..1d52c5a78e3e3 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py @@ -28,8 +28,8 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.output_parsers import BaseOutputParser -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import ConfigDict, Field from langchain.chains.llm import LLMChain from langchain.evaluation.agents.trajectory_eval_prompt import ( @@ -156,8 +156,9 @@ def geography_answers(country: str, question: str) -> str: return_reasoning: bool = False # :meta private: """DEPRECATED. Reasoning always returned.""" - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/comparison/eval_chain.py b/libs/langchain/langchain/evaluation/comparison/eval_chain.py index d76f836ad8725..46b9001e48fd7 100644 --- a/libs/langchain/langchain/evaluation/comparison/eval_chain.py +++ b/libs/langchain/langchain/evaluation/comparison/eval_chain.py @@ -10,7 +10,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Field +from pydantic import ConfigDict, Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -191,8 +191,9 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain): def is_lc_serializable(cls) -> bool: return False - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/criteria/eval_chain.py b/libs/langchain/langchain/evaluation/criteria/eval_chain.py index 34fc656cbc7e5..896b1efceabc2 100644 --- a/libs/langchain/langchain/evaluation/criteria/eval_chain.py +++ b/libs/langchain/langchain/evaluation/criteria/eval_chain.py @@ -8,7 +8,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +from pydantic import ConfigDict, Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -236,8 +236,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain): def is_lc_serializable(cls) -> bool: return False - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index d983c72fbf00f..569838841cd04 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -10,8 +10,8 @@ Callbacks, ) from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import Field from langchain_core.utils import pre_init +from pydantic import ConfigDict, Field from langchain.chains.base import Chain from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator @@ -113,8 +113,9 @@ def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - class Config: - arbitrary_types_allowed: bool = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def output_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/evaluation/qa/eval_chain.py b/libs/langchain/langchain/evaluation/qa/eval_chain.py index 0204d8fe90161..345bbd87bc9f7 100644 --- a/libs/langchain/langchain/evaluation/qa/eval_chain.py +++ b/libs/langchain/langchain/evaluation/qa/eval_chain.py @@ -9,6 +9,7 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate +from pydantic import ConfigDict from langchain.chains.llm import LLMChain from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT @@ -72,8 +73,9 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain): output_key: str = "results" #: :meta private: - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @classmethod def is_lc_serializable(cls) -> bool: @@ -220,8 +222,9 @@ def requires_input(self) -> bool: """Whether the chain requires an input string.""" return True - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @classmethod def _validate_input_vars(cls, prompt: PromptTemplate) -> None: diff --git a/libs/langchain/langchain/evaluation/qa/generate_chain.py b/libs/langchain/langchain/evaluation/qa/generate_chain.py index 32dea149a447b..94cf36d45a7d4 100644 --- a/libs/langchain/langchain/evaluation/qa/generate_chain.py +++ b/libs/langchain/langchain/evaluation/qa/generate_chain.py @@ -6,7 +6,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseLLMOutputParser -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.chains.llm import LLMChain from langchain.evaluation.qa.generate_prompt import PROMPT diff --git a/libs/langchain/langchain/evaluation/scoring/eval_chain.py b/libs/langchain/langchain/evaluation/scoring/eval_chain.py index 3b800f8ffc0d4..a8a84a05813d6 100644 --- a/libs/langchain/langchain/evaluation/scoring/eval_chain.py +++ b/libs/langchain/langchain/evaluation/scoring/eval_chain.py @@ -10,7 +10,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Field +from pydantic import ConfigDict, Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -179,8 +179,9 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain): criterion_name: str """The name of the criterion being evaluated.""" - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @classmethod def is_lc_serializable(cls) -> bool: diff --git a/libs/langchain/langchain/evaluation/string_distance/base.py b/libs/langchain/langchain/evaluation/string_distance/base.py index bb9c9719d736f..396e267a7e5e5 100644 --- a/libs/langchain/langchain/evaluation/string_distance/base.py +++ b/libs/langchain/langchain/evaluation/string_distance/base.py @@ -8,8 +8,8 @@ CallbackManagerForChainRun, Callbacks, ) -from langchain_core.pydantic_v1 import Field from langchain_core.utils import pre_init +from pydantic import Field from langchain.chains.base import Chain from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator diff --git a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index deb408dff2d0c..08042c63d7292 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -4,9 +4,9 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.vectorstores import VectorStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from pydantic import BaseModel, ConfigDict, Field from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA @@ -21,9 +21,10 @@ class VectorStoreIndexWrapper(BaseModel): vectorstore: VectorStore - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def query( self, @@ -142,9 +143,10 @@ class VectorstoreIndexCreator(BaseModel): text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) vectorstore_kwargs: dict = Field(default_factory=dict) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: """Create a vectorstore index from loaders.""" diff --git a/libs/langchain/langchain/memory/chat_memory.py b/libs/langchain/langchain/memory/chat_memory.py index 10feaa3e1b95f..3053a8198ddcf 100644 --- a/libs/langchain/langchain/memory/chat_memory.py +++ b/libs/langchain/langchain/memory/chat_memory.py @@ -8,7 +8,7 @@ ) from langchain_core.memory import BaseMemory from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.memory.utils import get_prompt_input_key diff --git a/libs/langchain/langchain/memory/combined.py b/libs/langchain/langchain/memory/combined.py index 5ab0048895bba..f9117af4107cf 100644 --- a/libs/langchain/langchain/memory/combined.py +++ b/libs/langchain/langchain/memory/combined.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Set from langchain_core.memory import BaseMemory -from langchain_core.pydantic_v1 import validator +from pydantic import validator from langchain.memory.chat_memory import BaseChatMemory diff --git a/libs/langchain/langchain/memory/entity.py b/libs/langchain/langchain/memory/entity.py index 57fdb75537eb2..3032a0479b0e8 100644 --- a/libs/langchain/langchain/memory/entity.py +++ b/libs/langchain/langchain/memory/entity.py @@ -6,7 +6,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory @@ -245,8 +245,9 @@ class SQLiteEntityStore(BaseEntityStore): table_name: str = "memory_store" conn: Any = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def __init__( self, diff --git a/libs/langchain/langchain/memory/summary.py b/libs/langchain/langchain/memory/summary.py index 23c3f2bca1f43..e1d32003c7148 100644 --- a/libs/langchain/langchain/memory/summary.py +++ b/libs/langchain/langchain/memory/summary.py @@ -7,8 +7,8 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils import pre_init +from pydantic import BaseModel from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory diff --git a/libs/langchain/langchain/memory/vectorstore.py b/libs/langchain/langchain/memory/vectorstore.py index b719749b1ce37..a4511dd192529 100644 --- a/libs/langchain/langchain/memory/vectorstore.py +++ b/libs/langchain/langchain/memory/vectorstore.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field from langchain_core.vectorstores import VectorStoreRetriever +from pydantic import Field from langchain.memory.chat_memory import BaseMemory from langchain.memory.utils import get_prompt_input_key diff --git a/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py index f611c04903d07..293773e84a173 100644 --- a/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py +++ b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py @@ -13,8 +13,8 @@ from langchain_core.messages import BaseMessage from langchain_core.prompts.chat import SystemMessagePromptTemplate -from langchain_core.pydantic_v1 import Field, PrivateAttr from langchain_core.vectorstores import VectorStoreRetriever +from pydantic import Field, PrivateAttr from langchain.memory import ConversationTokenBufferMemory, VectorStoreRetrieverMemory from langchain.memory.chat_memory import BaseChatMemory diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index a22bfff582be0..f0a1a701c2334 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import Any, TypeVar, Union +from typing import Annotated, Any, TypeVar, Union from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable, RunnableSerializable +from pydantic import SkipValidation from typing_extensions import TypedDict from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT @@ -26,7 +27,7 @@ class OutputFixingParser(BaseOutputParser[T]): def is_lc_serializable(cls) -> bool: return True - parser: BaseOutputParser[T] + parser: Annotated[BaseOutputParser[T], SkipValidation()] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Union[ diff --git a/libs/langchain/langchain/output_parsers/pandas_dataframe.py b/libs/langchain/langchain/output_parsers/pandas_dataframe.py index 3447767c088fa..2cc899bad8a27 100644 --- a/libs/langchain/langchain/output_parsers/pandas_dataframe.py +++ b/libs/langchain/langchain/output_parsers/pandas_dataframe.py @@ -3,7 +3,7 @@ from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.base import BaseOutputParser -from langchain_core.pydantic_v1 import validator +from pydantic import validator from langchain.output_parsers.format_instructions import ( PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS, diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index cc56c627c0f0c..06e1b410f7291 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -8,7 +8,8 @@ from langchain_core.prompt_values import PromptValue from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.runnables import RunnableSerializable -from typing_extensions import TypedDict +from pydantic import SkipValidation +from typing_extensions import Annotated, TypedDict NAIVE_COMPLETION_RETRY = """Prompt: {prompt} @@ -53,7 +54,7 @@ class RetryOutputParser(BaseOutputParser[T]): LLM, and telling it the completion did not satisfy criteria in the prompt. """ - parser: BaseOutputParser[T] + parser: Annotated[BaseOutputParser[T], SkipValidation()] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any] @@ -183,7 +184,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): LLM, which in theory should give it more information on how to fix it. """ - parser: BaseOutputParser[T] + parser: Annotated[BaseOutputParser[T], SkipValidation()] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Union[ diff --git a/libs/langchain/langchain/output_parsers/structured.py b/libs/langchain/langchain/output_parsers/structured.py index 097e1a7170f4c..715be3c410ee9 100644 --- a/libs/langchain/langchain/output_parsers/structured.py +++ b/libs/langchain/langchain/output_parsers/structured.py @@ -4,7 +4,7 @@ from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers.json import parse_and_check_json_markdown -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain.output_parsers.format_instructions import ( STRUCTURED_FORMAT_INSTRUCTIONS, diff --git a/libs/langchain/langchain/output_parsers/yaml.py b/libs/langchain/langchain/output_parsers/yaml.py index e7c071eb40068..5ea989b733a17 100644 --- a/libs/langchain/langchain/output_parsers/yaml.py +++ b/libs/langchain/langchain/output_parsers/yaml.py @@ -5,7 +5,7 @@ import yaml from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser -from langchain_core.pydantic_v1 import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError from langchain.output_parsers.format_instructions import YAML_FORMAT_INSTRUCTIONS diff --git a/libs/langchain/langchain/retrievers/contextual_compression.py b/libs/langchain/langchain/retrievers/contextual_compression.py index c73180b889d3a..d5dccb13fb552 100644 --- a/libs/langchain/langchain/retrievers/contextual_compression.py +++ b/libs/langchain/langchain/retrievers/contextual_compression.py @@ -6,6 +6,7 @@ ) from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever, RetrieverLike +from pydantic import ConfigDict from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, @@ -21,8 +22,9 @@ class ContextualCompressionRetriever(BaseRetriever): base_retriever: RetrieverLike """Base Retriever to use for getting relevant documents.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def _get_relevant_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/base.py b/libs/langchain/langchain/retrievers/document_compressors/base.py index a68515af8b023..dd25d428fa7fc 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/base.py +++ b/libs/langchain/langchain/retrievers/document_compressors/base.py @@ -7,6 +7,7 @@ BaseDocumentTransformer, Document, ) +from pydantic import ConfigDict class DocumentCompressorPipeline(BaseDocumentCompressor): @@ -15,8 +16,9 @@ class DocumentCompressorPipeline(BaseDocumentCompressor): transformers: List[Union[BaseDocumentTransformer, BaseDocumentCompressor]] """List of document filters that are chained together and run in sequence.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py index cc86f2be49b73..9933319ad0485 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py @@ -11,6 +11,7 @@ from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.runnables import Runnable +from pydantic import ConfigDict from langchain.chains.llm import LLMChain from langchain.retrievers.document_compressors.base import BaseDocumentCompressor @@ -56,8 +57,9 @@ class LLMChainExtractor(BaseDocumentCompressor): get_input: Callable[[str, Document], dict] = default_get_input """Callable for constructing the chain input from the query and a Document.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index 2db6f5be3a7b2..bfa1cd694dc26 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -9,6 +9,7 @@ from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.runnables import Runnable from langchain_core.runnables.config import RunnableConfig +from pydantic import ConfigDict from langchain.chains import LLMChain from langchain.output_parsers.boolean import BooleanOutputParser @@ -41,8 +42,9 @@ class LLMChainFilter(BaseDocumentCompressor): get_input: Callable[[str, Document], dict] = default_get_input """Callable for constructing the chain input from the query and a Document.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index f7d96c29df737..2030807ce310c 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -6,8 +6,8 @@ from langchain_core._api.deprecation import deprecated from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import root_validator from langchain_core.utils import get_from_dict_or_env +from pydantic import ConfigDict, model_validator from langchain.retrievers.document_compressors.base import BaseDocumentCompressor @@ -30,12 +30,14 @@ class CohereRerank(BaseDocumentCompressor): user_agent: str = "langchain" """Identifier for the application making the request.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: """Validate that api key and python package exists in environment.""" if not values.get("client"): try: diff --git a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py index d1b683f2d9b8f..fff77c15266bd 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document +from pydantic import ConfigDict from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder @@ -18,9 +19,10 @@ class CrossEncoderReranker(BaseDocumentCompressor): top_n: int = 3 """Number of documents to return.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index d29a0e7ac5f61..8e3f1dbf43f4f 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -4,8 +4,8 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import Field from langchain_core.utils import pre_init +from pydantic import ConfigDict, Field from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, @@ -41,8 +41,9 @@ class EmbeddingsFilter(BaseDocumentCompressor): to be considered redundant. Defaults to None, must be specified if `k` is set to None.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @pre_init def validate_params(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index 16647df82c602..5039a36b6aba2 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -6,8 +6,8 @@ from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough +from pydantic import BaseModel, ConfigDict, Field _default_system_tmpl = """{context} @@ -76,8 +76,9 @@ class LLMListwiseRerank(BaseDocumentCompressor): top_n: int = 3 """Number of documents to return.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index bf2cc46376c21..59a89e555b3e1 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -25,7 +25,6 @@ ) from langchain_core.documents import Document from langchain_core.load.dump import dumpd -from langchain_core.pydantic_v1 import root_validator from langchain_core.retrievers import BaseRetriever, RetrieverLike from langchain_core.runnables import RunnableConfig from langchain_core.runnables.config import ensure_config, patch_config @@ -33,6 +32,7 @@ ConfigurableFieldSpec, get_unique_config_specs, ) +from pydantic import model_validator T = TypeVar("T") H = TypeVar("H", bound=Hashable) @@ -83,8 +83,9 @@ def config_specs(self) -> List[ConfigurableFieldSpec]: spec for retriever in self.retrievers for spec in retriever.config_specs ) - @root_validator(pre=True) - def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="before") + @classmethod + def set_weights(cls, values: Dict[str, Any]) -> Any: if not values.get("weights"): n_retrievers = len(values["retrievers"]) values["weights"] = [1 / n_retrievers] * n_retrievers diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index 54a4d935dcd56..48e48d07ea6af 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -1,15 +1,15 @@ from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.stores import BaseStore, ByteStore from langchain_core.vectorstores import VectorStore +from pydantic import Field, model_validator from langchain.storage._lc_store import create_kv_docstore @@ -41,8 +41,9 @@ class MultiVectorRetriever(BaseRetriever): search_type: SearchType = SearchType.similarity """Type of search to perform (similarity / mmr)""" - @root_validator(pre=True) - def shim_docstore(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def shim_docstore(cls, values: Dict) -> Any: byte_store = values.get("byte_store") docstore = values.get("docstore") if byte_store is not None: diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 0b89db472c124..a5254d475924f 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -9,11 +9,11 @@ ) from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import Runnable from langchain_core.structured_query import StructuredQuery, Visitor from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict, Field, model_validator from langchain.chains.query_constructor.base import load_query_constructor_runnable from langchain.chains.query_constructor.schema import AttributeInfo @@ -223,12 +223,14 @@ class SelfQueryRetriever(BaseRetriever): use_original_query: bool = False """Use original query instead of the revised new query from LLM""" - class Config: - allow_population_by_field_name = True - arbitrary_types_allowed = True + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + ) - @root_validator(pre=True) - def validate_translator(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_translator(cls, values: Dict) -> Any: """Validate translator.""" if "structured_query_translator" not in values: values["structured_query_translator"] = _get_builtin_translator( diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index 0acf17edce806..706366dbf58bf 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -7,9 +7,9 @@ CallbackManagerForRetrieverRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict, Field def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float: @@ -46,8 +46,9 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): None assigns no salience to documents not fetched from the vector store. """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def _document_get_date(self, field: str, document: Document) -> datetime.datetime: """Return the value of the date field of a document.""" diff --git a/libs/langchain/langchain/smith/evaluation/config.py b/libs/langchain/langchain/smith/evaluation/config.py index e9bdd324779db..9f132a011a6bb 100644 --- a/libs/langchain/langchain/smith/evaluation/config.py +++ b/libs/langchain/langchain/smith/evaluation/config.py @@ -5,10 +5,10 @@ from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langsmith import RunEvaluator from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults from langsmith.schemas import Example, Run +from pydantic import BaseModel, ConfigDict, Field from langchain.evaluation.criteria.eval_chain import CRITERIA_TYPE from langchain.evaluation.embedding_distance.base import ( @@ -156,8 +156,9 @@ class RunEvalConfig(BaseModel): eval_llm: Optional[BaseLanguageModel] = None """The language model to pass to any evaluators that require one.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) class Criteria(SingleKeyEvalConfig): """Configuration for a reference-free criteria evaluator. @@ -217,8 +218,9 @@ class EmbeddingDistance(SingleKeyEvalConfig): embeddings: Optional[Embeddings] = None distance_metric: Optional[EmbeddingDistanceEnum] = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) class StringDistance(SingleKeyEvalConfig): """Configuration for a string distance evaluator. diff --git a/libs/langchain/scripts/check_pydantic.sh b/libs/langchain/scripts/check_pydantic.sh deleted file mode 100755 index 06b5bb81ae236..0000000000000 --- a/libs/langchain/scripts/check_pydantic.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# -# This script searches for lines starting with "import pydantic" or "from pydantic" -# in tracked files within a Git repository. -# -# Usage: ./scripts/check_pydantic.sh /path/to/repository - -# Check if a path argument is provided -if [ $# -ne 1 ]; then - echo "Usage: $0 /path/to/repository" - exit 1 -fi - -repository_path="$1" - -# Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') - -# Check if any matching lines were found -if [ -n "$result" ]; then - echo "ERROR: The following lines need to be updated:" - echo "$result" - echo "Please replace the code with an import from langchain_core.pydantic_v1." - echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain_core.pydantic_v1 import BaseModel'" - exit 1 -fi diff --git a/libs/langchain/tests/integration_tests/chat_models/test_base.py b/libs/langchain/tests/integration_tests/chat_models/test_base.py index cda11263ddfbc..efed6e1d52290 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_base.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_base.py @@ -4,9 +4,9 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import RunnableConfig from langchain_standard_tests.integration_tests import ChatModelIntegrationTests +from pydantic import BaseModel from langchain.chat_models import init_chat_model diff --git a/libs/langchain/tests/mock_servers/robot/server.py b/libs/langchain/tests/mock_servers/robot/server.py index 1156cf1bb89db..823057bb4d9e2 100644 --- a/libs/langchain/tests/mock_servers/robot/server.py +++ b/libs/langchain/tests/mock_servers/robot/server.py @@ -8,7 +8,7 @@ from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field PORT = 7289 diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index 6db94330a1328..038ffb2930996 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -6,7 +6,6 @@ from langchain_core.agents import ( AgentAction, - AgentActionMessageLog, AgentFinish, AgentStep, ) @@ -35,7 +34,9 @@ from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import ( + _AnyIdAIMessageChunk, +) class FakeListLLM(LLM): @@ -785,6 +786,26 @@ def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage: ) +def _recursive_dump(obj: Any) -> Any: + """Recursively dump the object if encountering any pydantic models.""" + if isinstance(obj, dict): + return { + k: _recursive_dump(v) + for k, v in obj.items() + if k != "id" # Remove the id field for testing purposes + } + if isinstance(obj, list): + return [_recursive_dump(v) for v in obj] + if hasattr(obj, "dict"): + # if the object contains an ID field, we'll remove it for testing purposes + if hasattr(obj, "id"): + d = obj.dict() + d.pop("id") + return _recursive_dump(d) + return _recursive_dump(obj.dict()) + return obj + + async def test_openai_agent_with_streaming() -> None: """Test openai agent with streaming.""" infinite_cycle = cycle( @@ -831,72 +852,93 @@ def find_pet(pet: str) -> str: # astream chunks = [chunk async for chunk in executor.astream({"question": "hello"})] - assert chunks == [ + assert _recursive_dump(chunks) == [ { "actions": [ - AgentActionMessageLog( - tool="find_pet", - tool_input={"pet": "cat"}, - log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", - message_log=[ - AIMessageChunk( - id=AnyStr(), - content="", - additional_kwargs={ + { + "log": "\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", + "message_log": [ + { + "additional_kwargs": { "function_call": { + "arguments": '{"pet": ' '"cat"}', "name": "find_pet", - "arguments": '{"pet": "cat"}', } }, - ) + "content": "", + "name": None, + "response_metadata": {}, + "type": "AIMessageChunk", + } ], - ) + "tool": "find_pet", + "tool_input": {"pet": "cat"}, + "type": "AgentActionMessageLog", + } ], "messages": [ - AIMessageChunk( - id=AnyStr(), - content="", - additional_kwargs={ + { + "additional_kwargs": { "function_call": { + "arguments": '{"pet": ' '"cat"}', "name": "find_pet", - "arguments": '{"pet": "cat"}', } }, - ) + "content": "", + "example": False, + "invalid_tool_calls": [], + "name": None, + "response_metadata": {}, + "tool_call_chunks": [], + "tool_calls": [], + "type": "AIMessageChunk", + "usage_metadata": None, + } ], }, { "messages": [ - FunctionMessage(content="Spying from under the bed.", name="find_pet") + { + "additional_kwargs": {}, + "content": "Spying from under the bed.", + "name": "find_pet", + "response_metadata": {}, + "type": "function", + } ], "steps": [ - AgentStep( - action=AgentActionMessageLog( - tool="find_pet", - tool_input={"pet": "cat"}, - log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", - message_log=[ - AIMessageChunk( - id=AnyStr(), - content="", - additional_kwargs={ - "function_call": { - "name": "find_pet", - "arguments": '{"pet": "cat"}', - } - }, - ) - ], - ), - observation="Spying from under the bed.", - ) + { + "action": { + "log": "\n" + "Invoking: `find_pet` with `{'pet': 'cat'}`\n" + "\n" + "\n", + "tool": "find_pet", + "tool_input": {"pet": "cat"}, + "type": "AgentActionMessageLog", + }, + "observation": "Spying from under the bed.", + } ], }, { - "messages": [AIMessage(content="The cat is spying from under the bed.")], + "messages": [ + { + "additional_kwargs": {}, + "content": "The cat is spying from under the bed.", + "example": False, + "invalid_tool_calls": [], + "name": None, + "response_metadata": {}, + "tool_calls": [], + "type": "ai", + "usage_metadata": None, + } + ], "output": "The cat is spying from under the bed.", }, ] + # # # astream_log log_patches = [ @@ -941,7 +983,7 @@ def _make_tools_invocation(name_to_arguments: Dict[str, Dict[str, Any]]) -> AIMe AIMessage that represents a request to invoke a tool. """ raw_tool_calls = [ - {"function": {"name": name, "arguments": json.dumps(arguments)}, "id": idx} + {"function": {"name": name, "arguments": json.dumps(arguments)}, "id": str(idx)} for idx, (name, arguments) in enumerate(name_to_arguments.items()) ] tool_calls = [ @@ -1030,8 +1072,7 @@ def check_time() -> str: tool_input={"pet": "cat"}, log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1040,14 +1081,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1057,8 +1098,7 @@ def check_time() -> str: ) ], "messages": [ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1067,14 +1107,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1088,8 +1128,7 @@ def check_time() -> str: tool_input={}, log="\nInvoking: `check_time` with `{}`\n\n\n", message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1098,14 +1137,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1115,8 +1154,7 @@ def check_time() -> str: ) ], "messages": [ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1125,14 +1163,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1152,8 +1190,7 @@ def check_time() -> str: tool_input={"pet": "cat"}, log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", # noqa: E501 message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1162,14 +1199,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1195,8 +1232,7 @@ def check_time() -> str: tool_input={}, log="\nInvoking: `check_time` with `{}`\n\n\n", message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1205,14 +1241,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py index fdef58e13ff4a..43351f1da381e 100644 --- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py @@ -6,7 +6,7 @@ from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel class BaseFakeCallbackHandler(BaseModel): @@ -254,7 +254,7 @@ def on_retriever_error( ) -> Any: self.on_retriever_error_common() - def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore return self @@ -388,5 +388,5 @@ async def on_text( ) -> None: self.on_text_common() - def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore return self diff --git a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py index 5178b6b20d99a..daad6fbffc99b 100644 --- a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py +++ b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py @@ -6,8 +6,8 @@ from langchain_core.agents import AgentAction, BaseMessage from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.exceptions import OutputParserException -from langchain_core.pydantic_v1 import Field from langchain_core.tools import tool +from pydantic import Field from langchain.evaluation.agents.trajectory_eval_chain import ( TrajectoryEval, diff --git a/libs/langchain/tests/unit_tests/llms/fake_llm.py b/libs/langchain/tests/unit_tests/llms/fake_llm.py index e75865b40f290..7e21cce3404cc 100644 --- a/libs/langchain/tests/unit_tests/llms/fake_llm.py +++ b/libs/langchain/tests/unit_tests/llms/fake_llm.py @@ -4,7 +4,7 @@ from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import validator +from pydantic import validator class FakeLLM(LLM): diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py index f009222608668..9f33d73f32e55 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -9,7 +9,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel -from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk def test_generic_fake_chat_model_invoke() -> None: @@ -64,8 +64,8 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=cycle([message])) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()), - AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()), + _AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}), + _AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}), ] message = AIMessage( diff --git a/libs/langchain/tests/unit_tests/load/test_dump.py b/libs/langchain/tests/unit_tests/load/test_dump.py index 0ac05f7df21ff..76af8513e73ee 100644 --- a/libs/langchain/tests/unit_tests/load/test_dump.py +++ b/libs/langchain/tests/unit_tests/load/test_dump.py @@ -8,7 +8,7 @@ import pytest from langchain_core.load.dump import dumps from langchain_core.load.serializable import Serializable -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import ConfigDict, Field, model_validator class Person(Serializable): @@ -84,11 +84,13 @@ class TestClass(Serializable): my_favorite_secret: str = Field(alias="my_favorite_secret_alias") my_other_secret: str = Field() - class Config: - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + ) - @root_validator(pre=True) - def get_from_env(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_from_env(cls, values: Dict) -> Any: """Get the values from the environment.""" if "my_favorite_secret" not in values: values["my_favorite_secret"] = os.getenv("MY_FAVORITE_SECRET") diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py index 61d2d8a0c4613..a8961663925ef 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py @@ -6,13 +6,11 @@ from langchain_core.messages import AIMessage from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough -from pytest_mock import MockerFixture from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT -from langchain.pydantic_v1 import Extra T = TypeVar("T") @@ -173,13 +171,7 @@ def test_output_fixing_parser_parse_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - base_parser.Config.extra = Extra.allow # type: ignore - invoke_spy = mocker.spy(retry_chain, "invoke") # NOTE: get_format_instructions of some parsers behave randomly instructions = base_parser.get_format_instructions() object.__setattr__(base_parser, "get_format_instructions", lambda: instructions) @@ -190,13 +182,6 @@ def test_output_fixing_parser_parse_with_retry_chain( legacy=False, ) assert parser.parse(input) == expected - invoke_spy.assert_called_once_with( - dict( - instructions=base_parser.get_format_instructions(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) @pytest.mark.parametrize( @@ -223,14 +208,7 @@ async def test_output_fixing_parser_aparse_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - base_parser.Config.extra = Extra.allow # type: ignore - ainvoke_spy = mocker.spy(retry_chain, "ainvoke") - # NOTE: get_format_instructions of some parsers behave randomly instructions = base_parser.get_format_instructions() object.__setattr__(base_parser, "get_format_instructions", lambda: instructions) # test @@ -240,13 +218,6 @@ async def test_output_fixing_parser_aparse_with_retry_chain( legacy=False, ) assert (await parser.aparse(input)) == expected - ainvoke_spy.assert_called_once_with( - dict( - instructions=base_parser.get_format_instructions(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) def _extract_exception( diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py index 7af3597f47573..5d4d4124355df 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py @@ -4,7 +4,6 @@ import pytest from langchain_core.prompt_values import PromptValue, StringPromptValue from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough -from pytest_mock import MockerFixture from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser @@ -16,7 +15,6 @@ RetryOutputParser, RetryWithErrorOutputParser, ) -from langchain.pydantic_v1 import Extra T = TypeVar("T") @@ -222,25 +220,13 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - invoke_spy = mocker.spy(retry_chain, "invoke") - # test parser = RetryOutputParser( parser=base_parser, retry_chain=retry_chain, legacy=False, ) assert parser.parse_with_prompt(input, prompt) == expected - invoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - ) - ) @pytest.mark.parametrize( @@ -262,12 +248,7 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - ainvoke_spy = mocker.spy(retry_chain, "ainvoke") # test parser = RetryOutputParser( parser=base_parser, @@ -275,12 +256,6 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( legacy=False, ) assert (await parser.aparse_with_prompt(input, prompt)) == expected - ainvoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - ) - ) @pytest.mark.parametrize( @@ -302,12 +277,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - invoke_spy = mocker.spy(retry_chain, "invoke") # test parser = RetryWithErrorOutputParser( parser=base_parser, @@ -315,13 +285,6 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( legacy=False, ) assert parser.parse_with_prompt(input, prompt) == expected - invoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) @pytest.mark.parametrize( @@ -343,26 +306,13 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chai base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - ainvoke_spy = mocker.spy(retry_chain, "ainvoke") - # test parser = RetryWithErrorOutputParser( parser=base_parser, retry_chain=retry_chain, legacy=False, ) assert (await parser.aparse_with_prompt(input, prompt)) == expected - ainvoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) def _extract_exception( diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py index 065ca4aa96ca4..6e678ed8f03d2 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py @@ -5,7 +5,7 @@ import pytest from langchain_core.exceptions import OutputParserException -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from langchain.output_parsers.yaml import YamlOutputParser diff --git a/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py b/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py index fd2f9e40c3efa..cb2916193a85b 100644 --- a/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py +++ b/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py @@ -12,11 +12,10 @@ def test_evaluate_run() -> None: run_mapper = ChainStringRunMapper() - example_mapper = MagicMock() string_evaluator = criteria.CriteriaEvalChain.from_llm(fake_llm.FakeLLM()) evaluator = StringRunEvaluatorChain( run_mapper=run_mapper, - example_mapper=example_mapper, + example_mapper=None, name="test_evaluator", string_evaluator=string_evaluator, ) diff --git a/libs/langchain/tests/unit_tests/test_imports.py b/libs/langchain/tests/unit_tests/test_imports.py index 1433638098273..22b98bf2cfbbe 100644 --- a/libs/langchain/tests/unit_tests/test_imports.py +++ b/libs/langchain/tests/unit_tests/test_imports.py @@ -34,6 +34,8 @@ def test_import_all() -> None: # If the module is not installed, we suppress the error if "Module langchain_community" in str(e) and COMMUNITY_NOT_INSTALLED: pass + except Exception as e: + raise AssertionError(f"Could not import {module_name}.{name}") from e def test_import_all_using_dir() -> None: diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index a498e234a5070..5e720db065347 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -19,26 +19,23 @@ ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue -from langchain_core.pydantic_v1 import BaseModel, ValidationError +from pydantic import RootModel, ValidationError def test_serialization_of_wellknown_objects() -> None: """Test that pydantic is able to serialize and deserialize well known objects.""" - - class WellKnownLCObject(BaseModel): - """A well known LangChain object.""" - - __root__: Union[ + well_known_lc_object = RootModel[ + Union[ Document, HumanMessage, SystemMessage, ChatMessage, FunctionMessage, + FunctionMessageChunk, AIMessage, HumanMessageChunk, SystemMessageChunk, ChatMessageChunk, - FunctionMessageChunk, AIMessageChunk, StringPromptValue, ChatPromptValueConcrete, @@ -49,6 +46,7 @@ class WellKnownLCObject(BaseModel): Generation, ChatGenerationChunk, ] + ] lc_objects = [ HumanMessage(content="human"), @@ -97,11 +95,11 @@ class WellKnownLCObject(BaseModel): ] for lc_object in lc_objects: - d = lc_object.dict() + d = lc_object.model_dump() assert "type" in d, f"Missing key `type` for {type(lc_object)}" - obj1 = WellKnownLCObject.parse_obj(d) - assert type(obj1.__root__) is type(lc_object), f"failed for {type(lc_object)}" + obj1 = well_known_lc_object.model_validate(d) + assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}" - with pytest.raises(ValidationError): + with pytest.raises((TypeError, ValidationError)): # Make sure that specifically validation error is raised - WellKnownLCObject.parse_obj({}) + well_known_lc_object.model_validate({}) diff --git a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py index 34a0b8126f083..ca66e1c64ae2f 100644 --- a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py @@ -1,5 +1,5 @@ -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.function_calling import convert_pydantic_to_openai_function +from pydantic import BaseModel, Field def test_convert_pydantic_to_openai_function() -> None: