diff --git a/lms/services/dashboard.py b/lms/services/dashboard.py index a206937d71..4e2783598a 100644 --- a/lms/services/dashboard.py +++ b/lms/services/dashboard.py @@ -178,7 +178,7 @@ def get_request_admin_organizations(self, request) -> list[Organization]: def get_assignment_roster( self, assignment: Assignment, h_userids: list[str] | None = None - ) -> Select[tuple[LMSUser]]: + ) -> Select[tuple[LMSUser, bool]]: rosters_enabled = ( assignment.course and assignment.course.application_instance.settings.get( @@ -203,7 +203,8 @@ def get_assignment_roster( role_type=RoleType.LEARNER, assignment_id=assignment.id, h_userids=h_userids, - ) + # For launch data we always add the "active" column as true for compatibility with the roster query. + ).add_columns(True) # Always return the results, no matter the source, sorted return query.order_by(LMSUser.display_name, LMSUser.id) diff --git a/lms/services/roster.py b/lms/services/roster.py index e57ef2aa41..be8503298d 100644 --- a/lms/services/roster.py +++ b/lms/services/roster.py @@ -80,16 +80,14 @@ def get_assignment_roster( role_scope: RoleScope | None = None, role_type: RoleType | None = None, h_userids: list[str] | None = None, - ) -> Select[tuple[LMSUser]]: + ) -> Select[tuple[LMSUser, bool]]: """Get the roster information for a course from our DB.""" roster_query = ( - select(AssignmentRoster.lms_user_id) - .join(LTIRole) - .where( - AssignmentRoster.assignment_id == assignment.id, - AssignmentRoster.active.is_(True), - ) - ) + select(LMSUser, AssignmentRoster.active) + .join(LMSUser, AssignmentRoster.lms_user_id == LMSUser.id) + .join(LTIRole, AssignmentRoster.lti_role_id == LTIRole.id) + .where(AssignmentRoster.assignment_id == assignment.id) + ).distinct() if role_scope: roster_query = roster_query.where(LTIRole.scope == role_scope) @@ -97,12 +95,10 @@ def get_assignment_roster( if role_type: roster_query = roster_query.where(LTIRole.type == role_type) - query = select(LMSUser).where(LMSUser.id.in_(roster_query)) - if h_userids: - query = query.where(LMSUser.h_userid.in_(h_userids)) + roster_query = roster_query.where(LMSUser.h_userid.in_(h_userids)) - return query + return roster_query def fetch_course_roster(self, lms_course: LMSCourse) -> None: """Fetch the roster information for a course from the LMS.""" diff --git a/lms/views/dashboard/api/user.py b/lms/views/dashboard/api/user.py index 10806c8ed3..ebea699f4a 100644 --- a/lms/views/dashboard/api/user.py +++ b/lms/views/dashboard/api/user.py @@ -147,11 +147,12 @@ def students_metrics(self) -> APIRoster: h_userids=request_h_userids, ) # Iterate over all the students we have in the DB - for user in self.request.db.scalars(users_query).all(): + for roster_data in self.request.db.execute(users_query).all(): + user, active = roster_data if s := stats_by_user.get(user.h_userid): # We seen this student in H, get all the data from there api_student = RosterEntry( - active=True, + active=active, h_userid=user.h_userid, lms_id=user.user_id, display_name=s["display_name"], @@ -165,7 +166,7 @@ def students_metrics(self) -> APIRoster: # We haven't seen this user H, # use LMS DB's data and set 0s for all annotation related fields. api_student = RosterEntry( - active=True, + active=active, h_userid=user.h_userid, lms_id=user.user_id, display_name=user.display_name, diff --git a/tests/unit/lms/services/dashboard_test.py b/tests/unit/lms/services/dashboard_test.py index 1cfdab2300..102cb934c5 100644 --- a/tests/unit/lms/services/dashboard_test.py +++ b/tests/unit/lms/services/dashboard_test.py @@ -234,7 +234,7 @@ def test_get_assignment_roster_with_roster_disabled( ) assert ( roster - == user_service.get_users_for_assignment.return_value.order_by.return_value + == user_service.get_users_for_assignment.return_value.add_columns.return_value.order_by.return_value ) def test_get_assignment_roster_with( diff --git a/tests/unit/lms/services/roster_test.py b/tests/unit/lms/services/roster_test.py index 1aa1fe20a0..508e9f83be 100644 --- a/tests/unit/lms/services/roster_test.py +++ b/tests/unit/lms/services/roster_test.py @@ -93,14 +93,43 @@ def test_get_assignment_roster( ) db_session.flush() - assert db_session.scalars( + result = db_session.execute( svc.get_assignment_roster( assignment, role_scope=lti_role.scope if with_role_scope else None, role_type=lti_role.type if with_role_type else None, - h_userids=[lms_user.h_userid] if with_h_userids else None, + h_userids=[lms_user.h_userid, inactive_lms_user.h_userid] + if with_h_userids + else None, ) - ).all() == [lms_user] + ).all() + + assert [(lms_user, True), (inactive_lms_user, False)] == result + + def test_get_assignment_roster_doesnt_return_duplicates( + self, assignment, db_session, svc + ): + lms_user = factories.LMSUser() + lti_role_1 = factories.LTIRole() + lti_role_2 = factories.LTIRole() + + factories.AssignmentRoster( + lms_user=lms_user, + assignment=assignment, + lti_role=lti_role_1, + active=True, + ) + factories.AssignmentRoster( + lms_user=lms_user, + assignment=assignment, + lti_role=lti_role_2, + active=True, + ) + db_session.flush() + + result = db_session.execute(svc.get_assignment_roster(assignment)).all() + + assert [(lms_user, True)] == result def test_fetch_course_roster( self, diff --git a/tests/unit/lms/views/dashboard/api/user_test.py b/tests/unit/lms/views/dashboard/api/user_test.py index 0e4a2cf069..db19747d7c 100644 --- a/tests/unit/lms/views/dashboard/api/user_test.py +++ b/tests/unit/lms/views/dashboard/api/user_test.py @@ -92,23 +92,39 @@ def test_students_metrics( pyramid_request.parsed_params["segment_authority_provided_ids"] = [ g.authority_provided_id for g in segments ] - user_service.get_users.return_value = select(User).where( - User.id.in_( - [ - u.id - for u in [student, student_no_annos, student_no_annos_no_name] - ] + user_service.get_users.return_value = ( + select(User) + .where( + User.id.in_( + [ + u.id + for u in [ + student, + student_no_annos, + student_no_annos_no_name, + ] + ] + ) ) + .add_columns(True) ) else: db_session.flush() - dashboard_service.get_assignment_roster.return_value = select(User).where( - User.id.in_( - [ - u.id - for u in [student, student_no_annos, student_no_annos_no_name] - ] + dashboard_service.get_assignment_roster.return_value = ( + select(User) + .where( + User.id.in_( + [ + u.id + for u in [ + student, + student_no_annos, + student_no_annos_no_name, + ] + ] + ) ) + .add_columns(True) ) db_session.flush() @@ -200,10 +216,17 @@ def test_students_metrics_with_auto_grading( auto_grading_service.get_last_grades.return_value = {} db_session.flush() - dashboard_service.get_assignment_roster.return_value = select(User).where( - User.id.in_( - [u.id for u in [student, student_no_annos, student_no_annos_no_name]] + dashboard_service.get_assignment_roster.return_value = ( + select(User) + .where( + User.id.in_( + [ + u.id + for u in [student, student_no_annos, student_no_annos_no_name] + ] + ) ) + .add_columns(True) ) dashboard_service.get_request_assignment.return_value = assignment h_api.get_annotation_counts.return_value = annotation_counts_response