diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index fb542cfa88392..79aceb2b71f03 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -58,6 +58,7 @@ def __init__( custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300, + restricted_keywords: Optional[List[str]] = None, ): """Create engine from database URI.""" self._engine = engine @@ -123,6 +124,10 @@ def __init__( schema=self._schema, ) + # Restricted keywords to not execute on database + self.restricted_keywords = restricted_keywords if restricted_keywords else [] + + @classmethod def from_uri( cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any @@ -447,11 +452,18 @@ def _execute( pass else: raise TypeError(f"Query expression has unknown type: {type(command)}") - cursor = connection.execute( - command, - parameters, - execution_options=execution_options, - ) + if not self.detect_restricted_keywords(command): + cursor = connection.execute( + command, + parameters, + execution_options=execution_options, + ) + else: + raise PermissionError( + f"Restricted keywords in the SQL '{command}' " + f"Commands '{self.restricted_keywords}' are forbidden." + ) + if cursor.returns_rows: if fetch == "all": diff --git a/libs/community/tests/unit_tests/test_sql_database_schema.py b/libs/community/tests/unit_tests/test_sql_database_schema.py index 22f12ab582a07..cead860e123bb 100644 --- a/libs/community/tests/unit_tests/test_sql_database_schema.py +++ b/libs/community/tests/unit_tests/test_sql_database_schema.py @@ -91,3 +91,26 @@ def test_sql_database_run() -> None: output = db.run(command) expected_output = "[('Harrison',)]" assert output == expected_output + + +def test_sql_restricted_keywords() -> None: + """Test that given keywords by the user will stop the execution of the SQL command and raise an error.""" + engine = create_engine("duckdb:///:memory:") + metadata_obj.create_all(engine) + + restricted_keywords = ["drop"] + db = SQLDatabase( + engine, + schema="schema_a", + metadata=metadata_obj, + restricted_keywords=restricted_keywords, + ) + + command = 'DROP TABLE IF EXISTS "user"' + with pytest.raises(PermissionError) as records: + db.run(command) + + assert ( + records.value.args[0] == f"Restricted keywords in the SQL '{command}' " + f"Commands '{restricted_keywords}' are forbidden." + )