Skip to content

Commit

Permalink
langchain[major]: Upgrade langchain to be pydantic 2 compatible (#26050)
Browse files Browse the repository at this point in the history
Upgrading the langchain package to be pydantic 2 compatible.

Had to remove some parts of unit tests in parsers that were relying on
spying on methods since that fails with pydantic 2. The unit tests don't
seem particularly good, so can be re-written at a future date.

Depends on: #26057

Most of this PR was done using gritql for code mods, followed by some
fixes done manually to account for changes made by pydantic
  • Loading branch information
eyurtsev authored Sep 4, 2024
1 parent c72a762 commit 9379613
Show file tree
Hide file tree
Showing 103 changed files with 557 additions and 523 deletions.
1 change: 0 additions & 1 deletion libs/langchain/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
Expand Down
38 changes: 21 additions & 17 deletions libs/langchain/langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.utils import AddableDict
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

from langchain.agents.agent_iterator import AgentExecutorIterator
from langchain.agents.agent_types import AgentType
Expand Down Expand Up @@ -420,8 +421,9 @@ class RunnableAgent(BaseSingleActionAgent):
individual LLM tokens will not be available in stream_log.
"""

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

@property
def return_values(self) -> List[str]:
Expand Down Expand Up @@ -528,8 +530,9 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
individual LLM tokens will not be available in stream_log.
"""

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

@property
def return_values(self) -> List[str]:
Expand Down Expand Up @@ -854,8 +857,8 @@ def input_keys(self) -> List[str]:
"""
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})

@root_validator(pre=False, skip_on_failure=True)
def validate_prompt(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_prompt(self) -> Self:
"""Validate that prompt matches format.
Args:
Expand All @@ -868,7 +871,7 @@ def validate_prompt(cls, values: Dict) -> Dict:
ValueError: If `agent_scratchpad` is not in prompt.input_variables
and prompt is not a FewShotPromptTemplate or a PromptTemplate.
"""
prompt = values["llm_chain"].prompt
prompt = self.llm_chain.prompt
if "agent_scratchpad" not in prompt.input_variables:
logger.warning(
"`agent_scratchpad` should be a variable in prompt.input_variables."
Expand All @@ -881,7 +884,7 @@ def validate_prompt(cls, values: Dict) -> Dict:
prompt.suffix += "\n{agent_scratchpad}"
else:
raise ValueError(f"Got unexpected prompt type {type(prompt)}")
return values
return self

@property
@abstractmethod
Expand Down Expand Up @@ -1120,8 +1123,8 @@ def from_agent_and_tools(
**kwargs,
)

@root_validator(pre=False, skip_on_failure=True)
def validate_tools(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_tools(self) -> Self:
"""Validate that tools are compatible with agent.
Args:
Expand All @@ -1133,19 +1136,20 @@ def validate_tools(cls, values: Dict) -> Dict:
Raises:
ValueError: If allowed tools are different than provided tools.
"""
agent = values["agent"]
tools = values["tools"]
allowed_tools = agent.get_allowed_tools()
agent = self.agent
tools = self.tools
allowed_tools = agent.get_allowed_tools() # type: ignore
if allowed_tools is not None:
if set(allowed_tools) != set([tool.name for tool in tools]):
raise ValueError(
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
return values
return self

@root_validator(pre=True)
def validate_runnable_agent(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def validate_runnable_agent(cls, values: Dict) -> Any:
"""Convert runnable to agent if passed in.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from typing import List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_core.vectorstores import VectorStore
from pydantic import BaseModel, ConfigDict, Field


class VectorStoreInfo(BaseModel):
Expand All @@ -16,8 +16,9 @@ class VectorStoreInfo(BaseModel):
name: str
description: str

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)


