Skip to content

Commit

Permalink
Include the real value of roster.active for assignments
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Dec 11, 2024
1 parent 23c6f94 commit a4aaec4
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 36 deletions.
5 changes: 3 additions & 2 deletions lms/services/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def get_request_admin_organizations(self, request) -> list[Organization]:

def get_assignment_roster(
self, assignment: Assignment, h_userids: list[str] | None = None
) -> Select[tuple[LMSUser]]:
) -> Select[tuple[LMSUser, bool]]:
rosters_enabled = (
assignment.course
and assignment.course.application_instance.settings.get(
Expand All @@ -203,7 +203,8 @@ def get_assignment_roster(
role_type=RoleType.LEARNER,
assignment_id=assignment.id,
h_userids=h_userids,
)
# For launch data we always add the "active" column as true for compatibility with the roster query.
).add_columns(True)

# Always return the results, no matter the source, sorted
return query.order_by(LMSUser.display_name, LMSUser.id)
Expand Down
20 changes: 8 additions & 12 deletions lms/services/roster.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,29 +80,25 @@ def get_assignment_roster(
role_scope: RoleScope | None = None,
role_type: RoleType | None = None,
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser]]:
) -> Select[tuple[LMSUser, bool]]:
"""Get the roster information for a course from our DB."""
roster_query = (
select(AssignmentRoster.lms_user_id)
.join(LTIRole)
.where(
AssignmentRoster.assignment_id == assignment.id,
AssignmentRoster.active.is_(True),
)
)
select(LMSUser, AssignmentRoster.active)
.join(LMSUser, AssignmentRoster.lms_user_id == LMSUser.id)
.join(LTIRole, AssignmentRoster.lti_role_id == LTIRole.id)
.where(AssignmentRoster.assignment_id == assignment.id)
).distinct()

if role_scope:
roster_query = roster_query.where(LTIRole.scope == role_scope)

if role_type:
roster_query = roster_query.where(LTIRole.type == role_type)

query = select(LMSUser).where(LMSUser.id.in_(roster_query))

if h_userids:
query = query.where(LMSUser.h_userid.in_(h_userids))
roster_query = roster_query.where(LMSUser.h_userid.in_(h_userids))

return query
return roster_query

def fetch_course_roster(self, lms_course: LMSCourse) -> None:
"""Fetch the roster information for a course from the LMS."""
Expand Down
7 changes: 4 additions & 3 deletions lms/views/dashboard/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,12 @@ def students_metrics(self) -> APIRoster:
h_userids=request_h_userids,
)
# Iterate over all the students we have in the DB
for user in self.request.db.scalars(users_query).all():
for roster_data in self.request.db.execute(users_query).all():
user, active = roster_data
if s := stats_by_user.get(user.h_userid):
# We seen this student in H, get all the data from there
api_student = RosterEntry(
active=True,
active=active,
h_userid=user.h_userid,
lms_id=user.user_id,
display_name=s["display_name"],
Expand All @@ -165,7 +166,7 @@ def students_metrics(self) -> APIRoster:
# We haven't seen this user H,
# use LMS DB's data and set 0s for all annotation related fields.
api_student = RosterEntry(
active=True,
active=active,
h_userid=user.h_userid,
lms_id=user.user_id,
display_name=user.display_name,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/lms/services/dashboard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_get_assignment_roster_with_roster_disabled(
)
assert (
roster
== user_service.get_users_for_assignment.return_value.order_by.return_value
== user_service.get_users_for_assignment.return_value.add_columns.return_value.order_by.return_value
)

def test_get_assignment_roster_with(
Expand Down
35 changes: 32 additions & 3 deletions tests/unit/lms/services/roster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,43 @@ def test_get_assignment_roster(
)
db_session.flush()

assert db_session.scalars(
result = db_session.execute(
svc.get_assignment_roster(
assignment,
role_scope=lti_role.scope if with_role_scope else None,
role_type=lti_role.type if with_role_type else None,
h_userids=[lms_user.h_userid] if with_h_userids else None,
h_userids=[lms_user.h_userid, inactive_lms_user.h_userid]
if with_h_userids
else None,
)
).all() == [lms_user]
).all()

assert [(lms_user, True), (inactive_lms_user, False)] == result

def test_get_assignment_roster_doesnt_return_duplicates(
self, assignment, db_session, svc
):
lms_user = factories.LMSUser()
lti_role_1 = factories.LTIRole()
lti_role_2 = factories.LTIRole()

factories.AssignmentRoster(
lms_user=lms_user,
assignment=assignment,
lti_role=lti_role_1,
active=True,
)
factories.AssignmentRoster(
lms_user=lms_user,
assignment=assignment,
lti_role=lti_role_2,
active=True,
)
db_session.flush()

result = db_session.execute(svc.get_assignment_roster(assignment)).all()

assert [(lms_user, True)] == result

def test_fetch_course_roster(
self,
Expand Down
53 changes: 38 additions & 15 deletions tests/unit/lms/views/dashboard/api/user_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,39 @@ def test_students_metrics(
pyramid_request.parsed_params["segment_authority_provided_ids"] = [
g.authority_provided_id for g in segments
]
user_service.get_users.return_value = select(User).where(
User.id.in_(
[
u.id
for u in [student, student_no_annos, student_no_annos_no_name]
]
user_service.get_users.return_value = (
select(User)
.where(
User.id.in_(
[
u.id
for u in [
student,
student_no_annos,
student_no_annos_no_name,
]
]
)
)
.add_columns(True)
)
else:
db_session.flush()
dashboard_service.get_assignment_roster.return_value = select(User).where(
User.id.in_(
[
u.id
for u in [student, student_no_annos, student_no_annos_no_name]
]
dashboard_service.get_assignment_roster.return_value = (
select(User)
.where(
User.id.in_(
[
u.id
for u in [
student,
student_no_annos,
student_no_annos_no_name,
]
]
)
)
.add_columns(True)
)

db_session.flush()
Expand Down Expand Up @@ -200,10 +216,17 @@ def test_students_metrics_with_auto_grading(
auto_grading_service.get_last_grades.return_value = {}

db_session.flush()
dashboard_service.get_assignment_roster.return_value = select(User).where(
User.id.in_(
[u.id for u in [student, student_no_annos, student_no_annos_no_name]]
dashboard_service.get_assignment_roster.return_value = (
select(User)
.where(
User.id.in_(
[
u.id
for u in [student, student_no_annos, student_no_annos_no_name]
]
)
)
.add_columns(True)
)
dashboard_service.get_request_assignment.return_value = assignment
h_api.get_annotation_counts.return_value = annotation_counts_response
Expand Down

0 comments on commit a4aaec4

Please sign in to comment.