diff --git a/libs/langchain/Makefile b/libs/langchain/Makefile index c77591f6430d6..e06cd2e65d1a9 100644 --- a/libs/langchain/Makefile +++ b/libs/langchain/Makefile @@ -54,7 +54,6 @@ lint_tests: PYTHON_FILES=tests lint_tests: MYPY_CACHE=.mypy_cache_test lint lint_diff lint_package lint_tests: - ./scripts/check_pydantic.sh . ./scripts/lint_imports.sh [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES) [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 50797228a53b8..5fae509f343db 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -40,11 +40,12 @@ from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.utils import AddableDict from langchain_core.tools import BaseTool from langchain_core.utils.input import get_color_mapping +from pydantic import BaseModel, ConfigDict, model_validator +from typing_extensions import Self from langchain.agents.agent_iterator import AgentExecutorIterator from langchain.agents.agent_types import AgentType @@ -420,8 +421,9 @@ class RunnableAgent(BaseSingleActionAgent): individual LLM tokens will not be available in stream_log. """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def return_values(self) -> List[str]: @@ -528,8 +530,9 @@ class RunnableMultiActionAgent(BaseMultiActionAgent): individual LLM tokens will not be available in stream_log. """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def return_values(self) -> List[str]: @@ -854,8 +857,8 @@ def input_keys(self) -> List[str]: """ return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"}) - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_prompt(self) -> Self: """Validate that prompt matches format. Args: @@ -868,7 +871,7 @@ def validate_prompt(cls, values: Dict) -> Dict: ValueError: If `agent_scratchpad` is not in prompt.input_variables and prompt is not a FewShotPromptTemplate or a PromptTemplate. """ - prompt = values["llm_chain"].prompt + prompt = self.llm_chain.prompt if "agent_scratchpad" not in prompt.input_variables: logger.warning( "`agent_scratchpad` should be a variable in prompt.input_variables." @@ -881,7 +884,7 @@ def validate_prompt(cls, values: Dict) -> Dict: prompt.suffix += "\n{agent_scratchpad}" else: raise ValueError(f"Got unexpected prompt type {type(prompt)}") - return values + return self @property @abstractmethod @@ -1120,8 +1123,8 @@ def from_agent_and_tools( **kwargs, ) - @root_validator(pre=False, skip_on_failure=True) - def validate_tools(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_tools(self) -> Self: """Validate that tools are compatible with agent. Args: @@ -1133,19 +1136,20 @@ def validate_tools(cls, values: Dict) -> Dict: Raises: ValueError: If allowed tools are different than provided tools. """ - agent = values["agent"] - tools = values["tools"] - allowed_tools = agent.get_allowed_tools() + agent = self.agent + tools = self.tools + allowed_tools = agent.get_allowed_tools() # type: ignore if allowed_tools is not None: if set(allowed_tools) != set([tool.name for tool in tools]): raise ValueError( f"Allowed tools ({allowed_tools}) different than " f"provided tools ({[tool.name for tool in tools]})" ) - return values + return self - @root_validator(pre=True) - def validate_runnable_agent(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_runnable_agent(cls, values: Dict) -> Any: """Convert runnable to agent if passed in. Args: diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py index c40b3f3b7f289..71114a49eea37 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py @@ -3,10 +3,10 @@ from typing import List from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool from langchain_core.tools.base import BaseToolkit from langchain_core.vectorstores import VectorStore +from pydantic import BaseModel, ConfigDict, Field class VectorStoreInfo(BaseModel): @@ -16,8 +16,9 @@ class VectorStoreInfo(BaseModel): name: str description: str - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) class VectorStoreToolkit(BaseToolkit): @@ -26,8 +27,9 @@ class VectorStoreToolkit(BaseToolkit): vectorstore_info: VectorStoreInfo = Field(exclude=True) llm: BaseLanguageModel - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" @@ -67,8 +69,9 @@ class VectorStoreRouterToolkit(BaseToolkit): vectorstores: List[VectorStoreInfo] = Field(exclude=True) llm: BaseLanguageModel - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" diff --git a/libs/langchain/langchain/agents/chat/base.py b/libs/langchain/langchain/agents/chat/base.py index a7a16be772611..00ced776f1342 100644 --- a/libs/langchain/langchain/agents/chat/base.py +++ b/libs/langchain/langchain/agents/chat/base.py @@ -10,8 +10,8 @@ HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.chat.output_parser import ChatOutputParser diff --git a/libs/langchain/langchain/agents/conversational/base.py b/libs/langchain/langchain/agents/conversational/base.py index bbbf666e5903d..a0ef85946abd3 100644 --- a/libs/langchain/langchain/agents/conversational/base.py +++ b/libs/langchain/langchain/agents/conversational/base.py @@ -8,8 +8,8 @@ from langchain_core.callbacks import BaseCallbackManager from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/conversational_chat/base.py b/libs/langchain/langchain/agents/conversational_chat/base.py index 08ec829613a53..138933addbaec 100644 --- a/libs/langchain/langchain/agents/conversational_chat/base.py +++ b/libs/langchain/langchain/agents/conversational_chat/base.py @@ -17,8 +17,8 @@ MessagesPlaceholder, SystemMessagePromptTemplate, ) -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.conversational_chat.output_parser import ConvoOutputParser diff --git a/libs/langchain/langchain/agents/mrkl/base.py b/libs/langchain/langchain/agents/mrkl/base.py index 1b16d4d862f6f..cc4d9da5537d7 100644 --- a/libs/langchain/langchain/agents/mrkl/base.py +++ b/libs/langchain/langchain/agents/mrkl/base.py @@ -8,9 +8,9 @@ from langchain_core.callbacks import BaseCallbackManager from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool, Tool from langchain_core.tools.render import render_text_description +from pydantic import Field from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index a2f0ef74a5bff..e7d4d2bfa9c9b 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -20,10 +20,11 @@ from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import CallbackManager from langchain_core.load import dumpd -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self if TYPE_CHECKING: import openai @@ -232,14 +233,14 @@ def execute_agent(agent, tools, input): as_agent: bool = False """Use as a LangChain agent, compatible with the AgentExecutor.""" - @root_validator(pre=False, skip_on_failure=True) - def validate_async_client(cls, values: dict) -> dict: - if values["async_client"] is None: + @model_validator(mode="after") + def validate_async_client(self) -> Self: + if self.async_client is None: import openai - api_key = values["client"].api_key - values["async_client"] = openai.AsyncOpenAI(api_key=api_key) - return values + api_key = self.client.api_key + self.async_client = openai.AsyncOpenAI(api_key=api_key) + return self @classmethod def create_assistant( diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index fbb23a56e4785..5a40a665e3310 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -17,10 +17,11 @@ HumanMessagePromptTemplate, MessagesPlaceholder, ) -from langchain_core.pydantic_v1 import root_validator from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_function +from pydantic import model_validator +from typing_extensions import Self from langchain.agents import BaseSingleActionAgent from langchain.agents.format_scratchpad.openai_functions import ( @@ -58,8 +59,8 @@ def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" return [t.name for t in self.tools] - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: dict) -> dict: + @model_validator(mode="after") + def validate_prompt(self) -> Self: """Validate prompt. Args: @@ -71,13 +72,13 @@ def validate_prompt(cls, values: dict) -> dict: Raises: ValueError: If `agent_scratchpad` is not in the prompt. """ - prompt: BasePromptTemplate = values["prompt"] + prompt: BasePromptTemplate = self.prompt if "agent_scratchpad" not in prompt.input_variables: raise ValueError( "`agent_scratchpad` should be one of the variables in the prompt, " f"got {prompt.input_variables}" ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index 819fc8b1d939e..bec49e5f14d24 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -21,8 +21,9 @@ HumanMessagePromptTemplate, MessagesPlaceholder, ) -from langchain_core.pydantic_v1 import root_validator from langchain_core.tools import BaseTool +from pydantic import model_validator +from typing_extensions import Self from langchain.agents import BaseMultiActionAgent from langchain.agents.format_scratchpad.openai_functions import ( @@ -115,15 +116,15 @@ def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" return [t.name for t in self.tools] - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: dict) -> dict: - prompt: BasePromptTemplate = values["prompt"] + @model_validator(mode="after") + def validate_prompt(self) -> Self: + prompt: BasePromptTemplate = self.prompt if "agent_scratchpad" not in prompt.input_variables: raise ValueError( "`agent_scratchpad` should be one of the variables in the prompt, " f"got {prompt.input_variables}" ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/agents/react/base.py b/libs/langchain/langchain/agents/react/base.py index 93a60bbe61ba2..81a38141fe686 100644 --- a/libs/langchain/langchain/agents/react/base.py +++ b/libs/langchain/langchain/agents/react/base.py @@ -8,8 +8,8 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool, Tool +from pydantic import Field from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/self_ask_with_search/base.py b/libs/langchain/langchain/agents/self_ask_with_search/base.py index 7a1d81f7d40d1..9a642b81b1289 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/base.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/base.py @@ -7,9 +7,9 @@ from langchain_core._api import deprecated from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool, Tool +from pydantic import Field from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType diff --git a/libs/langchain/langchain/agents/structured_chat/base.py b/libs/langchain/langchain/agents/structured_chat/base.py index e1403e26f2607..a520cfecf7142 100644 --- a/libs/langchain/langchain/agents/structured_chat/base.py +++ b/libs/langchain/langchain/agents/structured_chat/base.py @@ -11,10 +11,10 @@ HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain_core.pydantic_v1 import Field from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.tools.render import ToolsRenderer +from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.format_scratchpad import format_log_to_str diff --git a/libs/langchain/langchain/agents/structured_chat/output_parser.py b/libs/langchain/langchain/agents/structured_chat/output_parser.py index f73f87f4f4be4..1cdb4fe1fb4a9 100644 --- a/libs/langchain/langchain/agents/structured_chat/output_parser.py +++ b/libs/langchain/langchain/agents/structured_chat/output_parser.py @@ -8,7 +8,7 @@ from langchain_core.agents import AgentAction, AgentFinish from langchain_core.exceptions import OutputParserException from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.agents.agent import AgentOutputParser from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index 94896102dc6fd..4387cb623e06a 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -12,7 +12,8 @@ ) from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import Field, model_validator +from typing_extensions import Self from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain @@ -227,19 +228,20 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - @root_validator(pre=False, skip_on_failure=True) - def validate_api_request_prompt(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_api_request_prompt(self) -> Self: """Check that api request prompt expects the right variables.""" - input_vars = values["api_request_chain"].prompt.input_variables + input_vars = self.api_request_chain.prompt.input_variables expected_vars = {"question", "api_docs"} if set(input_vars) != expected_vars: raise ValueError( f"Input variables should be {expected_vars}, got {input_vars}" ) - return values + return self - @root_validator(pre=True) - def validate_limit_to_domains(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_limit_to_domains(cls, values: Dict) -> Any: """Check that allowed domains are valid.""" # This check must be a pre=True check, so that a default of None # won't be set to limit_to_domains if it's not provided. @@ -258,16 +260,16 @@ def validate_limit_to_domains(cls, values: Dict) -> Dict: ) return values - @root_validator(pre=False, skip_on_failure=True) - def validate_api_answer_prompt(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_api_answer_prompt(self) -> Self: """Check that api answer prompt expects the right variables.""" - input_vars = values["api_answer_chain"].prompt.input_variables + input_vars = self.api_answer_chain.prompt.input_variables expected_vars = {"question", "api_docs", "api_url", "api_response"} if set(input_vars) != expected_vars: raise ValueError( f"Input variables should be {expected_vars}, got {input_vars}" ) - return values + return self def _call( self, diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 4c1dcc07eb231..a2d9ca57ddc79 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -21,7 +21,6 @@ from langchain_core.load.dump import dumpd from langchain_core.memory import BaseMemory from langchain_core.outputs import RunInfo -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator from langchain_core.runnables import ( RunnableConfig, RunnableSerializable, @@ -29,6 +28,13 @@ run_in_executor, ) from langchain_core.runnables.utils import create_model +from pydantic import ( + BaseModel, + ConfigDict, + Field, + model_validator, + validator, +) from langchain.schema import RUN_KEY @@ -96,8 +102,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """[DEPRECATED] Use `callbacks` instead.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def get_input_schema( self, config: Optional[RunnableConfig] = None @@ -223,8 +230,9 @@ async def ainvoke( def _chain_type(self) -> str: raise NotImplementedError("Saving not supported for this chain type.") - @root_validator(pre=True) - def raise_callback_manager_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_callback_manager_deprecation(cls, values: Dict) -> Any: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: if values.get("callbacks") is not None: diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 00b6002da0ed4..2406cd4215f7a 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -10,10 +10,10 @@ ) from langchain_core.documents import Document from langchain_core.prompts import BasePromptTemplate, PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from pydantic import BaseModel, Field from langchain.chains.base import Chain diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index 229ed45740c5d..7b4885cdea4fe 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -6,9 +6,9 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model +from pydantic import BaseModel, ConfigDict, model_validator from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.reduce import ReduceDocumentsChain @@ -126,12 +126,14 @@ def output_keys(self) -> List[str]: _output_keys = _output_keys + ["intermediate_steps"] return _output_keys - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def get_reduce_chain(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_reduce_chain(cls, values: Dict) -> Any: """For backwards compatibility.""" if "combine_document_chain" in values: if "reduce_documents_chain" in values: @@ -153,16 +155,18 @@ def get_reduce_chain(cls, values: Dict) -> Dict: return values - @root_validator(pre=True) - def get_return_intermediate_steps(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_return_intermediate_steps(cls, values: Dict) -> Any: """For backwards compatibility.""" if "return_map_steps" in values: values["return_intermediate_steps"] = values["return_map_steps"] del values["return_map_steps"] return values - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided.""" if "llm_chain" not in values: raise ValueError("llm_chain must be provided") diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index 0fa346dee8bc8..61e6e226ed65d 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -6,9 +6,10 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model +from pydantic import BaseModel, ConfigDict, model_validator +from typing_extensions import Self from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain @@ -74,9 +75,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): """Return intermediate steps. Intermediate steps include the results of calling llm_chain on each document.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def get_output_schema( self, config: Optional[RunnableConfig] = None @@ -104,30 +106,31 @@ def output_keys(self) -> List[str]: _output_keys += self.metadata_keys return _output_keys - @root_validator(pre=False, skip_on_failure=True) - def validate_llm_output(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_llm_output(self) -> Self: """Validate that the combine chain outputs a dictionary.""" - output_parser = values["llm_chain"].prompt.output_parser + output_parser = self.llm_chain.prompt.output_parser if not isinstance(output_parser, RegexParser): raise ValueError( "Output parser of llm_chain should be a RegexParser," f" got {output_parser}" ) output_keys = output_parser.output_keys - if values["rank_key"] not in output_keys: + if self.rank_key not in output_keys: raise ValueError( - f"Got {values['rank_key']} as key to rank on, but did not find " + f"Got {self.rank_key} as key to rank on, but did not find " f"it in the llm_chain output keys ({output_keys})" ) - if values["answer_key"] not in output_keys: + if self.answer_key not in output_keys: raise ValueError( - f"Got {values['answer_key']} as key to return, but did not find " + f"Got {self.answer_key} as key to return, but did not find " f"it in the llm_chain output keys ({output_keys})" ) - return values + return self - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided.""" if "llm_chain" not in values: raise ValueError("llm_chain must be provided") diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index 7b2cd6c89a375..662be15e38b55 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -6,6 +6,7 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import Document +from pydantic import ConfigDict from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -204,9 +205,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): If None, it will keep trying to collapse documents to fit token_max. Otherwise, after it reaches the max number, it will throw an error""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def _collapse_chain(self) -> BaseCombineDocumentsChain: diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py index cf2f5d9e92f50..7129ac147dbb1 100644 --- a/libs/langchain/langchain/chains/combine_documents/refine.py +++ b/libs/langchain/langchain/chains/combine_documents/refine.py @@ -8,7 +8,7 @@ from langchain_core.documents import Document from langchain_core.prompts import BasePromptTemplate, format_document from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import ConfigDict, Field, model_validator from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, @@ -98,20 +98,23 @@ def output_keys(self) -> List[str]: _output_keys = _output_keys + ["intermediate_steps"] return _output_keys - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def get_return_intermediate_steps(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_return_intermediate_steps(cls, values: Dict) -> Any: """For backwards compatibility.""" if "return_refine_steps" in values: values["return_intermediate_steps"] = values["return_refine_steps"] del values["return_refine_steps"] return values - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided.""" if "initial_llm_chain" not in values: raise ValueError("initial_llm_chain must be provided") diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index cdecec0f40b82..26750d55fe84f 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -8,8 +8,8 @@ from langchain_core.language_models import LanguageModelLike from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import BasePromptTemplate, format_document -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables import Runnable, RunnablePassthrough +from pydantic import ConfigDict, Field, model_validator from langchain.chains.combine_documents.base import ( DEFAULT_DOCUMENT_PROMPT, @@ -156,12 +156,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): document_separator: str = "\n\n" """The string with which to join the formatted documents""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_default_document_variable_name(cls, values: Dict) -> Any: """Get default document variable name, if not provided. If only one variable is present in the llm_chain.prompt, diff --git a/libs/langchain/langchain/chains/constitutional_ai/models.py b/libs/langchain/langchain/chains/constitutional_ai/models.py index 8058553eb257a..7f9a623459913 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/models.py +++ b/libs/langchain/langchain/chains/constitutional_ai/models.py @@ -1,6 +1,6 @@ """Models for the Constitutional AI chain.""" -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel class ConstitutionalPrinciple(BaseModel): diff --git a/libs/langchain/langchain/chains/conversation/base.py b/libs/langchain/langchain/chains/conversation/base.py index dee4b9dbc0cce..881eed49e8c61 100644 --- a/libs/langchain/langchain/chains/conversation/base.py +++ b/libs/langchain/langchain/chains/conversation/base.py @@ -1,11 +1,12 @@ """Chain that carries on a conversation and calls an LLM.""" -from typing import Dict, List +from typing import List from langchain_core._api import deprecated from langchain_core.memory import BaseMemory from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import ConfigDict, Field, model_validator +from typing_extensions import Self from langchain.chains.conversation.prompt import PROMPT from langchain.chains.llm import LLMChain @@ -110,9 +111,10 @@ def get_session_history(session_id: str) -> InMemoryChatMessageHistory: input_key: str = "input" #: :meta private: output_key: str = "response" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @classmethod def is_lc_serializable(cls) -> bool: @@ -123,17 +125,17 @@ def input_keys(self) -> List[str]: """Use this since so some prompt vars come from history.""" return [self.input_key] - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt_input_variables(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_prompt_input_variables(self) -> Self: """Validate that prompt input variables are consistent.""" - memory_keys = values["memory"].memory_variables - input_key = values["input_key"] + memory_keys = self.memory.memory_variables + input_key = self.input_key if input_key in memory_keys: raise ValueError( f"The input key {input_key} was also found in the memory keys " f"({memory_keys}) - please provide keys that don't overlap." ) - prompt_variables = values["prompt"].input_variables + prompt_variables = self.prompt.input_variables expected_keys = memory_keys + [input_key] if set(expected_keys) != set(prompt_variables): raise ValueError( @@ -141,4 +143,4 @@ def validate_prompt_input_variables(cls, values: Dict) -> Dict: f"{prompt_variables}, but got {memory_keys} as inputs from " f"memory, and {input_key} as the normal input key." ) - return values + return self diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 4251ced2fdf1a..3c653c433362e 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -18,10 +18,10 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import RunnableConfig from langchain_core.vectorstores import VectorStore +from pydantic import BaseModel, ConfigDict, Field, model_validator from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -96,10 +96,11 @@ class BaseConversationalRetrievalChain(Chain): """If specified, the chain will return a fixed response if no docs are found for the question. """ - class Config: - allow_population_by_field_name = True - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -482,8 +483,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): def _chain_type(self) -> str: return "chat-vector-db" - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: warnings.warn( "`ChatVectorDBChain` is deprecated - " "please use `from langchain.chains import ConversationalRetrievalChain`" diff --git a/libs/langchain/langchain/chains/elasticsearch_database/base.py b/libs/langchain/langchain/chains/elasticsearch_database/base.py index 89875f2d8a425..85bf7de93d29a 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/base.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/base.py @@ -9,8 +9,9 @@ from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.output_parsers.json import SimpleJsonOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator from langchain_core.runnables import Runnable +from pydantic import ConfigDict, model_validator +from typing_extensions import Self from langchain.chains.base import Chain from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT @@ -51,17 +52,18 @@ class ElasticsearchDatabaseChain(Chain): return_intermediate_steps: bool = False """Whether or not to return the intermediate steps along with the final answer.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=False, skip_on_failure=True) - def validate_indices(cls, values: dict) -> dict: - if values["include_indices"] and values["ignore_indices"]: + @model_validator(mode="after") + def validate_indices(self) -> Self: + if self.include_indices and self.ignore_indices: raise ValueError( "Cannot specify both 'include_indices' and 'ignore_indices'." ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 53f3dd1e40e82..caf64fe77aa40 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -11,9 +11,9 @@ from langchain_core.messages import AIMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import Runnable +from pydantic import Field from langchain.chains.base import Chain from langchain.chains.flare.prompts import ( diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py index 833999127b659..cf64b29cada46 100644 --- a/libs/langchain/langchain/chains/hyde/base.py +++ b/libs/langchain/langchain/chains/hyde/base.py @@ -14,6 +14,7 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable +from pydantic import ConfigDict from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP @@ -29,9 +30,10 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): base_embeddings: Embeddings llm_chain: Runnable - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index feb3468fc44e4..f69aad7838bde 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -23,7 +23,6 @@ from langchain_core.outputs import ChatGeneration, Generation, LLMResult from langchain_core.prompt_values import PromptValue from langchain_core.prompts import BasePromptTemplate, PromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_core.runnables import ( Runnable, RunnableBinding, @@ -32,6 +31,7 @@ ) from langchain_core.runnables.configurable import DynamicRunnable from langchain_core.utils.input import get_colored_text +from pydantic import ConfigDict, Field from langchain.chains.base import Chain @@ -95,9 +95,10 @@ def is_lc_serializable(self) -> bool: If false, will return a bunch of extra information about the generation.""" llm_kwargs: dict = Field(default_factory=dict) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/llm_checker/base.py b/libs/langchain/langchain/chains/llm_checker/base.py index ea2bc546a5791..bfff3d1b40dd8 100644 --- a/libs/langchain/langchain/chains/llm_checker/base.py +++ b/libs/langchain/langchain/chains/llm_checker/base.py @@ -9,7 +9,7 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -100,12 +100,14 @@ class LLMCheckerChain(Chain): input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: if "llm" in values: warnings.warn( "Directly instantiating an LLMCheckerChain with an llm is deprecated. " diff --git a/libs/langchain/langchain/chains/llm_math/base.py b/libs/langchain/langchain/chains/llm_math/base.py index e7fd89dcd542c..5bc51bf253e54 100644 --- a/libs/langchain/langchain/chains/llm_math/base.py +++ b/libs/langchain/langchain/chains/llm_math/base.py @@ -14,7 +14,7 @@ ) from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -156,12 +156,14 @@ async def acall_model(state: ChainState, config: RunnableConfig): input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: try: import numexpr # noqa: F401 except ImportError: diff --git a/libs/langchain/langchain/chains/llm_summarization_checker/base.py b/libs/langchain/langchain/chains/llm_summarization_checker/base.py index f177f401529e4..c7d075dbf4405 100644 --- a/libs/langchain/langchain/chains/llm_summarization_checker/base.py +++ b/libs/langchain/langchain/chains/llm_summarization_checker/base.py @@ -10,7 +10,7 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -105,12 +105,14 @@ class LLMSummarizationCheckerChain(Chain): max_checks: int = 2 """Maximum number of times to check the assertions. Default to double-checking.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: if "llm" in values: warnings.warn( "Directly instantiating an LLMSummarizationCheckerChain with an llm is " diff --git a/libs/langchain/langchain/chains/mapreduce.py b/libs/langchain/langchain/chains/mapreduce.py index 1eaccf67a850f..c153b7f9a6dae 100644 --- a/libs/langchain/langchain/chains/mapreduce.py +++ b/libs/langchain/langchain/chains/mapreduce.py @@ -14,6 +14,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate from langchain_text_splitters import TextSplitter +from pydantic import ConfigDict from langchain.chains import ReduceDocumentsChain from langchain.chains.base import Chain @@ -77,9 +78,10 @@ def from_params( **kwargs, ) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index a4b3551491c45..670b4773e3d63 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -6,8 +6,8 @@ AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import check_package_version, get_from_dict_or_env +from pydantic import Field, model_validator from langchain.chains.base import Chain @@ -40,8 +40,9 @@ class OpenAIModerationChain(Chain): openai_organization: Optional[str] = None openai_pre_1_0: bool = Field(default=None) - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: """Validate that api key and python package exists in environment.""" openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" diff --git a/libs/langchain/langchain/chains/natbot/base.py b/libs/langchain/langchain/chains/natbot/base.py index e92131ff35cad..aca7fca8a7d8b 100644 --- a/libs/langchain/langchain/chains/natbot/base.py +++ b/libs/langchain/langchain/chains/natbot/base.py @@ -9,8 +9,8 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import StrOutputParser -from langchain_core.pydantic_v1 import root_validator from langchain_core.runnables import Runnable +from pydantic import ConfigDict, model_validator from langchain.chains.base import Chain from langchain.chains.natbot.prompt import PROMPT @@ -59,12 +59,14 @@ class NatBotChain(Chain): previous_command: str = "" #: :meta private: output_key: str = "command" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: if "llm" in values: warnings.warn( "Directly instantiating an NatBotChain with an llm is deprecated. " diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index 568d992de990f..729313ab6b4d8 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -19,11 +19,11 @@ PydanticAttrOutputFunctionsParser, ) from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils.function_calling import ( PYTHON_TO_JSON_TYPES, convert_to_openai_function, ) +from pydantic import BaseModel from langchain.chains import LLMChain from langchain.chains.structured_output.base import ( diff --git a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py index 038489d13a696..e9a83e8abc6e8 100644 --- a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py +++ b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -5,8 +5,8 @@ from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.output_parsers.openai_functions import PydanticOutputFunctionsParser from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable +from pydantic import BaseModel, Field from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import get_llm_kwargs diff --git a/libs/langchain/langchain/chains/openai_functions/extraction.py b/libs/langchain/langchain/chains/openai_functions/extraction.py index ec76ad8a6cbbe..f6b9debad4959 100644 --- a/libs/langchain/langchain/chains/openai_functions/extraction.py +++ b/libs/langchain/langchain/chains/openai_functions/extraction.py @@ -7,7 +7,7 @@ PydanticAttrOutputFunctionsParser, ) from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain diff --git a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py index f13e2f9e522f5..5bab1c1ced60a 100644 --- a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py +++ b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py @@ -10,8 +10,8 @@ ) from langchain_core.prompts import PromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import BaseModel, Field from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import get_llm_kwargs diff --git a/libs/langchain/langchain/chains/openai_tools/extraction.py b/libs/langchain/langchain/chains/openai_tools/extraction.py index 55251f5186784..189ef77d9d4d1 100644 --- a/libs/langchain/langchain/chains/openai_tools/extraction.py +++ b/libs/langchain/langchain/chains/openai_tools/extraction.py @@ -4,9 +4,9 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers.openai_tools import PydanticToolsParser from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable from langchain_core.utils.function_calling import convert_pydantic_to_openai_function +from pydantic import BaseModel _EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned \ in the following passage together with their properties. diff --git a/libs/langchain/langchain/chains/prompt_selector.py b/libs/langchain/langchain/chains/prompt_selector.py index 453e2ea03577f..4014cdc1fbbbf 100644 --- a/libs/langchain/langchain/chains/prompt_selector.py +++ b/libs/langchain/langchain/chains/prompt_selector.py @@ -5,7 +5,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field class BasePromptSelector(BaseModel, ABC): diff --git a/libs/langchain/langchain/chains/qa_generation/base.py b/libs/langchain/langchain/chains/qa_generation/base.py index b66b8a5442599..a55c078610124 100644 --- a/libs/langchain/langchain/chains/qa_generation/base.py +++ b/libs/langchain/langchain/chains/qa_generation/base.py @@ -7,8 +7,8 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from pydantic import Field from langchain.chains.base import Chain from langchain.chains.llm import LLMChain diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index aed2d57cf91e5..495f0ab79e78b 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -15,7 +15,7 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator from langchain.chains import ReduceDocumentsChain from langchain.chains.base import Chain @@ -97,9 +97,10 @@ def from_chain_type( ) return cls(combine_documents_chain=combine_documents_chain, **kwargs) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -120,8 +121,9 @@ def output_keys(self) -> List[str]: _output_keys = _output_keys + ["source_documents"] return _output_keys - @root_validator(pre=True) - def validate_naming(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_naming(cls, values: Dict) -> Any: """Fix backwards compatibility in naming.""" if "combine_document_chain" in values: values["combine_documents_chain"] = values.pop("combine_document_chain") diff --git a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py index f4a924c8dd162..95485b9fe0e3c 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py +++ b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py @@ -7,8 +7,8 @@ CallbackManagerForChainRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever +from pydantic import Field from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain diff --git a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py index ca594994f3a6b..6330db38bc9b4 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py +++ b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py @@ -8,8 +8,8 @@ CallbackManagerForChainRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.vectorstores import VectorStore +from pydantic import Field, model_validator from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain @@ -61,8 +61,9 @@ async def _aget_docs( ) -> List[Document]: raise NotImplementedError("VectorDBQAWithSourcesChain does not support async") - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: warnings.warn( "`VectorDBQAWithSourcesChain` is deprecated - " "please use `from langchain.chains import RetrievalQAWithSourcesChain`" diff --git a/libs/langchain/langchain/chains/query_constructor/schema.py b/libs/langchain/langchain/chains/query_constructor/schema.py index 585addc7919a3..56103d9a9fede 100644 --- a/libs/langchain/langchain/chains/query_constructor/schema.py +++ b/libs/langchain/langchain/chains/query_constructor/schema.py @@ -1,4 +1,4 @@ -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel, ConfigDict class AttributeInfo(BaseModel): @@ -8,6 +8,7 @@ class AttributeInfo(BaseModel): description: str type: str - class Config: - arbitrary_types_allowed = True - frozen = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + frozen=True, + ) diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 689dd8b0c217b..3d688474bd39a 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -16,9 +16,9 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict, Field, model_validator from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -47,10 +47,11 @@ class BaseRetrievalQA(Chain): return_source_documents: bool = False """Return the source documents or not.""" - class Config: - allow_population_by_field_name = True - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -309,16 +310,18 @@ class VectorDBQA(BaseRetrievalQA): search_kwargs: Dict[str, Any] = Field(default_factory=dict) """Extra search args.""" - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def raise_deprecation(cls, values: Dict) -> Any: warnings.warn( "`VectorDBQA` is deprecated - " "please use `from langchain.chains import RetrievalQA`" ) return values - @root_validator(pre=True) - def validate_search_type(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_search_type(cls, values: Dict) -> Any: """Validate search type.""" if "search_type" in values: search_type = values["search_type"] diff --git a/libs/langchain/langchain/chains/router/base.py b/libs/langchain/langchain/chains/router/base.py index d0b680dd952b1..fa489c8110c33 100644 --- a/libs/langchain/langchain/chains/router/base.py +++ b/libs/langchain/langchain/chains/router/base.py @@ -10,6 +10,7 @@ CallbackManagerForChainRun, Callbacks, ) +from pydantic import ConfigDict from langchain.chains.base import Chain @@ -60,9 +61,10 @@ class MultiRouteChain(Chain): """If True, use default_chain when an invalid destination name is provided. Defaults to False.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/router/embedding_router.py b/libs/langchain/langchain/chains/router/embedding_router.py index a1bc126a49f2a..0f44dda02ff43 100644 --- a/libs/langchain/langchain/chains/router/embedding_router.py +++ b/libs/langchain/langchain/chains/router/embedding_router.py @@ -9,6 +9,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict from langchain.chains.router.base import RouterChain @@ -19,9 +20,10 @@ class EmbeddingRouterChain(RouterChain): vectorstore: VectorStore routing_keys: List[str] = ["query"] - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/router/llm_router.py b/libs/langchain/langchain/chains/router/llm_router.py index 132f350e819c9..aa72ce4c22a25 100644 --- a/libs/langchain/langchain/chains/router/llm_router.py +++ b/libs/langchain/langchain/chains/router/llm_router.py @@ -13,8 +13,9 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import root_validator from langchain_core.utils.json import parse_and_check_json_markdown +from pydantic import model_validator +from typing_extensions import Self from langchain.chains import LLMChain from langchain.chains.router.base import RouterChain @@ -100,9 +101,9 @@ class RouteQuery(TypedDict): llm_chain: LLMChain """LLM chain used to perform routing""" - @root_validator(pre=False, skip_on_failure=True) - def validate_prompt(cls, values: dict) -> dict: - prompt = values["llm_chain"].prompt + @model_validator(mode="after") + def validate_prompt(self) -> Self: + prompt = self.llm_chain.prompt if prompt.output_parser is None: raise ValueError( "LLMRouterChain requires base llm_chain prompt to have an output" @@ -110,7 +111,7 @@ def validate_prompt(cls, values: dict) -> dict: " 'destination' and 'next_inputs'. Received a prompt with no output" " parser." ) - return values + return self @property def input_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/sequential.py b/libs/langchain/langchain/chains/sequential.py index e75300f7cbf84..b19f65e9aa2c3 100644 --- a/libs/langchain/langchain/chains/sequential.py +++ b/libs/langchain/langchain/chains/sequential.py @@ -6,8 +6,9 @@ AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain_core.pydantic_v1 import root_validator from langchain_core.utils.input import get_color_mapping +from pydantic import ConfigDict, model_validator +from typing_extensions import Self from langchain.chains.base import Chain @@ -20,9 +21,10 @@ class SequentialChain(Chain): output_variables: List[str] #: :meta private: return_all: bool = False - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -40,8 +42,9 @@ def output_keys(self) -> List[str]: """ return self.output_variables - @root_validator(pre=True) - def validate_chains(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_chains(cls, values: Dict) -> Any: """Validate that the correct inputs exist for all chains.""" chains = values["chains"] input_variables = values["input_variables"] @@ -129,9 +132,10 @@ class SimpleSequentialChain(Chain): input_key: str = "input" #: :meta private: output_key: str = "output" #: :meta private: - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) @property def input_keys(self) -> List[str]: @@ -149,10 +153,10 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - @root_validator(pre=False, skip_on_failure=True) - def validate_chains(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_chains(self) -> Self: """Validate that chains are all single input/output.""" - for chain in values["chains"]: + for chain in self.chains: if len(chain.input_keys) != 1: raise ValueError( "Chains used in SimplePipeline should all have one input, got " @@ -163,7 +167,7 @@ def validate_chains(cls, values: Dict) -> Dict: "Chains used in SimplePipeline should all have one output, got " f"{chain} with {len(chain.output_keys)} outputs." ) - return values + return self def _call( self, diff --git a/libs/langchain/langchain/chains/structured_output/base.py b/libs/langchain/langchain/chains/structured_output/base.py index 14526d014ccc4..0cd8be4f144d3 100644 --- a/libs/langchain/langchain/chains/structured_output/base.py +++ b/libs/langchain/langchain/chains/structured_output/base.py @@ -18,13 +18,13 @@ PydanticToolsParser, ) from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable from langchain_core.utils.function_calling import ( convert_to_openai_function, convert_to_openai_tool, ) from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import BaseModel @deprecated( diff --git a/libs/langchain/langchain/chains/transform.py b/libs/langchain/langchain/chains/transform.py index e95bfc9a85a84..2812722369b72 100644 --- a/libs/langchain/langchain/chains/transform.py +++ b/libs/langchain/langchain/chains/transform.py @@ -8,7 +8,7 @@ AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.chains.base import Chain diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index ce50f7b7b095d..63411ea152993 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -30,11 +30,11 @@ generate_from_stream, ) from langchain_core.messages import AnyMessage, BaseMessage -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import BaseTool from langchain_core.tracers import RunLog, RunLogPatch +from pydantic import BaseModel from typing_extensions import TypeAlias __all__ = [ diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py index 791b95f64cf7a..1d52c5a78e3e3 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py @@ -28,8 +28,8 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.output_parsers import BaseOutputParser -from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from pydantic import ConfigDict, Field from langchain.chains.llm import LLMChain from langchain.evaluation.agents.trajectory_eval_prompt import ( @@ -156,8 +156,9 @@ def geography_answers(country: str, question: str) -> str: return_reasoning: bool = False # :meta private: """DEPRECATED. Reasoning always returned.""" - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/comparison/eval_chain.py b/libs/langchain/langchain/evaluation/comparison/eval_chain.py index d76f836ad8725..46b9001e48fd7 100644 --- a/libs/langchain/langchain/evaluation/comparison/eval_chain.py +++ b/libs/langchain/langchain/evaluation/comparison/eval_chain.py @@ -10,7 +10,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Field +from pydantic import ConfigDict, Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -191,8 +191,9 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain): def is_lc_serializable(cls) -> bool: return False - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/criteria/eval_chain.py b/libs/langchain/langchain/evaluation/criteria/eval_chain.py index 34fc656cbc7e5..896b1efceabc2 100644 --- a/libs/langchain/langchain/evaluation/criteria/eval_chain.py +++ b/libs/langchain/langchain/evaluation/criteria/eval_chain.py @@ -8,7 +8,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +from pydantic import ConfigDict, Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -236,8 +236,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain): def is_lc_serializable(cls) -> bool: return False - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @property def requires_reference(self) -> bool: diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index d983c72fbf00f..569838841cd04 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -10,8 +10,8 @@ Callbacks, ) from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import Field from langchain_core.utils import pre_init +from pydantic import ConfigDict, Field from langchain.chains.base import Chain from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator @@ -113,8 +113,9 @@ def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - class Config: - arbitrary_types_allowed: bool = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def output_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/evaluation/qa/eval_chain.py b/libs/langchain/langchain/evaluation/qa/eval_chain.py index 0204d8fe90161..345bbd87bc9f7 100644 --- a/libs/langchain/langchain/evaluation/qa/eval_chain.py +++ b/libs/langchain/langchain/evaluation/qa/eval_chain.py @@ -9,6 +9,7 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate +from pydantic import ConfigDict from langchain.chains.llm import LLMChain from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT @@ -72,8 +73,9 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain): output_key: str = "results" #: :meta private: - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @classmethod def is_lc_serializable(cls) -> bool: @@ -220,8 +222,9 @@ def requires_input(self) -> bool: """Whether the chain requires an input string.""" return True - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @classmethod def _validate_input_vars(cls, prompt: PromptTemplate) -> None: diff --git a/libs/langchain/langchain/evaluation/qa/generate_chain.py b/libs/langchain/langchain/evaluation/qa/generate_chain.py index 32dea149a447b..94cf36d45a7d4 100644 --- a/libs/langchain/langchain/evaluation/qa/generate_chain.py +++ b/libs/langchain/langchain/evaluation/qa/generate_chain.py @@ -6,7 +6,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseLLMOutputParser -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.chains.llm import LLMChain from langchain.evaluation.qa.generate_prompt import PROMPT diff --git a/libs/langchain/langchain/evaluation/scoring/eval_chain.py b/libs/langchain/langchain/evaluation/scoring/eval_chain.py index 3b800f8ffc0d4..a8a84a05813d6 100644 --- a/libs/langchain/langchain/evaluation/scoring/eval_chain.py +++ b/libs/langchain/langchain/evaluation/scoring/eval_chain.py @@ -10,7 +10,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Field +from pydantic import ConfigDict, Field from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -179,8 +179,9 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain): criterion_name: str """The name of the criterion being evaluated.""" - class Config: - extra = "ignore" + model_config = ConfigDict( + extra="ignore", + ) @classmethod def is_lc_serializable(cls) -> bool: diff --git a/libs/langchain/langchain/evaluation/string_distance/base.py b/libs/langchain/langchain/evaluation/string_distance/base.py index bb9c9719d736f..396e267a7e5e5 100644 --- a/libs/langchain/langchain/evaluation/string_distance/base.py +++ b/libs/langchain/langchain/evaluation/string_distance/base.py @@ -8,8 +8,8 @@ CallbackManagerForChainRun, Callbacks, ) -from langchain_core.pydantic_v1 import Field from langchain_core.utils import pre_init +from pydantic import Field from langchain.chains.base import Chain from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator diff --git a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index deb408dff2d0c..08042c63d7292 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -4,9 +4,9 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.vectorstores import VectorStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from pydantic import BaseModel, ConfigDict, Field from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA @@ -21,9 +21,10 @@ class VectorStoreIndexWrapper(BaseModel): vectorstore: VectorStore - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def query( self, @@ -142,9 +143,10 @@ class VectorstoreIndexCreator(BaseModel): text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) vectorstore_kwargs: dict = Field(default_factory=dict) - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: """Create a vectorstore index from loaders.""" diff --git a/libs/langchain/langchain/memory/chat_memory.py b/libs/langchain/langchain/memory/chat_memory.py index 10feaa3e1b95f..3053a8198ddcf 100644 --- a/libs/langchain/langchain/memory/chat_memory.py +++ b/libs/langchain/langchain/memory/chat_memory.py @@ -8,7 +8,7 @@ ) from langchain_core.memory import BaseMemory from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.pydantic_v1 import Field +from pydantic import Field from langchain.memory.utils import get_prompt_input_key diff --git a/libs/langchain/langchain/memory/combined.py b/libs/langchain/langchain/memory/combined.py index 5ab0048895bba..f9117af4107cf 100644 --- a/libs/langchain/langchain/memory/combined.py +++ b/libs/langchain/langchain/memory/combined.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Set from langchain_core.memory import BaseMemory -from langchain_core.pydantic_v1 import validator +from pydantic import validator from langchain.memory.chat_memory import BaseChatMemory diff --git a/libs/langchain/langchain/memory/entity.py b/libs/langchain/langchain/memory/entity.py index 57fdb75537eb2..3032a0479b0e8 100644 --- a/libs/langchain/langchain/memory/entity.py +++ b/libs/langchain/langchain/memory/entity.py @@ -6,7 +6,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory @@ -245,8 +245,9 @@ class SQLiteEntityStore(BaseEntityStore): table_name: str = "memory_store" conn: Any = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def __init__( self, diff --git a/libs/langchain/langchain/memory/summary.py b/libs/langchain/langchain/memory/summary.py index 23c3f2bca1f43..e1d32003c7148 100644 --- a/libs/langchain/langchain/memory/summary.py +++ b/libs/langchain/langchain/memory/summary.py @@ -7,8 +7,8 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils import pre_init +from pydantic import BaseModel from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory diff --git a/libs/langchain/langchain/memory/vectorstore.py b/libs/langchain/langchain/memory/vectorstore.py index b719749b1ce37..a4511dd192529 100644 --- a/libs/langchain/langchain/memory/vectorstore.py +++ b/libs/langchain/langchain/memory/vectorstore.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field from langchain_core.vectorstores import VectorStoreRetriever +from pydantic import Field from langchain.memory.chat_memory import BaseMemory from langchain.memory.utils import get_prompt_input_key diff --git a/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py index f611c04903d07..293773e84a173 100644 --- a/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py +++ b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py @@ -13,8 +13,8 @@ from langchain_core.messages import BaseMessage from langchain_core.prompts.chat import SystemMessagePromptTemplate -from langchain_core.pydantic_v1 import Field, PrivateAttr from langchain_core.vectorstores import VectorStoreRetriever +from pydantic import Field, PrivateAttr from langchain.memory import ConversationTokenBufferMemory, VectorStoreRetrieverMemory from langchain.memory.chat_memory import BaseChatMemory diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index a22bfff582be0..f0a1a701c2334 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import Any, TypeVar, Union +from typing import Annotated, Any, TypeVar, Union from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable, RunnableSerializable +from pydantic import SkipValidation from typing_extensions import TypedDict from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT @@ -26,7 +27,7 @@ class OutputFixingParser(BaseOutputParser[T]): def is_lc_serializable(cls) -> bool: return True - parser: BaseOutputParser[T] + parser: Annotated[BaseOutputParser[T], SkipValidation()] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Union[ diff --git a/libs/langchain/langchain/output_parsers/pandas_dataframe.py b/libs/langchain/langchain/output_parsers/pandas_dataframe.py index 3447767c088fa..2cc899bad8a27 100644 --- a/libs/langchain/langchain/output_parsers/pandas_dataframe.py +++ b/libs/langchain/langchain/output_parsers/pandas_dataframe.py @@ -3,7 +3,7 @@ from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.base import BaseOutputParser -from langchain_core.pydantic_v1 import validator +from pydantic import validator from langchain.output_parsers.format_instructions import ( PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS, diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index cc56c627c0f0c..06e1b410f7291 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -8,7 +8,8 @@ from langchain_core.prompt_values import PromptValue from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.runnables import RunnableSerializable -from typing_extensions import TypedDict +from pydantic import SkipValidation +from typing_extensions import Annotated, TypedDict NAIVE_COMPLETION_RETRY = """Prompt: {prompt} @@ -53,7 +54,7 @@ class RetryOutputParser(BaseOutputParser[T]): LLM, and telling it the completion did not satisfy criteria in the prompt. """ - parser: BaseOutputParser[T] + parser: Annotated[BaseOutputParser[T], SkipValidation()] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any] @@ -183,7 +184,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): LLM, which in theory should give it more information on how to fix it. """ - parser: BaseOutputParser[T] + parser: Annotated[BaseOutputParser[T], SkipValidation()] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Union[ diff --git a/libs/langchain/langchain/output_parsers/structured.py b/libs/langchain/langchain/output_parsers/structured.py index 097e1a7170f4c..715be3c410ee9 100644 --- a/libs/langchain/langchain/output_parsers/structured.py +++ b/libs/langchain/langchain/output_parsers/structured.py @@ -4,7 +4,7 @@ from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers.json import parse_and_check_json_markdown -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain.output_parsers.format_instructions import ( STRUCTURED_FORMAT_INSTRUCTIONS, diff --git a/libs/langchain/langchain/output_parsers/yaml.py b/libs/langchain/langchain/output_parsers/yaml.py index e7c071eb40068..5ea989b733a17 100644 --- a/libs/langchain/langchain/output_parsers/yaml.py +++ b/libs/langchain/langchain/output_parsers/yaml.py @@ -5,7 +5,7 @@ import yaml from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser -from langchain_core.pydantic_v1 import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError from langchain.output_parsers.format_instructions import YAML_FORMAT_INSTRUCTIONS diff --git a/libs/langchain/langchain/retrievers/contextual_compression.py b/libs/langchain/langchain/retrievers/contextual_compression.py index c73180b889d3a..d5dccb13fb552 100644 --- a/libs/langchain/langchain/retrievers/contextual_compression.py +++ b/libs/langchain/langchain/retrievers/contextual_compression.py @@ -6,6 +6,7 @@ ) from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever, RetrieverLike +from pydantic import ConfigDict from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, @@ -21,8 +22,9 @@ class ContextualCompressionRetriever(BaseRetriever): base_retriever: RetrieverLike """Base Retriever to use for getting relevant documents.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def _get_relevant_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/base.py b/libs/langchain/langchain/retrievers/document_compressors/base.py index a68515af8b023..dd25d428fa7fc 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/base.py +++ b/libs/langchain/langchain/retrievers/document_compressors/base.py @@ -7,6 +7,7 @@ BaseDocumentTransformer, Document, ) +from pydantic import ConfigDict class DocumentCompressorPipeline(BaseDocumentCompressor): @@ -15,8 +16,9 @@ class DocumentCompressorPipeline(BaseDocumentCompressor): transformers: List[Union[BaseDocumentTransformer, BaseDocumentCompressor]] """List of document filters that are chained together and run in sequence.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py index cc86f2be49b73..9933319ad0485 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py @@ -11,6 +11,7 @@ from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.runnables import Runnable +from pydantic import ConfigDict from langchain.chains.llm import LLMChain from langchain.retrievers.document_compressors.base import BaseDocumentCompressor @@ -56,8 +57,9 @@ class LLMChainExtractor(BaseDocumentCompressor): get_input: Callable[[str, Document], dict] = default_get_input """Callable for constructing the chain input from the query and a Document.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index 2db6f5be3a7b2..bfa1cd694dc26 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -9,6 +9,7 @@ from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.runnables import Runnable from langchain_core.runnables.config import RunnableConfig +from pydantic import ConfigDict from langchain.chains import LLMChain from langchain.output_parsers.boolean import BooleanOutputParser @@ -41,8 +42,9 @@ class LLMChainFilter(BaseDocumentCompressor): get_input: Callable[[str, Document], dict] = default_get_input """Callable for constructing the chain input from the query and a Document.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index f7d96c29df737..2030807ce310c 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -6,8 +6,8 @@ from langchain_core._api.deprecation import deprecated from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document -from langchain_core.pydantic_v1 import root_validator from langchain_core.utils import get_from_dict_or_env +from pydantic import ConfigDict, model_validator from langchain.retrievers.document_compressors.base import BaseDocumentCompressor @@ -30,12 +30,14 @@ class CohereRerank(BaseDocumentCompressor): user_agent: str = "langchain" """Identifier for the application making the request.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: """Validate that api key and python package exists in environment.""" if not values.get("client"): try: diff --git a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py index d1b683f2d9b8f..fff77c15266bd 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document +from pydantic import ConfigDict from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder @@ -18,9 +19,10 @@ class CrossEncoderReranker(BaseDocumentCompressor): top_n: int = 3 """Number of documents to return.""" - class Config: - arbitrary_types_allowed = True - extra = "forbid" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index d29a0e7ac5f61..8e3f1dbf43f4f 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -4,8 +4,8 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import Field from langchain_core.utils import pre_init +from pydantic import ConfigDict, Field from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, @@ -41,8 +41,9 @@ class EmbeddingsFilter(BaseDocumentCompressor): to be considered redundant. Defaults to None, must be specified if `k` is set to None.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @pre_init def validate_params(cls, values: Dict) -> Dict: diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index 16647df82c602..5039a36b6aba2 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -6,8 +6,8 @@ from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough +from pydantic import BaseModel, ConfigDict, Field _default_system_tmpl = """{context} @@ -76,8 +76,9 @@ class LLMListwiseRerank(BaseDocumentCompressor): top_n: int = 3 """Number of documents to return.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index bf2cc46376c21..59a89e555b3e1 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -25,7 +25,6 @@ ) from langchain_core.documents import Document from langchain_core.load.dump import dumpd -from langchain_core.pydantic_v1 import root_validator from langchain_core.retrievers import BaseRetriever, RetrieverLike from langchain_core.runnables import RunnableConfig from langchain_core.runnables.config import ensure_config, patch_config @@ -33,6 +32,7 @@ ConfigurableFieldSpec, get_unique_config_specs, ) +from pydantic import model_validator T = TypeVar("T") H = TypeVar("H", bound=Hashable) @@ -83,8 +83,9 @@ def config_specs(self) -> List[ConfigurableFieldSpec]: spec for retriever in self.retrievers for spec in retriever.config_specs ) - @root_validator(pre=True) - def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="before") + @classmethod + def set_weights(cls, values: Dict[str, Any]) -> Any: if not values.get("weights"): n_retrievers = len(values["retrievers"]) values["weights"] = [1 / n_retrievers] * n_retrievers diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index 54a4d935dcd56..48e48d07ea6af 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -1,15 +1,15 @@ from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.stores import BaseStore, ByteStore from langchain_core.vectorstores import VectorStore +from pydantic import Field, model_validator from langchain.storage._lc_store import create_kv_docstore @@ -41,8 +41,9 @@ class MultiVectorRetriever(BaseRetriever): search_type: SearchType = SearchType.similarity """Type of search to perform (similarity / mmr)""" - @root_validator(pre=True) - def shim_docstore(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def shim_docstore(cls, values: Dict) -> Any: byte_store = values.get("byte_store") docstore = values.get("docstore") if byte_store is not None: diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 0b89db472c124..a5254d475924f 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -9,11 +9,11 @@ ) from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import Runnable from langchain_core.structured_query import StructuredQuery, Visitor from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict, Field, model_validator from langchain.chains.query_constructor.base import load_query_constructor_runnable from langchain.chains.query_constructor.schema import AttributeInfo @@ -223,12 +223,14 @@ class SelfQueryRetriever(BaseRetriever): use_original_query: bool = False """Use original query instead of the revised new query from LLM""" - class Config: - allow_population_by_field_name = True - arbitrary_types_allowed = True + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + ) - @root_validator(pre=True) - def validate_translator(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_translator(cls, values: Dict) -> Any: """Validate translator.""" if "structured_query_translator" not in values: values["structured_query_translator"] = _get_builtin_translator( diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index 0acf17edce806..706366dbf58bf 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -7,9 +7,9 @@ CallbackManagerForRetrieverRun, ) from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore +from pydantic import ConfigDict, Field def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float: @@ -46,8 +46,9 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): None assigns no salience to documents not fetched from the vector store. """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def _document_get_date(self, field: str, document: Document) -> datetime.datetime: """Return the value of the date field of a document.""" diff --git a/libs/langchain/langchain/smith/evaluation/config.py b/libs/langchain/langchain/smith/evaluation/config.py index e9bdd324779db..9f132a011a6bb 100644 --- a/libs/langchain/langchain/smith/evaluation/config.py +++ b/libs/langchain/langchain/smith/evaluation/config.py @@ -5,10 +5,10 @@ from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field from langsmith import RunEvaluator from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults from langsmith.schemas import Example, Run +from pydantic import BaseModel, ConfigDict, Field from langchain.evaluation.criteria.eval_chain import CRITERIA_TYPE from langchain.evaluation.embedding_distance.base import ( @@ -156,8 +156,9 @@ class RunEvalConfig(BaseModel): eval_llm: Optional[BaseLanguageModel] = None """The language model to pass to any evaluators that require one.""" - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) class Criteria(SingleKeyEvalConfig): """Configuration for a reference-free criteria evaluator. @@ -217,8 +218,9 @@ class EmbeddingDistance(SingleKeyEvalConfig): embeddings: Optional[Embeddings] = None distance_metric: Optional[EmbeddingDistanceEnum] = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) class StringDistance(SingleKeyEvalConfig): """Configuration for a string distance evaluator. diff --git a/libs/langchain/scripts/check_pydantic.sh b/libs/langchain/scripts/check_pydantic.sh deleted file mode 100755 index 06b5bb81ae236..0000000000000 --- a/libs/langchain/scripts/check_pydantic.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# -# This script searches for lines starting with "import pydantic" or "from pydantic" -# in tracked files within a Git repository. -# -# Usage: ./scripts/check_pydantic.sh /path/to/repository - -# Check if a path argument is provided -if [ $# -ne 1 ]; then - echo "Usage: $0 /path/to/repository" - exit 1 -fi - -repository_path="$1" - -# Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') - -# Check if any matching lines were found -if [ -n "$result" ]; then - echo "ERROR: The following lines need to be updated:" - echo "$result" - echo "Please replace the code with an import from langchain_core.pydantic_v1." - echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain_core.pydantic_v1 import BaseModel'" - exit 1 -fi diff --git a/libs/langchain/tests/integration_tests/chat_models/test_base.py b/libs/langchain/tests/integration_tests/chat_models/test_base.py index cda11263ddfbc..efed6e1d52290 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_base.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_base.py @@ -4,9 +4,9 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import RunnableConfig from langchain_standard_tests.integration_tests import ChatModelIntegrationTests +from pydantic import BaseModel from langchain.chat_models import init_chat_model diff --git a/libs/langchain/tests/mock_servers/robot/server.py b/libs/langchain/tests/mock_servers/robot/server.py index 1156cf1bb89db..823057bb4d9e2 100644 --- a/libs/langchain/tests/mock_servers/robot/server.py +++ b/libs/langchain/tests/mock_servers/robot/server.py @@ -8,7 +8,7 @@ from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field PORT = 7289 diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index 6db94330a1328..038ffb2930996 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -6,7 +6,6 @@ from langchain_core.agents import ( AgentAction, - AgentActionMessageLog, AgentFinish, AgentStep, ) @@ -35,7 +34,9 @@ from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel -from tests.unit_tests.stubs import AnyStr +from tests.unit_tests.stubs import ( + _AnyIdAIMessageChunk, +) class FakeListLLM(LLM): @@ -785,6 +786,26 @@ def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage: ) +def _recursive_dump(obj: Any) -> Any: + """Recursively dump the object if encountering any pydantic models.""" + if isinstance(obj, dict): + return { + k: _recursive_dump(v) + for k, v in obj.items() + if k != "id" # Remove the id field for testing purposes + } + if isinstance(obj, list): + return [_recursive_dump(v) for v in obj] + if hasattr(obj, "dict"): + # if the object contains an ID field, we'll remove it for testing purposes + if hasattr(obj, "id"): + d = obj.dict() + d.pop("id") + return _recursive_dump(d) + return _recursive_dump(obj.dict()) + return obj + + async def test_openai_agent_with_streaming() -> None: """Test openai agent with streaming.""" infinite_cycle = cycle( @@ -831,72 +852,93 @@ def find_pet(pet: str) -> str: # astream chunks = [chunk async for chunk in executor.astream({"question": "hello"})] - assert chunks == [ + assert _recursive_dump(chunks) == [ { "actions": [ - AgentActionMessageLog( - tool="find_pet", - tool_input={"pet": "cat"}, - log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", - message_log=[ - AIMessageChunk( - id=AnyStr(), - content="", - additional_kwargs={ + { + "log": "\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", + "message_log": [ + { + "additional_kwargs": { "function_call": { + "arguments": '{"pet": ' '"cat"}', "name": "find_pet", - "arguments": '{"pet": "cat"}', } }, - ) + "content": "", + "name": None, + "response_metadata": {}, + "type": "AIMessageChunk", + } ], - ) + "tool": "find_pet", + "tool_input": {"pet": "cat"}, + "type": "AgentActionMessageLog", + } ], "messages": [ - AIMessageChunk( - id=AnyStr(), - content="", - additional_kwargs={ + { + "additional_kwargs": { "function_call": { + "arguments": '{"pet": ' '"cat"}', "name": "find_pet", - "arguments": '{"pet": "cat"}', } }, - ) + "content": "", + "example": False, + "invalid_tool_calls": [], + "name": None, + "response_metadata": {}, + "tool_call_chunks": [], + "tool_calls": [], + "type": "AIMessageChunk", + "usage_metadata": None, + } ], }, { "messages": [ - FunctionMessage(content="Spying from under the bed.", name="find_pet") + { + "additional_kwargs": {}, + "content": "Spying from under the bed.", + "name": "find_pet", + "response_metadata": {}, + "type": "function", + } ], "steps": [ - AgentStep( - action=AgentActionMessageLog( - tool="find_pet", - tool_input={"pet": "cat"}, - log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", - message_log=[ - AIMessageChunk( - id=AnyStr(), - content="", - additional_kwargs={ - "function_call": { - "name": "find_pet", - "arguments": '{"pet": "cat"}', - } - }, - ) - ], - ), - observation="Spying from under the bed.", - ) + { + "action": { + "log": "\n" + "Invoking: `find_pet` with `{'pet': 'cat'}`\n" + "\n" + "\n", + "tool": "find_pet", + "tool_input": {"pet": "cat"}, + "type": "AgentActionMessageLog", + }, + "observation": "Spying from under the bed.", + } ], }, { - "messages": [AIMessage(content="The cat is spying from under the bed.")], + "messages": [ + { + "additional_kwargs": {}, + "content": "The cat is spying from under the bed.", + "example": False, + "invalid_tool_calls": [], + "name": None, + "response_metadata": {}, + "tool_calls": [], + "type": "ai", + "usage_metadata": None, + } + ], "output": "The cat is spying from under the bed.", }, ] + # # # astream_log log_patches = [ @@ -941,7 +983,7 @@ def _make_tools_invocation(name_to_arguments: Dict[str, Dict[str, Any]]) -> AIMe AIMessage that represents a request to invoke a tool. """ raw_tool_calls = [ - {"function": {"name": name, "arguments": json.dumps(arguments)}, "id": idx} + {"function": {"name": name, "arguments": json.dumps(arguments)}, "id": str(idx)} for idx, (name, arguments) in enumerate(name_to_arguments.items()) ] tool_calls = [ @@ -1030,8 +1072,7 @@ def check_time() -> str: tool_input={"pet": "cat"}, log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1040,14 +1081,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1057,8 +1098,7 @@ def check_time() -> str: ) ], "messages": [ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1067,14 +1107,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1088,8 +1128,7 @@ def check_time() -> str: tool_input={}, log="\nInvoking: `check_time` with `{}`\n\n\n", message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1098,14 +1137,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1115,8 +1154,7 @@ def check_time() -> str: ) ], "messages": [ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1125,14 +1163,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1152,8 +1190,7 @@ def check_time() -> str: tool_input={"pet": "cat"}, log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", # noqa: E501 message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1162,14 +1199,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, @@ -1195,8 +1232,7 @@ def check_time() -> str: tool_input={}, log="\nInvoking: `check_time` with `{}`\n\n\n", message_log=[ - AIMessageChunk( - id=AnyStr(), + _AnyIdAIMessageChunk( content="", additional_kwargs={ "tool_calls": [ @@ -1205,14 +1241,14 @@ def check_time() -> str: "name": "find_pet", "arguments": '{"pet": "cat"}', }, - "id": 0, + "id": "0", }, { "function": { "name": "check_time", "arguments": "{}", }, - "id": 1, + "id": "1", }, ] }, diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py index fdef58e13ff4a..43351f1da381e 100644 --- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py @@ -6,7 +6,7 @@ from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel class BaseFakeCallbackHandler(BaseModel): @@ -254,7 +254,7 @@ def on_retriever_error( ) -> Any: self.on_retriever_error_common() - def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore return self @@ -388,5 +388,5 @@ async def on_text( ) -> None: self.on_text_common() - def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore return self diff --git a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py index 5178b6b20d99a..daad6fbffc99b 100644 --- a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py +++ b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py @@ -6,8 +6,8 @@ from langchain_core.agents import AgentAction, BaseMessage from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.exceptions import OutputParserException -from langchain_core.pydantic_v1 import Field from langchain_core.tools import tool +from pydantic import Field from langchain.evaluation.agents.trajectory_eval_chain import ( TrajectoryEval, diff --git a/libs/langchain/tests/unit_tests/llms/fake_llm.py b/libs/langchain/tests/unit_tests/llms/fake_llm.py index e75865b40f290..7e21cce3404cc 100644 --- a/libs/langchain/tests/unit_tests/llms/fake_llm.py +++ b/libs/langchain/tests/unit_tests/llms/fake_llm.py @@ -4,7 +4,7 @@ from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import validator +from pydantic import validator class FakeLLM(LLM): diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py index f009222608668..9f33d73f32e55 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -9,7 +9,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel -from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk def test_generic_fake_chat_model_invoke() -> None: @@ -64,8 +64,8 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=cycle([message])) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()), - AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()), + _AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}), + _AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}), ] message = AIMessage( diff --git a/libs/langchain/tests/unit_tests/load/test_dump.py b/libs/langchain/tests/unit_tests/load/test_dump.py index 0ac05f7df21ff..76af8513e73ee 100644 --- a/libs/langchain/tests/unit_tests/load/test_dump.py +++ b/libs/langchain/tests/unit_tests/load/test_dump.py @@ -8,7 +8,7 @@ import pytest from langchain_core.load.dump import dumps from langchain_core.load.serializable import Serializable -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import ConfigDict, Field, model_validator class Person(Serializable): @@ -84,11 +84,13 @@ class TestClass(Serializable): my_favorite_secret: str = Field(alias="my_favorite_secret_alias") my_other_secret: str = Field() - class Config: - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + ) - @root_validator(pre=True) - def get_from_env(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def get_from_env(cls, values: Dict) -> Any: """Get the values from the environment.""" if "my_favorite_secret" not in values: values["my_favorite_secret"] = os.getenv("MY_FAVORITE_SECRET") diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py index 61d2d8a0c4613..a8961663925ef 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py @@ -6,13 +6,11 @@ from langchain_core.messages import AIMessage from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough -from pytest_mock import MockerFixture from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT -from langchain.pydantic_v1 import Extra T = TypeVar("T") @@ -173,13 +171,7 @@ def test_output_fixing_parser_parse_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - base_parser.Config.extra = Extra.allow # type: ignore - invoke_spy = mocker.spy(retry_chain, "invoke") # NOTE: get_format_instructions of some parsers behave randomly instructions = base_parser.get_format_instructions() object.__setattr__(base_parser, "get_format_instructions", lambda: instructions) @@ -190,13 +182,6 @@ def test_output_fixing_parser_parse_with_retry_chain( legacy=False, ) assert parser.parse(input) == expected - invoke_spy.assert_called_once_with( - dict( - instructions=base_parser.get_format_instructions(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) @pytest.mark.parametrize( @@ -223,14 +208,7 @@ async def test_output_fixing_parser_aparse_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - base_parser.Config.extra = Extra.allow # type: ignore - ainvoke_spy = mocker.spy(retry_chain, "ainvoke") - # NOTE: get_format_instructions of some parsers behave randomly instructions = base_parser.get_format_instructions() object.__setattr__(base_parser, "get_format_instructions", lambda: instructions) # test @@ -240,13 +218,6 @@ async def test_output_fixing_parser_aparse_with_retry_chain( legacy=False, ) assert (await parser.aparse(input)) == expected - ainvoke_spy.assert_called_once_with( - dict( - instructions=base_parser.get_format_instructions(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) def _extract_exception( diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py index 7af3597f47573..5d4d4124355df 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py @@ -4,7 +4,6 @@ import pytest from langchain_core.prompt_values import PromptValue, StringPromptValue from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough -from pytest_mock import MockerFixture from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser @@ -16,7 +15,6 @@ RetryOutputParser, RetryWithErrorOutputParser, ) -from langchain.pydantic_v1 import Extra T = TypeVar("T") @@ -222,25 +220,13 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - invoke_spy = mocker.spy(retry_chain, "invoke") - # test parser = RetryOutputParser( parser=base_parser, retry_chain=retry_chain, legacy=False, ) assert parser.parse_with_prompt(input, prompt) == expected - invoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - ) - ) @pytest.mark.parametrize( @@ -262,12 +248,7 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - ainvoke_spy = mocker.spy(retry_chain, "ainvoke") # test parser = RetryOutputParser( parser=base_parser, @@ -275,12 +256,6 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( legacy=False, ) assert (await parser.aparse_with_prompt(input, prompt)) == expected - ainvoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - ) - ) @pytest.mark.parametrize( @@ -302,12 +277,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - invoke_spy = mocker.spy(retry_chain, "invoke") # test parser = RetryWithErrorOutputParser( parser=base_parser, @@ -315,13 +285,6 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( legacy=False, ) assert parser.parse_with_prompt(input, prompt) == expected - invoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) @pytest.mark.parametrize( @@ -343,26 +306,13 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chai base_parser: BaseOutputParser[T], retry_chain: Runnable[Dict[str, Any], str], expected: T, - mocker: MockerFixture, ) -> None: - # preparation - # NOTE: Extra.allow is necessary in order to use spy and mock - retry_chain.Config.extra = Extra.allow # type: ignore - ainvoke_spy = mocker.spy(retry_chain, "ainvoke") - # test parser = RetryWithErrorOutputParser( parser=base_parser, retry_chain=retry_chain, legacy=False, ) assert (await parser.aparse_with_prompt(input, prompt)) == expected - ainvoke_spy.assert_called_once_with( - dict( - prompt=prompt.to_string(), - completion=input, - error=repr(_extract_exception(base_parser.parse, input)), - ) - ) def _extract_exception( diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py index 065ca4aa96ca4..6e678ed8f03d2 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py @@ -5,7 +5,7 @@ import pytest from langchain_core.exceptions import OutputParserException -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from langchain.output_parsers.yaml import YamlOutputParser diff --git a/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py b/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py index fd2f9e40c3efa..cb2916193a85b 100644 --- a/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py +++ b/libs/langchain/tests/unit_tests/smith/evaluation/test_string_run_evaluator.py @@ -12,11 +12,10 @@ def test_evaluate_run() -> None: run_mapper = ChainStringRunMapper() - example_mapper = MagicMock() string_evaluator = criteria.CriteriaEvalChain.from_llm(fake_llm.FakeLLM()) evaluator = StringRunEvaluatorChain( run_mapper=run_mapper, - example_mapper=example_mapper, + example_mapper=None, name="test_evaluator", string_evaluator=string_evaluator, ) diff --git a/libs/langchain/tests/unit_tests/test_imports.py b/libs/langchain/tests/unit_tests/test_imports.py index 1433638098273..22b98bf2cfbbe 100644 --- a/libs/langchain/tests/unit_tests/test_imports.py +++ b/libs/langchain/tests/unit_tests/test_imports.py @@ -34,6 +34,8 @@ def test_import_all() -> None: # If the module is not installed, we suppress the error if "Module langchain_community" in str(e) and COMMUNITY_NOT_INSTALLED: pass + except Exception as e: + raise AssertionError(f"Could not import {module_name}.{name}") from e def test_import_all_using_dir() -> None: diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index a498e234a5070..5e720db065347 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -19,26 +19,23 @@ ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue -from langchain_core.pydantic_v1 import BaseModel, ValidationError +from pydantic import RootModel, ValidationError def test_serialization_of_wellknown_objects() -> None: """Test that pydantic is able to serialize and deserialize well known objects.""" - - class WellKnownLCObject(BaseModel): - """A well known LangChain object.""" - - __root__: Union[ + well_known_lc_object = RootModel[ + Union[ Document, HumanMessage, SystemMessage, ChatMessage, FunctionMessage, + FunctionMessageChunk, AIMessage, HumanMessageChunk, SystemMessageChunk, ChatMessageChunk, - FunctionMessageChunk, AIMessageChunk, StringPromptValue, ChatPromptValueConcrete, @@ -49,6 +46,7 @@ class WellKnownLCObject(BaseModel): Generation, ChatGenerationChunk, ] + ] lc_objects = [ HumanMessage(content="human"), @@ -97,11 +95,11 @@ class WellKnownLCObject(BaseModel): ] for lc_object in lc_objects: - d = lc_object.dict() + d = lc_object.model_dump() assert "type" in d, f"Missing key `type` for {type(lc_object)}" - obj1 = WellKnownLCObject.parse_obj(d) - assert type(obj1.__root__) is type(lc_object), f"failed for {type(lc_object)}" + obj1 = well_known_lc_object.model_validate(d) + assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}" - with pytest.raises(ValidationError): + with pytest.raises((TypeError, ValidationError)): # Make sure that specifically validation error is raised - WellKnownLCObject.parse_obj({}) + well_known_lc_object.model_validate({}) diff --git a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py index 34a0b8126f083..ca66e1c64ae2f 100644 --- a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py @@ -1,5 +1,5 @@ -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.function_calling import convert_pydantic_to_openai_function +from pydantic import BaseModel, Field def test_convert_pydantic_to_openai_function() -> None: