From 4d3b49da6580b53319f10272d9030a0980ff985d Mon Sep 17 00:00:00 2001 From: sushant Date: Thu, 29 Feb 2024 13:01:52 +0530 Subject: [PATCH 01/33] querycreatortool added --- libs/langchain/langchain/agents/agent.py | 3 + .../agent_toolkits/unitycatalog/toolkit.py | 10 + libs/langchain/langchain/tools/__init__.py | 7 + .../tools/spark_unitycatalog/prompt.py | 26 +++ .../tools/spark_unitycatalog/tool.py | 201 +++++++++++++++++- libs/langchain/pyproject.toml | 2 +- 6 files changed, 247 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 79f13f0988068..dc366620576a6 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -1590,6 +1590,8 @@ def _take_next_step( self.state.append( {str(intermediate_steps[-1][0]): str(intermediate_steps[-1][1])} ) + elif self.state is not None and "few_shot_examples" in self.state[0].keys(): + pass else: self.state = [] # metadata_dict = {} @@ -1619,3 +1621,4 @@ def _take_next_step( ) result.append((agent_action, observation)) return result + diff --git a/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py index 385aa93738b6d..caed7191d405f 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py @@ -10,6 +10,7 @@ ListUnityCatalogTablesTool, QueryUCSQLDataBaseTool, SqlQueryValidatorTool, + SqlQueryCreatorTool, ) from langchain.tools.sql_database.tool import QuerySQLCheckerTool from langchain_core.pydantic_v1 import Field @@ -26,6 +27,7 @@ class UCSQLDatabaseToolkit(BaseToolkit): db_schema: str db_warehouse_id: str allow_extra_fields = True + sqlcreatorllm : BaseLanguageModel = Field(exclude=True) @property def dialect(self) -> str: @@ -76,4 +78,12 @@ def get_tools(self) -> List[BaseTool]: ), QuerySQLCheckerTool(db=self.db, llm=self.llm), SqlQueryValidatorTool(llm=self.llm), + SqlQueryCreatorTool( + sqlcreatorllm=self.sqlcreatorllm , + db=self.db, + db_token=self.db_token, + db_host=self.db_host, + db_catalog=self.db_catalog, + db_schema=self.db_schema, + db_warehouse_id=self.db_warehouse_id ) ] diff --git a/libs/langchain/langchain/tools/__init__.py b/libs/langchain/langchain/tools/__init__.py index ed9c8f849225b..e19e0fe6cbf2d 100644 --- a/libs/langchain/langchain/tools/__init__.py +++ b/libs/langchain/langchain/tools/__init__.py @@ -773,6 +773,10 @@ def _import_spark_unitycatalog_tool_SqlQueryValidatorTool() -> Any: return SqlQueryValidatorTool +def _import_spark_unitycatalog_tool_SqlQueryCreatorTool() -> Any: + from langchain.tools.spark_unitycatalog.tool import SqlQueryCreatorTool + + return SqlQueryCreatorTool def _import_stackexchange_tool() -> Any: from langchain.tools.stackexchange.tool import StackExchangeTool @@ -1075,6 +1079,8 @@ def __getattr__(name: str) -> Any: return _import_spark_unitycatalog_tool_QueryUCSQLDataBaseTool() elif name == "SqlQueryValidatorTool": return _import_spark_unitycatalog_tool_SqlQueryValidatorTool() + elif name == "SqlQueryCreatorTool": + return _import_spark_unitycatalog_tool_SqlQueryCreatorTool() elif name == "InfoSparkSQLTool": return _import_spark_sql_tool_InfoSparkSQLTool() elif name == "ListSparkSQLTool": @@ -1248,6 +1254,7 @@ def __getattr__(name: str) -> Any: "SlackSendMessage", "SleepTool", "SqlQueryValidatorTool", + "SqlQueryCreatorTool", "StdInInquireTool", "StackExchangeTool", 
"StackExchangeTool", diff --git a/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py b/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py index 9fed682b121b3..dc1874611bd4a 100644 --- a/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py +++ b/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py @@ -32,3 +32,29 @@ Begin SQL Query Validation. """ + + +SQL_QUERY_CREATOR = """### Instructions: +Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. +- When creating a ratio, always cast the numerator as float +You are an AI research assistant in the senior living industry. +You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events +When querying the database, given an input question, create a syntactically correct query to run, then look at the results of the query and return the answer. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results. +You can order the results by a relevant column to return the most interesting examples in the database. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +### Input: +Generate a SQL query that answers the question `{user_input}`. +This query will run on a database whose schema is represented in this string: +'{db_schema}' +Use the following examples to generate the sql query: +'{few_shot_examples}' +### Response: +Based on your instructions, here is the SQL query I have generated to answer '{user_input}' +```sql""" + diff --git a/libs/langchain/langchain/tools/spark_unitycatalog/tool.py b/libs/langchain/langchain/tools/spark_unitycatalog/tool.py index 14e95f6d9d5e8..b7e8f9253eac2 100644 --- a/libs/langchain/langchain/tools/spark_unitycatalog/tool.py +++ b/libs/langchain/langchain/tools/spark_unitycatalog/tool.py @@ -14,7 +14,7 @@ from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase -from langchain.tools.spark_unitycatalog.prompt import SQL_QUERY_VALIDATOR +from langchain.tools.spark_unitycatalog.prompt import SQL_QUERY_VALIDATOR, SQL_QUERY_CREATOR from langchain_core.pydantic_v1 import BaseModel, Extra, Field from langchain_core.tools import StateTool from requests.adapters import HTTPAdapter @@ -374,3 +374,202 @@ def _extract_sql_query(self): sql_query = match.group(1) return sql_query return None + + + +class SqlQueryCreatorTool(StateTool): + """Tool for creating SQL query.Use this to create sql query.""" + + name = "sql_db_query_creator" + description = """ + This is a tool used to create sql query for user input based on the schema of the table and few_shot_examples. 
+ Input to this tool is input prompt and table schema and few_shot_examples + Output is a sql query +""" + sqlcreatorllm: BaseLanguageModel = Field(exclude=True) + # user_input : str + + + class Config(StateTool.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.allow + + def __init__(__pydantic_self__, **data: Any) -> None: + """Initialize the tool.""" + super().__init__(**data) + + def _run( + self, + user_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + if hasattr(self, "state"): + return self._create_sql_query(user_input) + + else: + return "This tool is not meant to be run directly. Start with ListUnityCatalogTablesTool" + + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("SqlQueryCreatorTool does not support async") + + + + def _parse_few_shot_examples(self): + few_shot_examples = "" + for value in self.state: + for key, input_string in value.items(): + if "few_shot_examples" in key: + few_shot_examples = input_string + + return few_shot_examples + + + + def _create_sql_query(self,user_input): + for value in self.state: + for key in value.items(): + if "sql_db_list_tables" not in key: + table_names=self.get_table_list_from_unitycatalog() + self.state.append({"sql_db_list_tables":table_names}) + if "sql_db_schema" not in key: + db_schema = self.get_table_details_from_unity_catalog(table_names=table_names) + self.state.append({"sql_db_schema":db_schema}) + + + + few_shot_examples = self._parse_few_shot_examples() + + prompt_input = PromptTemplate( + input_variables=["db_schema", "user_input", "few_shot_examples"], + template=SQL_QUERY_CREATOR, + ) + chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + return chain.run( + ( + { + "db_schema": db_schema, + "user_input": user_input, + "few_shot_examples": few_shot_examples, + } + ) + ) + + + def get_table_list_from_unitycatalog( + self, + ) -> str: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.db_token}", + } + retries = Retry(total=5, backoff_factor=0.3) + adapter = HTTPAdapter(max_retries=retries) + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + + url = f"https://{self.db_host}/api/2.1/unity-catalog/tables" + params = {"catalog_name": self.db_catalog, "schema_name": self.db_schema} + response = session.get(url, headers=headers, params=params) + if response.status_code != 200: + raise Exception(f"Error fetching list of tables : {response.text}") + json_data = json.loads(response.text) + tables = json_data["tables"] + return [table["name"] for table in tables] + + def get_table_details_from_unity_catalog( + self, + table_names: list, + ): + final_string: str = "" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.db_token}", + } + retries = Retry(total=5, backoff_factor=0.3) + adapter = HTTPAdapter(max_retries=retries) + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + # TODO: Improve performance by using asyncio or threading to make concurrent requests + for table_name in table_names: + table_name = table_name.strip() + url = f"https://{self.db_host}/api/2.1/unity-catalog/tables/{self.db_catalog}.{self.db_schema}.{table_name}" + response = session.get(url, headers=headers) + if 
response.status_code != 200: + raise Exception(f"Error fetching table {table_name}: {response.text}") + json_data = json.loads(response.text) + column_data = json_data["columns"] + table_comment = ( + json_data["comment"] if "comment" in json_data.keys() else None + ) + string_data = self._generate_create_table_query( + table_data=column_data, + table_name=table_name, + table_comment=table_comment, + ) + final_string = f"{final_string}\n{string_data}" + return final_string + + def _generate_create_table_query( + self, table_data: List[Dict], table_name: str, table_comment: str + ): + sample_rows_in_table_info: int = 3 + if table_comment: + query = f"CREATE TABLE {table_name} COMMENT '{table_comment}' (\n" + else: + query = f"CREATE TABLE {table_name} (\n" + for column_info in table_data: + column_name = column_info["name"] + column_type = column_info["type_text"].upper() + if column_comment := column_info.get("comment", None): + query += f"\t{column_name} {column_type} COMMENT '{column_comment}' " + else: + query += f"\t{column_name} {column_type} " + + # Add a comma if it's not the last column + if column_info != table_data[-1]: + query += "," + + query += "\n" + + query += ") USING DELTA" + + column_names = [item["name"] for item in table_data] + columns_str = "\t".join(column_names) + + top_3_rows = self._get_sample_rows(table=table_name) + + return ( + f"{query}\n" + ) + + def _get_sample_rows(self, table: str): + sample_rows_in_table_info: int = 3 + command = "Select * from {table} limit {sample_rows_in_table_info}".format( + table=table, sample_rows_in_table_info=sample_rows_in_table_info + ) + + try: + with self.db._engine.connect() as connection: + sample_rows_result = connection.execute(command) + sample_rows = list( + map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) + ) + sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) + + except ProgrammingError: + sample_rows_str = "" + + return sample_rows_str + + + \ No newline at end of file diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index f8472f9b029e8..32d5fb9ad2ddd 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.21dev1" +version = "0.1.22dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From d49d0ede470ce681d91b78d67d6a8e46eedf2281 Mon Sep 17 00:00:00 2001 From: sushant Date: Fri, 1 Mar 2024 10:12:28 +0530 Subject: [PATCH 02/33] sqlquerycreatortool added --- .../tools/spark_unitycatalog/tool.py | 25 +++++++++++-------- libs/langchain/pyproject.toml | 2 +- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/tools/spark_unitycatalog/tool.py b/libs/langchain/langchain/tools/spark_unitycatalog/tool.py index b7e8f9253eac2..6a5c1438f9eb1 100644 --- a/libs/langchain/langchain/tools/spark_unitycatalog/tool.py +++ b/libs/langchain/langchain/tools/spark_unitycatalog/tool.py @@ -4,6 +4,7 @@ import os import re from typing import Any, Dict, List, Optional +import logging import requests from langchain.base_language import BaseLanguageModel @@ -314,9 +315,7 @@ def _validate_sql_query(self, query): ) chain = LLMChain(llm=self.llm, prompt=prompt_input) - query_validation = chain.run(({"db_schema": db_schema, "query": query})) - - return query_validation + return chain.run(({"db_schema": db_schema, "query": query})) class QueryUCSQLDataBaseTool(StateTool): @@ -433,14 +432,18 @@ def 
_parse_few_shot_examples(self): def _create_sql_query(self,user_input): - for value in self.state: - for key in value.items(): - if "sql_db_list_tables" not in key: - table_names=self.get_table_list_from_unitycatalog() - self.state.append({"sql_db_list_tables":table_names}) - if "sql_db_schema" not in key: - db_schema = self.get_table_details_from_unity_catalog(table_names=table_names) - self.state.append({"sql_db_schema":db_schema}) + + keys = [key for value in self.state for key in value.keys()] + + if "sql_db_list_tables" not in keys: + logging.info('Getting table names') + table_names=self.get_table_list_from_unitycatalog() + self.state.append({"sql_db_list_tables":table_names}) + + if "sql_db_schema" not in keys: + logging.info('Getting table schema') + db_schema = self.get_table_details_from_unity_catalog(table_names=table_names) + self.state.append({"sql_db_schema":db_schema}) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 32d5fb9ad2ddd..0c281828756dd 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.22dev1" +version = "0.1.24dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From d63efe3506d6f5c39959ff0ae2cc5a1b4924512c Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Mon, 4 Mar 2024 14:29:21 +0530 Subject: [PATCH 03/33] Update version and add SQLCoderToolkit --- .../agents/agent_toolkits/__init__.py | 2 + .../agent_toolkits/sqlcoder/__init__.py | 1 + .../agents/agent_toolkits/sqlcoder/base.py | 93 ++++++++ .../agents/agent_toolkits/sqlcoder/prompt.py | 23 ++ .../agents/agent_toolkits/sqlcoder/toolkit.py | 63 ++++++ .../agent_toolkits/unitycatalog/toolkit.py | 9 - libs/langchain/langchain/tools/__init__.py | 13 +- .../tools/spark_unitycatalog/prompt.py | 26 --- .../tools/spark_unitycatalog/tool.py | 205 +----------------- .../langchain/tools/sqlcoder/__init__.py | 1 + .../langchain/tools/sqlcoder/prompt.py | 30 +++ .../langchain/tools/sqlcoder/tool.py | 189 ++++++++++++++++ libs/langchain/pyproject.toml | 2 +- 13 files changed, 414 insertions(+), 243 deletions(-) create mode 100644 libs/langchain/langchain/agents/agent_toolkits/sqlcoder/__init__.py create mode 100644 libs/langchain/langchain/agents/agent_toolkits/sqlcoder/base.py create mode 100644 libs/langchain/langchain/agents/agent_toolkits/sqlcoder/prompt.py create mode 100644 libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py create mode 100644 libs/langchain/langchain/tools/sqlcoder/__init__.py create mode 100644 libs/langchain/langchain/tools/sqlcoder/prompt.py create mode 100644 libs/langchain/langchain/tools/sqlcoder/tool.py diff --git a/libs/langchain/langchain/agents/agent_toolkits/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/__init__.py index e7424b69f5e82..b152682cdcaa0 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/__init__.py @@ -39,6 +39,7 @@ from langchain.agents.agent_toolkits.office365.toolkit import O365Toolkit from langchain.agents.agent_toolkits.openapi.base import create_openapi_agent from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit +from langchain.agents.agent_toolkits.sqlcoder.toolkit import SQLCoderToolkit from langchain.agents.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit from langchain.agents.agent_toolkits.powerbi.base import create_pbi_agent from 
langchain.agents.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent @@ -102,6 +103,7 @@ def __getattr__(name: str) -> Any: "PowerBIToolkit", "SlackToolkit", "SteamToolkit", + "SQLCoderToolkit", "SQLDatabaseToolkit", "SparkSQLToolkit", "UCSQLDatabaseToolkit", diff --git a/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/__init__.py new file mode 100644 index 0000000000000..448ffc8cec4ac --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/__init__.py @@ -0,0 +1 @@ +"""SQL Coder Agent Toolkit""" diff --git a/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/base.py b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/base.py new file mode 100644 index 0000000000000..0d516622bde33 --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/base.py @@ -0,0 +1,93 @@ +"""SQL agent.""" +from typing import Any, Dict, List, Optional + +from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent +from langchain.agents.agent_toolkits.sql.prompt import ( + SQL_FUNCTIONS_SUFFIX, + SQL_PREFIX, + SQL_SUFFIX, +) +from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain.agents.agent_types import AgentType +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent +from langchain.callbacks.base import BaseCallbackManager +from langchain.chains.llm import LLMChain +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) +from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.messages import AIMessage, SystemMessage + + +def create_sql_agent( + llm: BaseLanguageModel, + toolkit: SQLDatabaseToolkit, + agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = SQL_PREFIX, + suffix: Optional[str] = None, + format_instructions: str = FORMAT_INSTRUCTIONS, + input_variables: Optional[List[str]] = None, + top_k: int = 10, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + verbose: bool = False, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any], +) -> AgentExecutor: + """Construct an SQL agent from an LLM and tools.""" + tools = toolkit.get_tools() + prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k) + agent: BaseSingleActionAgent + + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix or SQL_SUFFIX, + format_instructions=format_instructions, + input_variables=input_variables, + ) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + + elif agent_type == AgentType.OPENAI_FUNCTIONS: + messages = [ + SystemMessage(content=prefix), + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + input_variables = ["input", "agent_scratchpad"] + _prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages) + + agent = OpenAIFunctionsAgent( + llm=llm, + 
prompt=_prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") + + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/prompt.py b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/prompt.py new file mode 100644 index 0000000000000..92464da4b9b9f --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/prompt.py @@ -0,0 +1,23 @@ +# flake8: noqa + +SQL_PREFIX = """You are an agent designed to interact with a SQL database. +Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. +You can order the results by a relevant column to return the most interesting examples in the database. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +You have access to tools for interacting with the database. +Only use the below tools. Only use the information returned by the below tools to construct your final answer. +You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +If the question does not seem related to the database, just return "I don't know" as the answer. +""" + +SQL_SUFFIX = """Begin! + +Question: {input} +Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. +{agent_scratchpad}""" + +SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. 
Then I should query the schema of the most relevant tables.""" diff --git a/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py new file mode 100644 index 0000000000000..5dffff38c565e --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py @@ -0,0 +1,63 @@ +"""Toolkit for interacting with a SQL database.""" +from typing import List + +from langchain.agents.agent_toolkits.base import BaseToolkit +from langchain.base_language import BaseLanguageModel +from langchain.sql_database import SQLDatabase +from langchain.tools import BaseTool +from langchain.tools.sqlcoder.tool import ( + QuerySparkSQLDataBaseTool, + SqlQueryCreatorTool, +) +from langchain.tools.sql_database.tool import QuerySQLCheckerTool +from langchain_core.pydantic_v1 import Field + + +class SQLCoderToolkit(BaseToolkit): + """Toolkit for interacting with SQL databases.""" + + db: SQLDatabase = Field(exclude=True) + llm: BaseLanguageModel = Field(exclude=True) + db_token: str + db_host: str + db_catalog: str + db_schema: str + db_warehouse_id: str + allow_extra_fields = True + sqlcreatorllm : BaseLanguageModel = Field(exclude=True) + + @property + def dialect(self) -> str: + """Return string representation of dialect to use.""" + return self.db.dialect + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + query_sql_database_tool_description = ( + "Input to this tool is a detailed and correct SQL query, output is a " + "result from the database. If the query is not correct, an error message " + "will be returned. If an error is returned, rewrite the query, check the " + "query, and try again. If you encounter an issue with Unknown column " + "'xxxx' in 'field list', using schema_sql_db to query the correct table " + "fields." 
+ ) + return [ + QuerySparkSQLDataBaseTool( + db=self.db, description=query_sql_database_tool_description + ), + QuerySQLCheckerTool(db=self.db, llm=self.llm), + SqlQueryCreatorTool( + sqlcreatorllm=self.sqlcreatorllm , + db=self.db, + db_token=self.db_token, + db_host=self.db_host, + db_catalog=self.db_catalog, + db_schema=self.db_schema, + db_warehouse_id=self.db_warehouse_id + ) + ] diff --git a/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py index caed7191d405f..b028a805ca8c6 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/unitycatalog/toolkit.py @@ -10,7 +10,6 @@ ListUnityCatalogTablesTool, QueryUCSQLDataBaseTool, SqlQueryValidatorTool, - SqlQueryCreatorTool, ) from langchain.tools.sql_database.tool import QuerySQLCheckerTool from langchain_core.pydantic_v1 import Field @@ -78,12 +77,4 @@ def get_tools(self) -> List[BaseTool]: ), QuerySQLCheckerTool(db=self.db, llm=self.llm), SqlQueryValidatorTool(llm=self.llm), - SqlQueryCreatorTool( - sqlcreatorllm=self.sqlcreatorllm , - db=self.db, - db_token=self.db_token, - db_host=self.db_host, - db_catalog=self.db_catalog, - db_schema=self.db_schema, - db_warehouse_id=self.db_warehouse_id ) ] diff --git a/libs/langchain/langchain/tools/__init__.py b/libs/langchain/langchain/tools/__init__.py index e19e0fe6cbf2d..c80a22e88b242 100644 --- a/libs/langchain/langchain/tools/__init__.py +++ b/libs/langchain/langchain/tools/__init__.py @@ -773,11 +773,16 @@ def _import_spark_unitycatalog_tool_SqlQueryValidatorTool() -> Any: return SqlQueryValidatorTool -def _import_spark_unitycatalog_tool_SqlQueryCreatorTool() -> Any: - from langchain.tools.spark_unitycatalog.tool import SqlQueryCreatorTool +def _import_sqlcoder_tool_SqlQueryCreatorTool() -> Any: + from langchain.tools.sqlcoder.tool import SqlQueryCreatorTool return SqlQueryCreatorTool +def _import_sqlcoder_tool_QuerySparkSQLDatabaseTool() -> Any: + from langchain.tools.sqlcoder.tool import QuerySparkSQLDataBaseTool + + return QuerySparkSQLDataBaseTool + def _import_stackexchange_tool() -> Any: from langchain.tools.stackexchange.tool import StackExchangeTool @@ -1080,7 +1085,9 @@ def __getattr__(name: str) -> Any: elif name == "SqlQueryValidatorTool": return _import_spark_unitycatalog_tool_SqlQueryValidatorTool() elif name == "SqlQueryCreatorTool": - return _import_spark_unitycatalog_tool_SqlQueryCreatorTool() + return _import_sqlcoder_tool_SqlQueryCreatorTool() + elif name == "QuerySparkSQLDataBaseTool": + return _import_sqlcoder_tool_QuerySparkSQLDatabaseTool() elif name == "InfoSparkSQLTool": return _import_spark_sql_tool_InfoSparkSQLTool() elif name == "ListSparkSQLTool": diff --git a/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py b/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py index dc1874611bd4a..9fed682b121b3 100644 --- a/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py +++ b/libs/langchain/langchain/tools/spark_unitycatalog/prompt.py @@ -32,29 +32,3 @@ Begin SQL Query Validation. """ - - -SQL_QUERY_CREATOR = """### Instructions: -Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. -Adhere to these rules: -- **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **Use Table Aliases** to prevent ambiguity. 
For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. -- When creating a ratio, always cast the numerator as float -You are an AI research assistant in the senior living industry. -You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events -When querying the database, given an input question, create a syntactically correct query to run, then look at the results of the query and return the answer. -Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results. -You can order the results by a relevant column to return the most interesting examples in the database. -Never query for all the columns from a specific table, only ask for the relevant columns given the question. -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. - -### Input: -Generate a SQL query that answers the question `{user_input}`. -This query will run on a database whose schema is represented in this string: -'{db_schema}' -Use the following examples to generate the sql query: -'{few_shot_examples}' -### Response: -Based on your instructions, here is the SQL query I have generated to answer '{user_input}' -```sql""" - diff --git a/libs/langchain/langchain/tools/spark_unitycatalog/tool.py b/libs/langchain/langchain/tools/spark_unitycatalog/tool.py index 6a5c1438f9eb1..9d36779e38214 100644 --- a/libs/langchain/langchain/tools/spark_unitycatalog/tool.py +++ b/libs/langchain/langchain/tools/spark_unitycatalog/tool.py @@ -15,7 +15,7 @@ from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase -from langchain.tools.spark_unitycatalog.prompt import SQL_QUERY_VALIDATOR, SQL_QUERY_CREATOR +from langchain.tools.spark_unitycatalog.prompt import SQL_QUERY_VALIDATOR from langchain_core.pydantic_v1 import BaseModel, Extra, Field from langchain_core.tools import StateTool from requests.adapters import HTTPAdapter @@ -373,206 +373,3 @@ def _extract_sql_query(self): sql_query = match.group(1) return sql_query return None - - - -class SqlQueryCreatorTool(StateTool): - """Tool for creating SQL query.Use this to create sql query.""" - - name = "sql_db_query_creator" - description = """ - This is a tool used to create sql query for user input based on the schema of the table and few_shot_examples. - Input to this tool is input prompt and table schema and few_shot_examples - Output is a sql query -""" - sqlcreatorllm: BaseLanguageModel = Field(exclude=True) - # user_input : str - - - class Config(StateTool.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - extra = Extra.allow - - def __init__(__pydantic_self__, **data: Any) -> None: - """Initialize the tool.""" - super().__init__(**data) - - def _run( - self, - user_input: str, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - """Get the schema for tables in a comma-separated list.""" - if hasattr(self, "state"): - return self._create_sql_query(user_input) - - else: - return "This tool is not meant to be run directly. 
Start with ListUnityCatalogTablesTool" - - async def _arun( - self, - table_name: str, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, - ) -> str: - raise NotImplementedError("SqlQueryCreatorTool does not support async") - - - - def _parse_few_shot_examples(self): - few_shot_examples = "" - for value in self.state: - for key, input_string in value.items(): - if "few_shot_examples" in key: - few_shot_examples = input_string - - return few_shot_examples - - - - def _create_sql_query(self,user_input): - - keys = [key for value in self.state for key in value.keys()] - - if "sql_db_list_tables" not in keys: - logging.info('Getting table names') - table_names=self.get_table_list_from_unitycatalog() - self.state.append({"sql_db_list_tables":table_names}) - - if "sql_db_schema" not in keys: - logging.info('Getting table schema') - db_schema = self.get_table_details_from_unity_catalog(table_names=table_names) - self.state.append({"sql_db_schema":db_schema}) - - - - few_shot_examples = self._parse_few_shot_examples() - - prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples"], - template=SQL_QUERY_CREATOR, - ) - chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) - - return chain.run( - ( - { - "db_schema": db_schema, - "user_input": user_input, - "few_shot_examples": few_shot_examples, - } - ) - ) - - - def get_table_list_from_unitycatalog( - self, - ) -> str: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.db_token}", - } - retries = Retry(total=5, backoff_factor=0.3) - adapter = HTTPAdapter(max_retries=retries) - session = requests.Session() - session.mount("http://", adapter) - session.mount("https://", adapter) - - url = f"https://{self.db_host}/api/2.1/unity-catalog/tables" - params = {"catalog_name": self.db_catalog, "schema_name": self.db_schema} - response = session.get(url, headers=headers, params=params) - if response.status_code != 200: - raise Exception(f"Error fetching list of tables : {response.text}") - json_data = json.loads(response.text) - tables = json_data["tables"] - return [table["name"] for table in tables] - - def get_table_details_from_unity_catalog( - self, - table_names: list, - ): - final_string: str = "" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.db_token}", - } - retries = Retry(total=5, backoff_factor=0.3) - adapter = HTTPAdapter(max_retries=retries) - session = requests.Session() - session.mount("http://", adapter) - session.mount("https://", adapter) - # TODO: Improve performance by using asyncio or threading to make concurrent requests - for table_name in table_names: - table_name = table_name.strip() - url = f"https://{self.db_host}/api/2.1/unity-catalog/tables/{self.db_catalog}.{self.db_schema}.{table_name}" - response = session.get(url, headers=headers) - if response.status_code != 200: - raise Exception(f"Error fetching table {table_name}: {response.text}") - json_data = json.loads(response.text) - column_data = json_data["columns"] - table_comment = ( - json_data["comment"] if "comment" in json_data.keys() else None - ) - string_data = self._generate_create_table_query( - table_data=column_data, - table_name=table_name, - table_comment=table_comment, - ) - final_string = f"{final_string}\n{string_data}" - return final_string - - def _generate_create_table_query( - self, table_data: List[Dict], table_name: str, table_comment: str - ): - sample_rows_in_table_info: int = 3 - if table_comment: - query = f"CREATE TABLE 
{table_name} COMMENT '{table_comment}' (\n" - else: - query = f"CREATE TABLE {table_name} (\n" - for column_info in table_data: - column_name = column_info["name"] - column_type = column_info["type_text"].upper() - if column_comment := column_info.get("comment", None): - query += f"\t{column_name} {column_type} COMMENT '{column_comment}' " - else: - query += f"\t{column_name} {column_type} " - - # Add a comma if it's not the last column - if column_info != table_data[-1]: - query += "," - - query += "\n" - - query += ") USING DELTA" - - column_names = [item["name"] for item in table_data] - columns_str = "\t".join(column_names) - - top_3_rows = self._get_sample_rows(table=table_name) - - return ( - f"{query}\n" - ) - - def _get_sample_rows(self, table: str): - sample_rows_in_table_info: int = 3 - command = "Select * from {table} limit {sample_rows_in_table_info}".format( - table=table, sample_rows_in_table_info=sample_rows_in_table_info - ) - - try: - with self.db._engine.connect() as connection: - sample_rows_result = connection.execute(command) - sample_rows = list( - map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) - ) - sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) - - except ProgrammingError: - sample_rows_str = "" - - return sample_rows_str - - - \ No newline at end of file diff --git a/libs/langchain/langchain/tools/sqlcoder/__init__.py b/libs/langchain/langchain/tools/sqlcoder/__init__.py new file mode 100644 index 0000000000000..22d864cfb5bbf --- /dev/null +++ b/libs/langchain/langchain/tools/sqlcoder/__init__.py @@ -0,0 +1 @@ +"""Tools for interacting with a SQL database using SQLCoder LLM""" diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py new file mode 100644 index 0000000000000..d7d7fa3184cb9 --- /dev/null +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -0,0 +1,30 @@ +SQL_QUERY_CREATOR = """### Instructions: +Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. +- When creating a ratio, always cast the numerator as float +You are an AI research assistant in the senior living industry. +You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events +When querying the database, given an input question, create a syntactically correct query to run, then look at the results of the query and return the answer. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results. +You can order the results by a relevant column to return the most interesting examples in the database. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +### Input: +Generate a SQL query that answers the question `{user_input}`. 
+This query will run on a database whose schema is represented in this string: +'{db_schema}' +Use the following examples to generate the sql query: +'{few_shot_examples}' +### Response: +Based on your instructions, here is the SQL query I have generated to answer '{user_input}' +```sql""" + + +SQL_QUERY_CREATOR_RETRY = """ +You have failed in the first attempt to generate correct sql query. Please try again to generate correct sql query. +The previously generated queries are {sql_query}. +Make sure you create right query by using the schema of the table and few_shot_examples. +""" \ No newline at end of file diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py new file mode 100644 index 0000000000000..a5e4acda74f28 --- /dev/null +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -0,0 +1,189 @@ +# flake8: noqa +"""Tools for interacting with a SQL database.""" +from typing import Any, Dict, List, Optional +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.chains.llm import LLMChain +from langchain.prompts import PromptTemplate +from langchain.sql_database import SQLDatabase +from langchain.tools.sqlcoder.prompt import SQL_QUERY_CREATOR_RETRY, SQL_QUERY_CREATOR +from langchain_core.pydantic_v1 import BaseModel, Extra, Field +from langchain_core.tools import StateTool + +class BaseSQLDatabaseTool(BaseModel): + """Base tool for interacting with a SQL database.""" + + db: SQLDatabase = Field(exclude=True) + + # Override BaseTool.Config to appease mypy + # See https://github.com/pydantic/pydantic/issues/4173 + class Config(StateTool.Config): + + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.allow + + +class QuerySparkSQLDataBaseTool(StateTool): + """Tool for querying a SQL database.""" + + class Config(StateTool.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.allow + + db: SQLDatabase = Field(exclude=True) + name: str = "sql_db_query" + description: str = """ + Input to this tool is a detailed and correct SQL query, output is a result from the database. + If the query is not correct, an error message will be returned. + If an error is returned, rewrite the query, check the query, and try again. + """ + + def __init__(__pydantic_self__, **data: Any) -> None: + """Initialize the tool.""" + super().__init__(**data) + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Execute the query, return the results or an error message.""" + if not hasattr(self, "state"): + return "This tool is not meant to be run directly. 
Start with a SQLQueryCreatorTool" + executable_query = ( + extracted_sql_query.strip() + if (extracted_sql_query := self._extract_sql_query()) + else query.strip() + ) + executable_query = executable_query.strip('"') + return self.db.run_no_throw(executable_query) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async") + + def _extract_sql_query(self): + for value in self.state: + for key, input_string in value.items(): + if "sql_db_query_creator" in key: + return input_string + return None + + + +class SqlQueryCreatorTool(StateTool): + """Tool for creating SQL query.Use this to create sql query.""" + + name = "sql_db_query_creator" + description = """ + This is a tool used to create sql query for user input based on the schema of the table and few_shot_examples. + Input to this tool is input prompt and table schema and few_shot_examples + Output is a sql query + """ + sqlcreatorllm: BaseLanguageModel = Field(exclude=True) + + class Config(StateTool.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.allow + + def __init__(__pydantic_self__, **data: Any) -> None: + """Initialize the tool.""" + super().__init__(**data) + + def _run( + self, + user_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the SQL query for the user input.""" + return self._create_sql_query(user_input) + + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("SqlQueryCreatorTool does not support async") + + def _parse_few_shot_examples(self): + few_shot_examples = "" + for value in self.state: + for key, input_string in value.items(): + if "few_shot_examples" in key: + few_shot_examples = input_string + + return few_shot_examples + + def _parse_db_schema(self): + db_schema = {} + for value in self.state: + for key, input_string in value.items(): + if "sql_db_schema" in key: + db_schema = input_string + return db_schema + + def _create_sql_query(self,user_input): + + few_shot_examples = self._parse_few_shot_examples() + db_schema = self._parse_db_schema() + sql_query = self._extract_sql_query() + + if len(sql_query) == 0: + prompt_input = PromptTemplate( + input_variables=["db_schema", "user_input", "few_shot_examples"], + template=SQL_QUERY_CREATOR, + ) + query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + sql_query = query_creator_chain.run( + ( + { + "db_schema": db_schema, + "user_input": user_input, + "few_shot_examples": few_shot_examples, + } + ) + ) + + else: + prompt_input = PromptTemplate( + input_variables=["db_schema", "user_input", "few_shot_examples","sql_query"], + template=SQL_QUERY_CREATOR + SQL_QUERY_CREATOR_RETRY, + ) + query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + sql_query = query_creator_chain.run( + ( + { + "db_schema": db_schema, + "user_input": user_input, + "few_shot_examples": few_shot_examples, + "sql_query": sql_query + } + ) + ) + if hasattr(self, "state"): + self.state.append({"sql_db_query_creator": sql_query}) + return sql_query + + def _extract_sql_query(self): + sql_queries = [] + for value in self.state: + sql_queries.extend( + input_string + for key, input_string in value.items() + if "sql_db_query_creator" in key + ) + return sql_queries \ No newline at end of file diff --git 
a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 0c281828756dd..e738398e2bdd9 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.24dev1" +version = "0.1.25dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From d3955895aaf04ab28320b9ac06c1ea4967026e19 Mon Sep 17 00:00:00 2001 From: sushant Date: Mon, 4 Mar 2024 16:47:45 +0530 Subject: [PATCH 04/33] Updated tool.py for sqlcoder --- .../langchain/tools/sqlcoder/tool.py | 95 ++++++++++--------- libs/langchain/pyproject.toml | 2 +- 2 files changed, 49 insertions(+), 48 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index a5e4acda74f28..18512892a18b3 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -12,7 +12,7 @@ from langchain.tools.sqlcoder.prompt import SQL_QUERY_CREATOR_RETRY, SQL_QUERY_CREATOR from langchain_core.pydantic_v1 import BaseModel, Extra, Field from langchain_core.tools import StateTool - +import re class BaseSQLDatabaseTool(BaseModel): """Base tool for interacting with a SQL database.""" @@ -62,7 +62,8 @@ def _run( if (extracted_sql_query := self._extract_sql_query()) else query.strip() ) - executable_query = executable_query.strip('"') + executable_query = executable_query.strip('\"') + executable_query = re.sub('\\n```', '',executable_query) return self.db.run_no_throw(executable_query) async def _arun( @@ -138,52 +139,52 @@ def _create_sql_query(self,user_input): few_shot_examples = self._parse_few_shot_examples() db_schema = self._parse_db_schema() - sql_query = self._extract_sql_query() - - if len(sql_query) == 0: - prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples"], - template=SQL_QUERY_CREATOR, - ) - query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) - - sql_query = query_creator_chain.run( - ( - { - "db_schema": db_schema, - "user_input": user_input, - "few_shot_examples": few_shot_examples, - } - ) + #sql_query = self._extract_sql_query() + + #if len(sql_query) == 0: + prompt_input = PromptTemplate( + input_variables=["db_schema", "user_input", "few_shot_examples"], + template=SQL_QUERY_CREATOR, + ) + query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + sql_query = query_creator_chain.run( + ( + { + "db_schema": db_schema, + "user_input": user_input, + "few_shot_examples": few_shot_examples, + } ) + ) - else: - prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples","sql_query"], - template=SQL_QUERY_CREATOR + SQL_QUERY_CREATOR_RETRY, - ) - query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) - - sql_query = query_creator_chain.run( - ( - { - "db_schema": db_schema, - "user_input": user_input, - "few_shot_examples": few_shot_examples, - "sql_query": sql_query - } - ) - ) - if hasattr(self, "state"): - self.state.append({"sql_db_query_creator": sql_query}) + # else: + # prompt_input = PromptTemplate( + # input_variables=["db_schema", "user_input", "few_shot_examples","sql_query"], + # template=SQL_QUERY_CREATOR + SQL_QUERY_CREATOR_RETRY, + # ) + # query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + # sql_query = query_creator_chain.run( + # ( + # { + # "db_schema": db_schema, + # "user_input": user_input, + # 
"few_shot_examples": few_shot_examples, + # "sql_query": sql_query + # } + # ) + # ) + # if hasattr(self, "state"): + # self.state.append({"sql_db_query_creator": sql_query}) return sql_query - def _extract_sql_query(self): - sql_queries = [] - for value in self.state: - sql_queries.extend( - input_string - for key, input_string in value.items() - if "sql_db_query_creator" in key - ) - return sql_queries \ No newline at end of file + # def _extract_sql_query(self): + # sql_queries = [] + # for value in self.state: + # sql_queries.extend( + # input_string + # for key, input_string in value.items() + # if "sql_db_query_creator" in key + # ) + # return sql_queries \ No newline at end of file diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index e738398e2bdd9..01a8e99dfe20a 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.25dev1" +version = "0.1.26dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 5a220ce6c79a637b6926f1149d7b7777d2d961de Mon Sep 17 00:00:00 2001 From: sushant Date: Mon, 4 Mar 2024 17:50:53 +0530 Subject: [PATCH 05/33] retry logic added in sqlquerycreator tool --- .../langchain/tools/sqlcoder/tool.py | 88 +++++++++---------- libs/langchain/pyproject.toml | 2 +- 2 files changed, 44 insertions(+), 46 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index 18512892a18b3..494d34ae89cfe 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -139,52 +139,50 @@ def _create_sql_query(self,user_input): few_shot_examples = self._parse_few_shot_examples() db_schema = self._parse_db_schema() - #sql_query = self._extract_sql_query() - - #if len(sql_query) == 0: - prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples"], - template=SQL_QUERY_CREATOR, - ) - query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) - - sql_query = query_creator_chain.run( - ( - { - "db_schema": db_schema, - "user_input": user_input, - "few_shot_examples": few_shot_examples, - } + sql_query = self._extract_sql_query() + + if sql_query is None: + prompt_input = PromptTemplate( + input_variables=["db_schema", "user_input", "few_shot_examples"], + template=SQL_QUERY_CREATOR, + ) + query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + sql_query = query_creator_chain.run( + ( + { + "db_schema": db_schema, + "user_input": user_input, + "few_shot_examples": few_shot_examples, + } + ) ) - ) - # else: - # prompt_input = PromptTemplate( - # input_variables=["db_schema", "user_input", "few_shot_examples","sql_query"], - # template=SQL_QUERY_CREATOR + SQL_QUERY_CREATOR_RETRY, - # ) - # query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) - - # sql_query = query_creator_chain.run( - # ( - # { - # "db_schema": db_schema, - # "user_input": user_input, - # "few_shot_examples": few_shot_examples, - # "sql_query": sql_query - # } - # ) - # ) - # if hasattr(self, "state"): - # self.state.append({"sql_db_query_creator": sql_query}) + else: + prompt_input = PromptTemplate( + input_variables=["db_schema", "user_input", "few_shot_examples","sql_query"], + template=SQL_QUERY_CREATOR + SQL_QUERY_CREATOR_RETRY, + ) + query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + sql_query = 
query_creator_chain.run( + ( + { + "db_schema": db_schema, + "user_input": user_input, + "few_shot_examples": few_shot_examples, + "sql_query": sql_query + } + ) + ) + if hasattr(self, "state"): + self.state.append({"sql_db_query_creator": sql_query}) + return sql_query - # def _extract_sql_query(self): - # sql_queries = [] - # for value in self.state: - # sql_queries.extend( - # input_string - # for key, input_string in value.items() - # if "sql_db_query_creator" in key - # ) - # return sql_queries \ No newline at end of file + def _extract_sql_query(self): + for value in self.state: + for key, input_string in value.items(): + if "sql_db_query_creator" in key: + return input_string + return None \ No newline at end of file diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 01a8e99dfe20a..696e52a5e0580 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.26dev1" +version = "0.1.28dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From a651d322480957687a8032919a83a76e38f57065 Mon Sep 17 00:00:00 2001 From: sushant Date: Tue, 12 Mar 2024 19:28:27 +0530 Subject: [PATCH 06/33] sql query capture updated --- libs/langchain/langchain/tools/sqlcoder/tool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index 494d34ae89cfe..1c9ce04265223 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -157,7 +157,7 @@ def _create_sql_query(self,user_input): } ) ) - + sql_query = sql_query.replace("```","") else: prompt_input = PromptTemplate( input_variables=["db_schema", "user_input", "few_shot_examples","sql_query"], @@ -175,6 +175,7 @@ def _create_sql_query(self,user_input): } ) ) + sql_query = sql_query.replace("```","") if hasattr(self, "state"): self.state.append({"sql_db_query_creator": sql_query}) From 2da388a78355709baf6836a590bc9d918d6e9ed3 Mon Sep 17 00:00:00 2001 From: sushant Date: Tue, 12 Mar 2024 19:48:03 +0530 Subject: [PATCH 07/33] version update --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 696e52a5e0580..69807673d20bf 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.28dev1" +version = "0.1.29dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 6129071a21d7135b586adee7e2a14dd4af4b5537 Mon Sep 17 00:00:00 2001 From: sushant Date: Mon, 18 Mar 2024 09:30:13 +0530 Subject: [PATCH 08/33] langchain version update --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 5ef084dd2f97b..373e433619e8a 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.32dev1" +version = "0.1.33dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 3dd193e4a7f9bdd30430caf5547a85f36c8a0875 Mon Sep 17 00:00:00 2001 From: sushant Date: Mon, 18 Mar 2024 10:03:55 +0530 Subject: [PATCH 09/33] merge fix --- libs/langchain/langchain/agents/agent.py | 2 
++ libs/langchain/pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index ace9be8ccbcdb..6a8808e2ec231 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -1534,6 +1534,8 @@ def _take_next_step( self.state.append( {str(intermediate_steps[-1][0]): str(intermediate_steps[-1][1])} ) + elif self.state is not None and "few_shot_examples" in self.state[0].keys(): + pass else: self.state = [] diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 373e433619e8a..5795197e7811e 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.33dev1" +version = "0.1.34dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 225fbf4ace9cadbbe9229f57409f62471cf8a1a3 Mon Sep 17 00:00:00 2001 From: sushant Date: Tue, 19 Mar 2024 11:17:32 +0530 Subject: [PATCH 10/33] version change to accomodate 36dev1(uncommited code to fix unstructered query for genesis by arunRaja) --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 5795197e7811e..13816d774878d 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.34dev1" +version = "0.1.37dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 13cdeef10b845f495f48d589699b53da5e949806 Mon Sep 17 00:00:00 2001 From: sushant Date: Mon, 1 Apr 2024 19:14:11 +0530 Subject: [PATCH 11/33] add changes to sqlcoder prompt --- libs/langchain/langchain/tools/sqlcoder/prompt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py index d7d7fa3184cb9..ddba23438ea22 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -7,7 +7,7 @@ You are an AI research assistant in the senior living industry. You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events When querying the database, given an input question, create a syntactically correct query to run, then look at the results of the query and return the answer. -Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. You can order the results by a relevant column to return the most interesting examples in the database. Never query for all the columns from a specific table, only ask for the relevant columns given the question. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. 
@@ -18,6 +18,7 @@ '{db_schema}' Use the following examples to generate the sql query: '{few_shot_examples}' +Unless specified in the user input, always limit your query to 30 results ### Response: Based on your instructions, here is the SQL query I have generated to answer '{user_input}' ```sql""" @@ -25,6 +26,5 @@ SQL_QUERY_CREATOR_RETRY = """ You have failed in the first attempt to generate correct sql query. Please try again to generate correct sql query. -The previously generated queries are {sql_query}. -Make sure you create right query by using the schema of the table and few_shot_examples. +Make sure you create right query by using the {db_schema}, {few_shot_examples} and do not repeat any query from the previously generated queries of {sql_query}. """ \ No newline at end of file From 39ec7f77782062959719bf8eee66b99ded9b6bdc Mon Sep 17 00:00:00 2001 From: sushant Date: Wed, 3 Apr 2024 13:48:10 +0530 Subject: [PATCH 12/33] version update --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 92b2da9c718d3..ea64f153976e8 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.41dev1" +version = "0.1.42dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 099962abfcdfb8a3408746596ceb8684879b4107 Mon Sep 17 00:00:00 2001 From: sushant Date: Thu, 4 Apr 2024 09:54:24 +0530 Subject: [PATCH 13/33] =?UTF-8?q?added=20structured=20decomposition?= =?UTF-8?q?=F0=9F=98=81=20to=20sherloq?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langchain/tools/sqlcoder/prompt.py | 45 ++++++++++++++++--- .../langchain/tools/sqlcoder/tool.py | 25 +++++++---- libs/langchain/pyproject.toml | 2 +- 3 files changed, 58 insertions(+), 14 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py index ddba23438ea22..865c636892794 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -2,13 +2,18 @@ Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. Adhere to these rules: - **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Deliberately go through the question and database schema word by word** to ensure that correct column names and metric names are used to answer the question - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. - When creating a ratio, always cast the numerator as float You are an AI research assistant in the senior living industry. You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events -When querying the database, given an input question, create a syntactically correct query to run, then look at the results of the query and return the answer. +When querying the database, given an input question, create a syntactically correct query to run. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. +Don't query living options if amenities are asked for. 
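The reworked SQL_QUERY_CREATOR_RETRY above receives the previously generated queries through the {sql_query} placeholder, and it is only chosen once a first attempt has been recorded in the agent state. A condensed sketch of that selection logic, with the real templates and the LLM call replaced by hypothetical stand-ins:

from typing import Dict, List

# Hypothetical stand-ins for SQL_QUERY_CREATOR and SQL_QUERY_CREATOR_RETRY.
FIRST_ATTEMPT_TEMPLATE = "{db_schema}\n{few_shot_examples}\nQuestion: {user_input}"
RETRY_TEMPLATE = (
    "{db_schema}\n{few_shot_examples}\n"
    "Do not repeat these earlier queries: {sql_query}\nQuestion: {user_input}"
)

def has_previous_query(state: List[Dict[str, str]]) -> bool:
    """True once any tool has recorded a generated query in the shared state."""
    return any("sql_db_query_creator" in key for entry in state for key in entry)

def pick_template(state: List[Dict[str, str]]) -> str:
    """First call uses the plain creator prompt; later calls use the retry prompt."""
    return RETRY_TEMPLATE if has_previous_query(state) else FIRST_ATTEMPT_TEMPLATE

print(pick_template([]) is FIRST_ATTEMPT_TEMPLATE)                              # True
print(pick_template([{"sql_db_query_creator": "SELECT 1"}]) is RETRY_TEMPLATE)  # True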
+Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. +Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 20 You can order the results by a relevant column to return the most interesting examples in the database. +'{data_model_context}' Never query for all the columns from a specific table, only ask for the relevant columns given the question. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. @@ -24,7 +29,37 @@ ```sql""" -SQL_QUERY_CREATOR_RETRY = """ -You have failed in the first attempt to generate correct sql query. Please try again to generate correct sql query. -Make sure you create right query by using the {db_schema}, {few_shot_examples} and do not repeat any query from the previously generated queries of {sql_query}. -""" \ No newline at end of file +SQL_QUERY_CREATOR_RETRY = """ +### Instructions: +You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query. +Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Deliberately go through the '{sql_query}' and database schema word by word** to ensure that you get the correct column names and metric names to answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. 
+- When creating a ratio, always cast the numerator as float +You are an AI research assistant in the senior living industry. +You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events +When querying the database, given an input question, create a syntactically correct query to run. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. +Don't query living options if amenities are asked for. +Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. +Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 20 +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. +You can order the results by a relevant column to return the most interesting examples in the database. +'{data_model_context}' +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +Try different queries with different variations of the strings like synonyms, abbreviations, singular/plural forms of words in the string, etc. +Donot create sql queries on your own . Always use few_shotted_examples to create sql queries. + +### Input: +Generate a SQL query that answers the question `{user_input}`. 
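Both creator templates lean on few-shot examples pulled from the agent state, but the series does not show how those examples are serialized into the {few_shot_examples} slot. The rendering below is therefore only an assumed illustration; the field names and the sample question are invented:

from typing import Dict, List

def format_few_shot_examples(examples: List[Dict[str, str]]) -> str:
    """Render question/SQL pairs as a numbered block for the prompt slot."""
    lines = []
    for i, example in enumerate(examples, start=1):
        lines.append(f"{i}. Question: {example['question']}")
        lines.append(f"   SQL: {example['sql']}")
    return "\n".join(lines)

examples = [
    {
        "question": "How many distinct facilities reported data in 2024?",
        "sql": "SELECT COUNT(DISTINCT Facility_Name) FROM skypoint_metric_fact_denormalized_vw WHERE Year = 2024",
    }
]
print(format_few_shot_examples(examples))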
+This query will run on a database whose schema is represented in this string: +'{db_schema}' +Use the following examples to generate the sql query: +'{few_shot_examples}' +### Response: +Based on your instructions, here is the SQL query I have generated to answer '{user_input}' +```sql""" \ No newline at end of file diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index 1c9ce04265223..6ebea4bda1976 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -135,15 +135,22 @@ def _parse_db_schema(self): db_schema = input_string return db_schema + def _parse_data_model_context(self): + data_model_context = "" + for value in self.state: + for key, input_string in value.items(): + if "data_model_context" in key: + data_model_context = input_string + return data_model_context def _create_sql_query(self,user_input): few_shot_examples = self._parse_few_shot_examples() db_schema = self._parse_db_schema() sql_query = self._extract_sql_query() - + data_model_context = self._parse_data_model_context() if sql_query is None: prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples"], + input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], template=SQL_QUERY_CREATOR, ) query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) @@ -154,14 +161,14 @@ def _create_sql_query(self,user_input): "db_schema": db_schema, "user_input": user_input, "few_shot_examples": few_shot_examples, + "data_model_context": data_model_context } ) ) - sql_query = sql_query.replace("```","") else: prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples","sql_query"], - template=SQL_QUERY_CREATOR + SQL_QUERY_CREATOR_RETRY, + input_variables=["db_schema", "user_input", "few_shot_examples","sql_query","data_model_context"], + template=SQL_QUERY_CREATOR_RETRY, ) query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) @@ -171,11 +178,13 @@ def _create_sql_query(self,user_input): "db_schema": db_schema, "user_input": user_input, "few_shot_examples": few_shot_examples, - "sql_query": sql_query + "sql_query": sql_query, + "data_model_context": data_model_context + } ) ) - sql_query = sql_query.replace("```","") + sql_query = sql_query.replace("```","") if hasattr(self, "state"): self.state.append({"sql_db_query_creator": sql_query}) @@ -184,6 +193,6 @@ def _create_sql_query(self,user_input): def _extract_sql_query(self): for value in self.state: for key, input_string in value.items(): - if "sql_db_query_creator" in key: + if key == "sql_db_query_creator": return input_string return None \ No newline at end of file diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index ea64f153976e8..1aad54489eecc 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.42dev1" +version = "0.1.43dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 20d3f78c04f768bd642f3be3ad3f8b24046c28a8 Mon Sep 17 00:00:00 2001 From: sushant Date: Thu, 4 Apr 2024 18:48:22 +0530 Subject: [PATCH 14/33] sqlcoder prompt updated --- libs/langchain/langchain/tools/sqlcoder/prompt.py | 8 ++++---- libs/langchain/pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py 
b/libs/langchain/langchain/tools/sqlcoder/prompt.py index 865c636892794..81bc7fa18dc5e 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -2,7 +2,6 @@ Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. Adhere to these rules: - **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **Deliberately go through the question and database schema word by word** to ensure that correct column names and metric names are used to answer the question - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. - When creating a ratio, always cast the numerator as float You are an AI research assistant in the senior living industry. @@ -13,7 +12,6 @@ Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 20 You can order the results by a relevant column to return the most interesting examples in the database. -'{data_model_context}' Never query for all the columns from a specific table, only ask for the relevant columns given the question. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. @@ -21,6 +19,7 @@ Generate a SQL query that answers the question `{user_input}`. 
This query will run on a database whose schema is represented in this string: '{db_schema}' +'{data_model_context}' Use the following examples to generate the sql query: '{few_shot_examples}' Unless specified in the user input, always limit your query to 30 results @@ -47,17 +46,18 @@ Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 20 Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. You can order the results by a relevant column to return the most interesting examples in the database. -'{data_model_context}' Never query for all the columns from a specific table, only ask for the relevant columns given the question. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Try different queries with different variations of the strings like synonyms, abbreviations, singular/plural forms of words in the string, etc. -Donot create sql queries on your own . Always use few_shotted_examples to create sql queries. + ### Input: Generate a SQL query that answers the question `{user_input}`. This query will run on a database whose schema is represented in this string: '{db_schema}' +'{data_model_context}' +Replace Entity_Name with 'Facility_Name' in SQL query. 
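The added "Replace Entity_Name with 'Facility_Name'" instruction leaves the substitution to the model. Purely as an alternative illustration (not what this patch does), the same rename could be applied deterministically after generation:

import re

def rename_identifier(sql: str, old: str = "Entity_Name", new: str = "Facility_Name") -> str:
    """Swap one column identifier for another, using word boundaries so that
    other identifiers (for example Entity_Type) are left untouched."""
    return re.sub(rf"\b{re.escape(old)}\b", new, sql)

print(rename_identifier("SELECT Entity_Name, Entity_Type FROM vw GROUP BY Entity_Name"))
# -> SELECT Facility_Name, Entity_Type FROM vw GROUP BY Facility_Name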
Use the following examples to generate the sql query: '{few_shot_examples}' ### Response: diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 1aad54489eecc..e4e3ce7377c24 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.43dev1" +version = "0.1.44dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From f145added6f13631bdd843747ca34977ffbea273 Mon Sep 17 00:00:00 2001 From: sushant Date: Mon, 15 Apr 2024 16:08:39 +0530 Subject: [PATCH 15/33] prompt updated sqlcoder --- .../langchain/tools/sqlcoder/prompt.py | 37 +++++++++++++++++-- libs/langchain/pyproject.toml | 2 +- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py index 81bc7fa18dc5e..d1e1c20d80c63 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -10,10 +10,11 @@ Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. Don't query living options if amenities are asked for. Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. -Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 20 +Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS 
SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description, Entity_Name LIMIT 20 You can order the results by a relevant column to return the most interesting examples in the database. Never query for all the columns from a specific table, only ask for the relevant columns given the question. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. +Create the query for the specified year/month asked in the '{user_input}' ### Input: Generate a SQL query that answers the question `{user_input}`. @@ -43,7 +44,7 @@ Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. Don't query living options if amenities are asked for. Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. 
-Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 20 +Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Entity_Name LIMIT 20 Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. You can order the results by a relevant column to return the most interesting examples in the database. Never query for all the columns from a specific table, only ask for the relevant columns given the question. 
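The {data_model_context} placeholder referenced by these templates is filled from the shared agent state, much like the schema and the few-shot examples. A condensed sketch of that lookup, in the spirit of the _parse_data_model_context helper added to sqlcoder/tool.py in [PATCH 13/33]; here the last matching entry wins and the sample state values are placeholders:

from typing import Any, Dict, List

def parse_data_model_context(state: List[Dict[str, Any]]) -> str:
    """Return the data_model_context most recently written into the agent state,
    or an empty string when no tool has populated it yet."""
    context = ""
    for entry in state:
        for key, value in entry.items():
            if "data_model_context" in key:
                context = value
    return context

state = [
    {"few_shot_examples": "..."},
    {"data_model_context": "placeholder notes about the metric fact view"},
]
print(parse_data_model_context(state))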
@@ -62,4 +63,34 @@ '{few_shot_examples}' ### Response: Based on your instructions, here is the SQL query I have generated to answer '{user_input}' -```sql""" \ No newline at end of file +```sql""" + +SQL_QUERY_CREATOR_7b = """ +### Instructions: +Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. +- When creating a ratio, always cast the numerator as float +You are an AI research assistant in the senior living industry. +You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events +When querying the database, given an input question, create a syntactically correct query to run. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. +Don't query living options if amenities are asked for. +Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. +Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description, Entity_Name LIMIT 20 +You can order the results by a relevant column to return the most interesting examples in the database. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. 
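The "DO NOT make any DML statements" rule above is enforced only through the prompt. A small, hypothetical pre-execution guard (not part of this series) could reject such statements before they ever reach the warehouse; the keyword list below is an assumption and deliberately conservative:

import re

FORBIDDEN_KEYWORDS = ("insert", "update", "delete", "drop", "alter", "truncate", "merge")

def is_read_only(sql: str) -> bool:
    """Accept only SELECT statements and reject anything containing DML/DDL keywords.
    A SELECT that merely mentions a forbidden word is also refused, on purpose."""
    statement = sql.strip().rstrip(";").lower()
    if not statement.startswith("select"):
        return False
    return not any(re.search(rf"\b{keyword}\b", statement) for keyword in FORBIDDEN_KEYWORDS)

print(is_read_only("SELECT Facility_Name FROM vw LIMIT 30"))  # True
print(is_read_only("DROP TABLE vw"))                          # False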
+ +### Task: +Generate a SQL query that answers the question [QUESTION]`{user_input}`[/QUESTION]. +This query will run on a database whose schema is represented in this string: +'{db_schema}' +'{data_model_context}' +Use the following examples to generate the sql query: +'{few_shot_examples}' +Unless specified in the user input, always limit your query to 30 results +### Response: +Based on your instructions, here is the SQL query I have generated to answer [QUESTION]`{user_input}`[/QUESTION] +[SQL]""" \ No newline at end of file diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index e4e3ce7377c24..4330f8d122823 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.44dev1" +version = "0.1.45dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From dc8908563fe54766cb1a656059c173b5b2898b57 Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Sun, 5 May 2024 09:31:42 +0530 Subject: [PATCH 16/33] Add CustomPlanandSolveExecutor class to agent.py --- libs/langchain/langchain/agents/agent.py | 162 +++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 6a8808e2ec231..5d679750d4bbf 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -1622,3 +1622,165 @@ def _take_next_step( result.append((agent_action, observation)) return result + +class CustomPlanandSolveExecutor(AgentExecutor): + """Agent that is using tools.""" + + agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] + """The agent to run for creating a plan and determining actions + to take at each step of the execution loop.""" + tools: Sequence[BaseTool] + """The valid tools the agent can call.""" + return_intermediate_steps: bool = False + """Whether to return the agent's trajectory of intermediate steps + at the end in addition to the final output.""" + max_iterations: Optional[int] = 15 + """The maximum number of steps to take before ending the execution + loop. Setting to 'None' could lead to an infinite loop.""" + + max_execution_time: Optional[float] = None + """The maximum amount of wall clock time to spend in the execution + loop. + """ + early_stopping_method: str = "force" + """The method to use for early stopping if the agent never + returns `AgentFinish`. Either 'force' or 'generate'. + `"force"` returns a string saying that it stopped because it met a + time or iteration limit. + `"generate"` calls the agent's LLM Chain one final time to generate + a final answer based on the previous steps. + """ + handle_parsing_errors: Union[ + bool, str, Callable[[OutputParserException], str] + ] = False + """How to handle errors raised by the agent's output parser. + Defaults to `False`, which raises the error. +s + If `true`, the error will be sent back to the LLM as an observation. + If a string, the string itself will be sent to the LLM as an observation. + If a callable function, the function will be called with the exception + as an argument, and the result of that function will be passed to the agent + as an observation. 
+ """ + state: List[Dict[str, Any]] = None + trim_intermediate_steps: Union[ + int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]] + ] = -1 + + def _take_next_step( + self, + name_to_tool_map: Dict[str, StateTool], + color_mapping: Dict[str, str], + inputs: Dict[str, str], + intermediate_steps: List[Tuple[AgentAction, str]], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: + """Take a single step in the thought-action-observation loop. + Override this to take control of how the agent makes and acts on choices. + """ + try: + intermediate_steps = self._prepare_intermediate_steps(intermediate_steps) + # Check the intermediate_steps and add the last one to the state + if intermediate_steps: + try: + tool_name = intermediate_steps[-1][0].dict().get("tool",None) + tool_input = intermediate_steps[-1][0].dict().get("tool_input",None) + sql_query = json.loads(intermediate_steps[-1][1][0]).get("sql_query",None) + answer = json.loads(intermediate_steps[-1][1][0]).get("answer",None) + + self.state.append( + { + "tool": tool_name, + "tool_input": tool_input, + "sql_query": sql_query, + "answer": answer + + } + ) + except: + self.state = [] + else: + self.state = [] + + # Call the LLM to see what to do. + output = self.agent.plan( + intermediate_steps, + callbacks=run_manager.get_child() if run_manager else None, + **inputs, + ) + except OutputParserException as e: + if isinstance(self.handle_parsing_errors, bool): + raise_error = not self.handle_parsing_errors + else: + raise_error = False + if raise_error: + raise e + text = str(e) + if isinstance(self.handle_parsing_errors, bool): + if e.send_to_llm: + observation = str(e.observation) + text = str(e.llm_output) + else: + observation = "Invalid or incomplete response" + elif isinstance(self.handle_parsing_errors, str): + observation = self.handle_parsing_errors + elif callable(self.handle_parsing_errors): + observation = self.handle_parsing_errors(e) + else: + raise ValueError("Got unexpected type of `handle_parsing_errors`") + output = AgentAction("_Exception", observation, text) + if run_manager: + run_manager.on_agent_action(output, color="green") + tool_run_kwargs = self.agent.tool_run_logging_kwargs() + observation = ExceptionTool().run( + output.tool_input, + verbose=self.verbose, + color=None, + callbacks=run_manager.get_child() if run_manager else None, + **tool_run_kwargs, + ) + return [(output, observation)] + # If the tool chosen is the finishing tool, then we end and return. 
+ if isinstance(output, AgentFinish): + return output + actions: List[AgentAction] + if isinstance(output, AgentAction): + actions = [output] + else: + actions = output + result = [] + for agent_action in actions: + if run_manager: + run_manager.on_agent_action(agent_action, color="green") + # Otherwise we lookup the tool + if agent_action.tool in name_to_tool_map: + tool = name_to_tool_map[agent_action.tool] + return_direct = tool.return_direct + color = color_mapping[agent_action.tool] + tool_run_kwargs = self.agent.tool_run_logging_kwargs() + if return_direct: + tool_run_kwargs["llm_prefix"] = "" + # We then call the tool on the tool input to get an observation + observation = tool.run( + agent_action.tool_input, + verbose=self.verbose, + color=color, + state=self.state, + callbacks=run_manager.get_child() + if run_manager + else None**tool_run_kwargs, + ) + else: + tool_run_kwargs = self.agent.tool_run_logging_kwargs() + observation = InvalidTool().run( + { + "requested_tool_name": agent_action.tool, + "available_tool_names": list(name_to_tool_map.keys()), + }, + verbose=self.verbose, + color=None, + callbacks=run_manager.get_child() if run_manager else None, + **tool_run_kwargs, + ) + result.append((agent_action, observation)) + return result \ No newline at end of file From 14a50ad723e375602bbb0decafcf008d81568429 Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Sun, 5 May 2024 10:11:23 +0530 Subject: [PATCH 17/33] Update langchain version to 0.1.46dev1 --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 4330f8d122823..fb7db334d0d52 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.45dev1" +version = "0.1.46dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 868474cf90a7c8c1fdce6ab7a5fd36f91b281a93 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Mon, 6 May 2024 18:14:00 +0530 Subject: [PATCH 18/33] query mixing fixed --- .../langchain/langchain/tools/sqlcoder/tool.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index 6ebea4bda1976..dcc1477a5777f 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -42,7 +42,7 @@ class Config(StateTool.Config): description: str = """ Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. - If an error is returned, rewrite the query, check the query, and try again. + If an error is returned, re-run the sql_db_query_creator tool to get the correct query. 
""" def __init__(__pydantic_self__, **data: Any) -> None: @@ -64,7 +64,8 @@ def _run( ) executable_query = executable_query.strip('\"') executable_query = re.sub('\\n```', '',executable_query) - return self.db.run_no_throw(executable_query) + query_response = self.db.run_no_throw(executable_query) + return query_response async def _arun( self, @@ -145,8 +146,8 @@ def _parse_data_model_context(self): def _create_sql_query(self,user_input): few_shot_examples = self._parse_few_shot_examples() - db_schema = self._parse_db_schema() sql_query = self._extract_sql_query() + db_schema = self._parse_db_schema() data_model_context = self._parse_data_model_context() if sql_query is None: prompt_input = PromptTemplate( @@ -167,8 +168,8 @@ def _create_sql_query(self,user_input): ) else: prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples","sql_query","data_model_context"], - template=SQL_QUERY_CREATOR_RETRY, + input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], + template=SQL_QUERY_CREATOR_RETRY ) query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) @@ -178,21 +179,18 @@ def _create_sql_query(self,user_input): "db_schema": db_schema, "user_input": user_input, "few_shot_examples": few_shot_examples, - "sql_query": sql_query, "data_model_context": data_model_context - } ) ) sql_query = sql_query.replace("```","") - if hasattr(self, "state"): - self.state.append({"sql_db_query_creator": sql_query}) + return sql_query def _extract_sql_query(self): for value in self.state: for key, input_string in value.items(): - if key == "sql_db_query_creator": + if "sql_db_query_creator" in key: return input_string return None \ No newline at end of file From 7e3da7ff435b0d2c0414a2f506d286020d4da7e6 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Mon, 6 May 2024 18:17:22 +0530 Subject: [PATCH 19/33] version updated --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index fb7db334d0d52..da602749a25db 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.46dev1" +version = "0.1.47dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 22120ff3cf2c98036dcfdf7c7bf175edb71b406d Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Mon, 6 May 2024 19:05:55 +0530 Subject: [PATCH 20/33] chore: Add sources to agent state --- libs/langchain/langchain/agents/agent.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 5d679750d4bbf..939a863261f13 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -1687,13 +1687,15 @@ def _take_next_step( tool_input = intermediate_steps[-1][0].dict().get("tool_input",None) sql_query = json.loads(intermediate_steps[-1][1][0]).get("sql_query",None) answer = json.loads(intermediate_steps[-1][1][0]).get("answer",None) + sources = json.loads(intermediate_steps[-1][1][0]).get("sources",None) self.state.append( { "tool": tool_name, "tool_input": tool_input, "sql_query": sql_query, - "answer": answer + "answer": answer, + "sources": sources } ) @@ -1702,6 +1704,10 @@ def _take_next_step( else: self.state = [] + if len(self.state) != 0: + # return self.state['answer'] + return 
AgentFinish({"output": self.state[0]["answer"]}, intermediate_steps[0][0].log) + # Call the LLM to see what to do. output = self.agent.plan( intermediate_steps, From 1c22b3d813d0a758af95893285fdc2501243cf95 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Tue, 7 May 2024 21:24:01 +0530 Subject: [PATCH 21/33] sqlcoder retry prompt updated --- libs/langchain/langchain/tools/sqlcoder/prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py index d1e1c20d80c63..7dcfb6e9ab8a1 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -35,7 +35,7 @@ Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. Adhere to these rules: - **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **Deliberately go through the '{sql_query}' and database schema word by word** to ensure that you get the correct column names and metric names to answer the question +- **Deliberately go through the database schema word by word** to ensure that you get the correct column names and metric names to answer the question - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. - When creating a ratio, always cast the numerator as float You are an AI research assistant in the senior living industry. From 78f3ed51ad0348fc247589165f5936a690ffd272 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Tue, 7 May 2024 21:25:01 +0530 Subject: [PATCH 22/33] version updated --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index da602749a25db..97d93b9a6fd67 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.47dev1" +version = "0.1.48dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From a8027442a7cc02af55738d83734ef11ab6553dfa Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Wed, 8 May 2024 14:59:07 +0530 Subject: [PATCH 23/33] changes added --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 97d93b9a6fd67..7b012f35fc54e 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.48dev1" +version = "0.1.49dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 9db1f3500026af8b6608f086b0f079ec4534ece0 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Wed, 22 May 2024 11:48:31 +0530 Subject: [PATCH 24/33] retry mechanism bug fix --- .../langchain/tools/sqlcoder/prompt.py | 26 +++++++++++++------ .../langchain/tools/sqlcoder/tool.py | 2 +- libs/langchain/pyproject.toml | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py index 7dcfb6e9ab8a1..e0f6cf38b2f5b 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -10,9 +10,16 @@ Unless the user specifies a specific 
number of examples they wish to obtain, always limit your query to at most 30 results. Don't query living options if amenities are asked for. Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. -Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description, Entity_Name LIMIT 20 +Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , Facility_Name, COUNT(DISTICT Facility_Name), SUM(Total_Amount_Numerator) /COALESCE(SUM(NULLIF(Total_Amount_Denominator,0)),1) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Facility_Name, Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , 
Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 30 You can order the results by a relevant column to return the most interesting examples in the database. Never query for all the columns from a specific table, only ask for the relevant columns given the question. +Always cast rating to decimal with 2 points after decimal. +Always try to aggregate data while querying by using GROUP BY. +Always try to use DISTINCT with COUNT wherever required. +Don't use state if not mentioned in the question. +When querying string columns use rlike operator with regular expression and different combinations of the string. +When comparing strings use java regular expressions. +Use exists with rlike and regular expressions when matching string array elements DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Create the query for the specified year/month asked in the '{user_input}' @@ -35,7 +42,6 @@ Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. Adhere to these rules: - **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **Deliberately go through the database schema word by word** to ensure that you get the correct column names and metric names to answer the question - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. - When creating a ratio, always cast the numerator as float You are an AI research assistant in the senior living industry. @@ -44,23 +50,27 @@ Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. Don't query living options if amenities are asked for. Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. 
-Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , SUM(Total_Amount) , SUM(Total_Amount_Numerator) AS SumNumerator , SUM(Total_Amount_Denominator) AS SumDenominator , (SumNumerator/NULLIF(SumDenominator, 0)) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month LIKE '%' AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Entity_Name LIMIT 20 -Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. +Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , Facility_Name, COUNT(DISTICT Facility_Name), SUM(Total_Amount_Numerator) /COALESCE(SUM(NULLIF(Total_Amount_Denominator,0)),1) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Facility_Name, Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 30 You can order the results by a relevant column to return the most interesting examples in the database. Never query for all the columns from a specific table, only ask for the relevant columns given the question. +Always cast rating to decimal with 2 points after decimal. +Always try to aggregate data while querying by using GROUP BY. 
+Always try to use DISTINCT with COUNT wherever required. +Don't use state if not mentioned in the question. +When querying string columns use rlike operator with regular expression and different combinations of the string. +When comparing strings use java regular expressions. +Use exists with rlike and regular expressions when matching string array elements DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. - -Try different queries with different variations of the strings like synonyms, abbreviations, singular/plural forms of words in the string, etc. - +Create the query for the specified year/month asked in the '{user_input}' ### Input: Generate a SQL query that answers the question `{user_input}`. This query will run on a database whose schema is represented in this string: '{db_schema}' '{data_model_context}' -Replace Entity_Name with 'Facility_Name' in SQL query. Use the following examples to generate the sql query: '{few_shot_examples}' +Unless specified in the user input, always limit your query to 30 results ### Response: Based on your instructions, here is the SQL query I have generated to answer '{user_input}' ```sql""" diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index dcc1477a5777f..6d408b524e0ea 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -75,7 +75,7 @@ async def _arun( raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async") def _extract_sql_query(self): - for value in self.state: + for value in reversed(self.state): for key, input_string in value.items(): if "sql_db_query_creator" in key: return input_string diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 7b012f35fc54e..5ab78b257456e 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.49dev1" +version = "0.1.50dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From e2b0246e7cb6049ac4f3e8a043b5d595a31dfe78 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Fri, 24 May 2024 14:39:30 +0530 Subject: [PATCH 25/33] SQL_QUERY_CREATOR_TEMPLATE is removed from langchain --- .../langchain/tools/sqlcoder/prompt.py | 73 +------------------ .../langchain/tools/sqlcoder/tool.py | 7 +- 2 files changed, 5 insertions(+), 75 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py index e0f6cf38b2f5b..dffbf4f2e136a 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -1,79 +1,8 @@ -SQL_QUERY_CREATOR = """### Instructions: -Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. -Adhere to these rules: -- **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. -- When creating a ratio, always cast the numerator as float -You are an AI research assistant in the senior living industry. 
-You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events -When querying the database, given an input question, create a syntactically correct query to run. -Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. -Don't query living options if amenities are asked for. -Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. -Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , Facility_Name, COUNT(DISTICT Facility_Name), SUM(Total_Amount_Numerator) /COALESCE(SUM(NULLIF(Total_Amount_Denominator,0)),1) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Facility_Name, Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 30 -You can order the results by a relevant column to return the most interesting examples in the database. -Never query for all the columns from a specific table, only ask for the relevant columns given the question. -Always cast rating to decimal with 2 points after decimal. -Always try to aggregate data while querying by using GROUP BY. -Always try to use DISTINCT with COUNT wherever required. -Don't use state if not mentioned in the question. -When querying string columns use rlike operator with regular expression and different combinations of the string. -When comparing strings use java regular expressions. -Use exists with rlike and regular expressions when matching string array elements -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. -Create the query for the specified year/month asked in the '{user_input}' - -### Input: -Generate a SQL query that answers the question `{user_input}`. 
-This query will run on a database whose schema is represented in this string: -'{db_schema}' -'{data_model_context}' -Use the following examples to generate the sql query: -'{few_shot_examples}' -Unless specified in the user input, always limit your query to 30 results -### Response: -Based on your instructions, here is the SQL query I have generated to answer '{user_input}' -```sql""" SQL_QUERY_CREATOR_RETRY = """ -### Instructions: You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query. -Your task is convert a question into a SQL query, given a schema which is databricks sql compatible. -Adhere to these rules: -- **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. -- When creating a ratio, always cast the numerator as float -You are an AI research assistant in the senior living industry. -You have access to a database that contains the information about different communities, their amenities, residents, expenses, budget, revenue and other finances, facilities, beds, events -When querying the database, given an input question, create a syntactically correct query to run. -Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 30 results. -Don't query living options if amenities are asked for. -Most sql queries would reduce into the Unified query below, you are free to change or remove the predicates of this query based on the question. When denominator is null, use the total amount to compute the aggregate. Remember to change Metric_Code and other predicates with relevant metric asked in the prompt. Use this unified query below whenever you are not sure about what query to form or you are facing errors in forming sql. 
-Unified Query: SELECT Business_Unit , Entity_Name , Entity_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Date , Facility_Name, COUNT(DISTICT Facility_Name), SUM(Total_Amount_Numerator) /COALESCE(SUM(NULLIF(Total_Amount_Denominator,0)),1) AS Average_Metric , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description , Week_Num , Month_Name , Month_Number , Year_Month , Quarter_Number , Year , Facility_Name , Health_Type FROM genesishealthcare_sandbox_main.genesishealthcare_sandbox.skypoint_metric_fact_denormalized_vw WHERE Metric_Code = 'DSM_M' AND Facility_Name LIKE '%' AND Date BETWEEN '2019-01-01 00:00:00' AND '2025-12-31 23:59:59' AND Year BETWEEN 2019 AND 2025 AND Month_Name LIKE '%' AND Month_Number BETWEEN 1 AND 12 AND Facility_Name LIKE '%' AND Business_Unit LIKE '%' AND Entity_Name LIKE '%' AND Entity_Type LIKE '%' AND LOB_01 LIKE '%' AND LOB_02 LIKE '%' AND LOB_03 LIKE '%' AND LOB_04 LIKE '%' AND LOB_05 LIKE '%' AND LOB_06 LIKE '%' AND LOB_07 LIKE '%' AND LOB_08 LIKE '%' AND LOB_09 LIKE '%' AND Health_Type LIKE '%' AND ISNOTNULL(Total_Amount_Denominator) GROUP BY Facility_Name, Business_Unit , Entity_Name , Entity_Type , Facility_Name , Health_Type , LOB_01 , LOB_02 , LOB_03 , LOB_04 , LOB_05 , LOB_06 , LOB_07 , LOB_08 , LOB_09 , Year , Quarter_Number , Year_Month , Month_Name , Month_Number , Week_Num , Date , Metric_Code, Metric_Name , Metric_Description , Metric_Frequency , Calculation_Description ORDER BY Date, Entity_Name LIMIT 30 -You can order the results by a relevant column to return the most interesting examples in the database. -Never query for all the columns from a specific table, only ask for the relevant columns given the question. -Always cast rating to decimal with 2 points after decimal. -Always try to aggregate data while querying by using GROUP BY. -Always try to use DISTINCT with COUNT wherever required. -Don't use state if not mentioned in the question. -When querying string columns use rlike operator with regular expression and different combinations of the string. -When comparing strings use java regular expressions. -Use exists with rlike and regular expressions when matching string array elements -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. -Create the query for the specified year/month asked in the '{user_input}' - -### Input: -Generate a SQL query that answers the question `{user_input}`. 
-This query will run on a database whose schema is represented in this string: -'{db_schema}' -'{data_model_context}' -Use the following examples to generate the sql query: -'{few_shot_examples}' -Unless specified in the user input, always limit your query to 30 results -### Response: -Based on your instructions, here is the SQL query I have generated to answer '{user_input}' -```sql""" +""" SQL_QUERY_CREATOR_7b = """ ### Instructions: diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index 6d408b524e0ea..2681a6ca56c15 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -9,7 +9,7 @@ from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase -from langchain.tools.sqlcoder.prompt import SQL_QUERY_CREATOR_RETRY, SQL_QUERY_CREATOR +from langchain.tools.sqlcoder.prompt import SQL_QUERY_CREATOR_RETRY from langchain_core.pydantic_v1 import BaseModel, Extra, Field from langchain_core.tools import StateTool import re @@ -75,7 +75,7 @@ async def _arun( raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async") def _extract_sql_query(self): - for value in reversed(self.state): + for value in self.state: for key, input_string in value.items(): if "sql_db_query_creator" in key: return input_string @@ -93,6 +93,7 @@ class SqlQueryCreatorTool(StateTool): Output is a sql query """ sqlcreatorllm: BaseLanguageModel = Field(exclude=True) + SQL_QUERY_CREATOR_TEMPLATE: str class Config(StateTool.Config): """Configuration for this pydantic object.""" @@ -152,7 +153,7 @@ def _create_sql_query(self,user_input): if sql_query is None: prompt_input = PromptTemplate( input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], - template=SQL_QUERY_CREATOR, + template=self.SQL_QUERY_CREATOR_TEMPLATE, ) query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) From 4bcd25ec968313e2dd39317ef0bfb32b71a823d5 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Fri, 24 May 2024 16:25:38 +0530 Subject: [PATCH 26/33] changes to sqlcodertoolkit --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 5ab78b257456e..93599f9be7da3 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.50dev1" +version = "0.1.52dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 093b6607a987d742d731281d46a0f665d7ada1ac Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Fri, 24 May 2024 16:26:12 +0530 Subject: [PATCH 27/33] changes to sqlcodertoolkit --- .../langchain/agents/agent_toolkits/sqlcoder/toolkit.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py index 5dffff38c565e..b1be9ab0e5cf7 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/sqlcoder/toolkit.py @@ -25,6 +25,7 @@ class SQLCoderToolkit(BaseToolkit): db_warehouse_id: str allow_extra_fields = True sqlcreatorllm : BaseLanguageModel = Field(exclude=True) + sql_query_creator_template : str @property def dialect(self) -> str: @@ -58,6 
+59,8 @@ def get_tools(self) -> List[BaseTool]: db_host=self.db_host, db_catalog=self.db_catalog, db_schema=self.db_schema, - db_warehouse_id=self.db_warehouse_id + db_warehouse_id=self.db_warehouse_id, + SQL_QUERY_CREATOR_TEMPLATE=self.sql_query_creator_template + ) ] From 4f611ba7e3a5509b5feeae1116b830a83891dbe9 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Fri, 24 May 2024 16:26:42 +0530 Subject: [PATCH 28/33] changes to sqlcodertoolkit --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 93599f9be7da3..11698c4ef729b 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.52dev1" +version = "0.1.53dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From cf50a6c2dc50558c3f9aaa22a04136cec3e59998 Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Thu, 30 May 2024 08:34:02 +0530 Subject: [PATCH 29/33] Add prompt instructions to include sources in the document variable --- libs/langchain/langchain/chains/combine_documents/stuff.py | 1 + libs/langchain/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index a30d4a0e90b1b..fc8b06f6775fa 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -131,6 +131,7 @@ def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: if k in self.llm_chain.prompt.input_variables } inputs[self.document_variable_name] = self.document_separator.join(doc_strings) + inputs[self.document_variable_name] += "\n\nYou should return the sources only when it matches with the output. Else do not return sources" return inputs def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 11698c4ef729b..2dd37a9d70208 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.53dev1" +version = "0.1.54dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From b910473632c3e83be664811aab6dc4d65ff78ebe Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Thu, 30 May 2024 08:39:26 +0530 Subject: [PATCH 30/33] Update document variable to include relevant sources --- libs/langchain/langchain/chains/combine_documents/stuff.py | 2 +- libs/langchain/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index fc8b06f6775fa..d4e6829fa48a7 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -131,7 +131,7 @@ def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: if k in self.llm_chain.prompt.input_variables } inputs[self.document_variable_name] = self.document_separator.join(doc_strings) - inputs[self.document_variable_name] += "\n\nYou should return the sources only when it matches with the output. 
Else do not return sources" + inputs[self.document_variable_name] += "\n\nYou should return the sources only when it is relevant to the output. Else do not return Sources as None" return inputs def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 2dd37a9d70208..15744f1eb7e10 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.54dev1" +version = "0.1.55dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 16acf007e2411cb4c93c0ccb17668846fc8a75df Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Thu, 30 May 2024 08:40:55 +0530 Subject: [PATCH 31/33] Update document variable to include relevant sources --- libs/langchain/langchain/chains/combine_documents/stuff.py | 2 +- libs/langchain/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index d4e6829fa48a7..43eab1e5ddbc9 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -131,7 +131,7 @@ def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: if k in self.llm_chain.prompt.input_variables } inputs[self.document_variable_name] = self.document_separator.join(doc_strings) - inputs[self.document_variable_name] += "\n\nYou should return the sources only when it is relevant to the output. Else do not return Sources as None" + inputs[self.document_variable_name] += "\n\nYou should return the sources only when it is relevant to the output. 
Else return Sources as None" return inputs def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 15744f1eb7e10..8d5152af0a5a2 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.55dev1" +version = "0.1.56dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From e2737c3344ef3bf04072732e9abc6e0c9daddb29 Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Thu, 30 May 2024 09:16:11 +0530 Subject: [PATCH 32/33] Update sources handling in BaseQAWithSourcesChain --- libs/langchain/langchain/chains/qa_with_sources/base.py | 3 +++ libs/langchain/pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index 59fe0cfb4af7d..874a95b000650 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -133,6 +133,9 @@ def _split_sources(self, raw_answer: str) -> Tuple[str, str]: r"SOURCES?:|QUESTION:\s", raw_answer, flags=re.IGNORECASE )[:2] sources = re.split(r"\n", raw_sources)[0].strip() + if "/" in sources: + sources = sources.split("/")[-1].strip() + if sources == "": regex = r"- \s*(.+\.pdf)" sources_list = re.findall(regex, raw_sources) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 8d5152af0a5a2..3b441e1ac6827 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.56dev1" +version = "0.1.57dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 09d5c0cf7f863e4a75ccca7432ae37b4bc70c88a Mon Sep 17 00:00:00 2001 From: arunraja1 Date: Fri, 21 Jun 2024 19:20:02 +0530 Subject: [PATCH 33/33] Update SQL query creator tool to remove unnecessary code and improve readability --- libs/langchain/langchain/tools/sqlcoder/tool.py | 2 +- libs/langchain/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/tools/sqlcoder/tool.py b/libs/langchain/langchain/tools/sqlcoder/tool.py index 2681a6ca56c15..5a778513fa459 100644 --- a/libs/langchain/langchain/tools/sqlcoder/tool.py +++ b/libs/langchain/langchain/tools/sqlcoder/tool.py @@ -185,7 +185,7 @@ def _create_sql_query(self,user_input): ) ) sql_query = sql_query.replace("```","") - + sql_query = sql_query.replace("sql","") return sql_query diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 3b441e1ac6827..80fa9b4740100 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.57dev1" +version = "0.1.58dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT"
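To see how patches 25-27 fit together: the SQL creator prompt no longer lives in langchain, so the caller now supplies a template string to SQLCoderToolkit, which forwards it to SqlQueryCreatorTool as SQL_QUERY_CREATOR_TEMPLATE and renders it through a PromptTemplate over db_schema, user_input, few_shot_examples and data_model_context. The sketch below is a minimal illustration under those assumptions, not code from the patch series; the ChatOpenAI model, the connection placeholders and the template wording are invented, and the constructor arguments shown are only the toolkit fields visible in these diffs.

from langchain.agents.agent_toolkits.sqlcoder.toolkit import SQLCoderToolkit
from langchain.chat_models import ChatOpenAI  # assumed model wrapper for the sketch
from langchain.sql_database import SQLDatabase

# Caller-owned prompt template; only the placeholder names are dictated by the tool.
SQL_QUERY_CREATOR_TEMPLATE = (
    "### Instructions:\n"
    "Convert the question `{user_input}` into a Databricks-compatible SQL query.\n"
    "Schema: '{db_schema}'\n"
    "Context: '{data_model_context}'\n"
    "Examples: '{few_shot_examples}'\n"
    "### Response:"
)

db = SQLDatabase.from_uri("sqlite:///example.db")  # placeholder connection for the sketch

toolkit = SQLCoderToolkit(
    db=db,
    sqlcreatorllm=ChatOpenAI(temperature=0),
    db_token="<databricks-token>",      # placeholder credentials
    db_host="<workspace-host>",
    db_catalog="<catalog>",
    db_schema="<schema>",
    db_warehouse_id="<warehouse-id>",
    sql_query_creator_template=SQL_QUERY_CREATOR_TEMPLATE,
)
tools = toolkit.get_tools()  # includes SqlQueryCreatorTool with the injected template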
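Patch 24 switched _extract_sql_query to scan the shared state newest-first and patch 25 switches it back to oldest-first. The toy sketch below only shows what that toggle changes once several sql_db_query_creator results have accumulated; the state shape (a list of single-item dicts mapping an action string to its output) is an assumption read off the loop in the diff.

# Assumed shape of StateTool.state: one {action_repr: observation} dict per agent step.
state = [
    {"sql_db_query_creator(question 1)": "SELECT 1"},
    {"sql_db_query_validator(...)": "query ok"},
    {"sql_db_query_creator(question 2)": "SELECT 2"},
]

def extract_sql_query(state, newest_first=False):
    # Mirrors the loop in _extract_sql_query; newest_first corresponds to reversed(self.state).
    entries = reversed(state) if newest_first else state
    for value in entries:
        for key, input_string in value.items():
            if "sql_db_query_creator" in key:
                return input_string
    return None

print(extract_sql_query(state))                     # -> "SELECT 1" (forward scan, patch 25 behaviour)
print(extract_sql_query(state, newest_first=True))  # -> "SELECT 2" (reversed scan, patch 24 behaviour)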
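Patches 29-32 shape how sources move through the QA chain: the stuffed context gains a trailing instruction about when to emit a SOURCES line (final wording from patch 31), and _split_sources keeps only the file name when the model echoes a full path. A small illustrative snippet with made-up document text and a hypothetical path:

# Context assembly (StuffDocumentsChain._get_inputs): joined docs plus the appended instruction.
doc_strings = [
    "Occupancy was 87% in Q1. Source: /mnt/reports/community_q1.pdf",
    "Move-ins rose 5% month over month.",
]
context = "\n\n".join(doc_strings)
context += (
    "\n\nYou should return the sources only when it is relevant to the output. "
    "Else return Sources as None"
)

# Source cleanup (BaseQAWithSourcesChain._split_sources): strip any path prefix.
sources = "/mnt/reports/community_q1.pdf"  # hypothetical value parsed from the SOURCES line
if "/" in sources:
    sources = sources.split("/")[-1].strip()
print(sources)  # -> community_q1.pdf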
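Patch 33 cleans the generated query by deleting the fence markers and every "sql" substring from the completion. The helper below is an alternative sketch of that cleanup step, not the code in the patch: it anchors the removal to a leading fence and language tag, so identifiers that happen to contain the letters "sql" are left alone. The function name and the sample completion are illustrative.

import re

def strip_sql_fence(raw: str) -> str:
    # Illustrative helper, not taken from the patch: remove a leading ```sql / ``` fence
    # and a trailing ``` fence from an LLM completion, leaving only the query text.
    text = raw.strip()
    text = re.sub(r"^```(?:sql)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\s*```$", "", text)
    return text.strip()

print(strip_sql_fence("```sql\nSELECT Facility_Name FROM skypoint_metric_fact_denormalized_vw LIMIT 30\n```"))
# -> SELECT Facility_Name FROM skypoint_metric_fact_denormalized_vw LIMIT 30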