diff --git a/openedx_learning/apps/authoring/collections/api.py b/openedx_learning/apps/authoring/collections/api.py index cd9cb3dd..464fd467 100644 --- a/openedx_learning/apps/authoring/collections/api.py +++ b/openedx_learning/apps/authoring/collections/api.py @@ -21,6 +21,7 @@ "get_collection", "get_collections", "get_learning_package_collections", + "get_object_collections", "remove_from_collections", "update_collection", ] @@ -160,6 +161,16 @@ def remove_from_collections( return total_deleted +def get_object_collections(object_id: int) -> QuerySet[Collection]: + """ + Get all collections associated with a given PublishableEntity. + + Only enabled collections are returned. + """ + entity = PublishableEntity.objects.get(pk=object_id) + return entity.collections.filter(enabled=True).order_by("pk") + + def get_learning_package_collections(learning_package_id: int) -> QuerySet[Collection]: """ Get all collections for a given learning package diff --git a/tests/openedx_learning/apps/authoring/collections/test_api.py b/tests/openedx_learning/apps/authoring/collections/test_api.py index 2d1df7ed..16b8bc33 100644 --- a/tests/openedx_learning/apps/authoring/collections/test_api.py +++ b/tests/openedx_learning/apps/authoring/collections/test_api.py @@ -199,6 +199,7 @@ class CollectionContentsTestCase(CollectionTestCase): collection0: Collection collection1: Collection collection2: Collection + disabled_collection: Collection @classmethod def setUpTestData(cls) -> None: @@ -268,6 +269,17 @@ def setUpTestData(cls) -> None: cls.draft_entity.id, ]), ) + cls.disabled_collection = collection_api.create_collection( + cls.learning_package.id, + title="Disabled Collection", + created_by=None, + description="This disabled collection contains 1 object", + contents_qset=PublishableEntity.objects.filter(id__in=[ + cls.published_entity.id, + ]), + ) + cls.disabled_collection.enabled = False + cls.disabled_collection.save() def test_create_collection_contents(self): """ @@ -371,6 +383,18 @@ def test_remove_from_collections(self): self.collection2.refresh_from_db() assert self.collection2.modified == modified_time + def test_get_object_collections(self): + """ + Tests fetching the enabled collections which contain a given object. + """ + collections = collection_api.get_object_collections( + self.published_entity.id, + ) + assert list(collections) == [ + self.collection1, + self.collection2, + ] + class UpdateCollectionTestCase(CollectionTestCase): """