Skip to content

Commit

Permalink
Add m2m signal handling and tests for ModelObserver
Browse files Browse the repository at this point in the history
Integrated support for `m2m_changed` signals in `ModelObserver` to handle many-to-many relationship changes. Updated the `database_event` method accordingly and introduced a new test to validate m2m signal handling with actions such as add, remove, and clear operations. Adjusted test data to include `groups` field for verification.
  • Loading branch information
tumblingman committed Dec 13, 2024
1 parent e39ac3c commit 78cbf19
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 6 deletions.
26 changes: 25 additions & 1 deletion djangochannelsrestframework/observer/model_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from channels.layers import get_channel_layer
from django.db import transaction
from django.db.models import Model
from django.db.models.signals import post_delete, post_save, post_init
from django.db.models.signals import post_delete, post_save, post_init, m2m_changed
from rest_framework.serializers import Serializer

from djangochannelsrestframework.observer.base_observer import BaseObserver
Expand Down Expand Up @@ -68,6 +68,13 @@ def _connect(self):
self.post_save_receiver, sender=self.model_cls, dispatch_uid=id(self)
)

for field in self.model_cls._meta.many_to_many:
m2m_changed.connect(
self.m2m_changed_receiver,
sender=field.remote_field.through,
dispatch_uid=f"{id(self)}-{field.name}"
)

post_delete.connect(
self.post_delete_receiver, sender=self.model_cls, dispatch_uid=id(self)
)
Expand Down Expand Up @@ -99,6 +106,23 @@ def post_save_receiver(self, instance: Model, created: bool, **kwargs):
else:
self.database_event(instance, Action.UPDATE)

def m2m_changed_receiver(self, action: str, instance: Model, reverse: bool, model: Type[Model], pk_set: Set[Any], **kwargs):
"""
Handle many-to-many changes.
"""
if action not in {"post_add", "post_remove", "post_clear"}:
return

target_instances = []
if not reverse:
target_instances.append(instance)
else:
for pk in pk_set:
target_instances.append(model.objects.get(pk=pk))

for target_instance in target_instances:
self.database_event(target_instance, Action.UPDATE)

def post_delete_receiver(self, instance: Model, **kwargs):
self.database_event(instance, Action.DELETE)

Expand Down
171 changes: 166 additions & 5 deletions tests/test_model_observer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
from contextlib import AsyncExitStack

import pytest
from django.contrib.auth.models import Group
from channels import DEFAULT_CHANNEL_LAYER
from channels.db import database_sync_to_async
from channels.layers import channel_layers
Expand All @@ -21,6 +23,7 @@ class Meta:
"id",
"username",
"email",
"groups",
)


Expand Down Expand Up @@ -104,7 +107,7 @@ async def update_username(self, pk=None, username=None, **kwargs):
"errors": [],
"response_status": 200,
"request_id": 1,
"data": {"email": "[email protected]", "id": u1.id, "username": "test1"},
"data": {"email": "[email protected]", "id": u1.id, "username": "test1", "groups": []},
}

# lookup up u1
Expand Down Expand Up @@ -153,7 +156,7 @@ async def update_username(self, pk=None, username=None, **kwargs):
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {"email": "[email protected]", "id": u1.id, "username": "thenewname"},
"data": {"email": "[email protected]", "id": u1.id, "username": "thenewname", "groups": []},
}

u1_pk = u1.pk
Expand Down Expand Up @@ -348,7 +351,7 @@ async def update_username(self, pk=None, username=None, **kwargs):
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {"email": "[email protected]", "id": u1.pk, "username": "thenewname"},
"data": {"email": "[email protected]", "id": u1.pk, "username": "thenewname", "groups": []},
} in [a, b]

# unsubscribe
Expand Down Expand Up @@ -490,7 +493,7 @@ async def update_username(self, pk=None, username=None, **kwargs):
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {"email": "[email protected]", "id": u1.id, "username": "new name"},
"data": {"email": "[email protected]", "id": u1.id, "username": "new name", "groups": []},
}

assert await communicator.receive_nothing()
Expand Down Expand Up @@ -522,5 +525,163 @@ async def update_username(self, pk=None, username=None, **kwargs):
"errors": [],
"response_status": 200,
"request_id": 5,
"data": {"email": "[email protected]", "id": u2.id, "username": "the new name 2"},
"data": {"email": "[email protected]", "id": u2.id, "username": "the new name 2", "groups": []},
}


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_m2m_observer(settings):
"""
This tests
"""

settings.CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels.layers.InMemoryChannelLayer",
"TEST_CONFIG": {"expiry": 100500},
},
}

layer = channel_layers.make_test_backend(DEFAULT_CHANNEL_LAYER)

class TestConsumerMultipleSubs(ObserverModelInstanceMixin, GenericAsyncAPIConsumer):

queryset = get_user_model().objects.all()
serializer_class = UserSerializer

async def accept(self, subprotocol=None):
await super().accept()

assert not await database_sync_to_async(get_user_model().objects.all().exists)()

# Test a normal connection
async with connected_communicator(TestConsumerMultipleSubs()) as communicator:

u1 = await database_sync_to_async(get_user_model().objects.create)(
username="test1", email="[email protected]"
)

u2 = await database_sync_to_async(get_user_model().objects.create)(
username="test2", email="[email protected]"
)

# Subscribe to instance user 1
await communicator.send_json_to(
{"action": "subscribe_instance", "pk": u1.id, "request_id": 4}
)

response = await communicator.receive_json_from()

assert response == {
"action": "subscribe_instance",
"errors": [],
"response_status": 201,
"request_id": 4,
"data": None,
}

g1 = await database_sync_to_async(Group.objects.create)(name="group1")
g2 = await database_sync_to_async(Group.objects.create)(name="group2")
g3 = await database_sync_to_async(Group.objects.create)(name="group3")
g4 = await database_sync_to_async(Group.objects.create)(name="group4")

await database_sync_to_async(u1.groups.add)(g1, g2)

response = await communicator.receive_json_from()

assert response == {
"action": "update",
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {
"email": "[email protected]",
"id": u1.id,
"username": "test1",
"groups": [g1.id, g2.id]
},
}

await database_sync_to_async(u2.groups.add)(g4)

await communicator.receive_nothing()

await database_sync_to_async(g1.user_set.add)(u2)

await communicator.receive_nothing()

await database_sync_to_async(g3.user_set.add)(u1, u2)

response = await communicator.receive_json_from()

assert response == {
"action": "update",
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {
"email": "[email protected]",
"id": u1.id,
"username": "test1",
"groups": [g1.id, g2.id, g3.id]
},
}

await database_sync_to_async(g1.user_set.remove)(u1)

response = await communicator.receive_json_from()

assert response == {
"action": "update",
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {
"email": "[email protected]",
"id": u1.id,
"username": "test1",
"groups": [g2.id, g3.id]
},
}

await database_sync_to_async(u1.groups.clear)()

response = await communicator.receive_json_from()

assert response == {
"action": "update",
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {
"email": "[email protected]",
"id": u1.id,
"username": "test1",
"groups": []
},
}

await database_sync_to_async(u2.groups.clear)()

await communicator.receive_nothing()

await database_sync_to_async(u1.groups.set)([g1, g4])

response = await communicator.receive_json_from()
assert response == {
"action": "update",
"errors": [],
"response_status": 200,
"request_id": 4,
"data": {
"email": "[email protected]",
"id": u1.id,
"username": "test1",
"groups": [g1.id, g4.id]
},
}

await database_sync_to_async(u2.groups.set)([g1, g4])

await communicator.receive_nothing()

0 comments on commit 78cbf19

Please sign in to comment.