From 1d10189b221ff30251f8b2f6a5e214ecf4264312 Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Tue, 1 Aug 2023 15:06:05 +0200 Subject: [PATCH 01/10] Update sql_database.py Raise a `PermissionError` when a provided SQL keyword is provided to `SQLDatabase`. --- libs/langchain/langchain/utilities/sql_database.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index 110f081d3c0ee..1cc96394ac047 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -50,6 +50,7 @@ def __init__( custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300, + harmful_keywords: Optional[dict] = None, ): """Create engine from database URI.""" self._engine = engine @@ -115,6 +116,10 @@ def __init__( schema=self._schema, ) + # Harmful keywords to not execute on database + self.harmful_keywords = harmful_keywords if harmful_keywords else [] + + @classmethod def from_uri( cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any @@ -386,7 +391,11 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: pass else: # postgresql and compatible dialects connection.exec_driver_sql(f"SET search_path TO {self._schema}") - cursor = connection.execute(text(command)) + + if not self.detect_harmful_actions(command): + cursor = connection.execute(text(command)) + else: + raise PermissionError(f"Harmful actions in the SQL '{command}'\n Commands '{self.harmful_keywords}' are forbidden.") if cursor.returns_rows: if fetch == "all": result = cursor.fetchall() From 24059e6a4d56547a406ad19ebb7b7a1c5c4c6a3a Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Tue, 1 Aug 2023 15:49:37 +0200 Subject: [PATCH 02/10] Update sql_database.py add detect_harmful_actions --- libs/langchain/langchain/utilities/sql_database.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index 1cc96394ac047..8e605a02108bf 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -406,6 +406,20 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: return result return [] + def detect_harmful_actions(self, input_string): + # List of harmful keywords + harmful_keywords = [keyword.lower() for keyword in self.harmful_keywords] + + # Convert the input string to lowercase for case-insensitive matching + input_lower = input_string.lower() + + # Check if any harmful keyword is present in the input string + for keyword in harmful_keywords: + if keyword in input_lower: + return True + + return False + def run(self, command: str, fetch: str = "all") -> str: """Execute a SQL command and return a string representing the results. From 6e98c61f15816b84bd2103f0eea1b5bbec42038c Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:21:39 +0200 Subject: [PATCH 03/10] Update sql_database.py make format --- libs/langchain/langchain/utilities/sql_database.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index 8e605a02108bf..c5e4033637449 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -119,7 +119,6 @@ def __init__( # Harmful keywords to not execute on database self.harmful_keywords = harmful_keywords if harmful_keywords else [] - @classmethod def from_uri( cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any @@ -395,7 +394,9 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: if not self.detect_harmful_actions(command): cursor = connection.execute(text(command)) else: - raise PermissionError(f"Harmful actions in the SQL '{command}'\n Commands '{self.harmful_keywords}' are forbidden.") + raise PermissionError( + f"Harmful actions in the SQL '{command}'\n Commands '{self.harmful_keywords}' are forbidden." + ) if cursor.returns_rows: if fetch == "all": result = cursor.fetchall() From 5037a4b8c3e0fb40d4bbbf8c3de0666f40b6096d Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:37:38 +0200 Subject: [PATCH 04/10] Update test_sql_database_schema.py Add test_sql_harmful_keywords --- .../unit_tests/test_sql_database_schema.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/libs/langchain/tests/unit_tests/test_sql_database_schema.py b/libs/langchain/tests/unit_tests/test_sql_database_schema.py index b2a6589463dd5..e144d66c3f0b5 100644 --- a/libs/langchain/tests/unit_tests/test_sql_database_schema.py +++ b/libs/langchain/tests/unit_tests/test_sql_database_schema.py @@ -91,3 +91,26 @@ def test_sql_database_run() -> None: output = db.run(command) expected_output = "[('Harrison',)]" assert output == expected_output + + +def test_sql_harmful_keywords(): + """Test that given keywords by the user will stop the execution of the SQL command and raise an error.""" + engine = create_engine("duckdb:///:memory:") + metadata_obj.create_all(engine) + + harmful_keywords = ["drop"] + db = SQLDatabase( + engine, + schema="schema_a", + metadata=metadata_obj, + harmful_keywords=harmful_keywords, + ) + + command = 'DROP DATABASE IF EXISTS "user"' + with pytest.raises(PermissionError) as records: + db.run(command) + + assert ( + records.value.args[0] + == f"""Harmful actions in the SQL '{command}'\n Commands '{harmful_keywords}' are forbidden.""" + ) From 61bfc47412251d184b20b5152d0a09abb0b59ec7 Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:41:52 +0200 Subject: [PATCH 05/10] change actions to keywords --- libs/langchain/langchain/utilities/sql_database.py | 6 +++--- libs/langchain/tests/unit_tests/test_sql_database_schema.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index c5e4033637449..8fe38b7cea7b7 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -391,11 +391,11 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: else: # postgresql and compatible dialects connection.exec_driver_sql(f"SET search_path TO {self._schema}") - if not self.detect_harmful_actions(command): + if not self.detect_harmful_keywords(command): cursor = connection.execute(text(command)) else: raise PermissionError( - f"Harmful actions in the SQL '{command}'\n Commands '{self.harmful_keywords}' are forbidden." + f"Harmful keywords in the SQL '{command}'\n Commands '{self.harmful_keywords}' are forbidden." ) if cursor.returns_rows: if fetch == "all": @@ -407,7 +407,7 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: return result return [] - def detect_harmful_actions(self, input_string): + def detect_harmful_keywords(self, input_string): # List of harmful keywords harmful_keywords = [keyword.lower() for keyword in self.harmful_keywords] diff --git a/libs/langchain/tests/unit_tests/test_sql_database_schema.py b/libs/langchain/tests/unit_tests/test_sql_database_schema.py index e144d66c3f0b5..bf4a21a088bd1 100644 --- a/libs/langchain/tests/unit_tests/test_sql_database_schema.py +++ b/libs/langchain/tests/unit_tests/test_sql_database_schema.py @@ -112,5 +112,5 @@ def test_sql_harmful_keywords(): assert ( records.value.args[0] - == f"""Harmful actions in the SQL '{command}'\n Commands '{harmful_keywords}' are forbidden.""" + == f"""Harmful keywords in the SQL '{command}'\n Commands '{harmful_keywords}' are forbidden.""" ) From a4084094e7a2ccf7d53710841c4bb9ec67f87b53 Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:47:01 +0200 Subject: [PATCH 06/10] Update test_sql_database_schema.py user is a table not a database --- libs/langchain/tests/unit_tests/test_sql_database_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/test_sql_database_schema.py b/libs/langchain/tests/unit_tests/test_sql_database_schema.py index bf4a21a088bd1..9b73543ba9173 100644 --- a/libs/langchain/tests/unit_tests/test_sql_database_schema.py +++ b/libs/langchain/tests/unit_tests/test_sql_database_schema.py @@ -106,7 +106,7 @@ def test_sql_harmful_keywords(): harmful_keywords=harmful_keywords, ) - command = 'DROP DATABASE IF EXISTS "user"' + command = 'DROP TABLE IF EXISTS "user"' with pytest.raises(PermissionError) as records: db.run(command) From eeda7f29653ed00670076140c9087f76b3d115d8 Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Wed, 2 Aug 2023 09:19:25 +0200 Subject: [PATCH 07/10] harmful -> restricted More representative of the real usage of the feature. --- .../langchain/utilities/sql_database.py | 20 +++++++++---------- .../unit_tests/test_sql_database_schema.py | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index 8fe38b7cea7b7..1e8c80405cc4a 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -50,7 +50,7 @@ def __init__( custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300, - harmful_keywords: Optional[dict] = None, + restricted_keywords: Optional[dict] = None, ): """Create engine from database URI.""" self._engine = engine @@ -116,8 +116,8 @@ def __init__( schema=self._schema, ) - # Harmful keywords to not execute on database - self.harmful_keywords = harmful_keywords if harmful_keywords else [] + # Restricted keywords to not execute on database + self.restricted_keywords = restricted_keywords if restricted_keywords else [] @classmethod def from_uri( @@ -391,11 +391,11 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: else: # postgresql and compatible dialects connection.exec_driver_sql(f"SET search_path TO {self._schema}") - if not self.detect_harmful_keywords(command): + if not self.detect_restricted_keywords(command): cursor = connection.execute(text(command)) else: raise PermissionError( - f"Harmful keywords in the SQL '{command}'\n Commands '{self.harmful_keywords}' are forbidden." + f"Restricted keywords in the SQL '{command}'\n Commands '{self.restricted_keywords}' are forbidden." ) if cursor.returns_rows: if fetch == "all": @@ -407,15 +407,15 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: return result return [] - def detect_harmful_keywords(self, input_string): - # List of harmful keywords - harmful_keywords = [keyword.lower() for keyword in self.harmful_keywords] + def detect_restricted_keywords(self, input_string): + # List of restricted keywords + restricted_keywords = [keyword.lower() for keyword in self.restricted_keywords] # Convert the input string to lowercase for case-insensitive matching input_lower = input_string.lower() - # Check if any harmful keyword is present in the input string - for keyword in harmful_keywords: + # Check if any restricted keyword is present in the input string + for keyword in restricted_keywords: if keyword in input_lower: return True diff --git a/libs/langchain/tests/unit_tests/test_sql_database_schema.py b/libs/langchain/tests/unit_tests/test_sql_database_schema.py index 9b73543ba9173..9539b304b02b0 100644 --- a/libs/langchain/tests/unit_tests/test_sql_database_schema.py +++ b/libs/langchain/tests/unit_tests/test_sql_database_schema.py @@ -93,17 +93,17 @@ def test_sql_database_run() -> None: assert output == expected_output -def test_sql_harmful_keywords(): +def test_sql_restricted_keywords(): """Test that given keywords by the user will stop the execution of the SQL command and raise an error.""" engine = create_engine("duckdb:///:memory:") metadata_obj.create_all(engine) - harmful_keywords = ["drop"] + restricted_keywords = ["drop"] db = SQLDatabase( engine, schema="schema_a", metadata=metadata_obj, - harmful_keywords=harmful_keywords, + restricted_keywords=restricted_keywords, ) command = 'DROP TABLE IF EXISTS "user"' @@ -112,5 +112,5 @@ def test_sql_harmful_keywords(): assert ( records.value.args[0] - == f"""Harmful keywords in the SQL '{command}'\n Commands '{harmful_keywords}' are forbidden.""" + == f"""Restricted keywords in the SQL '{command}'\n Commands '{restricted_keywords}' are forbidden.""" ) From 2eae72c5e43e44f4a47a45167169c26ef99f930d Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Wed, 2 Aug 2023 09:30:21 +0200 Subject: [PATCH 08/10] Fix Linting ran make lint and pass my files for the feature, but other files are in error --- libs/langchain/langchain/utilities/sql_database.py | 4 ++-- libs/langchain/tests/unit_tests/test_sql_database_schema.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index 1e8c80405cc4a..581f04205cbf1 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -50,7 +50,7 @@ def __init__( custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300, - restricted_keywords: Optional[dict] = None, + restricted_keywords: Optional[List[str]] = None, ): """Create engine from database URI.""" self._engine = engine @@ -407,7 +407,7 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: return result return [] - def detect_restricted_keywords(self, input_string): + def detect_restricted_keywords(self, input_string: str) -> bool: # List of restricted keywords restricted_keywords = [keyword.lower() for keyword in self.restricted_keywords] diff --git a/libs/langchain/tests/unit_tests/test_sql_database_schema.py b/libs/langchain/tests/unit_tests/test_sql_database_schema.py index 9539b304b02b0..8833364651927 100644 --- a/libs/langchain/tests/unit_tests/test_sql_database_schema.py +++ b/libs/langchain/tests/unit_tests/test_sql_database_schema.py @@ -93,7 +93,7 @@ def test_sql_database_run() -> None: assert output == expected_output -def test_sql_restricted_keywords(): +def test_sql_restricted_keywords() -> None: """Test that given keywords by the user will stop the execution of the SQL command and raise an error.""" engine = create_engine("duckdb:///:memory:") metadata_obj.create_all(engine) From 3b382a2718cd868eea8234a63d68ec773656c3b3 Mon Sep 17 00:00:00 2001 From: Antoine Tavernier <6902440+Uranium2@users.noreply.github.com> Date: Wed, 2 Aug 2023 11:26:13 +0200 Subject: [PATCH 09/10] poetry run ruff . Fix line too long --- libs/langchain/langchain/utilities/sql_database.py | 3 ++- libs/langchain/tests/unit_tests/test_sql_database_schema.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index 581f04205cbf1..7cd03f7736b79 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -395,7 +395,8 @@ def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: cursor = connection.execute(text(command)) else: raise PermissionError( - f"Restricted keywords in the SQL '{command}'\n Commands '{self.restricted_keywords}' are forbidden." + f"Restricted keywords in the SQL '{command}' " + f"Commands '{self.restricted_keywords}' are forbidden." ) if cursor.returns_rows: if fetch == "all": diff --git a/libs/langchain/tests/unit_tests/test_sql_database_schema.py b/libs/langchain/tests/unit_tests/test_sql_database_schema.py index 8833364651927..2f880f001d84b 100644 --- a/libs/langchain/tests/unit_tests/test_sql_database_schema.py +++ b/libs/langchain/tests/unit_tests/test_sql_database_schema.py @@ -111,6 +111,6 @@ def test_sql_restricted_keywords() -> None: db.run(command) assert ( - records.value.args[0] - == f"""Restricted keywords in the SQL '{command}'\n Commands '{restricted_keywords}' are forbidden.""" + records.value.args[0] == f"Restricted keywords in the SQL '{command}' " + f"Commands '{restricted_keywords}' are forbidden." ) From 6feba54c86027850a709436acb886d9709dcb584 Mon Sep 17 00:00:00 2001 From: Uranium2 Date: Sun, 18 Feb 2024 19:52:28 +0100 Subject: [PATCH 10/10] readd my stuff --- .../utilities/sql_database.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index fb542cfa88392..79aceb2b71f03 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -58,6 +58,7 @@ def __init__( custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300, + restricted_keywords: Optional[List[str]] = None, ): """Create engine from database URI.""" self._engine = engine @@ -123,6 +124,10 @@ def __init__( schema=self._schema, ) + # Restricted keywords to not execute on database + self.restricted_keywords = restricted_keywords if restricted_keywords else [] + + @classmethod def from_uri( cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any @@ -447,11 +452,18 @@ def _execute( pass else: raise TypeError(f"Query expression has unknown type: {type(command)}") - cursor = connection.execute( - command, - parameters, - execution_options=execution_options, - ) + if not self.detect_restricted_keywords(command): + cursor = connection.execute( + command, + parameters, + execution_options=execution_options, + ) + else: + raise PermissionError( + f"Restricted keywords in the SQL '{command}' " + f"Commands '{self.restricted_keywords}' are forbidden." + ) + if cursor.returns_rows: if fetch == "all":