Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SherLoQ support added #9

Merged
merged 37 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
4d3b49d
querycreatortool added
Feb 29, 2024
d49d0ed
sqlquerycreatortool added
Mar 1, 2024
d63efe3
Update version and add SQLCoderToolkit
arunraja1 Mar 4, 2024
d395589
Updated tool.py for sqlcoder
Mar 4, 2024
5a220ce
retry logic added in sqlquerycreator tool
Mar 4, 2024
a651d32
sql query capture updated
Mar 12, 2024
2da388a
version update
Mar 12, 2024
f5a7e02
Merge remote-tracking branch 'origin/main' into users/sushant/sherloq
Mar 13, 2024
18f699f
Merge branch 'main' into users/sushant/sherloq
Mar 18, 2024
6129071
langchain version update
Mar 18, 2024
3dd193e
merge fix
Mar 18, 2024
225fbf4
version change to accomodate 36dev1(uncommited code to fix unstructer…
Mar 19, 2024
13cdeef
add changes to sqlcoder prompt
Apr 1, 2024
a190332
Merge branch 'main' into users/sushant/sherloq
Apr 3, 2024
39ec7f7
version update
Apr 3, 2024
099962a
added structured decomposition😁 to sherloq
Apr 4, 2024
20d3f78
sqlcoder prompt updated
Apr 4, 2024
f145add
prompt updated sqlcoder
Apr 15, 2024
dc89085
Add CustomPlanandSolveExecutor class to agent.py
arunraja1 May 5, 2024
14a50ad
Update langchain version to 0.1.46dev1
arunraja1 May 5, 2024
868474c
query mixing fixed
May 6, 2024
7e3da7f
version updated
sushantburnawal May 6, 2024
22120ff
chore: Add sources to agent state
arunraja1 May 6, 2024
1c22b3d
sqlcoder retry prompt updated
sushantburnawal May 7, 2024
78f3ed5
version updated
sushantburnawal May 7, 2024
ee19e78
Merge branch 'users/sushant/sherloq' of https://github.com/skypointcl…
arunraja1 May 8, 2024
a802744
changes added
sushantburnawal May 8, 2024
9db1f35
retry mechanism bug fix
sushantburnawal May 22, 2024
e2b0246
SQL_QUERY_CREATOR_TEMPLATE is removed from langchain
sushantburnawal May 24, 2024
4bcd25e
changes to sqlcodertoolkit
sushantburnawal May 24, 2024
093b660
changes to sqlcodertoolkit
sushantburnawal May 24, 2024
4f611ba
changes to sqlcodertoolkit
sushantburnawal May 24, 2024
cf50a6c
Add prompt instructions to include sources in the document variable
arunraja1 May 30, 2024
b910473
Update document variable to include relevant sources
arunraja1 May 30, 2024
16acf00
Update document variable to include relevant sources
arunraja1 May 30, 2024
e2737c3
Update sources handling in BaseQAWithSourcesChain
arunraja1 May 30, 2024
09d5c0c
Update SQL query creator tool to remove unnecessary code and improve …
arunraja1 Jun 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions libs/langchain/langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,8 @@ def _take_next_step(
self.state.append(
{str(intermediate_steps[-1][0]): str(intermediate_steps[-1][1])}
)
elif self.state is not None and "few_shot_examples" in self.state[0].keys():
pass
else:
self.state = []

Expand Down Expand Up @@ -1619,3 +1621,172 @@ def _take_next_step(
)
result.append((agent_action, observation))
return result


class CustomPlanandSolveExecutor(AgentExecutor):
"""Agent that is using tools."""

agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
"""The agent to run for creating a plan and determining actions
to take at each step of the execution loop."""
tools: Sequence[BaseTool]
"""The valid tools the agent can call."""
return_intermediate_steps: bool = False
"""Whether to return the agent's trajectory of intermediate steps
at the end in addition to the final output."""
max_iterations: Optional[int] = 15
"""The maximum number of steps to take before ending the execution
loop. Setting to 'None' could lead to an infinite loop."""

max_execution_time: Optional[float] = None
"""The maximum amount of wall clock time to spend in the execution
loop.
"""
early_stopping_method: str = "force"
"""The method to use for early stopping if the agent never
returns `AgentFinish`. Either 'force' or 'generate'.
`"force"` returns a string saying that it stopped because it met a
time or iteration limit.
`"generate"` calls the agent's LLM Chain one final time to generate
a final answer based on the previous steps.
"""
handle_parsing_errors: Union[
bool, str, Callable[[OutputParserException], str]
] = False
"""How to handle errors raised by the agent's output parser.
Defaults to `False`, which raises the error.
s
If `true`, the error will be sent back to the LLM as an observation.
If a string, the string itself will be sent to the LLM as an observation.
If a callable function, the function will be called with the exception
as an argument, and the result of that function will be passed to the agent
as an observation.
"""
state: List[Dict[str, Any]] = None
trim_intermediate_steps: Union[
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
] = -1

def _take_next_step(
self,
name_to_tool_map: Dict[str, StateTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
try:
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
# Check the intermediate_steps and add the last one to the state
if intermediate_steps:
try:
tool_name = intermediate_steps[-1][0].dict().get("tool",None)
tool_input = intermediate_steps[-1][0].dict().get("tool_input",None)
sql_query = json.loads(intermediate_steps[-1][1][0]).get("sql_query",None)
answer = json.loads(intermediate_steps[-1][1][0]).get("answer",None)
sources = json.loads(intermediate_steps[-1][1][0]).get("sources",None)

self.state.append(
{
"tool": tool_name,
"tool_input": tool_input,
"sql_query": sql_query,
"answer": answer,
"sources": sources

}
)
except:
self.state = []
else:
self.state = []

if len(self.state) != 0:
# return self.state['answer']
return AgentFinish({"output": self.state[0]["answer"]}, intermediate_steps[0][0].log)

# Call the LLM to see what to do.
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
else:
raise_error = False
if raise_error:
raise e
text = str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
else:
observation = "Invalid or incomplete response"
elif isinstance(self.handle_parsing_errors, str):
observation = self.handle_parsing_errors
elif callable(self.handle_parsing_errors):
observation = self.handle_parsing_errors(e)
else:
raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, text)
if run_manager:
run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool_input,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
actions: List[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
actions = output
result = []
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
state=self.state,
callbacks=run_manager.get_child()
if run_manager
else None**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
result.append((agent_action, observation))
return result
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/agent_toolkits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from langchain.agents.agent_toolkits.office365.toolkit import O365Toolkit
from langchain.agents.agent_toolkits.openapi.base import create_openapi_agent
from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain.agents.agent_toolkits.sqlcoder.toolkit import SQLCoderToolkit
from langchain.agents.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit
from langchain.agents.agent_toolkits.powerbi.base import create_pbi_agent
from langchain.agents.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent
Expand Down Expand Up @@ -102,6 +103,7 @@ def __getattr__(name: str) -> Any:
"PowerBIToolkit",
"SlackToolkit",
"SteamToolkit",
"SQLCoderToolkit",
"SQLDatabaseToolkit",
"SparkSQLToolkit",
"UCSQLDatabaseToolkit",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""SQL Coder Agent Toolkit"""
93 changes: 93 additions & 0 deletions libs/langchain/langchain/agents/agent_toolkits/sqlcoder/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""SQL agent."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, SystemMessage


def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
suffix: Optional[str] = None,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent

if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix or SQL_SUFFIX,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)

elif agent_type == AgentType.OPENAI_FUNCTIONS:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)

agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")

return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
23 changes: 23 additions & 0 deletions libs/langchain/langchain/agents/agent_toolkits/sqlcoder/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# flake8: noqa

SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.
"""

SQL_SUFFIX = """Begin!

Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""

SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
66 changes: 66 additions & 0 deletions libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Toolkit for interacting with a SQL database."""
from typing import List

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.base_language import BaseLanguageModel
from langchain.sql_database import SQLDatabase
from langchain.tools import BaseTool
from langchain.tools.sqlcoder.tool import (
QuerySparkSQLDataBaseTool,
SqlQueryCreatorTool,
)
from langchain.tools.sql_database.tool import QuerySQLCheckerTool
from langchain_core.pydantic_v1 import Field


class SQLCoderToolkit(BaseToolkit):
"""Toolkit for interacting with SQL databases."""

db: SQLDatabase = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
db_token: str
db_host: str
db_catalog: str
db_schema: str
db_warehouse_id: str
allow_extra_fields = True
sqlcreatorllm : BaseLanguageModel = Field(exclude=True)
sql_query_creator_template : str

@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
return self.db.dialect

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
"'xxxx' in 'field list', using schema_sql_db to query the correct table "
"fields."
)
return [
QuerySparkSQLDataBaseTool(
db=self.db, description=query_sql_database_tool_description
),
QuerySQLCheckerTool(db=self.db, llm=self.llm),
SqlQueryCreatorTool(
sqlcreatorllm=self.sqlcreatorllm ,
db=self.db,
db_token=self.db_token,
db_host=self.db_host,
db_catalog=self.db_catalog,
db_schema=self.db_schema,
db_warehouse_id=self.db_warehouse_id,
SQL_QUERY_CREATOR_TEMPLATE=self.sql_query_creator_template

)
]
1 change: 1 addition & 0 deletions libs/langchain/langchain/chains/combine_documents/stuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
inputs[self.document_variable_name] += "\n\nYou should return the sources only when it is relevant to the output. Else return Sources as None"
return inputs

def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
Expand Down
Loading
Loading