From 69c7918adc2457c61d4a1c49c2c11cf04d14b7ca Mon Sep 17 00:00:00 2001
From: B-Step62
Date: Wed, 18 Sep 2024 15:16:26 +0900
Subject: [PATCH] Support LangChain 0.3.0

Signed-off-by: B-Step62
---
 .../langchain_databricks/chat_models.py  | 24 +++++++++----------
 .../langchain_databricks/embeddings.py   |  2 +-
 libs/databricks/pyproject.toml           | 12 ++++++----
 .../tests/unit_tests/test_chat_models.py |  2 +-
 4 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/libs/databricks/langchain_databricks/chat_models.py b/libs/databricks/langchain_databricks/chat_models.py
index f9fc7be..3c00902 100644
--- a/libs/databricks/langchain_databricks/chat_models.py
+++ b/libs/databricks/langchain_databricks/chat_models.py
@@ -40,14 +40,11 @@
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import (
-    BaseModel,
-    Field,
-    PrivateAttr,
-)
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
+from mlflow.deployments import BaseDeploymentClient  # type: ignore
+from pydantic import BaseModel, Field
 
 from langchain_databricks.utils import get_deployment_client
 
@@ -180,7 +177,7 @@ class ChatDatabricks(BaseChatModel):
     Tool calling:
         .. code-block:: python
 
-            from langchain_core.pydantic_v1 import BaseModel, Field
+            from pydantic import BaseModel, Field
 
             class GetWeather(BaseModel):
                 '''Get the current weather in a given location'''
@@ -225,13 +222,16 @@ class GetPopulation(BaseModel):
     """List of strings to stop generation at."""
     max_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
-    extra_params: dict = Field(default_factory=dict)
+    extra_params: Optional[Dict[str, Any]] = None
     """Any extra parameters to pass to the endpoint."""
-    _client: Any = PrivateAttr()
+    client: Optional[BaseDeploymentClient] = Field(
+        default=None, exclude=True
+    )  #: :meta private:
 
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
-        self._client = get_deployment_client(self.target_uri)
+        self.client = get_deployment_client(self.target_uri)
+        self.extra_params = self.extra_params or {}
 
     @property
     def _default_params(self) -> Dict[str, Any]:
@@ -254,7 +254,7 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         data = self._prepare_inputs(messages, stop, **kwargs)
-        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
+        resp = self.client.predict(endpoint=self.endpoint, inputs=data)  # type: ignore
         return self._convert_response_to_chat_result(resp)
 
     def _prepare_inputs(
@@ -267,7 +267,7 @@ def _prepare_inputs(
             "messages": [_convert_message_to_dict(msg) for msg in messages],
             "temperature": self.temperature,
             "n": self.n,
-            **self.extra_params,
+            **self.extra_params,  # type: ignore
             **kwargs,
         }
         if stop := self.stop or stop:
@@ -299,7 +299,7 @@ def _stream(
     ) -> Iterator[ChatGenerationChunk]:
         data = self._prepare_inputs(messages, stop, **kwargs)
         first_chunk_role = None
-        for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
+        for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data):  # type: ignore
             if chunk["choices"]:
                 choice = chunk["choices"][0]
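Note on the chat_models.py change above: it follows the usual pydantic v1 -> v2 migration
pattern for private state, replacing a `PrivateAttr` with an optional field that is excluded
from serialization and populated after `super().__init__()`. A minimal standalone sketch of
the pattern — the `Endpoint` class and `make_client` factory are illustrative stand-ins, not
part of this patch:

    from typing import Any, Optional

    from pydantic import BaseModel, Field


    def make_client() -> Any:
        # Hypothetical stand-in for get_deployment_client(target_uri).
        return object()


    class Endpoint(BaseModel):
        target_uri: str
        # Excluded from dumps/schema, mirroring `exclude=True` in the patch;
        # stays None until __init__ fills it in.
        client: Optional[Any] = Field(default=None, exclude=True)

        def __init__(self, **kwargs: Any):
            super().__init__(**kwargs)
            # Plain attribute assignment is allowed on pydantic v2 models
            # by default (validate_assignment is off).
            self.client = make_client()


    e = Endpoint(target_uri="databricks")
    assert e.client is not None
    assert "client" not in e.model_dump()  # excluded from serialization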
diff --git a/libs/databricks/langchain_databricks/embeddings.py b/libs/databricks/langchain_databricks/embeddings.py
index 5211376..2cbf4a5 100644
--- a/libs/databricks/langchain_databricks/embeddings.py
+++ b/libs/databricks/langchain_databricks/embeddings.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Iterator, List
 
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
+from pydantic import BaseModel, PrivateAttr
 
 from langchain_databricks.utils import get_deployment_client

diff --git a/libs/databricks/pyproject.toml b/libs/databricks/pyproject.toml
index d4c0831..1b974db 100644
--- a/libs/databricks/pyproject.toml
+++ b/libs/databricks/pyproject.toml
@@ -12,10 +12,14 @@ license = "MIT"
 "Release Notes" = "https://github.com/langchain-ai/langchain-databricks/releases"
 
 [tool.poetry.dependencies]
-# TODO: Replace <3.12 to <4.0 once https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2 is released
-python = ">=3.8.1,<3.12"
-langchain-core = "^0.2.35"
-mlflow = ">=2.9"
+python = ">=3.8.1,<4.0"
+langchain-core = ">=0.2.35"
+
+# MLflow supports python 3.12 since https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2
+mlflow = [
+    { version = ">=2.16.0", python = ">=3.12" },
+    { version = ">=2.9", python = "<3.12" },
+]
 
 # MLflow depends on following libraries, which require different version for Python 3.8 vs 3.12
 numpy = [

diff --git a/libs/databricks/tests/unit_tests/test_chat_models.py b/libs/databricks/tests/unit_tests/test_chat_models.py
index bb20da4..fad3776 100644
--- a/libs/databricks/tests/unit_tests/test_chat_models.py
+++ b/libs/databricks/tests/unit_tests/test_chat_models.py
@@ -20,7 +20,7 @@
     ToolMessageChunk,
 )
 from langchain_core.messages.tool import ToolCallChunk
-from langchain_core.pydantic_v1 import BaseModel, Field
+from pydantic import BaseModel, Field
 
 from langchain_databricks.chat_models import (
     ChatDatabricks,
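For downstream users, the practical effect of this patch is that tool schemas are defined with
plain pydantic v2 imports instead of `langchain_core.pydantic_v1`, matching the updated
docstring example in chat_models.py. A hedged usage sketch — the endpoint name is a
placeholder, and a running Databricks model-serving endpoint plus credentials in the
environment are assumed:

    from pydantic import BaseModel, Field

    from langchain_databricks.chat_models import ChatDatabricks


    class GetWeather(BaseModel):
        """Get the current weather in a given location"""

        location: str = Field(..., description="City and state, e.g. San Francisco, CA")


    # Placeholder endpoint name; any chat-model serving endpoint works here.
    llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
    llm_with_tools = llm.bind_tools([GetWeather])

    msg = llm_with_tools.invoke("What is the weather in San Francisco?")
    print(msg.tool_calls)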