Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Langchain Genie #5

Merged
merged 7 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: tests

Check warning on line 1 in .github/workflows/main.yml

View workflow job for this annotation

GitHub Actions / lint

1:1 [document-start] missing document start "---"

on:

Check warning on line 3 in .github/workflows/main.yml

View workflow job for this annotation

GitHub Actions / lint

3:1 [truthy] truthy value should be one of [false, true]
push:
branches:
- master
Expand Down Expand Up @@ -50,3 +50,24 @@
- name: Run tests
run: |
pytest tests/

langchain_test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
timeout-minutes: 20
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .
pip install integrations/langchain[dev]
- name: Run tests
run: |
pytest integrations/langchain/tests
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Create a conda environment and install dev requirements
```
conda create --name databricks-ai-dev-env python=3.10
conda activate databricks-ai-dev-env
pip install -e ".[databricks-dev]"
pip install -e ".[dev]"
pip install -r requirements/lint-requirements.txt
```
67 changes: 67 additions & 0 deletions integrations/langchain/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
[project]
name = "databricks-langchain"
version = "0.0.1"
description = "Support for Databricks AI in LangChain"
authors = [
{ name="Prithvi Kannan", email="[email protected]" },
]
readme = "README.md"
license = { text="Apache-2.0" }
requires-python = ">=3.8"
dependencies = [
"langchain>=0.2.0",
"langchain-community>=0.2.0",
"databricks-ai-bridge",
]

[project.optional-dependencies]
dev = [
"pytest",
"typing_extensions",
"databricks-sdk>=0.34.0",
"ruff==0.6.4",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = [
"src/databricks_langchain/*"
]

[tool.hatch.build.targets.wheel]
packages = ["src/databricks_langchain"]

[tool.ruff]
line-length = 100
# Must match the lowest interpreter allowed by `requires-python` (>=3.8);
# a higher target lets ruff accept or suggest syntax that breaks on 3.8.
target-version = "py38"

[tool.ruff.lint]
select = [
# isort
"I",
# bugbear rules
"B",
# remove unused imports
"F401",
# bare except statements
"E722",
# print statements
"T201",
"T203",
# misuse of typing.TYPE_CHECKING
"TCH004",
# import rules
"TID251",
# undefined-local-with-import-star
"F403",
]

[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 88

[tool.ruff.lint.pydocstyle]
convention = "google"
47 changes: 47 additions & 0 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from databricks_ai_bridge.genie import Genie


def _concat_messages_array(messages):
concatenated_message = "\n".join(
[
f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}"
if isinstance(message, dict)
else f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}"
for message in messages
]
)
return concatenated_message


def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
# Forward a chat history to a Genie space and wrap the answer for LangChain.
#
# Args:
#   input: object exposing `.get("messages")`; the value is handed to
#          `_concat_messages_array` to build the chat transcript.
#   genie_space_id: passed unchanged to the `Genie` constructor.
#   genie_agent_name: name interpolated into the prompt preamble.
#
# Returns:
#   {"messages": [AIMessage]} where the message content is the Genie
#   response, or "" when the response is falsy.
#
# NOTE(review): AIMessage is imported inside the function, presumably to
# defer loading langchain_core until the agent is invoked - confirm.
from langchain_core.messages import AIMessage

genie = Genie(genie_space_id)

# Preamble telling Genie which speaker name in the transcript refers to it.
message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question here: Are we expecting users to use the genie agent name in their questions? Will the questions be something like:

"Use Genie to get me my financial records"

Will this also work when users ask "Get my financial records"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ideal end state is that the user can just specify "Get my financial records". This name is needed so that in a multi-agent (and in the future multi-user) chat array agents can identify who said what.


# Concatenate messages to form the chat history
message += _concat_messages_array(input.get("messages"))

# Send the message and wait for a response
genie_response = genie.ask_question(message)

# Always return an AIMessage; a falsy response becomes empty content.
if genie_response:
return {"messages": [AIMessage(content=genie_response)]}
else:
return {"messages": [AIMessage(content="")]}


def GenieAgent(genie_space_id, genie_agent_name="Genie", description=""):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding description to the interface, but TBD on how this will be used.

