From be5e873fa5b6ff24985774aabcaac3618b865507 Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Thu, 28 Nov 2024 14:24:44 +0100 Subject: [PATCH] Bugfix, clearing LMSUser.lti_v13_user_id on LTI1.1 launches bulk_upsert will update the table columns in `update_elements` with the value passed in updated. In the case of lti_v13_user_id that means that the value will be set on LTI1.3 and removed (set to None) on LTI1.1 launches. To fix this we change the value set from the raw value to coalesce(value, LMSUser.lti_v13_user_id) --- lms/services/upsert.py | 13 +++++++++++-- lms/services/user.py | 15 +++++++++++++-- tests/unit/lms/services/user_test.py | 15 +++++++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/lms/services/upsert.py b/lms/services/upsert.py index 8db15925e6..c02754e229 100644 --- a/lms/services/upsert.py +++ b/lms/services/upsert.py @@ -10,7 +10,7 @@ def bulk_upsert( model_class, values: list[dict], index_elements: list[str], - update_columns: list[str], + update_columns: list[str | tuple], ): """ Create or update the specified values in a table. @@ -50,7 +50,16 @@ def bulk_upsert( # The columns to use to find matching rows. index_elements=index_elements, # The columns to update. - set_={element: getattr(base.excluded, element) for element in update_columns}, + set_={ + # For tuples include the two elements as the key and value of the dict + # For strings use value: excluded.value by default + (element[0] if isinstance(element, tuple) else element): ( + element[1] + if isinstance(element, tuple) + else getattr(base.excluded, element) + ) + for element in update_columns + }, ).returning(*index_elements_columns) result = db.execute(stmt) diff --git a/lms/services/user.py b/lms/services/user.py index 53e57dd632..6a7dbca286 100644 --- a/lms/services/user.py +++ b/lms/services/user.py @@ -1,6 +1,6 @@ from functools import lru_cache -from sqlalchemy import select +from sqlalchemy import func, select, text from sqlalchemy.exc import NoResultFound from sqlalchemy.sql import Select @@ -89,7 +89,18 @@ def upsert_lms_user(self, user: User, lti_params: LTIParams) -> LMSUser: } ], index_elements=["h_userid"], - update_columns=["updated", "display_name", "email", "lti_v13_user_id"], + update_columns=[ + "updated", + "display_name", + "email", + ( + "lti_v13_user_id", + func.coalesce( + text('"excluded"."lti_v13_user_id"'), + text('"lms_user"."lti_v13_user_id"'), + ), + ), + ], ).one() bulk_upsert( self._db, diff --git a/tests/unit/lms/services/user_test.py b/tests/unit/lms/services/user_test.py index 990946cefe..3813ccd80c 100644 --- a/tests/unit/lms/services/user_test.py +++ b/tests/unit/lms/services/user_test.py @@ -76,6 +76,21 @@ def test_upsert_lms_user(self, service, lti_user, pyramid_request, db_session): assert lms_user.updated == user.updated assert lms_user.lti_v13_user_id == pyramid_request.lti_params.v13.get("sub") + def test_upsert_lms_user_doesnt_clear_lti_v13_user_id( + self, service, lti_user, pyramid_request, db_session + ): + lms_user = factories.LMSUser( + lti_v13_user_id="EXISTING", + h_userid=lti_user.h_user.userid(authority="authority.example.com"), + ) + db_session.commit() + pyramid_request.lti_params.v13["sub"] = None + + user = service.upsert_user(lti_user) + lms_user = service.upsert_lms_user(user, pyramid_request.lti_params) + + assert lms_user.lti_v13_user_id == "EXISTING" + def test_get(self, user, service): db_user = service.get(user.application_instance, user.user_id)