diff --git a/lms/product/canvas/_plugin/grouping.py b/lms/product/canvas/_plugin/grouping.py index 5884fbb777..9680384e51 100644 --- a/lms/product/canvas/_plugin/grouping.py +++ b/lms/product/canvas/_plugin/grouping.py @@ -88,8 +88,8 @@ def get_groups_for_instructor(self, _svc, course, group_set_id): ) except CanvasAPIError as canvas_api_error: group_set_name = None - if group_set := self._request.find_service(name="course").find_group_set( - group_set_id=group_set_id + if group_set := self._group_set_service.find_group_set( + course.application_instance, group_set_id=group_set_id ): group_set_name = group_set["name"] diff --git a/lms/product/plugin/course_copy.py b/lms/product/plugin/course_copy.py index 832d995e3e..bd4090516a 100644 --- a/lms/product/plugin/course_copy.py +++ b/lms/product/plugin/course_copy.py @@ -4,6 +4,7 @@ from lms.models import File from lms.services.exceptions import ExternalRequestError, OAuth2TokenError from lms.services.file import FileService +from lms.services.group_set import GroupSetService class CourseCopyFilesHelper: @@ -66,8 +67,8 @@ def factory(cls, _context, request): class CourseCopyGroupsHelper: - def __init__(self, course_service, grouping_plugin): - self._course_service = course_service + def __init__(self, group_set_service: GroupSetService, grouping_plugin): + self._group_set_service = group_set_service self._grouping_plugin = grouping_plugin def find_matching_group_set_in_course(self, course, group_set_id): @@ -91,7 +92,9 @@ def find_matching_group_set_in_course(self, course, group_set_id): pass # Get the original group set from the DB - group_set = self._course_service.find_group_set(group_set_id=group_set_id) + group_set = self._group_set_service.find_group_set( + application_instance=course.application_instance, group_set_id=group_set_id + ) if not group_set: # If we haven't found it could that either: # - The group set doesn't belong to this course @@ -102,8 +105,10 @@ def find_matching_group_set_in_course(self, course, group_set_id): # Try to find a matching group set in the new course. # We might have a record of this because we just called `grouping_plugin.get_group_sets` as the current user # or another user might have done it before for us. - if new_group_set := self._course_service.find_group_set( - name=group_set["name"], context_id=course.lms_id + if new_group_set := self._group_set_service.find_group_set( + application_instance=course.application_instance, + name=group_set["name"], + context_id=course.lms_id, ): # We found a match, store it to save the search for next time course.set_mapped_group_set_id(group_set_id, new_group_set["id"]) @@ -114,7 +119,9 @@ def find_matching_group_set_in_course(self, course, group_set_id): @classmethod def factory(cls, _context, request): - return cls(request.find_service(name="course"), request.product.plugin.grouping) + return cls( + request.find_service(GroupSetService), request.product.plugin.grouping + ) class CourseCopyPlugin: # pragma: nocover diff --git a/lms/services/course.py b/lms/services/course.py index 1f62562553..2830b260c1 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -362,44 +362,6 @@ def upsert_lms_course_membership( ) ) - def find_group_set(self, group_set_id=None, name=None, context_id=None): - """ - Find the first matching group set in this course. - - Group sets are stored as part of Course.extra, this method allows to query and filter them. - - :param context_id: Match only group sets of courses with this ID - :param name: Filter courses by name - :param group_set_id: Filter courses by ID - """ - group_set = ( - func.jsonb_to_recordset(Course.extra["group_sets"]) - .table_valued( - column("id", Text), column("name", Text), joins_implicitly=True - ) - .render_derived(with_types=True) - ) - - query = self._db.query(Grouping.id, group_set.c.id, group_set.c.name).filter( - Grouping.application_instance == self._application_instance - ) - - if context_id: - query = query.filter(Grouping.lms_id == context_id) - - if group_set_id: - query = query.filter(group_set.c.id == group_set_id) - - if name: - query = query.filter( - func.lower(func.trim(group_set.c.name)) == func.lower(func.trim(name)) - ) - - if group_set := query.first(): - return {"id": group_set.id, "name": group_set.name} - - return None - def get_by_id(self, id_: int) -> Course | None: return self._search_query(id_=id_).one_or_none() diff --git a/lms/services/group_set.py b/lms/services/group_set.py index 5d791661b4..c42db94052 100644 --- a/lms/services/group_set.py +++ b/lms/services/group_set.py @@ -1,6 +1,8 @@ from typing import TypedDict -from lms.models.group_set import LMSGroupSet +from sqlalchemy import Text, column, func, select, union + +from lms.models import Course, Grouping, LMSGroupSet from lms.services.upsert import bulk_upsert @@ -46,6 +48,46 @@ def store_group_sets(self, course, group_sets: list[dict]): update_columns=["name", "updated"], ) + def find_group_set( + self, application_instance, group_set_id=None, name=None, context_id=None + ) -> GroupSetDict | None: + """ + Find the first matching group set in this course. + + Group sets are stored as part of Course.extra, this method allows to query and filter them. + + :param context_id: Match only group sets of courses with this ID + :param name: Filter courses by name + :param group_set_id: Filter courses by ID + """ + group_set = ( + func.jsonb_to_recordset(Course.extra["group_sets"]) + .table_valued( + column("id", Text), column("name", Text), joins_implicitly=True + ) + .render_derived(with_types=True) + ) + + query = self._db.query(Grouping.id, group_set.c.id, group_set.c.name).filter( + Grouping.application_instance == application_instance + ) + + if context_id: + query = query.filter(Grouping.lms_id == context_id) + + if group_set_id: + query = query.filter(group_set.c.id == group_set_id) + + if name: + query = query.filter( + func.lower(func.trim(group_set.c.name)) == func.lower(func.trim(name)) + ) + + if group_set := query.first(): + return {"id": group_set.id, "name": group_set.name} + + return None + def factory(_context, request): return GroupSetService(db=request.db) diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index f17ca13bd2..30615ccd86 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -280,39 +280,6 @@ def test_upsert_lms_course_membership( update_columns=["updated"], ) - @pytest.mark.usefixtures("course_with_group_sets") - @pytest.mark.parametrize( - "params", - ( - {"context_id": "context_id", "group_set_id": "ID", "name": "NAME"}, - {"context_id": "context_id", "name": "NAME"}, - {"context_id": "context_id", "name": "name"}, - {"context_id": "context_id", "name": "NAME "}, - {"context_id": "context_id", "group_set_id": "ID"}, - ), - ) - def test_find_group_set(self, svc, params): - group_set = svc.find_group_set(**params) - - assert group_set["id"] == "ID" - assert group_set["name"] == "NAME" - - @pytest.mark.usefixtures("course_with_group_sets") - @pytest.mark.parametrize( - "params", - ( - {"context_id": "context_id", "group_set_id": "NOID", "name": "NAME"}, - {"context_id": "context_id", "group_set_id": "ID", "name": "NONAME"}, - {"context_id": "no_context_id", "group_set_id": "ID", "name": "NAME"}, - ), - ) - def test_find_group_set_no_matches(self, svc, params): - assert not svc.find_group_set(**params) - - @pytest.mark.usefixtures("course_with_group_sets") - def test_find_group_set_returns_first_result(self, svc): - assert svc.find_group_set() - @pytest.mark.parametrize( "param,field", ( @@ -475,22 +442,6 @@ def course(self, application_instance, grouping_service): lms_id="context_id", ) - @pytest.fixture - def course_with_group_sets(self, course): - course.extra = { - "group_sets": [ - { - "id": "ID", - "name": "NAME", - }, - { - "id": "NOT MATCHING ID NOISE", - "name": "NOT MATCHING NAME NOISE", - }, - ] - } - return course - @pytest.fixture def application_instance(self, application_instance): application_instance.tool_consumer_instance_guid = "tool_consumer_instance_guid" diff --git a/tests/unit/lms/services/group_set_test.py b/tests/unit/lms/services/group_set_test.py index 9b7448d05d..2daacf97a1 100644 --- a/tests/unit/lms/services/group_set_test.py +++ b/tests/unit/lms/services/group_set_test.py @@ -35,10 +35,69 @@ def test_set_group_sets(self, group_set, expected, svc, db_session): == group_set["name"] ) + @pytest.mark.usefixtures("course_with_group_sets") + @pytest.mark.parametrize( + "params", + ( + {"context_id": "context_id", "group_set_id": "ID", "name": "NAME"}, + {"context_id": "context_id", "name": "NAME"}, + {"context_id": "context_id", "name": "name"}, + {"context_id": "context_id", "name": "NAME "}, + {"context_id": "context_id", "group_set_id": "ID"}, + ), + ) + def test_find_group_set(self, svc, params, application_instance): + group_set = svc.find_group_set( + application_instance=application_instance, **params + ) + + assert group_set["id"] == "ID" + assert group_set["name"] == "NAME" + + @pytest.mark.usefixtures("course_with_group_sets") + @pytest.mark.parametrize( + "params", + ( + {"context_id": "context_id", "group_set_id": "NOID", "name": "NAME"}, + {"context_id": "context_id", "group_set_id": "ID", "name": "NONAME"}, + {"context_id": "no_context_id", "group_set_id": "ID", "name": "NAME"}, + ), + ) + def test_find_group_set_no_matches(self, svc, params, application_instance): + assert not svc.find_group_set( + application_instance=application_instance, **params + ) + + @pytest.mark.usefixtures("course_with_group_sets") + def test_find_group_set_returns_first_result(self, svc, application_instance): + assert svc.find_group_set(application_instance) + @pytest.fixture def svc(self, db_session): return GroupSetService(db=db_session) + @pytest.fixture + def course(self, application_instance): + return factories.Course( + application_instance=application_instance, lms_id="context_id" + ) + + @pytest.fixture + def course_with_group_sets(self, course): + course.extra = { + "group_sets": [ + { + "id": "ID", + "name": "NAME", + }, + { + "id": "NOT MATCHING ID NOISE", + "name": "NOT MATCHING NAME NOISE", + }, + ] + } + return course + class TestFactory: def test_it(self, pyramid_request, GroupSetService, db_session):