From 890ed775a38e4f1d0b9648fc9e50bc2bc07c573b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=20=E6=96=B9=E7=91=9E?= Date: Thu, 7 Sep 2023 08:08:12 +0800 Subject: [PATCH 1/2] Resolve: VectorSearch enabled SQLChain? (#10177) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squashed from #7454 with updated features We have separated the `SQLDatabseChain` from `VectorSQLDatabseChain` and put everything into `experimental/`. Below is the original PR message from #7454. ------- We have been working on features to fill up the gap among SQL, vector search and LLM applications. Some inspiring works like self-query retrievers for VectorStores (for example [Weaviate](https://python.langchain.com/en/latest/modules/indexes/retrievers/examples/weaviate_self_query.html) and [others](https://python.langchain.com/en/latest/modules/indexes/retrievers/examples/self_query.html)) really turn those vector search databases into a powerful knowledge base! 🚀🚀 We are thinking if we can merge all in one, like SQL and vector search and LLMChains, making this SQL vector database memory as the only source of your data. Here are some benefits we can think of for now, maybe you have more 👀: With ALL data you have: since you store all your pasta in the database, you don't need to worry about the foreign keys or links between names from other data source. Flexible data structure: Even if you have changed your schema, for example added a table, the LLM will know how to JOIN those tables and use those as filters. SQL compatibility: We found that vector databases that supports SQL in the marketplace have similar interfaces, which means you can change your backend with no pain, just change the name of the distance function in your DB solution and you are ready to go! ### Issue resolved: - [Feature Proposal: VectorSearch enabled SQLChain?](https://github.com/hwchase17/langchain/issues/5122) ### Change made in this PR: - An improved schema handling that ignore `types.NullType` columns - A SQL output Parser interface in `SQLDatabaseChain` to enable Vector SQL capability and further more - A Retriever based on `SQLDatabaseChain` to retrieve data from the database for RetrievalQAChains and many others - Allow `SQLDatabaseChain` to retrieve data in python native format - Includes PR #6737 - Vector SQL Output Parser for `SQLDatabaseChain` and `SQLDatabaseChainRetriever` - Prompts that can implement text to VectorSQL - Corresponding unit-tests and notebook ### Twitter handle: - @MyScaleDB ### Tag Maintainer: Prompts / General: @hwchase17, @baskaryan DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev ### Dependencies: No dependency added --- .../sql_database/myscale_vector_sql.ipynb | 200 +++++++++++++++ .../retrievers/vector_sql_database.py | 38 +++ .../langchain_experimental/sql/prompt.py | 85 +++++++ .../langchain_experimental/sql/vector_sql.py | 237 ++++++++++++++++++ libs/experimental/poetry.lock | 1 + .../langchain/utilities/sql_database.py | 6 + 6 files changed, 567 insertions(+) create mode 100644 docs/extras/modules/data_connection/retrievers/sql_database/myscale_vector_sql.ipynb create mode 100644 libs/experimental/langchain_experimental/retrievers/vector_sql_database.py create mode 100644 libs/experimental/langchain_experimental/sql/prompt.py create mode 100644 libs/experimental/langchain_experimental/sql/vector_sql.py diff --git a/docs/extras/modules/data_connection/retrievers/sql_database/myscale_vector_sql.ipynb b/docs/extras/modules/data_connection/retrievers/sql_database/myscale_vector_sql.ipynb new file mode 100644 index 0000000000000..65bd8323ed068 --- /dev/null +++ b/docs/extras/modules/data_connection/retrievers/sql_database/myscale_vector_sql.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "245065c6", + "metadata": {}, + "source": [ + "# Vector SQL Retriever with MyScale\n", + "\n", + ">[MyScale](https://docs.myscale.com/en/) is an integrated vector database. You can access your database in SQL and also from here, LangChain. MyScale can make a use of [various data types and functions for filters](https://blog.myscale.com/2023/06/06/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints). It will boost up your LLM app no matter if you are scaling up your data or expand your system to broader application." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0246c5bf", + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install clickhouse-sqlalchemy InstructorEmbedding sentence_transformers openai langchain-experimental" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7585d2c3", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from os import environ\n", + "import getpass\n", + "from typing import Dict, Any\n", + "from langchain import OpenAI, SQLDatabase, LLMChain\n", + "from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n", + "from sqlalchemy import create_engine, Column, MetaData\n", + "from langchain import PromptTemplate\n", + "\n", + "\n", + "from sqlalchemy import create_engine\n", + "\n", + "MYSCALE_HOST = \"msc-1decbcc9.us-east-1.aws.staging.myscale.cloud\"\n", + "MYSCALE_PORT = 443\n", + "MYSCALE_USER = \"chatdata\"\n", + "MYSCALE_PASSWORD = \"myscale_rocks\"\n", + "OPENAI_API_KEY = getpass.getpass(\"OpenAI API Key:\")\n", + "\n", + "engine = create_engine(\n", + " f\"clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https\"\n", + ")\n", + "metadata = MetaData(bind=engine)\n", + "environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e08d9ddc", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import HuggingFaceInstructEmbeddings\n", + "from langchain_experimental.sql.vector_sql import VectorSQLOutputParser\n", + "\n", + "output_parser = VectorSQLOutputParser.from_embeddings(\n", + " model=HuggingFaceInstructEmbeddings(\n", + " model_name=\"hkunlp/instructor-xl\", model_kwargs={\"device\": \"cpu\"}\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84b705b2", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from langchain.llms import OpenAI\n", + "from langchain.callbacks import StdOutCallbackHandler\n", + "\n", + "from langchain.utilities.sql_database import SQLDatabase\n", + "from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n", + "from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n", + "\n", + "chain = VectorSQLDatabaseChain(\n", + " llm_chain=LLMChain(\n", + " llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n", + " prompt=MYSCALE_PROMPT,\n", + " ),\n", + " top_k=10,\n", + " return_direct=True,\n", + " sql_cmd_parser=output_parser,\n", + " database=SQLDatabase(engine, None, metadata),\n", + ")\n", + "\n", + "import pandas as pd\n", + "\n", + "pd.DataFrame(\n", + " chain.run(\n", + " \"Please give me 10 papers to ask what is PageRank?\",\n", + " callbacks=[StdOutCallbackHandler()],\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6c09cda0", + "metadata": {}, + "source": [ + "## SQL Database as Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "734d7ff5", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain\n", + "\n", + "from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n", + "from langchain_experimental.retrievers.vector_sql_database \\\n", + " import VectorSQLDatabaseChainRetriever\n", + "from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n", + "from langchain_experimental.sql.vector_sql import VectorSQLRetrieveAllOutputParser\n", + "\n", + "output_parser_retrieve_all = VectorSQLRetrieveAllOutputParser.from_embeddings(\n", + " output_parser.model\n", + ")\n", + "\n", + "chain = VectorSQLDatabaseChain.from_llm(\n", + " llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n", + " prompt=MYSCALE_PROMPT,\n", + " top_k=10,\n", + " return_direct=True,\n", + " db=SQLDatabase(engine, None, metadata),\n", + " sql_cmd_parser=output_parser_retrieve_all,\n", + " native_format=True,\n", + ")\n", + "\n", + "# You need all those keys to get docs\n", + "retriever = VectorSQLDatabaseChainRetriever(sql_db_chain=chain, page_content_key=\"abstract\")\n", + "\n", + "document_with_metadata_prompt = PromptTemplate(\n", + " input_variables=[\"page_content\", \"id\", \"title\", \"authors\", \"pubdate\", \"categories\"],\n", + " template=\"Content:\\n\\tTitle: {title}\\n\\tAbstract: {page_content}\\n\\tAuthors: {authors}\\n\\tDate of Publication: {pubdate}\\n\\tCategories: {categories}\\nSOURCE: {id}\",\n", + ")\n", + "\n", + "chain = RetrievalQAWithSourcesChain.from_chain_type(\n", + " ChatOpenAI(\n", + " model_name=\"gpt-3.5-turbo-16k\", openai_api_key=OPENAI_API_KEY, temperature=0.6\n", + " ),\n", + " retriever=retriever,\n", + " chain_type=\"stuff\",\n", + " chain_type_kwargs={\n", + " \"document_prompt\": document_with_metadata_prompt,\n", + " },\n", + " return_source_documents=True,\n", + ")\n", + "ans = chain(\"Please give me 10 papers to ask what is PageRank?\",\n", + " callbacks=[StdOutCallbackHandler()])\n", + "print(ans[\"answer\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4948ff25", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/experimental/langchain_experimental/retrievers/vector_sql_database.py b/libs/experimental/langchain_experimental/retrievers/vector_sql_database.py new file mode 100644 index 0000000000000..1ec088dbc515f --- /dev/null +++ b/libs/experimental/langchain_experimental/retrievers/vector_sql_database.py @@ -0,0 +1,38 @@ +"""Vector SQL Database Chain Retriever""" +from typing import Any, Dict, List + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) +from langchain.schema import BaseRetriever, Document + +from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain + + +class VectorSQLDatabaseChainRetriever(BaseRetriever): + """Retriever that uses SQLDatabase as Retriever""" + + sql_db_chain: VectorSQLDatabaseChain + """SQL Database Chain""" + page_content_key: str = "content" + """column name for page content of documents""" + + def _get_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + **kwargs: Any, + ) -> List[Document]: + ret: List[Dict[str, Any]] = self.sql_db_chain( + query, callbacks=run_manager.get_child(), **kwargs + )["result"] + return [ + Document(page_content=r[self.page_content_key], metadata=r) for r in ret + ] + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + raise NotImplementedError diff --git a/libs/experimental/langchain_experimental/sql/prompt.py b/libs/experimental/langchain_experimental/sql/prompt.py new file mode 100644 index 0000000000000..5f4c9b8a4fd6f --- /dev/null +++ b/libs/experimental/langchain_experimental/sql/prompt.py @@ -0,0 +1,85 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + + +PROMPT_SUFFIX = """Only use the following tables: +{table_info} + +Question: {input}""" + +_VECTOR_SQL_DEFAULT_TEMPLATE = """You are a {dialect} expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question. +{dialect} queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance. +When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows. + +*NOTICE*: `DISTANCE(column, array)` only accept an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array. + +Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You should only order according to the distance function. +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. +Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema. + +Use the following format: + +Question: "Question here" +SQLQuery: "SQL Query to run" +SQLResult: "Result of the SQLQuery" +Answer: "Final answer here" +""" + +VECTOR_SQL_PROMPT = PromptTemplate( + input_variables=["input", "table_info", "dialect", "top_k"], + template=_VECTOR_SQL_DEFAULT_TEMPLATE + PROMPT_SUFFIX, +) + + +_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question. +MyScale queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance. +When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows. + +*NOTICE*: `DISTANCE(column, array)` only accept an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array. + +Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function. +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. +Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema. + +Use the following format: + +======== table info ======== + + +Question: "Question here" +SQLQuery: "SQL Query to run" + + +Here are some examples: + +======== table info ======== +CREATE TABLE "ChatPaper" ( + abstract String, + id String, + vector Array(Float32), +) ENGINE = ReplicatedReplacingMergeTree() + ORDER BY id + PRIMARY KEY id + +Question: What is Feartue Pyramid Network? +SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k} + + +Let's begin: +======== table info ======== +{table_info} + +Question: {input} +SQLQuery: """ + +MYSCALE_PROMPT = PromptTemplate( + input_variables=["input", "table_info", "top_k"], + template=_myscale_prompt + PROMPT_SUFFIX, +) + + +VECTOR_SQL_PROMPTS = { + "myscale": MYSCALE_PROMPT, +} diff --git a/libs/experimental/langchain_experimental/sql/vector_sql.py b/libs/experimental/langchain_experimental/sql/vector_sql.py new file mode 100644 index 0000000000000..98f3c2dee0c18 --- /dev/null +++ b/libs/experimental/langchain_experimental/sql/vector_sql.py @@ -0,0 +1,237 @@ +"""Vector SQL Database Chain Retriever""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union + +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.llm import LLMChain +from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS +from langchain.embeddings.base import Embeddings +from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BaseOutputParser, BasePromptTemplate +from langchain.schema.language_model import BaseLanguageModel +from langchain.tools.sql_database.prompt import QUERY_CHECKER +from langchain.utilities.sql_database import SQLDatabase + +from langchain_experimental.sql.base import INTERMEDIATE_STEPS_KEY, SQLDatabaseChain + + +class VectorSQLOutputParser(BaseOutputParser[str]): + """Output Parser for Vector SQL + 1. finds for `NeuralArray()` and replace it with the embedding + 2. finds for `DISTANCE()` and replace it with the distance name in backend SQL + """ + + model: Embeddings + """Embedding model to extract embedding for entity""" + distance_func_name: str = "distance" + """Distance name for Vector SQL""" + + class Config: + arbitrary_types_allowed = 1 + + @property + def _type(self) -> str: + return "vector_sql_parser" + + @classmethod + def from_embeddings( + cls, model: Embeddings, distance_func_name: str = "distance", **kwargs: Any + ) -> BaseOutputParser: + return cls(model=model, distance_func_name=distance_func_name, **kwargs) + + def parse(self, text: str) -> str: + text = text.strip() + start = text.find("NeuralArray(") + _sql_str_compl = text + if start > 0: + _matched = text[text.find("NeuralArray(") + len("NeuralArray(") :] + end = _matched.find(")") + start + len("NeuralArray(") + 1 + entity = _matched[: _matched.find(")")] + vecs = self.model.embed_query(entity) + vecs_str = "[" + ",".join(map(str, vecs)) + "]" + _sql_str_compl = text.replace("DISTANCE", self.distance_func_name).replace( + text[start:end], vecs_str + ) + if _sql_str_compl[-1] == ";": + _sql_str_compl = _sql_str_compl[:-1] + return _sql_str_compl + + +class VectorSQLRetrieveAllOutputParser(VectorSQLOutputParser): + """Based on VectorSQLOutputParser + It also modify the SQL to get all columns + """ + + @property + def _type(self) -> str: + return "vector_sql_retrieve_all_parser" + + def parse(self, text: str) -> str: + text = text.strip() + start = text.upper().find("SELECT") + if start >= 0: + end = text.upper().find("FROM") + text = text.replace(text[start + len("SELECT") + 1 : end - 1], "*") + return super().parse(text) + + +def _try_eval(x: Any) -> Any: + try: + return eval(x) + except Exception: + return x + + +def get_result_from_sqldb( + db: SQLDatabase, cmd: str +) -> Union[str, List[Dict[str, Any]], Dict[str, Any]]: + result = db._execute(cmd, fetch="all") # type: ignore + if isinstance(result, list): + return [{k: _try_eval(v) for k, v in dict(d._asdict()).items()} for d in result] + else: + return { + k: _try_eval(v) for k, v in dict(result._asdict()).items() # type: ignore + } + + +class VectorSQLDatabaseChain(SQLDatabaseChain): + """Chain for interacting with Vector SQL Database. + + Example: + .. code-block:: python + + from langchain_experimental.sql import SQLDatabaseChain + from langchain import OpenAI, SQLDatabase, OpenAIEmbeddings + db = SQLDatabase(...) + db_chain = VectorSQLDatabaseChain.from_llm(OpenAI(), db, OpenAIEmbeddings()) + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include the permissions this chain needs. + Failure to do so may result in data corruption or loss, since this chain may + attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this chain. + This issue shows an example negative outcome if these steps are not taken: + https://github.com/langchain-ai/langchain/issues/5923 + """ + + sql_cmd_parser: VectorSQLOutputParser + """Parser for Vector SQL""" + native_format: bool = False + """If return_direct, controls whether to return in python native format""" + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + input_text = f"{inputs[self.input_key]}\nSQLQuery:" + _run_manager.on_text(input_text, verbose=self.verbose) + # If not present, then defaults to None which is all tables. + table_names_to_use = inputs.get("table_names_to_use") + table_info = self.database.get_table_info(table_names=table_names_to_use) + llm_inputs = { + "input": input_text, + "top_k": str(self.top_k), + "dialect": self.database.dialect, + "table_info": table_info, + "stop": ["\nSQLResult:"], + } + intermediate_steps: List = [] + try: + intermediate_steps.append(llm_inputs) # input: sql generation + llm_out = self.llm_chain.predict( + callbacks=_run_manager.get_child(), + **llm_inputs, + ) + sql_cmd = self.sql_cmd_parser.parse(llm_out) + if self.return_sql: + return {self.output_key: sql_cmd} + if not self.use_query_checker: + _run_manager.on_text(llm_out, color="green", verbose=self.verbose) + intermediate_steps.append( + llm_out + ) # output: sql generation (no checker) + intermediate_steps.append({"sql_cmd": llm_out}) # input: sql exec + result = get_result_from_sqldb(self.database, sql_cmd) + intermediate_steps.append(str(result)) # output: sql exec + else: + query_checker_prompt = self.query_checker_prompt or PromptTemplate( + template=QUERY_CHECKER, input_variables=["query", "dialect"] + ) + query_checker_chain = LLMChain( + llm=self.llm_chain.llm, + prompt=query_checker_prompt, + output_parser=self.llm_chain.output_parser, + ) + query_checker_inputs = { + "query": llm_out, + "dialect": self.database.dialect, + } + checked_llm_out = query_checker_chain.predict( + callbacks=_run_manager.get_child(), **query_checker_inputs + ) + checked_sql_command = self.sql_cmd_parser.parse(checked_llm_out) + intermediate_steps.append( + checked_llm_out + ) # output: sql generation (checker) + _run_manager.on_text( + checked_llm_out, color="green", verbose=self.verbose + ) + intermediate_steps.append( + {"sql_cmd": checked_llm_out} + ) # input: sql exec + result = get_result_from_sqldb(self.database, checked_sql_command) + intermediate_steps.append(str(result)) # output: sql exec + llm_out = checked_llm_out + sql_cmd = checked_sql_command + + _run_manager.on_text("\nSQLResult: ", verbose=self.verbose) + _run_manager.on_text(str(result), color="yellow", verbose=self.verbose) + # If return direct, we just set the final result equal to + # the result of the sql query result, otherwise try to get a human readable + # final answer + if self.return_direct: + final_result = result + else: + _run_manager.on_text("\nAnswer:", verbose=self.verbose) + input_text += f"{llm_out}\nSQLResult: {result}\nAnswer:" + llm_inputs["input"] = input_text + intermediate_steps.append(llm_inputs) # input: final answer + final_result = self.llm_chain.predict( + callbacks=_run_manager.get_child(), + **llm_inputs, + ).strip() + intermediate_steps.append(final_result) # output: final answer + _run_manager.on_text(final_result, color="green", verbose=self.verbose) + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + return chain_result + except Exception as exc: + # Append intermediate steps to exception, to aid in logging and later + # improvement of few shot prompt seeds + exc.intermediate_steps = intermediate_steps # type: ignore + raise exc + + @property + def _chain_type(self) -> str: + return "vector_sql_database_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + db: SQLDatabase, + prompt: Optional[BasePromptTemplate] = None, + sql_cmd_parser: Optional[VectorSQLOutputParser] = None, + **kwargs: Any, + ) -> VectorSQLDatabaseChain: + assert sql_cmd_parser, "`sql_cmd_parser` must be set in VectorSQLDatabaseChain." + prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT) + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls( + llm_chain=llm_chain, database=db, sql_cmd_parser=sql_cmd_parser, **kwargs + ) diff --git a/libs/experimental/poetry.lock b/libs/experimental/poetry.lock index 620da0f99ae11..9e8cf9f1aff2d 100644 --- a/libs/experimental/poetry.lock +++ b/libs/experimental/poetry.lock @@ -1245,6 +1245,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index e621ffd17bd26..13718c8c0c7f6 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -9,6 +9,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.schema import CreateTable +from sqlalchemy.types import NullType from langchain.utils import get_from_env @@ -314,6 +315,11 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str: tables.append(self._custom_table_info[table.name]) continue + # Ignore JSON datatyped columns + for k, v in table.columns.items(): + if type(v.type) is NullType: + table._columns.remove(v) + # add create table command create_table = str(CreateTable(table).compile(self._engine)) table_info = f"{create_table.rstrip()}" From f4f9254dadd470766b2c21c2f8b8e37a299a57fb Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 6 Sep 2023 17:09:40 -0700 Subject: [PATCH 2/2] Move Myscale SQL vector retrieval nb --- .../qa_structured/integrations}/myscale_vector_sql.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/extras/{modules/data_connection/retrievers/sql_database => use_cases/qa_structured/integrations}/myscale_vector_sql.ipynb (100%) diff --git a/docs/extras/modules/data_connection/retrievers/sql_database/myscale_vector_sql.ipynb b/docs/extras/use_cases/qa_structured/integrations/myscale_vector_sql.ipynb similarity index 100% rename from docs/extras/modules/data_connection/retrievers/sql_database/myscale_vector_sql.ipynb rename to docs/extras/use_cases/qa_structured/integrations/myscale_vector_sql.ipynb