diff --git a/docs/src/piccolo/query_types/django_comparison.rst b/docs/src/piccolo/query_types/django_comparison.rst index a0447bbd9..18ef42db9 100644 --- a/docs/src/piccolo/query_types/django_comparison.rst +++ b/docs/src/piccolo/query_types/django_comparison.rst @@ -218,7 +218,7 @@ Or alternatively, using ``get_related``: fan_club = await band.get_related(Band.id.join_on(FanClub.band)) - # Similarly: + # Alternatively, by reversing the foreign key: fan_club = await band.get_related(FanClub.band.reverse()) If doing a select query, and you want data from the related table: @@ -231,7 +231,7 @@ If doing a select query, and you want data from the related table: ... ) [{'name': 'Pythonistas', 'address': '1 Flying Circus, UK'}, ...] -Similarly, in where clauses: +And filtering by related tables in the ``where`` clause: .. code-block:: python diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index bab051d28..67c9c1a9d 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -2031,6 +2031,10 @@ class FanClub(Table): band = ForeignKey(Band, unique=True) address = Text() + class Treasurer(Table): + fan_club = ForeignKey(FanClub, unique=True) + name = Varchar() + It's helpful with ``get_related``, for example: .. code-block:: python @@ -2039,17 +2043,34 @@ class FanClub(Table): >>> await band.get_related(FanClub.band.reverse()) + It works multiple levels deep: + + .. code-block:: python + + >>> await band.get_related(Treasurer.fan_club._.band.reverse()) + + """ if not self._meta.unique or any( not i._meta.unique for i in self._meta.call_chain ): raise ValueError("Only reverse unique foreign keys.") - target_column = self._foreign_key_meta.resolved_target_column - foreign_key = target_column.join_on(self) - foreign_key._meta.call_chain = [ - i.reverse() for i in reversed(self._meta.call_chain) - ] + foreign_keys = [*self._meta.call_chain, self] + + root_foreign_key = foreign_keys[0] + target_column = ( + root_foreign_key._foreign_key_meta.resolved_target_column + ) + foreign_key = target_column.join_on(root_foreign_key) + + call_chain = [] + for fk in reversed(foreign_keys[1:]): + target_column = fk._foreign_key_meta.resolved_target_column + call_chain.append(target_column.join_on(fk)) + + foreign_key._meta.call_chain = call_chain + return foreign_key def all_related( diff --git a/tests/base.py b/tests/base.py index b05f85622..f9f964c70 100644 --- a/tests/base.py +++ b/tests/base.py @@ -13,7 +13,12 @@ from piccolo.engine.finder import engine_finder from piccolo.engine.postgres import PostgresEngine from piccolo.engine.sqlite import SQLiteEngine -from piccolo.table import Table, create_table_class +from piccolo.table import ( + Table, + create_db_tables_sync, + create_table_class, + drop_db_tables_sync, +) from piccolo.utils.sync import run_sync ENGINE = engine_finder() @@ -454,3 +459,17 @@ def setUp(self): def tearDown(self): self.drop_tables() + + +class TableTest(TestCase): + """ + Used for tests where we need to create Piccolo tables. + """ + + tables: t.List[t.Type[Table]] + + def setUp(self) -> None: + create_db_tables_sync(*self.tables) + + def tearDown(self) -> None: + drop_db_tables_sync(*self.tables) diff --git a/tests/columns/foreign_key/test_reverse.py b/tests/columns/foreign_key/test_reverse.py new file mode 100644 index 000000000..2a90ac5ba --- /dev/null +++ b/tests/columns/foreign_key/test_reverse.py @@ -0,0 +1,56 @@ +from piccolo.columns import ForeignKey, Text, Varchar +from piccolo.table import Table +from tests.base import TableTest + + +class Band(Table): + name = Varchar() + + +class FanClub(Table): + address = Text() + band = ForeignKey(Band, unique=True) + + +class Treasurer(Table): + name = Varchar() + fan_club = ForeignKey(FanClub, unique=True) + + +class TestReverse(TableTest): + tables = [Band, FanClub, Treasurer] + + def setUp(self): + super().setUp() + + band = Band({Band.name: "Pythonistas"}) + band.save().run_sync() + + fan_club = FanClub( + {FanClub.band: band, FanClub.address: "1 Flying Circus, UK"} + ) + fan_club.save().run_sync() + + treasurer = Treasurer( + {Treasurer.fan_club: fan_club, Treasurer.name: "Bob"} + ) + treasurer.save().run_sync() + + def test_reverse(self): + response = Band.select( + Band.name, + FanClub.band.reverse().address.as_alias("address"), + Treasurer.fan_club._.band.reverse().name.as_alias( + "treasurer_name" + ), + ).run_sync() + self.assertListEqual( + response, + [ + { + "name": "Pythonistas", + "address": "1 Flying Circus, UK", + "treasurer_name": "Bob", + } + ], + ) diff --git a/tests/query/functions/base.py b/tests/query/functions/base.py index 168f5528b..623bc1a5a 100644 --- a/tests/query/functions/base.py +++ b/tests/query/functions/base.py @@ -1,21 +1,8 @@ -import typing as t -from unittest import TestCase - -from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from tests.base import TableTest from tests.example_apps.music.tables import Band, Manager -class FunctionTest(TestCase): - tables: t.List[t.Type[Table]] - - def setUp(self) -> None: - create_db_tables_sync(*self.tables) - - def tearDown(self) -> None: - drop_db_tables_sync(*self.tables) - - -class BandTest(FunctionTest): +class BandTest(TableTest): tables = [Band, Manager] def setUp(self) -> None: diff --git a/tests/query/functions/test_datetime.py b/tests/query/functions/test_datetime.py index 3e2d33d0b..360833dc4 100644 --- a/tests/query/functions/test_datetime.py +++ b/tests/query/functions/test_datetime.py @@ -12,16 +12,14 @@ Year, ) from piccolo.table import Table -from tests.base import engines_only, sqlite_only - -from .base import FunctionTest +from tests.base import TableTest, engines_only, sqlite_only class Concert(Table): starts = Timestamp() -class DatetimeTest(FunctionTest): +class DatetimeTest(TableTest): tables = [Concert] def setUp(self) -> None: diff --git a/tests/query/functions/test_math.py b/tests/query/functions/test_math.py index 1c82f9426..7029e7857 100644 --- a/tests/query/functions/test_math.py +++ b/tests/query/functions/test_math.py @@ -3,15 +3,14 @@ from piccolo.columns import Numeric from piccolo.query.functions.math import Abs, Ceil, Floor, Round from piccolo.table import Table - -from .base import FunctionTest +from tests.base import TableTest class Ticket(Table): price = Numeric(digits=(5, 2)) -class TestMath(FunctionTest): +class TestMath(TableTest): tables = [Ticket]