Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: allow SQLDatabase to run queries with parameters #15453

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions libs/community/langchain_community/utilities/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union

import sqlalchemy
from langchain_core.utils import get_from_env
Expand Down Expand Up @@ -376,14 +376,15 @@ def _get_sample_rows(self, table: Table) -> str:
def _execute(
self,
command: str,
fetch: Literal["all", "one"] = "all",
fetch: Union[Literal["all"], Literal["one"]] = "all",
query_params: Optional[Dict[str, Any]] = None,
) -> Sequence[Dict[str, Any]]:
"""
Executes SQL command through underlying engine.

If the statement returns no rows, an empty list is returned.
"""
with self._engine.begin() as connection: # type: Connection
with self._engine.begin() as connection:
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
Expand Down Expand Up @@ -411,7 +412,7 @@ def _execute(
pass
else: # postgresql and other compatible dialects
connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
cursor = connection.execute(text(command))
cursor = connection.execute(text(command), parameters=query_params)
if cursor.returns_rows:
if fetch == "all":
result = [x._asdict() for x in cursor.fetchall()]
Expand All @@ -426,15 +427,17 @@ def _execute(
def run(
self,
command: str,
fetch: Literal["all", "one"] = "all",
fetch: Union[Literal["all"], Literal["one"]] = "all",
include_columns: bool = False,
query_params: Optional[Dict[str, Any]] = None,
) -> str:
"""Execute a SQL command and return a string representing the results.
Optionally, query_params may be provided to parameterize the query.

If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
result = self._execute(command, fetch)
result = self._execute(command, fetch, query_params)

res = [
{
Expand Down Expand Up @@ -471,18 +474,20 @@ def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> st
def run_no_throw(
self,
command: str,
fetch: Literal["all", "one"] = "all",
fetch: Union[Literal["all"], Literal["one"]] = "all",
include_columns: bool = False,
query_params: Optional[Dict[str, Any]] = None,
) -> str:
"""Execute a SQL command and return a string representing the results.
Optionally, query_params may be provided to parameterize the query.

If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.

If the statement throws an error, the error message is returned.
"""
try:
return self.run(command, fetch, include_columns)
return self.run(command, fetch, include_columns, query_params)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"
Loading