class VectorStoreToolkit(BaseToolkit):
Expand All @@ -26,8 +27,9 @@ class VectorStoreToolkit(BaseToolkit):
vectorstore_info: VectorStoreInfo = Field(exclude=True)
llm: BaseLanguageModel

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down Expand Up @@ -67,8 +69,9 @@ class VectorStoreRouterToolkit(BaseToolkit):
vectorstores: List[VectorStoreInfo] = Field(exclude=True)
llm: BaseLanguageModel

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from pydantic import Field

from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.chat.output_parser import ChatOutputParser
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/conversational/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from pydantic import Field

from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent_types import AgentType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
from pydantic import Field

from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/mrkl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool, Tool
from langchain_core.tools.render import render_text_description
from pydantic import Field

from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
Expand Down
15 changes: 8 additions & 7 deletions libs/langchain/langchain/agents/openai_assistant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import CallbackManager
from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

if TYPE_CHECKING:
import openai
Expand Down Expand Up @@ -232,14 +233,14 @@ def execute_agent(agent, tools, input):
as_agent: bool = False
"""Use as a LangChain agent, compatible with the AgentExecutor."""

@root_validator(pre=False, skip_on_failure=True)
def validate_async_client(cls, values: dict) -> dict:
if values["async_client"] is None:
@model_validator(mode="after")
def validate_async_client(self) -> Self:
if self.async_client is None:
import openai

api_key = values["client"].api_key
values["async_client"] = openai.AsyncOpenAI(api_key=api_key)
return values
api_key = self.client.api_key
self.async_client = openai.AsyncOpenAI(api_key=api_key)
return self

@classmethod
def create_assistant(
Expand Down
11 changes: 6 additions & 5 deletions libs/langchain/langchain/agents/openai_functions_agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from pydantic import model_validator
from typing_extensions import Self

from langchain.agents import BaseSingleActionAgent
from langchain.agents.format_scratchpad.openai_functions import (
Expand Down Expand Up @@ -58,8 +59,8 @@ def get_allowed_tools(self) -> List[str]:
"""Get allowed tools."""
return [t.name for t in self.tools]

@root_validator(pre=False, skip_on_failure=True)
def validate_prompt(cls, values: dict) -> dict:
@model_validator(mode="after")
def validate_prompt(self) -> Self:
"""Validate prompt.
Args:
Expand All @@ -71,13 +72,13 @@ def validate_prompt(cls, values: dict) -> dict:
Raises:
ValueError: If `agent_scratchpad` is not in the prompt.
"""
prompt: BasePromptTemplate = values["prompt"]
prompt: BasePromptTemplate = self.prompt
if "agent_scratchpad" not in prompt.input_variables:
raise ValueError(
"`agent_scratchpad` should be one of the variables in the prompt, "
f"got {prompt.input_variables}"
)
return values
return self

@property
def input_keys(self) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from pydantic import model_validator
from typing_extensions import Self

from langchain.agents import BaseMultiActionAgent
from langchain.agents.format_scratchpad.openai_functions import (
Expand Down Expand Up @@ -115,15 +116,15 @@ def get_allowed_tools(self) -> List[str]:
"""Get allowed tools."""
return [t.name for t in self.tools]

@root_validator(pre=False, skip_on_failure=True)
def validate_prompt(cls, values: dict) -> dict:
prompt: BasePromptTemplate = values["prompt"]
@model_validator(mode="after")
def validate_prompt(self) -> Self:
prompt: BasePromptTemplate = self.prompt
if "agent_scratchpad" not in prompt.input_variables:
raise ValueError(
"`agent_scratchpad` should be one of the variables in the prompt, "
f"got {prompt.input_variables}"
)
return values
return self

@property
def input_keys(self) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/react/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool, Tool
from pydantic import Field

from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool, Tool
from pydantic import Field

from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/structured_chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.tools.render import ToolsRenderer
from pydantic import Field

from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.format_scratchpad import format_log_to_str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from pydantic import Field

from langchain.agents.agent import AgentOutputParser
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS
Expand Down
Loading

0 comments on commit 9379613

Please sign in to comment.