Skip to content

Commit

Permalink
Support LangChain 0.3.0
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <[email protected]>
  • Loading branch information
B-Step62 committed Sep 18, 2024
1 parent deee25b commit 69c7918
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 18 deletions.
24 changes: 12 additions & 12 deletions libs/databricks/langchain_databricks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,11 @@
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from mlflow.deployments import BaseDeploymentClient # type: ignore
from pydantic import BaseModel, Field

from langchain_databricks.utils import get_deployment_client

Expand Down Expand Up @@ -180,7 +177,7 @@ class ChatDatabricks(BaseChatModel):
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
Expand Down Expand Up @@ -225,13 +222,16 @@ class GetPopulation(BaseModel):
"""List of strings to stop generation at."""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
extra_params: dict = Field(default_factory=dict)
extra_params: Optional[Dict[str, Any]] = None
"""Any extra parameters to pass to the endpoint."""
_client: Any = PrivateAttr()
client: Optional[BaseDeploymentClient] = Field(
default=None, exclude=True
) #: :meta private:

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._client = get_deployment_client(self.target_uri)
self.client = get_deployment_client(self.target_uri)
self.extra_params = self.extra_params or {}

@property
def _default_params(self) -> Dict[str, Any]:
Expand All @@ -254,7 +254,7 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
data = self._prepare_inputs(messages, stop, **kwargs)
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
resp = self.client.predict(endpoint=self.endpoint, inputs=data) # type: ignore
return self._convert_response_to_chat_result(resp)

def _prepare_inputs(
Expand All @@ -267,7 +267,7 @@ def _prepare_inputs(
"messages": [_convert_message_to_dict(msg) for msg in messages],
"temperature": self.temperature,
"n": self.n,
**self.extra_params,
**self.extra_params, # type: ignore
**kwargs,
}
if stop := self.stop or stop:
Expand Down Expand Up @@ -299,7 +299,7 @@ def _stream(
) -> Iterator[ChatGenerationChunk]:
data = self._prepare_inputs(messages, stop, **kwargs)
first_chunk_role = None
for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data): # type: ignore
if chunk["choices"]:
choice = chunk["choices"][0]

Expand Down
2 changes: 1 addition & 1 deletion libs/databricks/langchain_databricks/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Iterator, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from pydantic import BaseModel, PrivateAttr

from langchain_databricks.utils import get_deployment_client

Expand Down
12 changes: 8 additions & 4 deletions libs/databricks/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ license = "MIT"
"Release Notes" = "https://github.com/langchain-ai/langchain-databricks/releases"

[tool.poetry.dependencies]
# TODO: Replace <3.12 to <4.0 once https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2 is released
python = ">=3.8.1,<3.12"
langchain-core = "^0.2.35"
mlflow = ">=2.9"
python = ">=3.8.1,<4.0"
langchain-core = ">=0.2.35"

# MLflow supports python 3.12 since https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2
mlflow = [
{ version = ">=2.16.0", python = ">=3.12" },
{ version = ">=2.9", python = "<3.12" },
]

# MLflow depends on following libraries, which require different version for Python 3.8 vs 3.12
numpy = [
Expand Down
2 changes: 1 addition & 1 deletion libs/databricks/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ToolMessageChunk,
)
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field

from langchain_databricks.chat_models import (
ChatDatabricks,
Expand Down

0 comments on commit 69c7918

Please sign in to comment.