-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Prithvi Kannan <[email protected]>
- Loading branch information
1 parent
86a74fb
commit e9dfff7
Showing
3 changed files
with
183 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
[project] | ||
name = "databricks-langchain" | ||
version = "0.0.1" | ||
description = "Support for Datarbricks AI support in LangChain" | ||
authors = [ | ||
{ name="Prithvi Kannan", email="[email protected]" }, | ||
] | ||
readme = "README.md" | ||
license = { text="Apache-2.0" } | ||
requires-python = ">=3.8" | ||
dependencies = [ | ||
"langchain>=0.2.0", | ||
"langchain-community>=0.2.0", | ||
"databricks-ai-bridge", | ||
] | ||
|
||
[project.optional-dependencies] | ||
dev = [ | ||
"pytest", | ||
"typing_extensions", | ||
"databricks-sdk>=0.34.0", | ||
"ruff==0.6.4", | ||
"langgraph", | ||
] | ||
|
||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[tool.hatch.build] | ||
include = [ | ||
"src/databricks_langchain/*" | ||
] | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/databricks_langchain"] | ||
|
||
[tool.ruff] | ||
line-length = 100 | ||
target-version = "py39" | ||
|
||
[tool.ruff.lint] | ||
select = [ | ||
# isort | ||
"I", | ||
# bugbear rules | ||
"B", | ||
# remove unused imports | ||
"F401", | ||
# bare except statements | ||
"E722", | ||
# print statements | ||
"T201", | ||
"T203", | ||
# misuse of typing.TYPE_CHECKING | ||
"TCH004", | ||
# import rules | ||
"TID251", | ||
# undefined-local-with-import-star | ||
"F403", | ||
] | ||
|
||
[tool.ruff.format] | ||
docstring-code-format = true | ||
docstring-code-line-length = 88 | ||
|
||
[tool.ruff.lint.pydocstyle] | ||
convention = "google" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from databricks_ai_bridge.genie import Genie | ||
|
||
def _concat_messages_array(messages): | ||
concatenated_message = "\n".join( | ||
[ | ||
f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}" | ||
if isinstance(message, dict) | ||
else f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}" | ||
for message in messages | ||
] | ||
) | ||
return concatenated_message | ||
|
||
|
||
def _query_genie_as_agent(input, genie_space_id, genie_agent_name): | ||
from langchain_core.messages import AIMessage | ||
genie = Genie(genie_space_id) | ||
|
||
message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" | ||
|
||
# Concatenate messages to form the chat history | ||
message += _concat_messages_array(input.get("messages")) | ||
|
||
# Send the message and wait for a response | ||
genie_response = genie.ask_question(message) | ||
|
||
if genie_response: | ||
return {"messages": [AIMessage(content=genie_response)]} | ||
else: | ||
return {"messages": [AIMessage(content="")]} | ||
|
||
|
||
def create_genie_agent(genie_space_id, genie_agent_name="Genie"): | ||
"""Create a genie agent that can be used to query the API""" | ||
from functools import partial | ||
|
||
from langchain_core.runnables import RunnableLambda | ||
|
||
# Create a partial function with the genie_space_id pre-filled | ||
partial_genie_agent = partial(_query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name) | ||
|
||
# Use the partial function in the RunnableLambda | ||
return RunnableLambda(partial_genie_agent) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pytest | ||
from unittest.mock import patch, MagicMock | ||
from langchain_core.messages import AIMessage | ||
from my_module import _concat_messages_array, _query_genie_as_agent, create_genie_agent | ||
|
||
def test_concat_messages_array(): | ||
# Test a simple case with multiple messages | ||
messages = [ | ||
{"role": "user", "content": "What is the weather?"}, | ||
{"role": "assistant", "content": "It is sunny."} | ||
] | ||
result = _concat_messages_array(messages) | ||
expected = "user: What is the weather?\nassistant: It is sunny." | ||
assert result == expected | ||
|
||
# Test case with missing content | ||
messages = [ | ||
{"role": "user"}, | ||
{"role": "assistant", "content": "I don't know."} | ||
] | ||
result = _concat_messages_array(messages) | ||
expected = "user: \nassistant: I don't know." | ||
assert result == expected | ||
|
||
# Test case with non-dict message objects | ||
class Message: | ||
def __init__(self, role, content): | ||
self.role = role | ||
self.content = content | ||
|
||
messages = [ | ||
Message("user", "Tell me a joke."), | ||
Message("assistant", "Why did the chicken cross the road?") | ||
] | ||
result = _concat_messages_array(messages) | ||
expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?" | ||
assert result == expected | ||
|
||
|
||
@patch('databricks_ai_bridge.genie.Genie') | ||
def test_query_genie_as_agent(MockGenie): | ||
# Mock the Genie class and its response | ||
mock_genie = MockGenie.return_value | ||
mock_genie.ask_question.return_value = "It is sunny." | ||
|
||
input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} | ||
result = _query_genie_as_agent(input_data, "space-id", "Genie") | ||
|
||
expected_message = { | ||
"messages": [AIMessage(content="It is sunny.")] | ||
} | ||
assert result == expected_message | ||
|
||
# Test the case when genie_response is empty | ||
mock_genie.ask_question.return_value = None | ||
result = _query_genie_as_agent(input_data, "space-id", "Genie") | ||
|
||
expected_message = { | ||
"messages": [AIMessage(content="")] | ||
} | ||
assert result == expected_message | ||
|
||
|
||
@patch('langchain_core.runnables.RunnableLambda') | ||
def test_create_genie_agent(MockRunnableLambda): | ||
mock_runnable = MockRunnableLambda.return_value | ||
|
||
agent = create_genie_agent("space-id", "Genie") | ||
assert agent == mock_runnable | ||
|
||
# Check that the partial function is created with the correct arguments | ||
MockRunnableLambda.assert_called() |