From 8c7a4af6949eb4c4e5dad70a27093978b8a4547b Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Mon, 11 Nov 2024 15:59:37 +0100 Subject: [PATCH] Code changes to maintain the lms_segment table up to date --- lms/services/__init__.py | 4 + lms/services/grouping/factory.py | 2 + lms/services/grouping/service.py | 42 +++++++-- lms/services/segment.py | 86 +++++++++++++++++++ tests/factories/__init__.py | 1 + tests/factories/lms_group_set.py | 9 +- tests/factories/lms_segment.py | 14 +++ .../lms/services/grouping/factory_test.py | 5 +- .../lms/services/grouping/service_test.py | 39 +++++++-- tests/unit/lms/services/segment_test.py | 74 ++++++++++++++++ tests/unit/services.py | 7 ++ 11 files changed, 265 insertions(+), 18 deletions(-) create mode 100644 lms/services/segment.py create mode 100644 tests/factories/lms_segment.py create mode 100644 tests/unit/lms/services/segment_test.py diff --git a/lms/services/__init__.py b/lms/services/__init__.py index 9e4ec0ce1c..98e9b9a47c 100644 --- a/lms/services/__init__.py +++ b/lms/services/__init__.py @@ -39,6 +39,7 @@ from lms.services.organization_usage_report import OrganizationUsageReportService from lms.services.roster import RosterService from lms.services.rsa_key import RSAKeyService +from lms.services.segment import SegmentService from lms.services.user import UserService from lms.services.user_preferences import UserPreferencesService from lms.services.vitalsource import VitalSourceService @@ -156,6 +157,9 @@ def includeme(config): # noqa: PLR0915 config.register_service_factory( "lms.services.group_set.factory", iface=GroupSetService ) + config.register_service_factory( + "lms.services.segment.factory", iface=SegmentService + ) config.register_service_factory( "lms.services.auto_grading.factory", iface=AutoGradingService ) diff --git a/lms/services/grouping/factory.py b/lms/services/grouping/factory.py index 31046466ed..2d7d3dba40 100644 --- a/lms/services/grouping/factory.py +++ b/lms/services/grouping/factory.py @@ -1,4 +1,5 @@ from lms.services.grouping.service import GroupingService +from lms.services.segment import SegmentService def service_factory(_context, request): @@ -8,4 +9,5 @@ def service_factory(_context, request): request.lti_user.application_instance if request.lti_user else None ), plugin=request.product.plugin.grouping, + segment_service=request.find_service(SegmentService), ) diff --git a/lms/services/grouping/service.py b/lms/services/grouping/service.py index 582068653a..c2196dae42 100644 --- a/lms/services/grouping/service.py +++ b/lms/services/grouping/service.py @@ -1,17 +1,25 @@ from sqlalchemy import func from sqlalchemy.orm import aliased -from lms.models import Course, Grouping, GroupingMembership, LTIUser, User +from lms.models import Course, Grouping, GroupingMembership, LTIRole, LTIUser, User from lms.models._hashed_id import hashed_id from lms.product.plugin.grouping import GroupingPlugin +from lms.services.segment import SegmentService from lms.services.upsert import bulk_upsert class GroupingService: - def __init__(self, db, application_instance, plugin: GroupingPlugin): + def __init__( + self, + db, + application_instance, + plugin: GroupingPlugin, + segment_service: SegmentService, + ): self._db = db self.application_instance = application_instance self.plugin = plugin + self.segment_service = segment_service def get_authority_provided_id( self, lms_id, type_: Grouping.Type, parent: Grouping | None = None @@ -170,7 +178,13 @@ def get_sections( else: groupings = self.plugin.get_sections_for_instructor(self, course) - return self._to_groupings(user, groupings, course, self.plugin.sections_type) + return self._to_groupings( + user, + groupings, + course, + self.plugin.sections_type, + lti_roles=lti_user.lti_roles, + ) def get_groups( # noqa: PLR0913 self, @@ -202,7 +216,13 @@ def get_groups( # noqa: PLR0913 self, course, group_set_id ) - return self._to_groupings(user, groupings, course, self.plugin.group_type) + return self._to_groupings( + user, + groupings, + course, + self.plugin.group_type, + lti_roles=lti_user.lti_roles, + ) def get_launch_grouping_type(self, request, course, assignment) -> Grouping.Type: """ @@ -224,7 +244,9 @@ def get_launch_grouping_type(self, request, course, assignment) -> Grouping.Type return Grouping.Type.COURSE - def _to_groupings(self, user, groupings, course, type_): + def _to_groupings( # noqa: PLR0913 + self, user, groupings, course, type_, lti_roles: list[LTIRole] + ): if groupings and not isinstance(groupings[0], Grouping): groupings = [ { @@ -241,7 +263,15 @@ def _to_groupings(self, user, groupings, course, type_): for grouping in groupings ] groupings = self.upsert_groupings(groupings, parent=course, type_=type_) + segments = self.segment_service.upsert_segments( + course=course.lms_course, + type_=groupings[0].type, + groupings=groupings, + lms_group_set_id=groupings[0].extra["group_set_id"], + ) + self.segment_service.upsert_segment_memberships( + lms_user=user.lms_user, segments=segments, lti_roles=lti_roles + ) self.upsert_grouping_memberships(user, groupings) - return groupings diff --git a/lms/services/segment.py b/lms/services/segment.py new file mode 100644 index 0000000000..c25a1cf854 --- /dev/null +++ b/lms/services/segment.py @@ -0,0 +1,86 @@ +from typing import TypedDict + +from sqlalchemy import func + +from lms.models import ( + Grouping, + LMSCourse, + LMSSegment, + LMSSegmentMembership, + LMSUser, + LTIRole, +) +from lms.services.group_set import GroupSetService +from lms.services.upsert import bulk_upsert + + +class SegmentService: + def __init__(self, db, group_set_service: GroupSetService): + self._db = db + self._group_set_service = group_set_service + + def upsert_segments( + self, + course: LMSCourse, + type_: Grouping.Type, + groupings: list[Grouping], + lms_group_set_id: str | None = None, + ) -> list[LMSSegment]: + group_set = None + if lms_group_set_id: + group_set = self._group_set_service.find_group_set( + course.course.application_instance, + lms_id=lms_group_set_id, + context_id=course.lti_context_id, + ) + + return bulk_upsert( + self._db, + LMSSegment, + [ + { + "type": type_, + "lms_id": segment.lms_id, + "name": segment.lms_name, + "h_authority_provided_id": segment.authority_provided_id, + "lms_course_id": course.id, + "lms_group_set_id": group_set.id if group_set else None, + } + for segment in groupings + ], + index_elements=["h_authority_provided_id"], + update_columns=["name", "updated"], + ).all() + + def upsert_segment_memberships( + self, + lms_user: LMSUser, + lti_roles: list[LTIRole], + segments: list[LMSSegment], + ) -> list[LMSSegmentMembership]: + if not lms_user.id or any(s.id is None for s in segments): + # Ensure all ORM objects have their PK populated + self._db.flush() + + return bulk_upsert( + self._db, + LMSSegmentMembership, + [ + { + "lms_segment_id": s.id, + "lms_user_id": lms_user.id, + "lti_role_id": lti_role.id, + "updated": func.now(), + } + for s in segments + for lti_role in lti_roles + ], + index_elements=["lms_segment_id", "lms_user_id", "lti_role_id"], + update_columns=["updated"], + ) + + +def factory(_context, request): + return SegmentService( + db=request.db, group_set_service=request.find_service(GroupSetService) + ) diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index 2462c8017d..64cfb38cbf 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -38,6 +38,7 @@ LMSCourseMembership, ) from tests.factories.lms_group_set import LMSGroupSet +from tests.factories.lms_segment import LMSSegment from tests.factories.lms_user import LMSUser from tests.factories.lti_registration import LTIRegistration from tests.factories.lti_role import LTIRole, LTIRoleOverride diff --git a/tests/factories/lms_group_set.py b/tests/factories/lms_group_set.py index 2c9a3ee875..8dc29450df 100644 --- a/tests/factories/lms_group_set.py +++ b/tests/factories/lms_group_set.py @@ -1,6 +1,11 @@ -from factory import make_factory +from factory import Faker, make_factory from factory.alchemy import SQLAlchemyModelFactory from lms import models -LMSGroupSet = make_factory(models.LMSGroupSet, FACTORY_CLASS=SQLAlchemyModelFactory) +LMSGroupSet = make_factory( + models.LMSGroupSet, + FACTORY_CLASS=SQLAlchemyModelFactory, + lms_id=Faker("hexify", text="^" * 40), + name=Faker("word"), +) diff --git a/tests/factories/lms_segment.py b/tests/factories/lms_segment.py new file mode 100644 index 0000000000..586852c910 --- /dev/null +++ b/tests/factories/lms_segment.py @@ -0,0 +1,14 @@ +import factory +from factory import Faker +from factory.alchemy import SQLAlchemyModelFactory + +from lms import models + +LMSSegment = factory.make_factory( + models.LMSSegment, + FACTORY_CLASS=SQLAlchemyModelFactory, + type=Faker("random_element", elements=models.Grouping.Type), + lms_id=Faker("hexify", text="^" * 40), + name=Faker("word"), + h_authority_provided_id=Faker("hexify", text="^" * 40), +) diff --git a/tests/unit/lms/services/grouping/factory_test.py b/tests/unit/lms/services/grouping/factory_test.py index 11ec0b944b..754726be4c 100644 --- a/tests/unit/lms/services/grouping/factory_test.py +++ b/tests/unit/lms/services/grouping/factory_test.py @@ -7,13 +7,16 @@ @pytest.mark.usefixtures("application_instance_service", "with_plugins") class TestFactory: - def test_it(self, pyramid_request, application_instance, GroupingService): + def test_it( + self, pyramid_request, application_instance, GroupingService, segment_service + ): svc = service_factory(sentinel.context, pyramid_request) GroupingService.assert_called_once_with( db=pyramid_request.db, application_instance=application_instance, plugin=pyramid_request.product.plugin.grouping, + segment_service=segment_service, ) assert svc == GroupingService.return_value diff --git a/tests/unit/lms/services/grouping/service_test.py b/tests/unit/lms/services/grouping/service_test.py index 5007a83a8c..da9d95dcdb 100644 --- a/tests/unit/lms/services/grouping/service_test.py +++ b/tests/unit/lms/services/grouping/service_test.py @@ -401,8 +401,15 @@ def test_get_groups_with_instructor(self, svc, lti_user, assert_groups_returned) "group_set_key", ("groupSetId", "group_category_id", "group_set_id") ) def test_to_groupings_with_dicts( - self, svc, upsert_groupings, upsert_grouping_memberships, group_set_key + self, + svc, + upsert_groupings, + upsert_grouping_memberships, + group_set_key, + lti_user, ): + user = factories.User() + course = factories.Course() grouping_dicts = [ { "id": sentinel.id, @@ -413,7 +420,11 @@ def test_to_groupings_with_dicts( ] groupings = svc._to_groupings( # noqa: SLF001 - sentinel.user, grouping_dicts, sentinel.course, sentinel.grouping_type + user, + grouping_dicts, + course, + sentinel.grouping_type, + lti_roles=lti_user.lti_roles, ) upsert_groupings.assert_called_once_with( @@ -425,21 +436,25 @@ def test_to_groupings_with_dicts( "settings": sentinel.settings, } ], - parent=sentinel.course, + parent=course, type_=sentinel.grouping_type, ) upsert_grouping_memberships.assert_called_once_with( - sentinel.user, upsert_groupings.return_value + user, upsert_groupings.return_value ) assert groupings == upsert_groupings.return_value def test_to_groupings_when_already_groupings( - self, svc, upsert_groupings, upsert_grouping_memberships + self, svc, upsert_groupings, upsert_grouping_memberships, lti_user ): groupings = factories.CanvasSection.create_batch(5) svc._to_groupings( # noqa: SLF001 - sentinel.user, groupings, sentinel.course, sentinel.grouping_type + sentinel.user, + groupings, + sentinel.course, + sentinel.grouping_type, + lti_roles=lti_user.lti_roles, ) upsert_groupings.assert_not_called() @@ -480,13 +495,14 @@ def assert_sections_returned(self, svc, assert_groupings_returned): ) @pytest.fixture - def assert_groupings_returned(self, _to_groupings): + def assert_groupings_returned(self, _to_groupings, lti_user): def assert_groupings_returned(groupings, plugin_method, grouping_type): _to_groupings.assert_called_once_with( sentinel.user, plugin_method.return_value, sentinel.course, grouping_type, + lti_roles=lti_user.lti_roles, ) assert groupings == _to_groupings.return_value @@ -520,5 +536,10 @@ def user(): @pytest.fixture -def svc(db_session, application_instance, grouping_plugin): - return GroupingService(db_session, application_instance, plugin=grouping_plugin) +def svc(db_session, application_instance, grouping_plugin, segment_service): + return GroupingService( + db_session, + application_instance, + plugin=grouping_plugin, + segment_service=segment_service, + ) diff --git a/tests/unit/lms/services/segment_test.py b/tests/unit/lms/services/segment_test.py new file mode 100644 index 0000000000..b6b3d03a20 --- /dev/null +++ b/tests/unit/lms/services/segment_test.py @@ -0,0 +1,74 @@ +from unittest.mock import sentinel + +import pytest + +from lms.services.segment import SegmentService, factory +from tests import factories + + +class TestSegmentService: + def test_upsert_segments_with_group_set(self, svc, group_set_service, db_session): + course = factories.LMSCourse(course=factories.Course()) + groups = factories.CanvasGroup.create_batch(5, parent=course.course) + group_set = factories.LMSGroupSet(lms_course=course) + group_set_service.find_group_set.return_value = group_set + db_session.flush() + + segments = svc.upsert_segments(course, groups[0].type, groups, group_set.lms_id) + + group_set_service.find_group_set.assert_called_once_with( + course.course.application_instance, + lms_id=group_set.lms_id, + context_id=course.lti_context_id, + ) + assert { + (s.type, s.lms_course.h_authority_provided_id, s.name, s.lms_id) + for s in segments + } == { + (g.type, g.parent.authority_provided_id, g.name, g.lms_id) for g in groups + } + + def test_upsert_segments(self, svc, db_session): + course = factories.LMSCourse(course=factories.Course()) + groups = factories.CanvasGroup.create_batch(5, parent=course.course) + db_session.flush() + + segments = svc.upsert_segments(course, groups[0].type, groups) + + assert { + (s.type, s.lms_course.h_authority_provided_id, s.name, s.lms_id) + for s in segments + } == { + (g.type, g.parent.authority_provided_id, g.name, g.lms_id) for g in groups + } + + @pytest.mark.parametrize("with_flush", [True, False]) + def test_upsert_segment_memberships(self, svc, db_session, with_flush): + segments = factories.LMSSegment.create_batch(5) + lms_user = factories.LMSUser() + roles = factories.LTIRole.create_batch(3) + + if with_flush: + db_session.flush() + + svc.upsert_segment_memberships( + lms_user=lms_user, segments=segments, lti_roles=roles + ) + + @pytest.fixture + def svc(self, db_session, group_set_service): + return SegmentService(db=db_session, group_set_service=group_set_service) + + +class TestFactory: + def test_it(self, pyramid_request, SegmentService, db_session, group_set_service): + service = factory(sentinel.context, pyramid_request) + + SegmentService.assert_called_once_with( + db=db_session, group_set_service=group_set_service + ) + assert service == SegmentService.return_value + + @pytest.fixture + def SegmentService(self, patch): + return patch("lms.services.segment.SegmentService") diff --git a/tests/unit/services.py b/tests/unit/services.py index dd7a3969fa..7053586117 100644 --- a/tests/unit/services.py +++ b/tests/unit/services.py @@ -51,6 +51,7 @@ from lms.services.organization_usage_report import OrganizationUsageReportService from lms.services.roster import RosterService from lms.services.rsa_key import RSAKeyService +from lms.services.segment import SegmentService from lms.services.user import UserService from lms.services.user_preferences import UserPreferencesService from lms.services.vitalsource import VitalSourceService @@ -108,6 +109,7 @@ "user_preferences_service", "vitalsource_service", "email_preferences_service", + "segment_service", "youtube_service", # Product plugins "grouping_plugin", @@ -198,6 +200,11 @@ def roster_service(mock_service): return mock_service(RosterService) +@pytest.fixture +def segment_service(mock_service): + return mock_service(SegmentService) + + @pytest.fixture def group_set_service(mock_service): return mock_service(GroupSetService)