diff --git a/lms/services/assignment.py b/lms/services/assignment.py index 4b56dfc997..c281c86ffd 100644 --- a/lms/services/assignment.py +++ b/lms/services/assignment.py @@ -11,6 +11,9 @@ AutoGradingConfig, Course, Grouping, + LMSUser, + LMSUserAssignmentMembership, + LTIParams, LTIRole, User, ) @@ -182,7 +185,11 @@ def get_assignment_for_launch(self, request, course: Course) -> Assignment | Non ) def upsert_assignment_membership( - self, assignment: Assignment, user: User, lti_roles: list[LTIRole] + self, + lti_params: LTIParams, + assignment: Assignment, + user: User, + lti_roles: list[LTIRole], ) -> list[AssignmentMembership]: """Store details of the roles a user plays in an assignment.""" @@ -198,6 +205,10 @@ def upsert_assignment_membership( for lti_role in lti_roles ] + self._upsert_lms_user_assignment_memberships( + lti_params, user.lms_user, assignment, lti_roles + ) + return list( bulk_upsert( self._db, @@ -208,6 +219,35 @@ def upsert_assignment_membership( ) ) + def _upsert_lms_user_assignment_memberships( + self, + lti_params, + lms_user: LMSUser, + assignment: Assignment, + lti_roles: list[LTIRole], + ) -> list[LMSUserAssignmentMembership]: + values = [ + { + "lms_user_id": lms_user.id, + "assignment_id": assignment.id, + "lti_role_id": lti_role.id, + "lti_v11_lis_result_sourcedid": None + if lti_params.v13 + else lti_params.get("lis_result_sourcedid"), + } + for lti_role in lti_roles + ] + + return list( + bulk_upsert( + self._db, + model_class=LMSUserAssignmentMembership, + values=values, + index_elements=["lms_user_id", "assignment_id", "lti_role_id"], + update_columns=["updated", "lti_v11_lis_result_sourcedid"], + ) + ) + def upsert_assignment_groupings( self, assignment: Assignment, groupings: list[Grouping] ) -> list[AssignmentGrouping]: diff --git a/lms/views/lti/basic_launch.py b/lms/views/lti/basic_launch.py index e79c6c396c..9a4f149d23 100644 --- a/lms/views/lti/basic_launch.py +++ b/lms/views/lti/basic_launch.py @@ -172,6 +172,7 @@ def _show_document(self, assignment): # Store the relationship between the assignment and the user self.assignment_service.upsert_assignment_membership( + lti_params=self.request.lti_params, assignment=assignment, user=self.request.user, lti_roles=self.request.lti_user.lti_roles, diff --git a/tests/unit/lms/services/assignment_test.py b/tests/unit/lms/services/assignment_test.py index 24dc17b7fc..60860896ac 100644 --- a/tests/unit/lms/services/assignment_test.py +++ b/tests/unit/lms/services/assignment_test.py @@ -3,11 +3,13 @@ import pytest from h_matchers import Any +from sqlalchemy import select from lms.models import ( AssignmentGrouping, AssignmentMembership, AutoGradingConfig, + LMSUserAssignmentMembership, RoleScope, RoleType, ) @@ -269,16 +271,25 @@ def test_get_assignment_created_assignments_point_to_copy( assert assignment.copied_from == sentinel.original_assignment assert assignment.document_url == sentinel.document_url - def test_upsert_assignment_membership(self, svc, assignment): - user = factories.User() + @pytest.mark.parametrize("with_lti11_grading_id", [True, False]) + def test_upsert_assignment_membership( + self, svc, assignment, pyramid_request, db_session, with_lti11_grading_id + ): + lms_user = factories.LMSUser() + user = factories.User(lms_user=lms_user) lti_roles = factories.LTIRole.create_batch(3) # One existing row factories.AssignmentMembership.create( assignment=assignment, user=user, lti_role=lti_roles[0] ) + if with_lti11_grading_id: + pyramid_request.lti_params['lis_result_sourcedid'] = "SOURCEDID" membership = svc.upsert_assignment_membership( - assignment=assignment, user=user, lti_roles=lti_roles + pyramid_request.lti_params, + assignment=assignment, + user=user, + lti_roles=lti_roles, ) assert ( membership @@ -291,6 +302,22 @@ def test_upsert_assignment_membership(self, svc, assignment): ] ).only() ) + assert ( + db_session.scalars(select(LMSUserAssignmentMembership)).all() + == Any.list.containing( + [ + Any.instance_of(LMSUserAssignmentMembership).with_attrs( + { + "lms_user": lms_user, + "assignment": assignment, + "lti_role": lti_role, + "lti_v11_lis_result_sourcedid": "SOURCEDID" if with_lti11_grading_id else None + } + ) + for lti_role in lti_roles + ] + ).only() + ) def test_upsert_assignment_grouping(self, svc, assignment, db_session): groupings = factories.CanvasGroup.create_batch(3) diff --git a/tests/unit/lms/views/lti/basic_launch_test.py b/tests/unit/lms/views/lti/basic_launch_test.py index 40a140dd06..616941cc88 100644 --- a/tests/unit/lms/views/lti/basic_launch_test.py +++ b/tests/unit/lms/views/lti/basic_launch_test.py @@ -258,6 +258,7 @@ def test__show_document( ) assignment_service.upsert_assignment_membership.assert_called_once_with( + pyramid_request.lti_params, assignment=assignment, user=pyramid_request.user, lti_roles=lti_user.lti_roles,