From efbfc54ded449a4dc30f577971e30923c74a5a89 Mon Sep 17 00:00:00 2001
From: Prithvi Kannan <46332835+prithvikannan@users.noreply.github.com>
Date: Wed, 11 Dec 2024 10:48:28 -0800
Subject: [PATCH 1/5] Port databricks-langchain integration tests (#26)

* Port databricks-langchain integration tests

Signed-off-by: Prithvi Kannan

* format

Signed-off-by: Prithvi Kannan

* only unit tests in ci

Signed-off-by: Prithvi Kannan

* update import

Signed-off-by: Prithvi Kannan

* lint

Signed-off-by: Prithvi Kannan

---------

Signed-off-by: Prithvi Kannan
---
 .github/workflows/main.yml                    |   2 +-
 .../tests/integration_tests/__init__.py       |   0
 .../integration_tests/test_chat_models.py     | 420 ++++++++++++++++++
 .../tests/integration_tests/test_compile.py   |   7 +
 .../integration_tests/test_embeddings.py      |  29 ++
 .../integration_tests/test_vectorstore.py     |  46 ++
 .../langchain/tests/unit_tests/__init__.py    |   0
 .../{ => unit_tests}/test_chat_models.py      |   0
 .../tests/{ => unit_tests}/test_embeddings.py |   0
 .../tests/{ => unit_tests}/test_genie.py      |   0
 .../{ => unit_tests}/test_vectorstores.py     |   0
 11 files changed, 503 insertions(+), 1 deletion(-)
 create mode 100644 integrations/langchain/tests/integration_tests/__init__.py
 create mode 100644 integrations/langchain/tests/integration_tests/test_chat_models.py
 create mode 100644 integrations/langchain/tests/integration_tests/test_compile.py
 create mode 100644 integrations/langchain/tests/integration_tests/test_embeddings.py
 create mode 100644 integrations/langchain/tests/integration_tests/test_vectorstore.py
 create mode 100644 integrations/langchain/tests/unit_tests/__init__.py
 rename integrations/langchain/tests/{ => unit_tests}/test_chat_models.py (100%)
 rename integrations/langchain/tests/{ => unit_tests}/test_embeddings.py (100%)
 rename integrations/langchain/tests/{ => unit_tests}/test_genie.py (100%)
 rename integrations/langchain/tests/{ => unit_tests}/test_vectorstores.py (100%)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index fa5e819..08781c6 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -68,4 +68,4 @@ jobs:
           pip install integrations/langchain[dev]
       - name: Run tests
         run: |
-          pytest integrations/langchain/tests
+          pytest integrations/langchain/tests/unit_tests
diff --git a/integrations/langchain/tests/integration_tests/__init__.py b/integrations/langchain/tests/integration_tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/integrations/langchain/tests/integration_tests/test_chat_models.py b/integrations/langchain/tests/integration_tests/test_chat_models.py
new file mode 100644
index 0000000..cbce738
--- /dev/null
+++ b/integrations/langchain/tests/integration_tests/test_chat_models.py
@@ -0,0 +1,420 @@
+"""
+This file contains the integration tests for the ChatDatabricks class.
+
+We run the integration tests nightly on a trusted CI/CD system defined in
+a private repository, so that the tests run securely. By design, the
+integration tests are not meant to be run manually by OSS contributors.
+If you update the ChatDatabricks implementation and think the corresponding
+integration tests need to change, please contact the maintainers of the
+repository to verify the changes.
+""" + +from typing import Annotated +from unittest import mock + +import pytest +from langchain.agents import AgentExecutor, create_tool_calling_agent +from langchain_core.callbacks.base import BaseCallbackHandler +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.tools import tool +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition +from pydantic import BaseModel, Field +from typing_extensions import TypedDict + +from databricks_langchain.chat_models import ChatDatabricks + +_TEST_ENDPOINT = "databricks-meta-llama-3-70b-instruct" + + +def test_chat_databricks_invoke(): + chat = ChatDatabricks(endpoint=_TEST_ENDPOINT, temperature=0, max_tokens=10, stop=["Java"]) + + response = chat.invoke("How to learn Java? Start the response by 'To learn Java,'") + assert isinstance(response, AIMessage) + assert response.content == "To learn " + assert response.response_metadata["prompt_tokens"] == 24 + assert response.response_metadata["completion_tokens"] == 3 + assert response.response_metadata["total_tokens"] == 27 + + response = chat.invoke("How to learn Python? Start the response by 'To learn Python,'") + assert response.content.startswith("To learn Python,") + assert len(response.content.split(" ")) <= 15 # Give some margin for tokenization difference + + # Call with a system message + response = chat.invoke( + [ + ("system", "You are helpful programming tutor."), + ("user", "How to learn Python? 
Start the response by 'To learn Python,'"), + ] + ) + assert response.content.startswith("To learn Python,") + + # Call with message history + response = chat.invoke( + [ + SystemMessage(content="You are helpful sports coach."), + HumanMessage(content="How to swim better?"), + AIMessage(content="You need more and more practice.", id="12345"), + HumanMessage(content="No, I need more tips."), + ] + ) + assert response.content is not None + + +def test_chat_databricks_invoke_multiple_completions(): + chat = ChatDatabricks( + endpoint=_TEST_ENDPOINT, + temperature=0.5, + n=3, + max_tokens=10, + ) + response = chat.invoke("How to learn Python?") + assert isinstance(response, AIMessage) + + +def test_chat_databricks_stream(): + class FakeCallbackHandler(BaseCallbackHandler): + def __init__(self): + self.chunk_counts = 0 + + def on_llm_new_token(self, *args, **kwargs): + self.chunk_counts += 1 + + callback = FakeCallbackHandler() + + chat = ChatDatabricks( + endpoint=_TEST_ENDPOINT, + temperature=0, + stop=["Python"], + max_tokens=100, + ) + + chunks = list(chat.stream("How to learn Python?", config={"callbacks": [callback]})) + assert len(chunks) > 0 + assert all(isinstance(chunk, AIMessageChunk) for chunk in chunks) + assert all("Python" not in chunk.content for chunk in chunks) + assert callback.chunk_counts == len(chunks) + + last_chunk = chunks[-1] + assert last_chunk.response_metadata["finish_reason"] == "stop" + + +def test_chat_databricks_stream_with_usage(): + class FakeCallbackHandler(BaseCallbackHandler): + def __init__(self): + self.chunk_counts = 0 + + def on_llm_new_token(self, *args, **kwargs): + self.chunk_counts += 1 + + callback = FakeCallbackHandler() + + chat = ChatDatabricks( + endpoint=_TEST_ENDPOINT, + temperature=0, + stop=["Python"], + max_tokens=100, + stream_usage=True, + ) + + chunks = list(chat.stream("How to learn Python?", config={"callbacks": [callback]})) + assert len(chunks) > 0 + assert all(isinstance(chunk, AIMessageChunk) for chunk in chunks) + assert all("Python" not in chunk.content for chunk in chunks) + assert callback.chunk_counts == len(chunks) + + last_chunk = chunks[-1] + assert last_chunk.response_metadata["finish_reason"] == "stop" + assert last_chunk.usage_metadata is not None + assert last_chunk.usage_metadata["input_tokens"] > 0 + assert last_chunk.usage_metadata["output_tokens"] > 0 + assert last_chunk.usage_metadata["total_tokens"] > 0 + + +@pytest.mark.asyncio +async def test_chat_databricks_ainvoke(): + chat = ChatDatabricks( + endpoint=_TEST_ENDPOINT, + temperature=0, + max_tokens=10, + ) + + response = await chat.ainvoke("How to learn Python? 
Start the response by 'To learn Python,'") + assert isinstance(response, AIMessage) + assert response.content.startswith("To learn Python,") + + +async def test_chat_databricks_astream(): + chat = ChatDatabricks( + endpoint=_TEST_ENDPOINT, + temperature=0, + max_tokens=10, + ) + chunk_count = 0 + async for chunk in chat.astream("How to learn Python?"): + assert isinstance(chunk, AIMessageChunk) + chunk_count += 1 + assert chunk_count > 0 + + +@pytest.mark.asyncio +async def test_chat_databricks_abatch(): + chat = ChatDatabricks( + endpoint=_TEST_ENDPOINT, + temperature=0, + max_tokens=10, + ) + + responses = await chat.abatch( + [ + "How to learn Python?", + "How to learn Java?", + "How to learn C++?", + ] + ) + assert len(responses) == 3 + assert all(isinstance(response, AIMessage) for response in responses) + + +@pytest.mark.parametrize("tool_choice", [None, "auto", "required", "any", "none"]) +def test_chat_databricks_tool_calls(tool_choice): + chat = ChatDatabricks( + endpoint=_TEST_ENDPOINT, + temperature=0, + max_tokens=100, + ) + + class GetWeather(BaseModel): + """Get the current weather in a given location""" + + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + + llm_with_tools = chat.bind_tools([GetWeather], tool_choice=tool_choice) + question = "Which is the current weather in Los Angeles, CA?" + response = llm_with_tools.invoke(question) + + if tool_choice == "none": + assert response.tool_calls == [] + return + + assert response.tool_calls == [ + { + "name": "GetWeather", + "args": {"location": "Los Angeles, CA"}, + "id": mock.ANY, + "type": "tool_call", + } + ] + + tool_msg = ToolMessage( + "GetWeather", + tool_call_id=response.additional_kwargs["tool_calls"][0]["id"], + ) + response = llm_with_tools.invoke( + [ + HumanMessage(question), + response, + tool_msg, + HumanMessage("What about San Francisco, CA?"), + ] + ) + + assert response.tool_calls == [ + { + "name": "GetWeather", + "args": {"location": "San Francisco, CA"}, + "id": mock.ANY, + "type": "tool_call", + } + ] + + +# Pydantic-based schema +class AnswerWithJustification(BaseModel): + """An answer to the user question along with justification for the answer.""" + + answer: str = Field(description="The answer to the user question.") + justification: str = Field(description="The justification for the answer.") + + +# Raw JSON schema +JSON_SCHEMA = { + "title": "AnswerWithJustification", + "description": "An answer to the user question along with justification.", + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the user question.", + }, + "justification": { + "type": "string", + "description": "The justification for the answer.", + }, + }, + "required": ["answer", "justification"], +} + + +@pytest.mark.parametrize("schema", [AnswerWithJustification, JSON_SCHEMA, None]) +@pytest.mark.parametrize("method", ["function_calling", "json_mode"]) +def test_chat_databricks_with_structured_output(schema, method): + llm = ChatDatabricks(endpoint=_TEST_ENDPOINT) + + if schema is None and method == "function_calling": + pytest.skip("Cannot use function_calling without schema") + + structured_llm = llm.with_structured_output(schema, method=method) + + if method == "function_calling": + prompt = "What day comes two days after Monday?" + else: + prompt = ( + "What day comes two days after Monday? Return in JSON format with key " + "'answer' for the answer and 'justification' for the justification." 
+        )
+
+    response = structured_llm.invoke(prompt)
+
+    if schema == AnswerWithJustification:
+        assert response.answer == "Wednesday"
+        assert response.justification is not None
+    else:
+        assert response["answer"] == "Wednesday"
+        assert response["justification"] is not None
+
+    # Invoke with raw output
+    structured_llm = llm.with_structured_output(schema, method=method, include_raw=True)
+    response_with_raw = structured_llm.invoke(prompt)
+    assert isinstance(response_with_raw["raw"], AIMessage)
+
+
+def test_chat_databricks_runnable_sequence():
+    chat = ChatDatabricks(
+        endpoint=_TEST_ENDPOINT,
+        temperature=0,
+        max_tokens=100,
+    )
+
+    prompt = ChatPromptTemplate.from_template("tell me a joke about {topic}")
+    chain = prompt | chat | StrOutputParser()
+
+    response = chain.invoke({"topic": "chicken"})
+    assert "chicken" in response
+
+
+@tool
+def add(a: int, b: int) -> int:
+    """Add two integers.
+
+    Args:
+        a: First integer
+        b: Second integer
+    """
+    return a + b
+
+
+@tool
+def multiply(a: int, b: int) -> int:
+    """Multiply two integers.
+
+    Args:
+        a: First integer
+        b: Second integer
+    """
+    return a * b
+
+
+def test_chat_databricks_agent_executor():
+    model = ChatDatabricks(
+        endpoint=_TEST_ENDPOINT,
+        temperature=0,
+        max_tokens=100,
+    )
+    tools = [add, multiply]
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are a helpful assistant"),
+            ("human", "{input}"),
+            ("placeholder", "{agent_scratchpad}"),
+        ]
+    )
+
+    agent = create_tool_calling_agent(model, tools, prompt)
+    agent_executor = AgentExecutor(agent=agent, tools=tools)
+
+    response = agent_executor.invoke({"input": "What is (10 + 5) * 3?"})
+    assert "45" in response["output"]
+
+
+def test_chat_databricks_langgraph():
+    model = ChatDatabricks(
+        endpoint=_TEST_ENDPOINT,
+        temperature=0,
+        max_tokens=100,
+    )
+    tools = [add, multiply]
+
+    app = create_react_agent(model, tools)
+    response = app.invoke({"messages": [("human", "What is (10 + 5) * 3?")]})
+    assert "45" in response["messages"][-1].content
+
+
+def test_chat_databricks_langgraph_with_memory():
+    class State(TypedDict):
+        messages: Annotated[list, add_messages]
+
+    tools = [add, multiply]
+    llm = ChatDatabricks(
+        endpoint=_TEST_ENDPOINT,
+        temperature=0,
+        max_tokens=100,
+    )
+    llm_with_tools = llm.bind_tools(tools)
+
+    def chatbot(state: State):
+        return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+    graph_builder = StateGraph(State)
+    graph_builder.add_node("chatbot", chatbot)
+
+    tool_node = ToolNode(tools=tools)
+    graph_builder.add_node("tools", tool_node)
+    graph_builder.add_conditional_edges("chatbot", tools_condition)
+    # Any time a tool is called, we return to the chatbot to decide the next step
+    graph_builder.add_edge("tools", "chatbot")
+    graph_builder.add_edge(START, "chatbot")
+
+    graph = graph_builder.compile(checkpointer=MemorySaver())
+
+    response = graph.invoke(
+        {"messages": [("user", "What is (10 + 5) * 3?")]},
+        config={"configurable": {"thread_id": "1"}},
+    )
+    assert "45" in response["messages"][-1].content
+
+    response = graph.invoke(
+        {"messages": [("user", "Subtract 5 from it")]},
+        config={"configurable": {"thread_id": "1"}},
+    )
+
+    # Interestingly, the agent sometimes mistakes the subtraction for addition. :(
+    # In such a case, the agent asks for a retry, so we need one more step.
+    if "Let me try again." in response["messages"][-1].content:
+        response = graph.invoke(
+            {"messages": [("user", "Ok, try again")]},
+            config={"configurable": {"thread_id": "1"}},
+        )
+
+    assert "40" in response["messages"][-1].content
diff --git a/integrations/langchain/tests/integration_tests/test_compile.py b/integrations/langchain/tests/integration_tests/test_compile.py
new file mode 100644
index 0000000..33ecccd
--- /dev/null
+++ b/integrations/langchain/tests/integration_tests/test_compile.py
@@ -0,0 +1,7 @@
+import pytest
+
+
+@pytest.mark.compile
+def test_placeholder() -> None:
+    """Used for compiling integration tests without running any real tests."""
+    pass
diff --git a/integrations/langchain/tests/integration_tests/test_embeddings.py b/integrations/langchain/tests/integration_tests/test_embeddings.py
new file mode 100644
index 0000000..3abda59
--- /dev/null
+++ b/integrations/langchain/tests/integration_tests/test_embeddings.py
@@ -0,0 +1,29 @@
+"""
+This file contains the integration tests for the DatabricksEmbeddings class.
+
+We run the integration tests nightly on a trusted CI/CD system defined in
+a private repository, so that the tests run securely. By design, the
+integration tests are not meant to be run manually by OSS contributors.
+If you update the DatabricksEmbeddings implementation and think the
+corresponding integration tests need to change, please contact the
+maintainers of the repository to verify the changes.
+"""
+
+from databricks_langchain import DatabricksEmbeddings
+
+_TEST_ENDPOINT = "databricks-bge-large-en"
+
+
+def test_embedding_documents() -> None:
+    documents = ["foo bar"]
+    embedding = DatabricksEmbeddings(endpoint=_TEST_ENDPOINT)
+    output = embedding.embed_documents(documents)
+    assert len(output) == 1
+    assert len(output[0]) > 0
+
+
+def test_embedding_query() -> None:
+    document = "foo bar"
+    embedding = DatabricksEmbeddings(endpoint=_TEST_ENDPOINT)
+    output = embedding.embed_query(document)
+    assert len(output) > 0
diff --git a/integrations/langchain/tests/integration_tests/test_vectorstore.py b/integrations/langchain/tests/integration_tests/test_vectorstore.py
new file mode 100644
index 0000000..767deb7
--- /dev/null
+++ b/integrations/langchain/tests/integration_tests/test_vectorstore.py
@@ -0,0 +1,46 @@
+"""
+This file contains the integration tests for the DatabricksVectorSearch class.
+
+We run the integration tests nightly on a trusted CI/CD system defined in
+a private repository, so that the tests run securely. By design, the
+integration tests are not meant to be run manually by OSS contributors.
+If you update the DatabricksVectorSearch implementation and think the
+corresponding integration tests need to change, please contact the
+maintainers of the repository to verify the changes.
+"""
+
+import os
+from datetime import timedelta
+
+import pytest
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.jobs import RunLifecycleStateV2State, TerminationTypeType
+
+
+@pytest.mark.timeout(3600)
+def test_vectorstore():
+    """
+    We run the integration tests for the vector store via a Databricks Workflow,
+    because the setup is too complex to run within a single Python file.
+    Therefore, this test simply triggers the workflow by calling the REST API.
+ """ + test_job_id = os.getenv("VS_TEST_JOB_ID") + if not test_job_id: + raise RuntimeError("Please set the environment variable VS_TEST_JOB_ID") + + w = WorkspaceClient() + + # Check if there is any ongoing job run + run_list = list(w.jobs.list_runs(job_id=test_job_id, active_only=True)) + no_active_run = len(run_list) == 0 + assert no_active_run, "There is an ongoing job run. Please wait for it to complete." + + # Trigger the workflow + response = w.jobs.run_now(job_id=test_job_id) + job_url = f"{w.config.host}/jobs/{test_job_id}/runs/{response.run_id}" + print(f"Started the job at {job_url}") # noqa: T201 + + # Wait for the job to complete + result = response.result(timeout=timedelta(seconds=3600)) + assert result.status.state == RunLifecycleStateV2State.TERMINATED + assert result.status.termination_details.type == TerminationTypeType.SUCCESS diff --git a/integrations/langchain/tests/unit_tests/__init__.py b/integrations/langchain/tests/unit_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/integrations/langchain/tests/test_chat_models.py b/integrations/langchain/tests/unit_tests/test_chat_models.py similarity index 100% rename from integrations/langchain/tests/test_chat_models.py rename to integrations/langchain/tests/unit_tests/test_chat_models.py diff --git a/integrations/langchain/tests/test_embeddings.py b/integrations/langchain/tests/unit_tests/test_embeddings.py similarity index 100% rename from integrations/langchain/tests/test_embeddings.py rename to integrations/langchain/tests/unit_tests/test_embeddings.py diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py similarity index 100% rename from integrations/langchain/tests/test_genie.py rename to integrations/langchain/tests/unit_tests/test_genie.py diff --git a/integrations/langchain/tests/test_vectorstores.py b/integrations/langchain/tests/unit_tests/test_vectorstores.py similarity index 100% rename from integrations/langchain/tests/test_vectorstores.py rename to integrations/langchain/tests/unit_tests/test_vectorstores.py From 681d59d0a672e2efdc6859442d8af5b742d93f76 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan <46332835+prithvikannan@users.noreply.github.com> Date: Thu, 12 Dec 2024 00:22:56 -0800 Subject: [PATCH 2/5] Add langchain integration test deps (#30) Signed-off-by: Prithvi Kannan --- integrations/langchain/pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 1e19cb9..cb03c22 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -24,6 +24,11 @@ dev = [ "ruff==0.6.4", ] +integration = [ + "langgraph>=0.2.27", + "pytest-timeout>=2.3.1", +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" From c9946ad183a5a7c18db850da329d6233abcaaa11 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan <46332835+prithvikannan@users.noreply.github.com> Date: Thu, 12 Dec 2024 00:41:16 -0800 Subject: [PATCH 3/5] Fix test_chat_databricks_invoke with llama 3.3 (#29) * Fix test_chat_databricks_invoke with llama 3.3 Signed-off-by: Prithvi Kannan * fix Signed-off-by: Prithvi Kannan --------- Signed-off-by: Prithvi Kannan --- .../tests/integration_tests/test_chat_models.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/integrations/langchain/tests/integration_tests/test_chat_models.py b/integrations/langchain/tests/integration_tests/test_chat_models.py index 
cbce738..ab67e3d 100644 --- a/integrations/langchain/tests/integration_tests/test_chat_models.py +++ b/integrations/langchain/tests/integration_tests/test_chat_models.py @@ -43,9 +43,13 @@ def test_chat_databricks_invoke(): response = chat.invoke("How to learn Java? Start the response by 'To learn Java,'") assert isinstance(response, AIMessage) assert response.content == "To learn " - assert response.response_metadata["prompt_tokens"] == 24 - assert response.response_metadata["completion_tokens"] == 3 - assert response.response_metadata["total_tokens"] == 27 + assert 20 <= response.response_metadata["prompt_tokens"] <= 30 + assert 1 <= response.response_metadata["completion_tokens"] <= 10 + expected_total = ( + response.response_metadata["prompt_tokens"] + + response.response_metadata["completion_tokens"] + ) + assert response.response_metadata["total_tokens"] == expected_total response = chat.invoke("How to learn Python? Start the response by 'To learn Python,'") assert response.content.startswith("To learn Python,") From b540ce5f43d249cc171845fdb892ae3da07a2b90 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan <46332835+prithvikannan@users.noreply.github.com> Date: Fri, 13 Dec 2024 15:03:38 -0800 Subject: [PATCH 4/5] databricks-langchain as primary package (#34) * databricks-langchain as primary package Signed-off-by: Prithvi Kannan * genie Signed-off-by: Prithvi Kannan --------- Signed-off-by: Prithvi Kannan --- integrations/langchain/README.md | 44 ++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/integrations/langchain/README.md b/integrations/langchain/README.md index c60cef0..0d24f6b 100644 --- a/integrations/langchain/README.md +++ b/integrations/langchain/README.md @@ -1,38 +1,54 @@ -# 🦜🔗 Using Databricks AI Bridge with Langchain +# 🦜🔗 Databricks LangChain Integration -Integrate Databricks AI Bridge package with Langchain to allow seamless usage of Databricks AI features with Langchain/Langgraph applications. - -Note: This repository is the future home for all Databricks integrations currently found in `langchain-databricks` and `langchain-community`. We have now aliased `langchain-databricks` to `databricks-langchain`, consolidating integrations such as ChatDatabricks, DatabricksEmbeddings, DatabricksVectorSearch, and more under this package. +The `databricks-langchain` package provides seamless integration of Databricks AI features into LangChain applications. This repository is now the central hub for all Databricks-related LangChain components, consolidating previous packages such as `langchain-databricks` and `langchain-community`. ## Installation -### Install from PyPI +### From PyPI ```sh pip install databricks-langchain ``` -### Install from source - +### From Source ```sh pip install git+ssh://git@github.com/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain ``` -## Get started +## Key Features -### Use LLMs on Databricks +- **LLMs Integration:** Use Databricks-hosted large language models (LLMs) like Llama and Mixtral through `ChatDatabricks`. +- **Vector Search:** Store and query vector representations using `DatabricksVectorSearch`. +- **Embeddings:** Generate embeddings with `DatabricksEmbeddings`. +- **Genie:** Use [Genie](https://www.databricks.com/product/ai-bi/genie) in Langchain. 
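+
+As a quick taste of the vector search and embeddings pieces listed above, a
+minimal sketch might look like the following (the endpoint and index names are
+illustrative placeholders, and the index is assumed to be a Delta Sync index
+with Databricks-managed embeddings):
+
+```python
+from databricks_langchain import DatabricksEmbeddings, DatabricksVectorSearch
+
+# Generate embeddings with a Databricks-hosted embedding endpoint
+embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
+vector = embeddings.embed_query("What is Databricks?")
+
+# Query an existing Vector Search index (the name is a placeholder)
+vector_store = DatabricksVectorSearch(index_name="catalog.schema.my_index")
+docs = vector_store.similarity_search("What is Databricks?", k=3)
+```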
+## Getting Started
+
+### Use LLMs on Databricks
 ```python
 from databricks_langchain import ChatDatabricks
+
 llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
 ```
 
-### (Preview) Use a Genie space as an agent
-
-> [!NOTE]
-> Requires Genie API Private Preview. Reach out to your account team for enablement.
+### Use a Genie Space as an Agent (Preview)
+> **Note:** Requires Genie API Private Preview. Contact your Databricks account team for enablement.
 
 ```python
 from databricks_langchain.genie import GenieAgent
 
-genie_agent = GenieAgent("space-id", "Genie", description="This Genie space has access to sales data in Europe")
+genie_agent = GenieAgent(
+    "space-id", "Genie",
+    description="This Genie space has access to sales data in Europe"
+)
 ```
+
+---
+
+## Contribution Guide
+We welcome contributions! Please see our [contribution guidelines](https://github.com/databricks/databricks-ai-bridge/tree/main/integrations/langchain) for details.
+
+## License
+This project is licensed under the [MIT License](LICENSE).
+
+Thank you for using Databricks LangChain!

From efb256a73d4cb42c21d710a3db42569790ebdf56 Mon Sep 17 00:00:00 2001
From: Sunish Sheth
Date: Mon, 16 Dec 2024 11:06:00 -0800
Subject: [PATCH 5/5] Adding a python http_request wrapper to create external tools (#28)

* Adding a python http_request wrapper to create external tools

Signed-off-by: Sunish Sheth

* Update based on comments

Signed-off-by: Sunish Sheth

---------

Signed-off-by: Sunish Sheth
---
 src/databricks_ai_bridge/external_tools.py    | 57 ++++++++++++
 src/databricks_ai_bridge/utils/annotations.py | 67 ++++++++++++++
 .../test_external_tools.py                    | 89 +++++++++++++++++++
 3 files changed, 213 insertions(+)
 create mode 100644 src/databricks_ai_bridge/external_tools.py
 create mode 100644 src/databricks_ai_bridge/utils/annotations.py
 create mode 100644 tests/databricks_ai_bridge/test_external_tools.py

diff --git a/src/databricks_ai_bridge/external_tools.py b/src/databricks_ai_bridge/external_tools.py
new file mode 100644
index 0000000..7813b86
--- /dev/null
+++ b/src/databricks_ai_bridge/external_tools.py
@@ -0,0 +1,57 @@
+import json as js
+from typing import Any, Dict, Optional
+
+import requests
+from databricks.sdk import WorkspaceClient
+
+from databricks_ai_bridge.utils.annotations import experimental
+
+
+@experimental
+def http_request(
+    conn: str,
+    method: str,
+    path: str,
+    *,
+    json: Optional[Any] = None,
+    headers: Optional[Dict[str, str]] = None,
+    params: Optional[Dict[str, Any]] = None,
+) -> requests.Response:
+    """
+    Makes an HTTP request to a remote API using authentication from a Unity Catalog HTTP connection.
+
+    Args:
+        conn (str): The connection name to use. This is required to identify the external connection.
+        method (str): The HTTP method to use (e.g., "GET", "POST"). This is required.
+        path (str): The relative path for the API endpoint. This is required.
+        json (Optional[Any]): JSON payload for the request.
+        headers (Optional[Dict[str, str]]): Additional headers for the request.
+            If not provided, only the auth headers from the connection are passed.
+        params (Optional[Dict[str, Any]]): Query parameters for the request.
+
+    Returns:
+        requests.Response: The HTTP response from the external function.
+
+    Example Usage:
+        response = http_request(
+            conn="my_connection",
+            method="POST",
+            path="/api/v1/resource",
+            json={"key": "value"},
+            headers={"extra_header_key": "extra_header_value"},
+            params={"query": "example"}
+        )
+    """
+    workspaceConfig = WorkspaceClient().config
+    url = f"{workspaceConfig.host}/external-functions"
+    request_headers = workspaceConfig._header_factory()
+    payload = {
+        "connection_name": conn,
+        "method": method,
+        "path": path,
+        "json": js.dumps(json),
+        "header": headers,
+        "params": params,
+    }
+
+    return requests.post(url, headers=request_headers, json=payload)
diff --git a/src/databricks_ai_bridge/utils/annotations.py b/src/databricks_ai_bridge/utils/annotations.py
new file mode 100644
index 0000000..098073f
--- /dev/null
+++ b/src/databricks_ai_bridge/utils/annotations.py
@@ -0,0 +1,67 @@
+# This code is copied from MLflow: https://github.com/mlflow/mlflow/blob/v2.19.0/mlflow/utils/annotations.py#L31
+
+import inspect
+import re
+import types
+from typing import Any, Callable, TypeVar, Union
+
+C = TypeVar("C", bound=Callable[..., Any])
+
+
+def _get_min_indent_of_docstring(docstring_str: str) -> str:
+    """
+    Get the minimum indentation string of a docstring. This relies on the
+    closing triple quote of a multiline docstring being on its own line,
+    which ruff rule D209 enforces.
+
+    Args:
+        docstring_str: string with docstring
+
+    Returns:
+        Whitespace corresponding to the indent of a docstring.
+    """
+
+    if not docstring_str or "\n" not in docstring_str:
+        return ""
+
+    return re.match(r"^\s*", docstring_str.rsplit("\n", 1)[-1]).group()
+
+
+def experimental(api_or_type: Union[C, str]) -> C:
+    """Decorator / decorator creator for marking APIs experimental in the docstring.
+
+    Args:
+        api_or_type: An API to mark, or an API typestring for which to generate a decorator.
+
+    Returns:
+        Decorated API (if ``api_or_type`` is an API) or a function that decorates
+        the specified API type (if ``api_or_type`` is a typestring).
+    """
+    if isinstance(api_or_type, str):
+
+        def f(api: C) -> C:
+            return _experimental(api=api, api_type=api_or_type)
+
+        return f
+    elif inspect.isclass(api_or_type):
+        return _experimental(api=api_or_type, api_type="class")
+    elif inspect.isfunction(api_or_type):
+        return _experimental(api=api_or_type, api_type="function")
+    elif isinstance(api_or_type, (property, types.MethodType)):
+        return _experimental(api=api_or_type, api_type="property")
+    else:
+        return _experimental(api=api_or_type, api_type=str(type(api_or_type)))
+
+
+def _experimental(api: C, api_type: str) -> C:
+    indent = _get_min_indent_of_docstring(api.__doc__)
+    notice = (
+        indent + f".. Note:: Experimental: This {api_type} may change or "
+        "be removed in a future release without warning.\n\n"
+    )
+    if api_type == "property":
+        api.__doc__ = api.__doc__ + "\n\n" + notice if api.__doc__ else notice
+    else:
+        api.__doc__ = notice + api.__doc__ if api.__doc__ else notice
+    return api
diff --git a/tests/databricks_ai_bridge/test_external_tools.py b/tests/databricks_ai_bridge/test_external_tools.py
new file mode 100644
index 0000000..6a1149d
--- /dev/null
+++ b/tests/databricks_ai_bridge/test_external_tools.py
@@ -0,0 +1,89 @@
+from unittest.mock import MagicMock, patch
+
+from databricks_ai_bridge.external_tools import http_request
+
+
+@patch("databricks_ai_bridge.external_tools.WorkspaceClient")
+@patch("databricks_ai_bridge.external_tools.requests.post")
+def test_http_request_success(mock_post, mock_workspace_client):
+    # Mock the WorkspaceClient config
+    mock_workspace_config = MagicMock()
+    mock_workspace_config.host = "https://mock-host"
+    mock_workspace_config._header_factory.return_value = {"Authorization": "Bearer mock-token"}
+    mock_workspace_client.return_value.config = mock_workspace_config
+
+    # Mock the POST request
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {"success": True}
+    mock_post.return_value = mock_response
+
+    # Call the function
+    response = http_request(
+        conn="mock_connection",
+        method="POST",
+        path="/mock-path",
+        json={"key": "value"},
+        headers={"Custom-Header": "HeaderValue"},
+        params={"query": "test"},
+    )
+
+    # Assertions
+    assert response.status_code == 200
+    assert response.json() == {"success": True}
+    mock_post.assert_called_once_with(
+        "https://mock-host/external-functions",
+        headers={
+            "Authorization": "Bearer mock-token",
+        },
+        json={
+            "connection_name": "mock_connection",
+            "method": "POST",
+            "path": "/mock-path",
+            "json": '{"key": "value"}',
+            "header": {
+                "Custom-Header": "HeaderValue",
+            },
+            "params": {"query": "test"},
+        },
+    )
+
+
+@patch("databricks_ai_bridge.external_tools.WorkspaceClient")
+@patch("databricks_ai_bridge.external_tools.requests.post")
+def test_http_request_error_response(mock_post, mock_workspace_client):
+    # Mock the WorkspaceClient config
+    mock_workspace_config = MagicMock()
+    mock_workspace_config.host = "https://mock-host"
+    mock_workspace_config._header_factory.return_value = {"Authorization": "Bearer mock-token"}
+    mock_workspace_client.return_value.config = mock_workspace_config
+
+    # Mock the POST request to return an error
+    mock_response = MagicMock()
+    mock_response.status_code = 400
+    mock_response.json.return_value = {"error": "Bad Request"}
+    mock_post.return_value = mock_response
+
+    # Call the function
+    response = http_request(
+        conn="mock_connection",
+        method="POST",
+        path="/mock-path",
+        json={"key": "value"},
+    )
+
+    # Assertions
+    assert response.status_code == 400
+    assert response.json() == {"error": "Bad Request"}
+    mock_post.assert_called_once_with(
+        "https://mock-host/external-functions",
+        headers={"Authorization": "Bearer mock-token"},
+        json={
+            "connection_name": "mock_connection",
+            "method": "POST",
+            "path": "/mock-path",
+            "json": '{"key": "value"}',
+            "header": None,
+            "params": None,
+        },
+    )
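
For reference, a quick sketch of what the `@experimental` decorator from PATCH 5/5 does to a decorated API's docstring. The `greet` function below is a made-up placeholder for illustration, not part of the patches:

```python
from databricks_ai_bridge.utils.annotations import experimental


@experimental
def greet() -> str:
    """Say hello."""  # `greet` is a throwaway example used only for illustration
    return "hello"


# _experimental() prepends the notice to the wrapped function's __doc__, so this
# prints the experimental warning followed by the original docstring:
#
#   .. Note:: Experimental: This function may change or be removed in a future
#   release without warning.
#
#   Say hello.
print(greet.__doc__)
```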