Skip to content

Commit

Permalink
Support LangChain 0.3.0 (#14)
Browse files Browse the repository at this point in the history
* Support LangChain 0.3.0

Signed-off-by: B-Step62 <[email protected]>

* update lock file

Signed-off-by: B-Step62 <[email protected]>

---------

Signed-off-by: B-Step62 <[email protected]>
  • Loading branch information
B-Step62 authored Sep 19, 2024
1 parent deee25b commit c2dd754
Show file tree
Hide file tree
Showing 5 changed files with 663 additions and 485 deletions.
24 changes: 12 additions & 12 deletions libs/databricks/langchain_databricks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,11 @@
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from mlflow.deployments import BaseDeploymentClient # type: ignore
from pydantic import BaseModel, Field

from langchain_databricks.utils import get_deployment_client

Expand Down Expand Up @@ -180,7 +177,7 @@ class ChatDatabricks(BaseChatModel):
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
Expand Down Expand Up @@ -225,13 +222,16 @@ class GetPopulation(BaseModel):
"""List of strings to stop generation at."""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
extra_params: dict = Field(default_factory=dict)
extra_params: Optional[Dict[str, Any]] = None
"""Any extra parameters to pass to the endpoint."""
_client: Any = PrivateAttr()
client: Optional[BaseDeploymentClient] = Field(
default=None, exclude=True
) #: :meta private:

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._client = get_deployment_client(self.target_uri)
self.client = get_deployment_client(self.target_uri)
self.extra_params = self.extra_params or {}

@property
def _default_params(self) -> Dict[str, Any]:
Expand All @@ -254,7 +254,7 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
data = self._prepare_inputs(messages, stop, **kwargs)
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
resp = self.client.predict(endpoint=self.endpoint, inputs=data) # type: ignore
return self._convert_response_to_chat_result(resp)

def _prepare_inputs(
Expand All @@ -267,7 +267,7 @@ def _prepare_inputs(
"messages": [_convert_message_to_dict(msg) for msg in messages],
"temperature": self.temperature,
"n": self.n,
**self.extra_params,
**self.extra_params, # type: ignore
**kwargs,
}
if stop := self.stop or stop:
Expand Down Expand Up @@ -299,7 +299,7 @@ def _stream(
) -> Iterator[ChatGenerationChunk]:
data = self._prepare_inputs(messages, stop, **kwargs)
first_chunk_role = None
for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data): # type: ignore
if chunk["choices"]:
choice = chunk["choices"][0]

Expand Down
2 changes: 1 addition & 1 deletion libs/databricks/langchain_databricks/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Iterator, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from pydantic import BaseModel, PrivateAttr

from langchain_databricks.utils import get_deployment_client

Expand Down
Loading

0 comments on commit c2dd754

Please sign in to comment.