Add support for observing Many-to-Many field changes in ModelObserver #209

Open · wants to merge 13 commits into base: master
31 changes: 9 additions & 22 deletions djangochannelsrestframework/observer/base_observer.py
@@ -175,30 +175,17 @@ async def unsubscribe_to_comment_activity(self, request_id, **kwargs):
         groups = list(self.group_names_for_consumer(*args, consumer=consumer, **kwargs))

         for group_name in groups:
-            # remove group to request mappings
-            if (
-                group_name
-                in consumer._observer_group_to_request_id[self._stable_observer_id]
-            ):
+            # unsubscribe all requests to this group
+            if group_name in consumer._observer_group_to_request_id[self._stable_observer_id]:
                 if request_id is None:
-                    consumer._observer_group_to_request_id[
-                        self._stable_observer_id
-                    ].pop(group_name)
+                    consumer._observer_group_to_request_id[self._stable_observer_id].pop(group_name)
                 else:
-                    consumer._observer_group_to_request_id[self._stable_observer_id][
-                        group_name
-                    ].remove(request_id)
-
-            if (
-                len(
-                    consumer._observer_group_to_request_id[self._stable_observer_id][
-                        group_name
-                    ]
-                )
-                > 0
-            ):
-                await consumer.remove_group(group_name)
+                    consumer._observer_group_to_request_id[self._stable_observer_id][group_name].discard(request_id)
+
+                    if not consumer._observer_group_to_request_id[self._stable_observer_id][group_name]:
+                        consumer._observer_group_to_request_id[self._stable_observer_id].pop(group_name)
+
+            if group_name not in consumer._observer_group_to_request_id[self._stable_observer_id]:
+                await consumer.remove_group(group_name)

         return groups
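For orientation, a rough sketch of the bookkeeping this hunk refactors, inferred from the names in the diff (the structure and values are illustrative, not the library's actual internals):

    # Per observer, each subscribed group maps to the set of request ids that
    # asked for it (illustrative values only).
    _observer_group_to_request_id = {
        "stable-observer-id": {                          # self._stable_observer_id
            "some-group-name": {"request-1", "request-2"},
        },
    }

After the change, unsubscribing a single request id discards it from that set, the group key is popped once its set becomes empty, and await consumer.remove_group(...) is only reached for groups that are no longer tracked.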

121 changes: 86 additions & 35 deletions djangochannelsrestframework/observer/model_observer.py
@@ -10,7 +10,7 @@
 from channels.layers import get_channel_layer
 from django.db import transaction
 from django.db.models import Model
-from django.db.models.signals import post_delete, post_save, post_init
+from django.db.models.signals import post_delete, post_save, post_init, m2m_changed
 from rest_framework.serializers import Serializer

 from djangochannelsrestframework.observer.base_observer import BaseObserver
@@ -61,17 +61,37 @@ def _connect(self):

         # this is used to capture the current state for the model
         post_init.connect(
-            self.post_init_receiver, sender=self.model_cls, dispatch_uid=id(self)
+            self.post_init_receiver, sender=self.model_cls, dispatch_uid=str(id(self))
         )

         post_save.connect(
-            self.post_save_receiver, sender=self.model_cls, dispatch_uid=id(self)
+            self.post_save_receiver, sender=self.model_cls, dispatch_uid=str(id(self))
         )
+        have_m2m = False
+        for field in self.model_cls._meta.many_to_many:
+            if hasattr(field.remote_field, 'through'):
+                m2m_changed.connect(
+                    self.m2m_changed_receiver,
+                    sender=field.remote_field.through,
+                    dispatch_uid=f"{str(id(self))}-{self.model_cls.__name__}-{field.name}"
+                )
+                have_m2m = True

         post_delete.connect(
-            self.post_delete_receiver, sender=self.model_cls, dispatch_uid=id(self)
+            self.post_delete_receiver, sender=self.model_cls, dispatch_uid=str(id(self))
         )

+        if have_m2m:
+            warnings.warn(
+                "Model observation with many-to-many fields is partially supported. " +
+                "If you delete a related object, the signal will not be sent. " +
+                "This is a Django bug that is over 10 years old: https://code.djangoproject.com/ticket/17688. " +
+                "Also, when working with many-to-many fields Django uses savepoints, " +
+                "which we do not support, so this can lead to unexpected results.",
+                UnsupportedWarning,
+            )
+
     def post_init_receiver(self, instance: Model, **kwargs):

         if instance.pk is None:
Expand Down Expand Up @@ -99,28 +119,64 @@ def post_save_receiver(self, instance: Model, created: bool, **kwargs):
else:
self.database_event(instance, Action.UPDATE)

def m2m_changed_receiver(self, sender, instance: Model, action: str, reverse: bool, model: Type[Model],
pk_set: Set[Any], **kwargs):
"""
Handle many-to-many changes.
"""
if action not in {"post_add", "post_remove", "post_clear"} and not reverse:
return

if action not in {"post_add", "post_remove", "pre_clear"} and reverse:
return

target_instances = []
if not reverse:
target_instances.append(instance)
else:
if pk_set:
for pk in pk_set:
target_instances.append(model.objects.get(pk=pk))
else: # pre_clear case
related_field = next(
(field for field in instance._meta.get_fields()
if field.many_to_many and hasattr(field, 'through') and field.through == sender),
None
)
if related_field:
related_manager = getattr(instance, related_field.related_name or f"{related_field.name}_set", None)
if related_manager:
target_instances.extend(related_manager.all())

target_instances = list(set(target_instances)) # remove duplicates if any
for target_instance in target_instances:
self.database_event(target_instance, Action.UPDATE)

def post_delete_receiver(self, instance: Model, **kwargs):
self.database_event(instance, Action.DELETE)

def database_event(self, instance: Model, action: Action):
"""
Handles database events and prepares messages for sending on commit.
"""
messages = list(self.prepare_messages(instance, action))

connection = transaction.get_connection()

if connection.in_atomic_block:
if len(connection.savepoint_ids) > 0:
warnings.warn(
"Model observation with save points is unsupported and will"
" result in unexpected beauvoir.",
" result in unexpected behavior.",
UnsupportedWarning,
)

connection.on_commit(partial(self.post_change_receiver, instance, action))
connection.on_commit(partial(self.send_prepared_messages, messages))

def post_change_receiver(self, instance: Model, action: Action, **kwargs):
def prepare_messages(self, instance: Model, action: Action, **kwargs):
"""
Triggers the old_binding to possibly send to its group.
Prepares messages for sending based on the given action and instance.
"""

if action == Action.CREATE:
old_group_names = set()
else:
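For context, the new receiver relies on Django's documented m2m_changed semantics: sender is the through model, instance is the object whose field was modified, reverse tells which side of the relation was touched, and pk_set holds the affected primary keys (it is None for the clear actions, which is why the reverse side is handled at pre_clear, while the soon-to-be-removed relations can still be queried). A minimal sketch with hypothetical Post and Tag models of the arguments the receiver would see; note that the resulting messages are still deferred via transaction.on_commit, so nothing is sent if the surrounding transaction rolls back:

    # Hypothetical models in some installed app's models.py, for illustration only.
    from django.db import models

    class Tag(models.Model):
        name = models.CharField(max_length=50)

    class Post(models.Model):
        title = models.CharField(max_length=50)
        tags = models.ManyToManyField(Tag)

    # post.tags.add(tag)        -> sender=Post.tags.through, instance=post,
    #                              reverse=False, model=Tag, pk_set={tag.pk},
    #                              action="pre_add" then "post_add"
    # tag.post_set.remove(post) -> sender=Post.tags.through, instance=tag,
    #                              reverse=True, model=Post, pk_set={post.pk},
    #                              action="pre_remove" then "post_remove"
    # tag.post_set.clear()      -> reverse=True, pk_set=None; only "pre_clear"
    #                              still sees the related posts, so the receiver
    #                              resolves them through the related manager there.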
@@ -133,37 +189,32 @@ def post_change_receiver(self, instance: Model, action: Action, **kwargs):

         self.get_observer_state(instance).current_groups = new_group_names

-        # if post delete, new_group_names should be []
+        yield from self.generate_messages(instance, old_group_names, new_group_names, action, **kwargs)

-        # Django DDP had used the ordering of DELETE, UPDATE then CREATE for good reasons.
-        self.send_messages(
-            instance, old_group_names - new_group_names, Action.DELETE, **kwargs
-        )
-        # the object has been updated so that its groups are not the same.
-        self.send_messages(
-            instance, old_group_names & new_group_names, Action.UPDATE, **kwargs
-        )
+    def generate_messages(self, instance: Model, old_group_names: Set[str], new_group_names: Set[str], action: Action,
+                          **kwargs):
+        """
+        Generates messages for the given group names and action.
+        """
+        for group_name in old_group_names - new_group_names:
+            yield {**self.serialize(instance, Action.DELETE, **kwargs), "group": group_name}

-        #
-        self.send_messages(
-            instance, new_group_names - old_group_names, Action.CREATE, **kwargs
-        )
+        for group_name in old_group_names & new_group_names:
+            yield {**self.serialize(instance, Action.UPDATE, **kwargs), "group": group_name}

-    def send_messages(
-        self, instance: Model, group_names: Set[str], action: Action, **kwargs
-    ):
-        if not group_names:
-            return
-        message = self.serialize(instance, action, **kwargs)
-        channel_layer = get_channel_layer()
-
-        for group_name in group_names:
-            message_to_send = deepcopy(message)
+        for group_name in new_group_names - old_group_names:
+            yield {**self.serialize(instance, Action.CREATE, **kwargs), "group": group_name}

-            # Include the group name in the message being sent
-            message_to_send["group"] = group_name
+    def send_prepared_messages(self, messages):
+        """
+        Sends prepared messages to the channel layer.
+        """
+        if not messages:
+            return

-            async_to_sync(channel_layer.group_send)(group_name, message_to_send)
+        channel_layer = get_channel_layer()
+        for message in messages:
+            async_to_sync(channel_layer.group_send)(message["group"], deepcopy(message))

     def group_names(self, *args, **kwargs):
         # one channel for all updates.
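To make the set arithmetic in generate_messages concrete: if a saved instance's groups change from {a, b} to {b, c}, subscribers of a receive a delete payload, b an update, and c a create, in that order:

    # Worked example of the three set differences (group names are made up).
    old_group_names = {"group-a", "group-b"}
    new_group_names = {"group-b", "group-c"}

    assert old_group_names - new_group_names == {"group-a"}  # -> Action.DELETE
    assert old_group_names & new_group_names == {"group-b"}  # -> Action.UPDATE
    assert new_group_names - old_group_names == {"group-c"}  # -> Action.CREATE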
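From the consumer side nothing new should be required to pick this up: the existing @model_observer decorator should now also fire an update when a many-to-many field on the observed model is added to, removed from, or cleared, subject to the caveats in the warning above. A minimal sketch, assuming a hypothetical Post model with a tags field and the library's documented decorator API:

    from djangochannelsrestframework.consumers import GenericAsyncAPIConsumer
    from djangochannelsrestframework.decorators import action
    from djangochannelsrestframework.observer import model_observer

    from myapp.models import Post  # hypothetical app and model


    class PostConsumer(GenericAsyncAPIConsumer):

        @model_observer(Post)
        async def post_activity(self, message, observer=None, **kwargs):
            # Fires on create/update/delete and, with this PR, also when
            # post.tags changes.
            await self.send_json(message)

        @action()
        async def subscribe_to_post_activity(self, request_id, **kwargs):
            await self.post_activity.subscribe(request_id=request_id)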
14 changes: 7 additions & 7 deletions requirements.txt
@@ -1,8 +1,8 @@
-django==4.2.4
-djangorestframework>=3.14.0
-channels>=4.0.0
-pytest
-pytest-asyncio
-pytest-django
+django>=4.2.16
+djangorestframework>=3.15.2
+channels>=4.1.0
+pytest>=8.3.3
+pytest-django>=4.9.0
+pytest-asyncio>=0.24.0
 sphinx
-black==20.8b1
+black==24.3.0
1 change: 1 addition & 0 deletions setup.cfg
@@ -1 +1,2 @@
 [tool:pytest]
+asyncio_default_fixture_loop_scope = session
11 changes: 6 additions & 5 deletions setup.py
@@ -12,13 +12,13 @@
     license="MIT",
     packages=find_packages(exclude=["tests"]),
     include_package_data=True,
-    install_requires=["Django>=3.2", "channels>=4.0.0", "djangorestframework>=3.14.0"],
+    install_requires=["Django>=4.2.16", "channels>=4.1.0", "djangorestframework>=3.15.2"],
     extras_require={
         "tests": [
-            "channels[daphne]>=4.0.0",
-            "pytest>=7.0.1",
-            "pytest-django>=4.5.2",
-            "pytest-asyncio>=0.18.1",
+            "channels[daphne]>=4.1.0",
+            "pytest>=8.3.3",
+            "pytest-django>=4.9.0",
+            "pytest-asyncio>=0.24.0",
             "coverage>=6.3.1",
         ],
     },
@@ -33,6 +33,7 @@
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
         "Framework :: Django",
         "Topic :: Internet :: WWW/HTTP",
     ],
85 changes: 85 additions & 0 deletions tests/communicator.py
@@ -0,0 +1,85 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Awaitable

from asgiref.timeout import timeout as async_timeout

from channels.testing import WebsocketCommunicator


class Communicator(WebsocketCommunicator):
    """
    Custom communicator class for WebSocket communication in tests.

    This override resolves an issue where the default ApplicationCommunicator
    cancels the application future unnecessarily when a timeout occurs
    while waiting for output. This behavior disrupts subsequent attempts
    to receive output or reconnect, causing the application to raise
    CancelledError.

    The `receive_output` method is modified to:
    - Avoid cancelling the application future when a timeout occurs.
    - Ensure that the method can be called repeatedly without causing
      failures in subsequent operations.

    This makes the communicator more flexible for tests where output
    availability is uncertain, allowing patterns like:

        outputs = []
        while True:
            try:
                outputs.append(await communicator.receive_output())
            except asyncio.TimeoutError:
                break
    """
    _connected = False

    @property
    def connected(self):
        return self._connected

    async def receive_output(self, timeout=1):
        if self.future.done():
            self.future.result()  # Ensure exceptions are re-raised if the future is complete
        try:
            async with async_timeout(timeout):  # Wait for output with a timeout
                return await self.output_queue.get()
        except asyncio.TimeoutError as e:
            if self.future.done():  # Re-check the state of the future after the timeout
                self.future.result()
            raise e  # Propagate the timeout exception

    async def connect(self, timeout=1):
        self._connected, subprotocol = await super().connect(timeout)
        return self._connected, subprotocol

    async def disconnect(self, code=1000, timeout=1):
        await super().disconnect(code, timeout)
        self._connected = False


@asynccontextmanager
async def connected_communicator(consumer, path: str = "/testws/") -> Awaitable[Communicator]:
    """
    Asynchronous context manager for managing the WebSocket communicator lifecycle.

    This utility simplifies tests involving WebSocket communication by:
    - Initializing and connecting a Communicator instance for the given consumer and path.
    - Ensuring the connection is properly established, raising an assertion error if not.
    - Guaranteeing cleanup by disconnecting the communicator upon exiting the context.

    Example usage:

        async with connected_communicator(TestConsumer) as communicator:
            await communicator.send_json_to({"key": "value"})
            response = await communicator.receive_json_from()
            assert response == {"key": "value"}
    """
    communicator = Communicator(consumer, path)
    connected, _ = await communicator.connect()
    try:
        assert connected, "Failed to connect to WebSocket"
        yield communicator
    finally:
        if communicator.connected:
            await communicator.disconnect()
3 changes: 3 additions & 0 deletions tests/models.py
@@ -3,11 +3,14 @@

 class TestModel(models.Model):
     """Simple model to test with."""
+    __test__ = False  # Prevent pytest from collecting this as a test class

     name = models.CharField(max_length=255)


 class TestModelWithCustomPK(models.Model):
     """Simple model with custom primary key to test with."""
+    __test__ = False  # Prevent pytest from collecting this as a test class

     name = models.CharField(max_length=255, primary_key=True)
     description = models.CharField(max_length=255, null=True, blank=True)