Skip to content

Commit

Permalink
Create Langchain Genie
Browse files Browse the repository at this point in the history
Signed-off-by: Prithvi Kannan <[email protected]>
  • Loading branch information
prithvikannan committed Oct 22, 2024
1 parent 86a74fb commit e9dfff7
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
68 changes: 68 additions & 0 deletions integrations/langchain/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
[project]
name = "databricks-langchain"
version = "0.0.1"
description = "Support for Datarbricks AI support in LangChain"
authors = [
{ name="Prithvi Kannan", email="[email protected]" },
]
readme = "README.md"
license = { text="Apache-2.0" }
requires-python = ">=3.8"
dependencies = [
"langchain>=0.2.0",
"langchain-community>=0.2.0",
"databricks-ai-bridge",
]

[project.optional-dependencies]
dev = [
"pytest",
"typing_extensions",
"databricks-sdk>=0.34.0",
"ruff==0.6.4",
"langgraph",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = [
"src/databricks_langchain/*"
]

[tool.hatch.build.targets.wheel]
packages = ["src/databricks_langchain"]

[tool.ruff]
line-length = 100
target-version = "py39"

[tool.ruff.lint]
select = [
# isort
"I",
# bugbear rules
"B",
# remove unused imports
"F401",
# bare except statements
"E722",
# print statements
"T201",
"T203",
# misuse of typing.TYPE_CHECKING
"TCH004",
# import rules
"TID251",
# undefined-local-with-import-star
"F403",
]

[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 88

[tool.ruff.lint.pydocstyle]
convention = "google"
43 changes: 43 additions & 0 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from databricks_ai_bridge.genie import Genie

def _concat_messages_array(messages):
concatenated_message = "\n".join(
[
f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}"
if isinstance(message, dict)
else f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}"
for message in messages
]
)
return concatenated_message


def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
from langchain_core.messages import AIMessage
genie = Genie(genie_space_id)

message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n"

# Concatenate messages to form the chat history
message += _concat_messages_array(input.get("messages"))

# Send the message and wait for a response
genie_response = genie.ask_question(message)

if genie_response:
return {"messages": [AIMessage(content=genie_response)]}
else:
return {"messages": [AIMessage(content="")]}


def create_genie_agent(genie_space_id, genie_agent_name="Genie"):
"""Create a genie agent that can be used to query the API"""
from functools import partial

from langchain_core.runnables import RunnableLambda

# Create a partial function with the genie_space_id pre-filled
partial_genie_agent = partial(_query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name)

# Use the partial function in the RunnableLambda
return RunnableLambda(partial_genie_agent)
72 changes: 72 additions & 0 deletions integrations/langchain/tests/test_genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from unittest.mock import patch, MagicMock
from langchain_core.messages import AIMessage
from my_module import _concat_messages_array, _query_genie_as_agent, create_genie_agent

def test_concat_messages_array():
# Test a simple case with multiple messages
messages = [
{"role": "user", "content": "What is the weather?"},
{"role": "assistant", "content": "It is sunny."}
]
result = _concat_messages_array(messages)
expected = "user: What is the weather?\nassistant: It is sunny."
assert result == expected

# Test case with missing content
messages = [
{"role": "user"},
{"role": "assistant", "content": "I don't know."}
]
result = _concat_messages_array(messages)
expected = "user: \nassistant: I don't know."
assert result == expected

# Test case with non-dict message objects
class Message:
def __init__(self, role, content):
self.role = role
self.content = content

messages = [
Message("user", "Tell me a joke."),
Message("assistant", "Why did the chicken cross the road?")
]
result = _concat_messages_array(messages)
expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?"
assert result == expected


@patch('databricks_ai_bridge.genie.Genie')
def test_query_genie_as_agent(MockGenie):
# Mock the Genie class and its response
mock_genie = MockGenie.return_value
mock_genie.ask_question.return_value = "It is sunny."

input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
result = _query_genie_as_agent(input_data, "space-id", "Genie")

expected_message = {
"messages": [AIMessage(content="It is sunny.")]
}
assert result == expected_message

# Test the case when genie_response is empty
mock_genie.ask_question.return_value = None
result = _query_genie_as_agent(input_data, "space-id", "Genie")

expected_message = {
"messages": [AIMessage(content="")]
}
assert result == expected_message


@patch('langchain_core.runnables.RunnableLambda')
def test_create_genie_agent(MockRunnableLambda):
mock_runnable = MockRunnableLambda.return_value

agent = create_genie_agent("space-id", "Genie")
assert agent == mock_runnable

# Check that the partial function is created with the correct arguments
MockRunnableLambda.assert_called()

0 comments on commit e9dfff7

Please sign in to comment.