Skip to content

Commit

Permalink
Allow tuple-valued params in read_sql[_query] (pandas-dev#997)
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-at-cs authored Sep 12, 2024
1 parent 336718a commit 708c8aa
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 4 deletions.
36 changes: 32 additions & 4 deletions pandas-stubs/io/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ def read_sql_query(
con: _SQLConnection,
index_col: str | list[str] | None = ...,
coerce_float: bool = ...,
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
params: (
list[Scalar]
| tuple[Scalar, ...]
| tuple[tuple[Scalar, ...], ...]
| Mapping[str, Scalar]
| Mapping[str, tuple[Scalar, ...]]
| None
) = ...,
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
*,
chunksize: int,
Expand All @@ -79,7 +86,14 @@ def read_sql_query(
con: _SQLConnection,
index_col: str | list[str] | None = ...,
coerce_float: bool = ...,
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
params: (
list[Scalar]
| tuple[Scalar, ...]
| tuple[tuple[Scalar, ...], ...]
| Mapping[str, Scalar]
| Mapping[str, tuple[Scalar, ...]]
| None
) = ...,
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
chunksize: None = ...,
dtype: DtypeArg | None = ...,
Expand All @@ -91,7 +105,14 @@ def read_sql(
con: _SQLConnection,
index_col: str | list[str] | None = ...,
coerce_float: bool = ...,
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
params: (
list[Scalar]
| tuple[Scalar, ...]
| tuple[tuple[Scalar, ...], ...]
| Mapping[str, Scalar]
| Mapping[str, tuple[Scalar, ...]]
| None
) = ...,
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
columns: list[str] = ...,
*,
Expand All @@ -105,7 +126,14 @@ def read_sql(
con: _SQLConnection,
index_col: str | list[str] | None = ...,
coerce_float: bool = ...,
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
params: (
list[Scalar]
| tuple[Scalar, ...]
| tuple[tuple[Scalar, ...], ...]
| Mapping[str, Scalar]
| Mapping[str, tuple[Scalar, ...]]
| None
) = ...,
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
columns: list[str] = ...,
chunksize: None = ...,
Expand Down
33 changes: 33 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,39 @@ def test_read_sql_query_via_sqlalchemy_engine_with_params():
engine.dispose()


@pytest.mark.skip(
reason="Only works in Postgres (and MySQL, but with different query syntax)"
)
def test_read_sql_query_via_sqlalchemy_engine_with_tuple_valued_params():
with ensure_clean() as path:
db_uri = "postgresql+psycopg2://postgres@localhost:5432/postgres"
engine = sqlalchemy.create_engine(db_uri)

check(
assert_type(
read_sql_query(
"select * from test where a in %(a)s",
con=engine,
params={"a": (1, 2)},
),
DataFrame,
),
DataFrame,
)
check(
assert_type(
read_sql_query(
"select * from test where a in %s",
con=engine,
params=((1, 2),),
),
DataFrame,
),
DataFrame,
)
engine.dispose()


def test_read_html():
check(assert_type(DF.to_html(), str), str)
with ensure_clean() as path:
Expand Down

0 comments on commit 708c8aa

Please sign in to comment.