diff --git a/piccolo/table.py b/piccolo/table.py index bae9b8a47..6c86ef77c 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -851,6 +851,11 @@ def __repr__(self) -> str: ) return f"<{self.__class__.__name__}: {pk}>" + def __eq__(self, other) -> bool: + return isinstance(other, self.__class__) and getattr( + self, self._meta.primary_key._meta.name, None + ) == getattr(other, other._meta.primary_key._meta.name, None) + ########################################################################### # Classmethods diff --git a/tests/table/instance/test_instance_equality.py b/tests/table/instance/test_instance_equality.py new file mode 100644 index 000000000..b95334534 --- /dev/null +++ b/tests/table/instance/test_instance_equality.py @@ -0,0 +1,31 @@ +from piccolo.testing.test_case import AsyncTableTest +from tests.example_apps.music.tables import Band, Manager + + +class TestInstanceEquality(AsyncTableTest): + tables = [Manager, Band] + + async def asyncSetUp(self): + await super().asyncSetUp() + + self.manager = Manager(name="Guido") + await self.manager.save() + + self.band = Band( + name="Pythonistas", manager=self.manager.id, popularity=100 + ) + await self.band.save() + + async def test_instance_equality(self) -> None: + """ + Make sure for instance equailty. + """ + band_pk = await self.band.objects().first() + band = await self.band.objects(Band.manager).get( + (Band._meta.primary_key == band_pk.id) + ) + manager_pk = await self.manager.objects().first() + manager = await self.manager.objects().get( + Manager._meta.primary_key == manager_pk.id + ) + self.assertTrue(band.manager == manager)