Create Langchain Genie #5
Merged
7 commits
e9dfff7 Create Langchain Genie (prithvikannan)
ab8b450 ruff (prithvikannan)
4076c00 Merge remote-tracking branch 'origin/main' into langchain-genie (prithvikannan)
aa4f3d1 update langchain test (prithvikannan)
2879071 pyproject (prithvikannan)
580c928 ruff (prithvikannan)
776d8c1 rename (prithvikannan)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
[project] | ||
name = "databricks-langchain" | ||
version = "0.0.1" | ||
description = "Support for Databricks AI in LangChain" | ||
authors = [ | ||
{ name="Prithvi Kannan", email="[email protected]" }, | ||
] | ||
readme = "README.md" | ||
license = { text="Apache-2.0" } | ||
requires-python = ">=3.8" | ||
dependencies = [ | ||
"langchain>=0.2.0", | ||
"langchain-community>=0.2.0", | ||
"databricks-ai-bridge", | ||
] | ||
|
||
[project.optional-dependencies] | ||
dev = [ | ||
"pytest", | ||
"typing_extensions", | ||
"databricks-sdk>=0.34.0", | ||
"ruff==0.6.4", | ||
] | ||
|
||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[tool.hatch.build] | ||
include = [ | ||
"src/databricks_langchain/*" | ||
] | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/databricks_langchain"] | ||
|
||
[tool.ruff] | ||
line-length = 100 | ||
target-version = "py39" | ||
|
||
[tool.ruff.lint] | ||
select = [ | ||
# isort | ||
"I", | ||
# bugbear rules | ||
"B", | ||
# remove unused imports | ||
"F401", | ||
# bare except statements | ||
"E722", | ||
# print statements | ||
"T201", | ||
"T203", | ||
# misuse of typing.TYPE_CHECKING | ||
"TCH004", | ||
# import rules | ||
"TID251", | ||
# undefined-local-with-import-star | ||
"F403", | ||
] | ||
|
||
[tool.ruff.format] | ||
docstring-code-format = true | ||
docstring-code-line-length = 88 | ||
|
||
[tool.ruff.lint.pydocstyle] | ||
convention = "google" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from databricks_ai_bridge.genie import Genie | ||
|
||
|
||
def _concat_messages_array(messages): | ||
concatenated_message = "\n".join( | ||
[ | ||
f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}" | ||
if isinstance(message, dict) | ||
else f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}" | ||
for message in messages | ||
] | ||
) | ||
return concatenated_message | ||
|
||
|
||
def _query_genie_as_agent(input, genie_space_id, genie_agent_name): | ||
from langchain_core.messages import AIMessage | ||
|
||
genie = Genie(genie_space_id) | ||
|
||
message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" | ||
|
||
# Concatenate messages to form the chat history | ||
message += _concat_messages_array(input.get("messages")) | ||
|
||
# Send the message and wait for a response | ||
genie_response = genie.ask_question(message) | ||
|
||
if genie_response: | ||
return {"messages": [AIMessage(content=genie_response)]} | ||
else: | ||
return {"messages": [AIMessage(content="")]} | ||
|
||
|
||
def GenieAgent(genie_space_id, genie_agent_name="Genie", description=""): | ||
"""Create a genie agent that can be used to query the API""" | ||
from functools import partial | ||
|
||
from langchain_core.runnables import RunnableLambda | ||
|
||
# Create a partial function with the genie_space_id pre-filled | ||
partial_genie_agent = partial( | ||
_query_genie_as_agent, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name | ||
) | ||
|
||
# Use the partial function in the RunnableLambda | ||
return RunnableLambda(partial_genie_agent) |
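
The factory above binds its configuration with `functools.partial` so that the resulting callable takes only the chat state. A minimal sketch of that pattern follows. It does not use the real `Genie` client or `RunnableLambda`: `FakeGenie` and `query_as_agent` are hypothetical stand-ins introduced only for illustration, and the shortened prompt text is an assumption.

```python
from functools import partial


class FakeGenie:
    """Hypothetical stand-in for the Genie client; makes no network calls."""

    def __init__(self, space_id):
        self.space_id = space_id

    def ask_question(self, question):
        # Echo back a summary instead of querying a real Genie space.
        return f"[{self.space_id}] received {len(question.splitlines())} line(s)"


def query_as_agent(state, genie_space_id, genie_agent_name):
    # Flatten the chat history into a single prompt, as the diff above does.
    history = "\n".join(f"{m['role']}: {m['content']}" for m in state["messages"])
    prompt = f"Your name is {genie_agent_name}.\n{history}"
    answer = FakeGenie(genie_space_id).ask_question(prompt)
    # Fall back to an empty string when the client returns nothing.
    return {"messages": [{"role": genie_agent_name, "content": answer or ""}]}


# Bind the configuration once; the bound callable takes only the chat state.
agent = partial(query_as_agent, genie_space_id="space-id", genie_agent_name="Genie")
result = agent({"messages": [{"role": "user", "content": "What is the weather?"}]})
print(result["messages"][0]["content"])
```

Binding the space ID and agent name up front keeps the call signature down to a single state argument, which is the shape a single-input runnable wrapper expects.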
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from unittest.mock import patch | ||
|
||
from langchain_core.messages import AIMessage | ||
|
||
from databricks_langchain.genie import ( | ||
GenieAgent, | ||
_concat_messages_array, | ||
_query_genie_as_agent, | ||
) | ||
|
||
|
||
def test_concat_messages_array(): | ||
# Test a simple case with multiple messages | ||
messages = [ | ||
{"role": "user", "content": "What is the weather?"}, | ||
{"role": "assistant", "content": "It is sunny."}, | ||
] | ||
result = _concat_messages_array(messages) | ||
expected = "user: What is the weather?\nassistant: It is sunny." | ||
assert result == expected | ||
|
||
# Test case with missing content | ||
messages = [{"role": "user"}, {"role": "assistant", "content": "I don't know."}] | ||
result = _concat_messages_array(messages) | ||
expected = "user: \nassistant: I don't know." | ||
assert result == expected | ||
|
||
# Test case with non-dict message objects | ||
class Message: | ||
def __init__(self, role, content): | ||
self.role = role | ||
self.content = content | ||
|
||
messages = [ | ||
Message("user", "Tell me a joke."), | ||
Message("assistant", "Why did the chicken cross the road?"), | ||
] | ||
result = _concat_messages_array(messages) | ||
expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?" | ||
assert result == expected | ||
|
||
|
||
@patch("databricks_langchain.genie.Genie") | ||
def test_query_genie_as_agent(MockGenie): | ||
# Mock the Genie class and its response | ||
mock_genie = MockGenie.return_value | ||
mock_genie.ask_question.return_value = "It is sunny." | ||
|
||
input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} | ||
result = _query_genie_as_agent(input_data, "space-id", "Genie") | ||
|
||
expected_message = {"messages": [AIMessage(content="It is sunny.")]} | ||
assert result == expected_message | ||
|
||
# Test the case when genie_response is empty | ||
mock_genie.ask_question.return_value = None | ||
result = _query_genie_as_agent(input_data, "space-id", "Genie") | ||
|
||
expected_message = {"messages": [AIMessage(content="")]} | ||
assert result == expected_message | ||
|
||
|
||
@patch("langchain_core.runnables.RunnableLambda") | ||
def test_create_genie_agent(MockRunnableLambda): | ||
mock_runnable = MockRunnableLambda.return_value | ||
|
||
agent = GenieAgent("space-id", "Genie") | ||
assert agent == mock_runnable | ||
|
||
# Check that RunnableLambda was constructed; the partial's arguments are not inspected here | ||
MockRunnableLambda.assert_called() |
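
The last test only asserts that the wrapper class was called. If a stricter check is wanted, the bound arguments can be read back from the mock's `call_args`. The sketch below assumes a simplified factory: `build_agent` and `fake_query` are hypothetical names standing in for `GenieAgent` and `_query_genie_as_agent`, and the wrapper class is passed in explicitly rather than patched.

```python
from functools import partial
from unittest.mock import MagicMock


def fake_query(state, genie_space_id, genie_agent_name):
    # Hypothetical stand-in for the real query function.
    return {"messages": []}


def build_agent(runnable_cls, genie_space_id, genie_agent_name="Genie"):
    # Same shape as the factory in the diff: bind the config, then wrap.
    bound = partial(
        fake_query, genie_space_id=genie_space_id, genie_agent_name=genie_agent_name
    )
    return runnable_cls(bound)


runnable_cls = MagicMock(name="RunnableLambda")
agent = build_agent(runnable_cls, "space-id", "Genie")

# call_args unpacks into the positional and keyword arguments of the last call.
args, kwargs = runnable_cls.call_args
bound_fn = args[0]

assert agent is runnable_cls.return_value
assert bound_fn.func is fake_query
assert bound_fn.keywords == {"genie_space_id": "space-id", "genie_agent_name": "Genie"}
```

Inspecting `partial.func` and `partial.keywords` verifies what was bound without invoking the wrapped function.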
Question here: are we expecting users to use the Genie agent name in their questions? Will the questions be something like:
"Use Genie to get me my financial records"
Will this also work when users ask "Get my financial records"?
The ideal end state is that the user can just specify "Get my financial records". The name is needed so that, in a multi-agent (and, in the future, multi-user) chat array, agents can identify who said what.
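
To illustrate the point about attribution, here is a small sketch of how a speaker-prefixed history might look once several agents have contributed. The helper names `speaker` and `flatten_history` are hypothetical, the `Summarizer` agent and the message contents are invented for the example, and the role-then-name fallback order mirrors `_concat_messages_array` in the diff.

```python
def speaker(message):
    # Prefer "role", then "name", then "unknown", for dicts and objects alike.
    if isinstance(message, dict):
        return message.get("role", message.get("name", "unknown"))
    return getattr(message, "role", getattr(message, "name", "unknown"))


def content(message):
    if isinstance(message, dict):
        return message.get("content", "")
    return getattr(message, "content", "")


def flatten_history(messages):
    # One "speaker: text" line per message, so each turn stays attributable.
    return "\n".join(f"{speaker(m)}: {content(m)}" for m in messages)


history = [
    {"role": "user", "content": "Get my financial records"},
    {"name": "Genie", "content": "Here are the records for last quarter."},
    {"name": "Summarizer", "content": "Spending rose quarter over quarter."},
]
flat = flatten_history(history)
print(flat)
```

The user never mentions an agent by name here; the prefixes exist only so that downstream agents reading the flattened history can tell the turns apart.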