diff --git a/docs/src/piccolo/query_types/objects.rst b/docs/src/piccolo/query_types/objects.rst index e4a2c126e..0d6f93d26 100644 --- a/docs/src/piccolo/query_types/objects.rst +++ b/docs/src/piccolo/query_types/objects.rst @@ -176,6 +176,13 @@ using ``get_related``. >>> manager.name 'Guido' +It works multiple levels deep - for example: + +.. code-block:: python + + concert = await Concert.objects().first() + manager = await concert.get_related(Concert.band_1.manager) + Prefetching related objects ~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 5e4dda50b..eff040201 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -2,7 +2,7 @@ import typing as t -from piccolo.columns.column_types import ForeignKey +from piccolo.columns.column_types import ForeignKey, ReferencedTable from piccolo.columns.combination import And, Where from piccolo.custom_types import Combinable, TableInstance from piccolo.engine.base import BaseBatch @@ -231,6 +231,54 @@ def run_sync(self, *args, **kwargs) -> None: return run_sync(self.run(*args, **kwargs)) +class GetRelated(t.Generic[ReferencedTable]): + + def __init__(self, row: Table, foreign_key: ForeignKey[ReferencedTable]): + self.row = row + self.foreign_key = foreign_key + + async def run( + self, + node: t.Optional[str] = None, + in_pool: bool = True, + ) -> t.Optional[ReferencedTable]: + references = t.cast( + t.Type[ReferencedTable], + self.foreign_key._foreign_key_meta.resolved_references, + ) + + data = ( + await self.row.__class__.select( + *[ + i.as_alias(i._meta.name) + for i in self.foreign_key.all_columns() + ] + ) + .first() + .run(node=node, in_pool=in_pool) + ) + + # Make sure that some values were returned: + if data is None or not any(data.values()): + return None + + referenced_object = references(**data) + referenced_object._exists_in_db = True + return referenced_object + + def __await__( + self, + ) -> t.Generator[None, None, t.Optional[ReferencedTable]]: + """ + If the user doesn't explicity call .run(), proxy to it as a + convenience. + """ + return self.run().__await__() + + def run_sync(self, *args, **kwargs) -> t.Optional[ReferencedTable]: + return run_sync(self.run(*args, **kwargs)) + + ############################################################################### diff --git a/piccolo/table.py b/piccolo/table.py index b50855f95..bae9b8a47 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -46,7 +46,7 @@ ) from piccolo.query.methods.create_index import CreateIndex from piccolo.query.methods.indexes import Indexes -from piccolo.query.methods.objects import First, UpdateSelf +from piccolo.query.methods.objects import GetRelated, UpdateSelf from piccolo.query.methods.refresh import Refresh from piccolo.querystring import QueryString from piccolo.utils import _camel_to_snake @@ -612,14 +612,14 @@ def refresh( @t.overload def get_related( self, foreign_key: ForeignKey[ReferencedTable] - ) -> First[ReferencedTable]: ... + ) -> GetRelated[ReferencedTable]: ... @t.overload - def get_related(self, foreign_key: str) -> First[Table]: ... + def get_related(self, foreign_key: str) -> GetRelated[Table]: ... def get_related( self, foreign_key: t.Union[str, ForeignKey[ReferencedTable]] - ) -> t.Union[First[Table], First[ReferencedTable]]: + ) -> GetRelated[ReferencedTable]: """ Used to fetch a ``Table`` instance, for the target of a foreign key. @@ -630,8 +630,8 @@ def get_related( >>> print(manager.name) 'Guido' - It can only follow foreign keys one level currently. - i.e. ``Band.manager``, but not ``Band.manager.x.y.z``. + It can only follow foreign keys multiple levels deep. For example, + ``Concert.band_1.manager``. """ if isinstance(foreign_key, str): @@ -645,18 +645,7 @@ def get_related( "ForeignKey column." ) - column_name = foreign_key._meta.name - - references = foreign_key._foreign_key_meta.resolved_references - - return ( - references.objects() - .where( - foreign_key._foreign_key_meta.resolved_target_column - == getattr(self, column_name) - ) - .first() - ) + return GetRelated(foreign_key=foreign_key, row=self) def get_m2m(self, m2m: M2M) -> M2MGetRelated: """ diff --git a/tests/table/instance/test_get_related.py b/tests/table/instance/test_get_related.py index 28c572314..6cae3b9fc 100644 --- a/tests/table/instance/test_get_related.py +++ b/tests/table/instance/test_get_related.py @@ -1,42 +1,62 @@ import typing as t -from unittest import TestCase -from tests.example_apps.music.tables import Band, Manager +from piccolo.testing.test_case import AsyncTableTest +from tests.example_apps.music.tables import Band, Concert, Manager, Venue -TABLES = [Manager, Band] +class TestGetRelated(AsyncTableTest): + tables = [Manager, Band, Concert, Venue] -class TestGetRelated(TestCase): - def setUp(self): - for table in TABLES: - table.create_table().run_sync() + async def asyncSetUp(self): + await super().asyncSetUp() - def tearDown(self): - for table in reversed(TABLES): - table.alter().drop_table().run_sync() + self.manager = Manager(name="Guido") + await self.manager.save() - def test_get_related(self) -> None: + self.band = Band( + name="Pythonistas", manager=self.manager.id, popularity=100 + ) + await self.band.save() + + async def test_foreign_key(self) -> None: """ Make sure you can get a related object from another object instance. """ - manager = Manager(name="Guido") - manager.save().run_sync() + manager = await self.band.get_related(Band.manager) + assert manager is not None + self.assertTrue(manager.name == "Guido") - band = Band(name="Pythonistas", manager=manager.id, popularity=100) - band.save().run_sync() + async def test_non_foreign_key(self): + """ + Make sure that non-ForeignKey raise an exception. + """ + with self.assertRaises(ValueError): + self.band.get_related(Band.name) # type: ignore - _manager = band.get_related(Band.manager).run_sync() - assert _manager is not None - self.assertTrue(_manager.name == "Guido") + async def test_string(self): + """ + Make sure it also works using a string representation of a foreign key. + """ + manager = t.cast(Manager, await self.band.get_related("manager")) + self.assertTrue(manager.name == "Guido") - # Test non-ForeignKey + async def test_invalid_string(self): + """ + Make sure an exception is raised if the foreign key string is invalid. + """ with self.assertRaises(ValueError): - band.get_related(Band.name) # type: ignore + self.band.get_related("abc123") + + async def test_multiple_levels(self): + """ + Make sure ``get_related`` works multiple levels deep. + """ + concert = Concert(band_1=self.band) + await concert.save() - # Make sure it also works using a string - _manager_2 = t.cast(Manager, band.get_related("manager").run_sync()) - self.assertTrue(_manager_2.name == "Guido") + manager = await concert.get_related(Concert.band_1._.manager) + assert manager is not None + self.assertTrue(manager.name == "Guido") - # Test an invalid string - with self.assertRaises(ValueError): - band.get_related("abc123") + band_2_manager = await concert.get_related(Concert.band_2._.manager) + assert band_2_manager is None diff --git a/tests/type_checking.py b/tests/type_checking.py index d1e9d96ca..7288768e1 100644 --- a/tests/type_checking.py +++ b/tests/type_checking.py @@ -49,6 +49,12 @@ async def get_related() -> None: manager = await band.get_related(Band.manager) assert_type(manager, t.Optional[Manager]) + async def get_related_multiple_levels() -> None: + concert = await Concert.objects().first() + assert concert is not None + manager = await concert.get_related(Concert.band_1._.manager) + assert_type(manager, t.Optional[Manager]) + async def get_or_create() -> None: query = Band.objects().get_or_create(Band.name == "Pythonistas") assert_type(await query, Band)