diff --git a/lms/services/lti_grading/_v13.py b/lms/services/lti_grading/_v13.py index a6cd2ab203..c98833373d 100644 --- a/lms/services/lti_grading/_v13.py +++ b/lms/services/lti_grading/_v13.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone from urllib.parse import urlparse -from lms.models import ApplicationInstance, LTIRegistration +from lms.models import ApplicationInstance, LMSUser, LTIRegistration from lms.product.family import Family from lms.product.plugin.misc import MiscPlugin from lms.services.exceptions import ExternalRequestError, StudentNotInCourse @@ -97,7 +97,7 @@ def sync_grade( # noqa: PLR0913 application_instance: ApplicationInstance, lis_outcome_service_url: str, grade_timestamp: str, - user_grading_id: str, + lms_user: LMSUser, score: float, ): """ @@ -106,7 +106,12 @@ def sync_grade( # noqa: PLR0913 This is very similar to `record_result` but not scoped to the request context, taking all the necessary information as parameters. """ - payload = self._record_score_payload(score, user_grading_id, grade_timestamp) + assert ( + lms_user.lti_v13_user_id + ), "Trying to grade a student without lti_v13_user_id" + payload = LTI13GradingService._record_score_payload( + score, lms_user.lti_v13_user_id, grade_timestamp + ) if application_instance.lti_registration.product_family == Family.CANVAS: # By default Canvas calls to /score create a new submission diff --git a/lms/services/lti_grading/factory.py b/lms/services/lti_grading/factory.py index 4373751929..aac7c141af 100644 --- a/lms/services/lti_grading/factory.py +++ b/lms/services/lti_grading/factory.py @@ -1,27 +1,48 @@ from lms.services.lti_grading._v11 import LTI11GradingService from lms.services.lti_grading._v13 import LTI13GradingService +from lms.services.lti_grading.interface import LTIGradingService from lms.services.ltia_http import LTIAHTTPService -def service_factory(_context, request): - application_instance = request.lti_user.application_instance +def service_factory(_context, request, application_instance=None) -> LTIGradingService: + """Create a new LTIGradingService. - if application_instance.lti_version == "1.3.0": + When called via pyramid services (ie request.find_service) the LTI version is selected + depending on the current request's application_instance. + + For other uses cases (e.g. from a Celery task) the passed application_instance will be used instead. + """ + + if not application_instance: + application_instance = request.lti_user.application_instance + + lti_version = application_instance.lti_version + lis_outcome_service_url = _get_lis_outcome_service_url(request) + + if lti_version == "1.3.0": return LTI13GradingService( - # Pick the value from the right dictionary depending on the context we are running - # either an API call from the frontend (parsed_params) or inside an LTI launch (lti_params). - line_item_url=request.parsed_params.get("lis_outcome_service_url") - or request.lti_params.get("lis_outcome_service_url"), + line_item_url=lis_outcome_service_url, line_item_container_url=request.lti_params.get("lineitems"), ltia_service=request.find_service(LTIAHTTPService), product_family=request.product.family, misc_plugin=request.product.plugin.misc, - lti_registration=request.lti_user.application_instance.lti_registration, + lti_registration=application_instance.lti_registration, ) return LTI11GradingService( - line_item_url=request.parsed_params.get("lis_outcome_service_url"), + line_item_url=lis_outcome_service_url, http_service=request.find_service(name="http"), oauth1_service=request.find_service(name="oauth1"), application_instance=request.lti_user.application_instance, ) + + +def _get_lis_outcome_service_url(request) -> str | None: + # Pick the value from the right dictionary depending on the context we are running + # either an API call from the frontend (parsed_params) or inside an LTI launch (lti_params). + if hasattr(request, "parsed_params") and ( + lis_outcome_service_url := request.parsed_params.get("lis_outcome_service_url") + ): + return lis_outcome_service_url + + return request.lti_params.get("lis_outcome_service_url") diff --git a/lms/services/lti_grading/interface.py b/lms/services/lti_grading/interface.py index dae63d749e..c02ab3e0f2 100644 --- a/lms/services/lti_grading/interface.py +++ b/lms/services/lti_grading/interface.py @@ -1,5 +1,7 @@ from dataclasses import dataclass +from lms.models import ApplicationInstance, LMSUser + @dataclass class GradingResult: @@ -85,6 +87,22 @@ def record_result(self, grading_id, score=None, pre_record_hook=None, comment=No """ raise NotImplementedError() + def sync_grade( # noqa: PLR0913 + self, + application_instance: ApplicationInstance, + lis_outcome_service_url: str, + grade_timestamp: str, + lms_user: LMSUser, + score: float, + ): + """ + Send a grade to the LMS. + + This is very similar to `record_result` but not scoped to the request context, + taking all the necessary information as parameters. + """ + raise NotImplementedError() + def create_line_item(self, resource_link_id, label): """ Create a new line item associated to one resource_link_id. diff --git a/lms/tasks/grading.py b/lms/tasks/grading.py index 66787f786a..2b6ac509b3 100644 --- a/lms/tasks/grading.py +++ b/lms/tasks/grading.py @@ -3,8 +3,7 @@ from sqlalchemy import exists, select from lms.models import GradingSync, GradingSyncGrade -from lms.services import LTIAHTTPService -from lms.services.lti_grading.factory import LTI13GradingService +from lms.services.lti_grading.factory import service_factory from lms.tasks.celery import app @@ -47,15 +46,8 @@ def sync_grade(*, lis_outcome_service_url: str, grading_sync_grade_id: int): grading_sync_grade = request.db.get(GradingSyncGrade, grading_sync_grade_id) grading_sync = grading_sync_grade.grading_sync application_instance = grading_sync.assignment.course.application_instance + grading_service = service_factory(None, request, application_instance) - grading_service = LTI13GradingService( - ltia_service=request.find_service(LTIAHTTPService), - line_item_url=None, - line_item_container_url=None, - product_family=None, # type: ignore - misc_plugin=None, # type: ignore - lti_registration=None, # type: ignore - ) try: grading_service.sync_grade( application_instance, @@ -63,7 +55,7 @@ def sync_grade(*, lis_outcome_service_url: str, grading_sync_grade_id: int): # DB dates are not TZ aware but are always in UTC # Make them TZ aware so the LTI API calls have an explicit timezone grading_sync.created.replace(tzinfo=UTC).isoformat(), - grading_sync_grade.lms_user.lti_v13_user_id, + grading_sync_grade.lms_user, grading_sync_grade.grade, ) except Exception as err: diff --git a/tests/unit/lms/services/lti_grading/_v13_test.py b/tests/unit/lms/services/lti_grading/_v13_test.py index ab1be0ae18..e2b38621bd 100644 --- a/tests/unit/lms/services/lti_grading/_v13_test.py +++ b/tests/unit/lms/services/lti_grading/_v13_test.py @@ -9,6 +9,7 @@ from lms.product.family import Family from lms.services.exceptions import ExternalRequestError, StudentNotInCourse from lms.services.lti_grading._v13 import LTI13GradingService +from tests import factories class TestLTI13GradingService: @@ -153,6 +154,7 @@ def test_get_score_maximum_no_line_item(self, svc, ltia_http_service): def test_sync_grade( self, svc, ltia_http_service, lti_v13_application_instance, is_canvas ): + lms_user = factories.LMSUser(lti_v13_user_id=sentinel.user_id) if is_canvas: lti_v13_application_instance.lti_registration.issuer = ( "https://canvas.instructure.com" @@ -162,7 +164,7 @@ def test_sync_grade( lti_v13_application_instance, "LIS_OUTCOME_SERVICE_URL", datetime(2022, 4, 4).isoformat(), - sentinel.user_id, + lms_user, sentinel.grade, ) diff --git a/tests/unit/lms/services/lti_grading/factory_test.py b/tests/unit/lms/services/lti_grading/factory_test.py index 5690e82873..a63b4b9348 100644 --- a/tests/unit/lms/services/lti_grading/factory_test.py +++ b/tests/unit/lms/services/lti_grading/factory_test.py @@ -2,7 +2,11 @@ import pytest -from lms.services.lti_grading.factory import service_factory +from lms.services.lti_grading.factory import ( + LTI11GradingService, + LTI13GradingService, + service_factory, +) class TestFactory: @@ -56,6 +60,24 @@ def test_v13_line_item_url_from_lti_params( ) assert svc == LTI13GradingService.return_value + @pytest.mark.usefixtures("ltia_http_service", "misc_plugin") + def test_with_explicit_lti_v13_application_instance( + self, pyramid_request, lti_v13_application_instance + ): + svc = service_factory( + sentinel.context, pyramid_request, lti_v13_application_instance + ) + + assert isinstance(svc, LTI13GradingService) + + @pytest.mark.usefixtures("http_service", "oauth1_service") + def test_with_explicit_lti_v11_application_instance( + self, pyramid_request, application_instance + ): + svc = service_factory(sentinel.context, pyramid_request, application_instance) + + assert isinstance(svc, LTI11GradingService) + @pytest.fixture def pyramid_request(self, pyramid_request): pyramid_request.parsed_params = { diff --git a/tests/unit/lms/tasks/grading_test.py b/tests/unit/lms/tasks/grading_test.py index a8ddc7b511..7023e8064c 100644 --- a/tests/unit/lms/tasks/grading_test.py +++ b/tests/unit/lms/tasks/grading_test.py @@ -25,8 +25,8 @@ def test_sync_grade( self, grading_sync, lti_v13_application_instance, - LTI13GradingService, - ltia_http_service, + service_factory, + pyramid_request, sync_grades_complete, ): sync_grade( @@ -34,21 +34,16 @@ def test_sync_grade( grading_sync_grade_id=grading_sync.grades[0].id, ) - LTI13GradingService.assert_called_once_with( - ltia_service=ltia_http_service, - line_item_url=None, - line_item_container_url=None, - product_family=None, - misc_plugin=None, - lti_registration=None, + service_factory.assert_called_once_with( + None, pyramid_request, lti_v13_application_instance ) - grading_service = LTI13GradingService.return_value + grading_service = service_factory.return_value grading_service.sync_grade.assert_called_once_with( lti_v13_application_instance, "URL", grading_sync.created.replace(tzinfo=UTC).isoformat(), - grading_sync.grades[0].lms_user.lti_v13_user_id, + grading_sync.grades[0].lms_user, grading_sync.grades[0].grade, ) sync_grades_complete.apply_async.assert_called_once_with( @@ -58,9 +53,9 @@ def test_sync_grade( @pytest.mark.usefixtures("ltia_http_service") def test_sync_grade_raises( - self, grading_sync, LTI13GradingService, sync_grades_complete + self, grading_sync, service_factory, sync_grades_complete ): - grading_service = LTI13GradingService.return_value + grading_service = service_factory.return_value grading_service.sync_grade.side_effect = Exception sync_grade.max_retries = 2 @@ -75,9 +70,9 @@ def test_sync_grade_raises( @pytest.mark.usefixtures("ltia_http_service") def test_sync_grade_last_retry( - self, grading_sync, LTI13GradingService, sync_grades_complete + self, grading_sync, service_factory, sync_grades_complete ): - grading_service = LTI13GradingService.return_value + grading_service = service_factory.return_value grading_service.sync_grade.side_effect = Exception sync_grade.max_retries = 0 @@ -136,8 +131,8 @@ def sync_grades_complete(self, patch): return patch("lms.tasks.grading.sync_grades_complete") @pytest.fixture - def LTI13GradingService(self, patch): - return patch("lms.tasks.grading.LTI13GradingService") + def service_factory(self, patch): + return patch("lms.tasks.grading.service_factory") @pytest.fixture(autouse=True)