Skip to content

Commit

Permalink
Keep the LMSUserAssignmentMembership table in sync with AssignmentMem…
Browse files Browse the repository at this point in the history
…bership
  • Loading branch information
marcospri committed Oct 8, 2024
1 parent 6d33459 commit c71f539
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
42 changes: 41 additions & 1 deletion lms/services/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
AutoGradingConfig,
Course,
Grouping,
LMSUser,
LMSUserAssignmentMembership,
LTIParams,
LTIRole,
User,
)
Expand Down Expand Up @@ -182,7 +185,11 @@ def get_assignment_for_launch(self, request, course: Course) -> Assignment | Non
)

def upsert_assignment_membership(
self, assignment: Assignment, user: User, lti_roles: list[LTIRole]
self,
lti_params: LTIParams,
assignment: Assignment,
user: User,
lti_roles: list[LTIRole],
) -> list[AssignmentMembership]:
"""Store details of the roles a user plays in an assignment."""

Expand All @@ -198,6 +205,10 @@ def upsert_assignment_membership(
for lti_role in lti_roles
]

self._upsert_lms_user_assignment_memberships(
lti_params, user.lms_user, assignment, lti_roles
)

return list(
bulk_upsert(
self._db,
Expand All @@ -208,6 +219,35 @@ def upsert_assignment_membership(
)
)

def _upsert_lms_user_assignment_memberships(
self,
lti_params,
lms_user: LMSUser,
assignment: Assignment,
lti_roles: list[LTIRole],
) -> list[LMSUserAssignmentMembership]:
values = [
{
"lms_user_id": lms_user.id,
"assignment_id": assignment.id,
"lti_role_id": lti_role.id,
"lti_v11_lis_result_sourcedid": None
if lti_params.v13
else lti_params.get("lis_result_sourcedid"),
}
for lti_role in lti_roles
]

return list(
bulk_upsert(
self._db,
model_class=LMSUserAssignmentMembership,
values=values,
index_elements=["lms_user_id", "assignment_id", "lti_role_id"],
update_columns=["updated", "lti_v11_lis_result_sourcedid"],
)
)

def upsert_assignment_groupings(
self, assignment: Assignment, groupings: list[Grouping]
) -> list[AssignmentGrouping]:
Expand Down
1 change: 1 addition & 0 deletions lms/views/lti/basic_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def _show_document(self, assignment):

# Store the relationship between the assignment and the user
self.assignment_service.upsert_assignment_membership(
lti_params=self.request.lti_params,
assignment=assignment,
user=self.request.user,
lti_roles=self.request.lti_user.lti_roles,
Expand Down
29 changes: 26 additions & 3 deletions tests/unit/lms/services/assignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import pytest
from h_matchers import Any
from sqlalchemy import select

from lms.models import (
AssignmentGrouping,
AssignmentMembership,
AutoGradingConfig,
LMSUserAssignmentMembership,
RoleScope,
RoleType,
)
Expand Down Expand Up @@ -269,16 +271,22 @@ def test_get_assignment_created_assignments_point_to_copy(
assert assignment.copied_from == sentinel.original_assignment
assert assignment.document_url == sentinel.document_url

def test_upsert_assignment_membership(self, svc, assignment):
user = factories.User()
def test_upsert_assignment_membership(
self, svc, assignment, pyramid_request, db_session
):
lms_user = factories.LMSUser()
user = factories.User(lms_user=lms_user)
lti_roles = factories.LTIRole.create_batch(3)
# One existing row
factories.AssignmentMembership.create(
assignment=assignment, user=user, lti_role=lti_roles[0]
)

membership = svc.upsert_assignment_membership(
assignment=assignment, user=user, lti_roles=lti_roles
pyramid_request.lti_params,
assignment=assignment,
user=user,
lti_roles=lti_roles,
)
assert (
membership
Expand All @@ -291,6 +299,21 @@ def test_upsert_assignment_membership(self, svc, assignment):
]
).only()
)
assert (
db_session.scalars(select(LMSUserAssignmentMembership)).all()
== Any.list.containing(
[
Any.instance_of(LMSUserAssignmentMembership).with_attrs(
{
"lms_user": lms_user,
"assignment": assignment,
"lti_role": lti_role,
}
)
for lti_role in lti_roles
]
).only()
)

def test_upsert_assignment_grouping(self, svc, assignment, db_session):
groupings = factories.CanvasGroup.create_batch(3)
Expand Down

0 comments on commit c71f539

Please sign in to comment.