diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 91ebd2e..94a3681 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -50,3 +50,24 @@ jobs:
       - name: Run tests
         run: |
           pytest tests/
+
+  langchain_test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ['3.8', '3.9', '3.10']
+    timeout-minutes: 20
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          pip install .
+          pip install integrations/langchain[dev]
+      - name: Run tests
+        run: |
+          pytest integrations/langchain/tests
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7d00e19..3baf1f5 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -5,6 +5,6 @@ Create a conda environement and install dev requirements
 ```
 conda create --name databricks-ai-dev-env python=3.10
 conda activate databricks-ai-dev-env
-pip install -e ".[databricks-dev]"
+pip install -e ".[dev]"
 pip install -r requirements/lint-requirements.txt
 ```
diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml
new file mode 100644
index 0000000..40527ed
--- /dev/null
+++ b/integrations/langchain/pyproject.toml
@@ -0,0 +1,67 @@
+[project]
+name = "databricks-langchain"
+version = "0.0.1"
+description = "Support for Databricks AI in LangChain"
+authors = [
+    { name="Prithvi Kannan", email="prithvi.kannan@databricks.com" },
+]
+readme = "README.md"
+license = { text="Apache-2.0" }
+requires-python = ">=3.8"
+dependencies = [
+    "langchain>=0.2.0",
+    "langchain-community>=0.2.0",
+    "databricks-ai-bridge",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest",
+    "typing_extensions",
+    "databricks-sdk>=0.34.0",
+    "ruff==0.6.4",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build]
+include = [
+    "src/databricks_langchain/*"
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/databricks_langchain"]
+
+[tool.ruff]
+line-length = 100
+target-version = "py39"
+
+[tool.ruff.lint]
+select = [
+    # isort
+    "I",
+    # bugbear rules
+    "B",
+    # remove unused imports
+    "F401",
+    # bare except statements
+    "E722",
+    # print statements
+    "T201",
+    "T203",
+    # misuse of typing.TYPE_CHECKING
+    "TCH004",
+    # import rules
+    "TID251",
+    # undefined-local-with-import-star
+    "F403",
+]
+
+[tool.ruff.format]
+docstring-code-format = true
+docstring-code-line-length = 88
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
\ No newline at end of file
diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py
new file mode 100644
index 0000000..153c2df
--- /dev/null
+++ b/integrations/langchain/src/databricks_langchain/genie.py
@@ -0,0 +1,47 @@
+from databricks_ai_bridge.genie import Genie
+
+
+def _concat_messages_array(messages):
+    concatenated_message = "\n".join(
+        [
+            f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}"
+            if isinstance(message, dict)
+            else f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}"
+            for message in messages
+        ]
+    )
+    return concatenated_message
+
+
+def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
+    from langchain_core.messages import AIMessage
+
+    genie = Genie(genie_space_id)
+
+    message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n"
+
+    # Concatenate messages to form the chat history
+    message += _concat_messages_array(input.get("messages"))
+
+    # Send the message and wait for a response
+    genie_response = genie.ask_question(message)
+
+    if genie_response:
+        return {"messages": [AIMessage(content=genie_response)]}
+    else:
+        return {"messages": [AIMessage(content="")]}
+
+
+def GenieAgent(genie_space_id, genie_agent_name="Genie", description=""):
+    """Create a genie agent that can be used to query the API"""
+    from functools import partial
+
+    from langchain_core.runnables import RunnableLambda
+
+    # Create a partial function with the genie_space_id pre-filled
+    partial_genie_agent = partial(
+        _query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name
+    )
+
+    # Use the partial function in the RunnableLambda
+    return RunnableLambda(partial_genie_agent)
diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/test_genie.py
new file mode 100644
index 0000000..70c6c28
--- /dev/null
+++ b/integrations/langchain/tests/test_genie.py
@@ -0,0 +1,71 @@
+from unittest.mock import patch
+
+from langchain_core.messages import AIMessage
+
+from databricks_langchain.genie import (
+    GenieAgent,
+    _concat_messages_array,
+    _query_genie_as_agent,
+)
+
+
+def test_concat_messages_array():
+    # Test a simple case with multiple messages
+    messages = [
+        {"role": "user", "content": "What is the weather?"},
+        {"role": "assistant", "content": "It is sunny."},
+    ]
+    result = _concat_messages_array(messages)
+    expected = "user: What is the weather?\nassistant: It is sunny."
+    assert result == expected
+
+    # Test case with missing content
+    messages = [{"role": "user"}, {"role": "assistant", "content": "I don't know."}]
+    result = _concat_messages_array(messages)
+    expected = "user: \nassistant: I don't know."
+    assert result == expected
+
+    # Test case with non-dict message objects
+    class Message:
+        def __init__(self, role, content):
+            self.role = role
+            self.content = content
+
+    messages = [
+        Message("user", "Tell me a joke."),
+        Message("assistant", "Why did the chicken cross the road?"),
+    ]
+    result = _concat_messages_array(messages)
+    expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?"
+    assert result == expected
+
+
+@patch("databricks_langchain.genie.Genie")
+def test_query_genie_as_agent(MockGenie):
+    # Mock the Genie class and its response
+    mock_genie = MockGenie.return_value
+    mock_genie.ask_question.return_value = "It is sunny."
+
+    input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
+    result = _query_genie_as_agent(input_data, "space-id", "Genie")
+
+    expected_message = {"messages": [AIMessage(content="It is sunny.")]}
+    assert result == expected_message
+
+    # Test the case when genie_response is empty
+    mock_genie.ask_question.return_value = None
+    result = _query_genie_as_agent(input_data, "space-id", "Genie")
+
+    expected_message = {"messages": [AIMessage(content="")]}
+    assert result == expected_message
+
+
+@patch("langchain_core.runnables.RunnableLambda")
+def test_create_genie_agent(MockRunnableLambda):
+    mock_runnable = MockRunnableLambda.return_value
+
+    agent = GenieAgent("space-id", "Genie")
+    assert agent == mock_runnable
+
+    # Check that the partial function is created with the correct arguments
+    MockRunnableLambda.assert_called()
diff --git a/pyproject.toml b/pyproject.toml
index 775d184..c1df40d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,9 @@ readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
     "typing_extensions",
-    "pydantic"
+    "pydantic",
+    "databricks-sdk>=0.34.0",
+    "pandas",
 ]

 [project.license]
@@ -28,22 +30,9 @@ include = [
 packages = ["src/databricks_ai_bridge"]

 [project.optional-dependencies]
-databricks = [
-    "databricks-sdk>=0.34.0",
-    "pandas",
-]
-databricks-dev = [
-    "hatch",
-    "pytest",
-    "databricks-sdk>=0.34.0",
-    "pandas",
-    "ruff==0.6.4",
-]
 dev = [
     "hatch",
     "pytest",
-    "databricks-sdk>=0.34.0",
-    "pandas",
     "ruff==0.6.4",
 ]
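For reviewers, here is a minimal usage sketch (not part of the diff) of the `GenieAgent` runnable added in `integrations/langchain/src/databricks_langchain/genie.py`. It assumes `databricks-langchain` is installed, Databricks authentication is already configured for `databricks-sdk`, and `<genie-space-id>` is a placeholder for a real Genie space ID.

```
from databricks_langchain.genie import GenieAgent

# Wrap a Genie space as a LangChain runnable ("<genie-space-id>" is a placeholder).
agent = GenieAgent(genie_space_id="<genie-space-id>", genie_agent_name="Genie")

# The runnable expects a dict with a "messages" list (role/content dicts or
# message objects); it concatenates them into a chat history and asks Genie.
result = agent.invoke(
    {"messages": [{"role": "user", "content": "What were total sales last quarter?"}]}
)

# The result has the shape {"messages": [AIMessage(content=<Genie's answer>)]}.
print(result["messages"][-1].content)
```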