"""Create a genie agent that can be used to query the API"""
# Args:
#   genie_space_id: bound as the `genie_space_id` keyword of
#                   `_query_genie_as_agent`.
#   genie_agent_name: bound as the `genie_agent_name` keyword; defaults
#                     to "Genie".
#   description: accepted but not referenced anywhere in this function body.
#
# Returns:
#   A RunnableLambda wrapping `_query_genie_as_agent` with both keywords
#   pre-bound, so only the `input` argument remains to be supplied.
from functools import partial

from langchain_core.runnables import RunnableLambda

# Create a partial function with genie_space_id and genie_agent_name pre-filled
partial_genie_agent = partial(
_query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name
)

# Use the partial function in the RunnableLambda
return RunnableLambda(partial_genie_agent)
71 changes: 71 additions & 0 deletions integrations/langchain/tests/test_genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest.mock import patch

from langchain_core.messages import AIMessage

from databricks_langchain.genie import (
GenieAgent,
_concat_messages_array,
_query_genie_as_agent,
)


def test_concat_messages_array():
    """Check ``speaker: content`` rendering for dicts, missing content and objects."""
    # Dict messages: one rendered line per message, joined by newlines.
    dict_history = [
        {"role": "user", "content": "What is the weather?"},
        {"role": "assistant", "content": "It is sunny."},
    ]
    assert (
        _concat_messages_array(dict_history)
        == "user: What is the weather?\nassistant: It is sunny."
    )

    # A dict without "content" renders with an empty body.
    partial_history = [{"role": "user"}, {"role": "assistant", "content": "I don't know."}]
    assert _concat_messages_array(partial_history) == "user: \nassistant: I don't know."

    # Non-dict messages are read through their attributes instead of keys.
    class Message:
        def __init__(self, role, content):
            self.role = role
            self.content = content

    object_history = [
        Message("user", "Tell me a joke."),
        Message("assistant", "Why did the chicken cross the road?"),
    ]
    rendered = _concat_messages_array(object_history)
    assert rendered == "user: Tell me a joke.\nassistant: Why did the chicken cross the road?"


@patch("databricks_langchain.genie.Genie")
def test_query_genie_as_agent(MockGenie):
    """Check that the Genie response, or its absence, is wrapped in an AIMessage."""
    # The function instantiates Genie, so stub the method on the instance mock.
    ask_question = MockGenie.return_value.ask_question
    payload = {"messages": [{"role": "user", "content": "What is the weather?"}]}

    # A non-empty response is passed through as the message content.
    ask_question.return_value = "It is sunny."
    answered = _query_genie_as_agent(payload, "space-id", "Genie")
    assert answered == {"messages": [AIMessage(content="It is sunny.")]}

    # A missing response yields an AIMessage with empty content.
    ask_question.return_value = None
    unanswered = _query_genie_as_agent(payload, "space-id", "Genie")
    assert unanswered == {"messages": [AIMessage(content="")]}


@patch("langchain_core.runnables.RunnableLambda")
def test_create_genie_agent(MockRunnableLambda):
    """Check that GenieAgent wraps a correctly bound partial in a RunnableLambda."""
    mock_runnable = MockRunnableLambda.return_value

    agent = GenieAgent("space-id", "Genie")
    assert agent == mock_runnable

    # Check that the partial function is created with the correct arguments:
    # RunnableLambda must receive exactly one positional argument, a partial
    # of _query_genie_as_agent with both identifiers bound by keyword.
    MockRunnableLambda.assert_called_once()
    call_args, call_kwargs = MockRunnableLambda.call_args
    assert len(call_args) == 1
    assert call_kwargs == {}

    bound = call_args[0]
    assert bound.func is _query_genie_as_agent
    assert bound.args == ()
    assert bound.keywords == {"genie_space_id": "space-id", "genie_agent_name": "Genie"}
17 changes: 3 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"typing_extensions",
"pydantic"
"pydantic",
"databricks-sdk>=0.34.0",
"pandas",
]

[project.license]
Expand All @@ -28,22 +30,9 @@ include = [
packages = ["src/databricks_ai_bridge"]

[project.optional-dependencies]
databricks = [
"databricks-sdk>=0.34.0",
"pandas",
]
databricks-dev = [
"hatch",
"pytest",
"databricks-sdk>=0.34.0",
"pandas",
"ruff==0.6.4",
]
dev = [
"hatch",
"pytest",
"databricks-sdk>=0.34.0",
"pandas",
"ruff==0.6.4",
]

Expand Down