diff --git a/djangochannelsrestframework/observer/base_observer.py b/djangochannelsrestframework/observer/base_observer.py index 5fc6d02..03ed5bf 100644 --- a/djangochannelsrestframework/observer/base_observer.py +++ b/djangochannelsrestframework/observer/base_observer.py @@ -175,30 +175,17 @@ async def unsubscribe_to_comment_activity(self, request_id, **kwargs): groups = list(self.group_names_for_consumer(*args, consumer=consumer, **kwargs)) for group_name in groups: - # remove group to request mappings - if ( - group_name - in consumer._observer_group_to_request_id[self._stable_observer_id] - ): - # unsubscribe all requests to this group + if group_name in consumer._observer_group_to_request_id[self._stable_observer_id]: if request_id is None: - consumer._observer_group_to_request_id[ - self._stable_observer_id - ].pop(group_name) + consumer._observer_group_to_request_id[self._stable_observer_id].pop(group_name) else: - consumer._observer_group_to_request_id[self._stable_observer_id][ - group_name - ].remove(request_id) - - if ( - len( - consumer._observer_group_to_request_id[self._stable_observer_id][ - group_name - ] - ) - > 0 - ): - await consumer.remove_group(group_name) + consumer._observer_group_to_request_id[self._stable_observer_id][group_name].discard(request_id) + + if not consumer._observer_group_to_request_id[self._stable_observer_id][group_name]: + consumer._observer_group_to_request_id[self._stable_observer_id].pop(group_name) + + if group_name not in consumer._observer_group_to_request_id[self._stable_observer_id]: + await consumer.remove_group(group_name) return groups diff --git a/tests/test_observer.py b/tests/test_observer.py index be9787c..95972af 100644 --- a/tests/test_observer.py +++ b/tests/test_observer.py @@ -582,3 +582,132 @@ def user_change_custom_groups(self, username=None, **kwargs): "type": "user.change.custom.groups", "subscribing_request_ids": [5], } == response + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.asyncio +async def test_observer_unsubscribe_behavior_with_custom_groups(settings): + settings.CHANNEL_LAYERS = { + "default": { + "BACKEND": "channels.layers.InMemoryChannelLayer", + "TEST_CONFIG": { + "expiry": 100500, + }, + }, + } + + layer = channel_layers.make_test_backend(DEFAULT_CHANNEL_LAYER) + + class TestConsumerObserverCustomGroups(AsyncAPIConsumer): + @action() + async def subscribe(self, username, request_id, **kwargs): + await self.user_change_custom_groups.subscribe( + username=username, request_id=request_id + ) + + @action() + async def unsubscribe(self, username, request_id, **kwargs): + await self.user_change_custom_groups.unsubscribe( + username=username, request_id=request_id + ) + + @model_observer(get_user_model()) + async def user_change_custom_groups( + self, + message, + action, + message_type, + observer=None, + subscribing_request_ids=None, + **kwargs + ): + await self.send_json( + dict( + body=message, + action=action, + type=message_type, + subscribing_request_ids=subscribing_request_ids, + ) + ) + + @user_change_custom_groups.groups_for_signal + def user_change_custom_groups(self, instance=None, **kwargs): + yield "-instance-username-{}".format(instance.username) + + @user_change_custom_groups.groups_for_consumer + def user_change_custom_groups(self, username=None, **kwargs): + yield "-instance-username-{}".format(slugify(username)) + + async with connected_communicator(TestConsumerObserverCustomGroups()) as communicator: + + user = await database_sync_to_async(get_user_model().objects.create)( + username="thenewname", email="test@example.com" + ) + + assert await communicator.receive_nothing(timeout=0.5) + + await database_sync_to_async(user.delete)() + + assert await communicator.receive_nothing(timeout=0.5) + + await communicator.send_json_to( + { + "action": "subscribe", + "username": "thenewname", + "request_id": 5, + } + ) + + await asyncio.sleep(0.5) + + user = await database_sync_to_async(get_user_model().objects.create)( + username="thenewname", email="test@example.com" + ) + + response = await communicator.receive_json_from() + + assert { + "action": "create", + "body": {"pk": user.pk}, + "type": "user.change.custom.groups", + "subscribing_request_ids": [5], + } == response + + await communicator.send_json_to( + { + "action": "unsubscribe", + "username": "thenewname", + "request_id": 5, + } + ) + + await communicator.send_json_to( + { + "action": "subscribe", + "username": "thenewname2", + "request_id": 6, + } + ) + + await asyncio.sleep(0.5) + + await database_sync_to_async(user.delete)() + + user = await database_sync_to_async(get_user_model().objects.create)( + username="thenewname", email="test@example.com" + ) + + assert await communicator.receive_nothing(timeout=0.5) + + user = await database_sync_to_async(get_user_model().objects.create)( + username="thenewname2", email="test2@example.com" + ) + + response = await communicator.receive_json_from() + + assert { + "action": "create", + "body": {"pk": user.pk}, + "type": "user.change.custom.groups", + "subscribing_request_ids": [6], + } == response