Skip to content

Commit

Permalink
make reverse work multiple levels deep
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Jun 19, 2024
1 parent cf4ffc1 commit 216907f
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 30 deletions.
4 changes: 2 additions & 2 deletions docs/src/piccolo/query_types/django_comparison.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ Or alternatively, using ``get_related``:
fan_club = await band.get_related(Band.id.join_on(FanClub.band))
# Similarly:
# Alternatively, by reversing the foreign key:
fan_club = await band.get_related(FanClub.band.reverse())
If doing a select query, and you want data from the related table:
Expand All @@ -231,7 +231,7 @@ If doing a select query, and you want data from the related table:
... )
[{'name': 'Pythonistas', 'address': '1 Flying Circus, UK'}, ...]
Similarly, in where clauses:
And filtering by related tables in the ``where`` clause:

.. code-block:: python
Expand Down
31 changes: 26 additions & 5 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,10 @@ class FanClub(Table):
band = ForeignKey(Band, unique=True)
address = Text()
class Treasurer(Table):
fan_club = ForeignKey(FanClub, unique=True)
name = Varchar()
It's helpful with ``get_related``, for example:
.. code-block:: python
Expand All @@ -2039,17 +2043,34 @@ class FanClub(Table):
>>> await band.get_related(FanClub.band.reverse())
<Fan Club: 1>
It works multiple levels deep:
.. code-block:: python
>>> await band.get_related(Treasurer.fan_club._.band.reverse())
<Treasurer: 1>
"""
if not self._meta.unique or any(
not i._meta.unique for i in self._meta.call_chain
):
raise ValueError("Only reverse unique foreign keys.")

target_column = self._foreign_key_meta.resolved_target_column
foreign_key = target_column.join_on(self)
foreign_key._meta.call_chain = [
i.reverse() for i in reversed(self._meta.call_chain)
]
foreign_keys = [*self._meta.call_chain, self]

root_foreign_key = foreign_keys[0]
target_column = (
root_foreign_key._foreign_key_meta.resolved_target_column
)
foreign_key = target_column.join_on(root_foreign_key)

call_chain = []
for fk in reversed(foreign_keys[1:]):
target_column = fk._foreign_key_meta.resolved_target_column
call_chain.append(target_column.join_on(fk))

foreign_key._meta.call_chain = call_chain

return foreign_key

def all_related(
Expand Down
21 changes: 20 additions & 1 deletion tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from piccolo.engine.finder import engine_finder
from piccolo.engine.postgres import PostgresEngine
from piccolo.engine.sqlite import SQLiteEngine
from piccolo.table import Table, create_table_class
from piccolo.table import (
Table,
create_db_tables_sync,
create_table_class,
drop_db_tables_sync,
)
from piccolo.utils.sync import run_sync

ENGINE = engine_finder()
Expand Down Expand Up @@ -454,3 +459,17 @@ def setUp(self):

def tearDown(self):
self.drop_tables()


class TableTest(TestCase):
"""
Used for tests where we need to create Piccolo tables.
"""

tables: t.List[t.Type[Table]]

def setUp(self) -> None:
create_db_tables_sync(*self.tables)

def tearDown(self) -> None:
drop_db_tables_sync(*self.tables)
56 changes: 56 additions & 0 deletions tests/columns/foreign_key/test_reverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from piccolo.columns import ForeignKey, Text, Varchar
from piccolo.table import Table
from tests.base import TableTest


class Band(Table):
name = Varchar()


class FanClub(Table):
address = Text()
band = ForeignKey(Band, unique=True)


class Treasurer(Table):
name = Varchar()
fan_club = ForeignKey(FanClub, unique=True)


class TestReverse(TableTest):
tables = [Band, FanClub, Treasurer]

def setUp(self):
super().setUp()

band = Band({Band.name: "Pythonistas"})
band.save().run_sync()

fan_club = FanClub(
{FanClub.band: band, FanClub.address: "1 Flying Circus, UK"}
)
fan_club.save().run_sync()

treasurer = Treasurer(
{Treasurer.fan_club: fan_club, Treasurer.name: "Bob"}
)
treasurer.save().run_sync()

def test_reverse(self):
response = Band.select(
Band.name,
FanClub.band.reverse().address.as_alias("address"),
Treasurer.fan_club._.band.reverse().name.as_alias(
"treasurer_name"
),
).run_sync()
self.assertListEqual(
response,
[
{
"name": "Pythonistas",
"address": "1 Flying Circus, UK",
"treasurer_name": "Bob",
}
],
)
17 changes: 2 additions & 15 deletions tests/query/functions/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,8 @@
import typing as t
from unittest import TestCase

from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync
from tests.base import TableTest
from tests.example_apps.music.tables import Band, Manager


class FunctionTest(TestCase):
tables: t.List[t.Type[Table]]

def setUp(self) -> None:
create_db_tables_sync(*self.tables)

def tearDown(self) -> None:
drop_db_tables_sync(*self.tables)


class BandTest(FunctionTest):
class BandTest(TableTest):
tables = [Band, Manager]

def setUp(self) -> None:
Expand Down
6 changes: 2 additions & 4 deletions tests/query/functions/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
Year,
)
from piccolo.table import Table
from tests.base import engines_only, sqlite_only

from .base import FunctionTest
from tests.base import TableTest, engines_only, sqlite_only


class Concert(Table):
starts = Timestamp()


class DatetimeTest(FunctionTest):
class DatetimeTest(TableTest):
tables = [Concert]

def setUp(self) -> None:
Expand Down
5 changes: 2 additions & 3 deletions tests/query/functions/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
from piccolo.columns import Numeric
from piccolo.query.functions.math import Abs, Ceil, Floor, Round
from piccolo.table import Table

from .base import FunctionTest
from tests.base import TableTest


class Ticket(Table):
price = Numeric(digits=(5, 2))


class TestMath(FunctionTest):
class TestMath(TableTest):

tables = [Ticket]

Expand Down

0 comments on commit 216907f

Please sign in to comment.