Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix aiopg + sqlalchemy >= 1.4 compatibility issue #167

Merged
merged 2 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ jobs:
- name: Run unit tests
run: tox run -- --cov-report=term

test-db:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Lets
uses: lets-cli/[email protected]
with:
version: latest
- name: Test database integration
run: timeout 600 lets test-pg

federation-test:
runs-on: ubuntu-latest
steps:
Expand Down
42 changes: 42 additions & 0 deletions hiku/sources/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
Iterable,
Any,
List,
Iterator,
Tuple,
)

import sqlalchemy
from sqlalchemy.sql import Select
from sqlalchemy import any_
from sqlalchemy.sql.elements import BinaryExpression

Expand All @@ -17,12 +20,51 @@
FETCH_SIZE = 100


def _uniq_fields(fields: List[Field]) -> Iterator[Field]:
visited = set()
for f in fields:
if f.name not in visited:
visited.add(f.name)
yield f


class FieldsQuery(_sa.FieldsQuery):
def in_impl(
self, column: sqlalchemy.Column, values: Iterable
) -> BinaryExpression:
return column == any_(values)

def select_expr(
self, fields_: List[Field], ids: Iterable
) -> Tuple[Select, Callable]:
result_columns = [self.from_clause.c[f.name] for f in fields_]
# aiopg requires unique columns to be passed to select,
# otherwise it will raise an error
query_columns = [
column
for f in _uniq_fields(fields_)
if (column := self.from_clause.c[f.name]) != self.primary_key
]

expr = (
sqlalchemy.select(
*_sa._process_select_params([self.primary_key] + query_columns)
)
.select_from(self.from_clause)
.where(self.in_impl(self.primary_key, ids))
)

def result_proc(rows: List[_sa.Row]) -> List:
rows_map = {
row[self.primary_key]: [row[c] for c in result_columns]
for row in map(_sa._process_result_row, rows)
}

nulls = [None for _ in fields_]
return [rows_map.get(id_, nulls) for id_ in ids]

return expr, result_proc

async def __call__(
self, ctx: Context, fields_: List[Field], ids: Iterable
) -> List:
Expand Down
2 changes: 1 addition & 1 deletion lets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ commands:
test-pg:
description: Run tests with pg
depends: [_build-tests]
cmd: [docker-compose, run, --rm, test-pg]
cmd: [docker compose, run, --rm, test-pg]

test-tox:
description: Run tests using tox
Expand Down
Loading