diff --git a/lms/services/course.py b/lms/services/course.py index 698e239e41..026a07fd10 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -13,7 +13,10 @@ CourseGroupsExportedFromH, Grouping, GroupingMembership, + LTIRole, Organization, + RoleScope, + RoleType, User, ) from lms.product import Product @@ -329,6 +332,24 @@ def get_assignments( return self._db.scalars(assignments_query).all() + def get_members( + self, course: Course, role_type: RoleType, role_scope: RoleScope + ) -> list[User]: + return self._db.scalars( + select(User) + .join(AssignmentMembership, User.id == AssignmentMembership.user_id) + .join(LTIRole) + .join( + AssignmentGrouping, + AssignmentMembership.assignment_id == AssignmentGrouping.assignment_id, + ) + .where( + AssignmentGrouping.grouping_id == course.id, + LTIRole.scope == role_scope, + LTIRole.type == role_type, + ) + ).all() + def _get_authority_provided_id(self, context_id): return self._grouping_service.get_authority_provided_id( lms_id=context_id, type_=Grouping.Type.COURSE diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index e9ed60eb19..1f5150499a 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -5,7 +5,12 @@ from h_matchers import Any from sqlalchemy.exc import NoResultFound -from lms.models import ApplicationSettings, CourseGroupsExportedFromH, Grouping +from lms.models import ( + ApplicationSettings, + CourseGroupsExportedFromH, + Grouping, + RoleScope, +) from lms.product.product import Product from lms.services.course import CourseService, course_service_factory from tests import factories @@ -338,6 +343,29 @@ def test_is_member(self, svc, db_session): assert svc.is_member(course, user.h_userid) assert not svc.is_member(course, other_user.h_userid) + def test_get_members(self, svc, db_session): + factories.User() # User not in assignment + assignment = factories.Assignment() + user = factories.User() + course = factories.Course() + lti_role = factories.LTIRole(scope=RoleScope.COURSE) + factories.AssignmentMembership.create( + assignment=assignment, user=user, lti_role=lti_role + ) + # User in assignment with other role + factories.AssignmentMembership.create( + assignment=assignment, + user=factories.User(), + lti_role=factories.LTIRole(scope=RoleScope.SYSTEM), + ) + factories.AssignmentGrouping(grouping=course, assignment=assignment) + + db_session.flush() + + assert svc.get_members( + course, role_scope=lti_role.scope, role_type=lti_role.type + ) == [user] + def test_get_organization_courses_deduplicates(self, db_session, svc): org = factories.Organization()