From 79b08976c56de41fc84018b0526d9787bd528424 Mon Sep 17 00:00:00 2001 From: Tumblingman <73558578+tumblingman@users.noreply.github.com> Date: Tue, 24 Dec 2024 23:47:41 +0400 Subject: [PATCH] Fix and enhance many-to-many change handling in ModelObserver Added checks for `through` attribute and improved handling of `pre_clear` and `reverse` cases in many-to-many field changes. Enhanced the logic to avoid duplicates and ensure correct updates to related instances. Updated tests to validate these changes, ensuring robust many-to-many relationship observation. --- .../observer/model_observer.py | 40 +++++++++----- tests/test_model_observer.py | 52 ++++++++++++++++++- 2 files changed, 79 insertions(+), 13 deletions(-) diff --git a/djangochannelsrestframework/observer/model_observer.py b/djangochannelsrestframework/observer/model_observer.py index bd0ff46..6439a41 100644 --- a/djangochannelsrestframework/observer/model_observer.py +++ b/djangochannelsrestframework/observer/model_observer.py @@ -69,12 +69,13 @@ def _connect(self): ) have_m2m = False for field in self.model_cls._meta.many_to_many: - m2m_changed.connect( - self.m2m_changed_receiver, - sender=field.remote_field.through, - dispatch_uid=f"{id(self)}-{field.name}" - ) - have_m2m = True + if hasattr(field.remote_field, 'through'): + m2m_changed.connect( + self.m2m_changed_receiver, + sender=field.remote_field.through, + dispatch_uid=f"{str(id(self))}-{self.model_cls.__name__}-{field.name}" + ) + have_m2m = True post_delete.connect( self.post_delete_receiver, sender=self.model_cls, dispatch_uid=str(id(self)) @@ -118,21 +119,36 @@ def post_save_receiver(self, instance: Model, created: bool, **kwargs): else: self.database_event(instance, Action.UPDATE) - def m2m_changed_receiver(self, action: str, instance: Model, reverse: bool, model: Type[Model], pk_set: Set[Any], - **kwargs): + def m2m_changed_receiver(self, sender, instance: Model, action: str, reverse: bool, model: Type[Model], + pk_set: Set[Any], **kwargs): """ Handle many-to-many changes. """ - if action not in {"post_add", "post_remove", "post_clear"}: + if action not in {"post_add", "post_remove", "post_clear"} and not reverse: + return + + if action not in {"post_add", "post_remove", "pre_clear"} and reverse: return target_instances = [] if not reverse: target_instances.append(instance) else: - for pk in pk_set: - target_instances.append(model.objects.get(pk=pk)) - + if pk_set: + for pk in pk_set: + target_instances.append(model.objects.get(pk=pk)) + else: # pre_clear case + related_field = next( + (field for field in instance._meta.get_fields() + if field.many_to_many and hasattr(field, 'through') and field.through == sender), + None + ) + if related_field: + related_manager = getattr(instance, related_field.related_name or f"{related_field.name}_set", None) + if related_manager: + target_instances.extend(related_manager.all()) + + target_instances = list(set(target_instances)) # remove duplicates if any for target_instance in target_instances: self.database_event(target_instance, Action.UPDATE) diff --git a/tests/test_model_observer.py b/tests/test_model_observer.py index b663222..41875b3 100644 --- a/tests/test_model_observer.py +++ b/tests/test_model_observer.py @@ -1,4 +1,3 @@ -import asyncio from contextlib import AsyncExitStack import pytest @@ -685,3 +684,54 @@ async def accept(self, subprotocol=None): await database_sync_to_async(u2.groups.set)([g1, g4]) await communicator.receive_nothing() + + await database_sync_to_async(u1.groups.set)([g1, g2, g3, g4]) + + response = await communicator.receive_json_from() + + assert response == { + "action": "update", + "errors": [], + "response_status": 200, + "request_id": 4, + "data": { + "email": "42@example.com", + "id": u1.id, + "username": "test1", + "groups": [g1.id, g2.id, g3.id, g4.id] + }, + } + + await database_sync_to_async(g4.user_set.clear)() + + response = await communicator.receive_json_from() + + assert response == { + "action": "update", + "errors": [], + "response_status": 200, + "request_id": 4, + "data": { + "email": "42@example.com", + "id": u1.id, + "username": "test1", + "groups": [g1.id, g2.id, g3.id] + }, + } + + await database_sync_to_async(g3.user_set.remove)(u1) + + response = await communicator.receive_json_from() + + assert response == { + "action": "update", + "errors": [], + "response_status": 200, + "request_id": 4, + "data": { + "email": "42@example.com", + "id": u1.id, + "username": "test1", + "groups": [g1.id, g2.id] + }, + }