diff --git a/ecommerce/entitlements/tests/test_utils.py b/ecommerce/entitlements/tests/test_utils.py index 2c6cc7f39b2..2e3f6568f50 100644 --- a/ecommerce/entitlements/tests/test_utils.py +++ b/ecommerce/entitlements/tests/test_utils.py @@ -23,20 +23,22 @@ def test_course_entitlement_creation(self): def test_course_entitlement_update(self): """ Test course entitlement product update """ + original_variant_id = '00000000-0000-0000-0000-000000000000' product = create_or_update_course_entitlement( - 'verified', 100, self.partner, 'foo-bar', 'Foo Bar Entitlement') + 'verified', 100, self.partner, 'foo-bar', 'Foo Bar Entitlement', variant_id=original_variant_id) + assert product.attr.variant_id == original_variant_id stock_record = StockRecord.objects.get(product=product, partner=self.partner) self.assertEqual(stock_record.price_excl_tax, 100) self.assertEqual(product.title, 'Course Foo Bar Entitlement') - variant_id = '00000000-0000-0000-0000-000000000000' + new_variant_id = '11111111-1111-1111-1111-11111111' product = create_or_update_course_entitlement( - 'verified', 200, self.partner, 'foo-bar', 'Foo Bar Entitlement', variant_id=variant_id) + 'verified', 200, self.partner, 'foo-bar', 'Foo Bar Entitlement', variant_id=new_variant_id) stock_record = StockRecord.objects.get(product=product, partner=self.partner) self.assertEqual(stock_record.price_excl_tax, 200) self.assertEqual(stock_record.price_excl_tax, 200) product.refresh_from_db() - assert product.attr.variant_id == '00000000-0000-0000-0000-000000000000' + assert product.attr.variant_id == new_variant_id diff --git a/ecommerce/entitlements/utils.py b/ecommerce/entitlements/utils.py index fd50a65dcc4..bd7e3feda6f 100644 --- a/ecommerce/entitlements/utils.py +++ b/ecommerce/entitlements/utils.py @@ -62,10 +62,12 @@ def create_or_update_course_entitlement( """ Create or Update Course Entitlement Products """ certificate_type = certificate_type.lower() UUID = str(UUID) + has_existing_course_entitlement = False try: parent_entitlement, __ = create_parent_course_entitlement(title, UUID) course_entitlement = get_entitlement(UUID, certificate_type) + has_existing_course_entitlement = True except Product.DoesNotExist: course_entitlement = Product() @@ -79,6 +81,11 @@ def create_or_update_course_entitlement( course_entitlement.parent = parent_entitlement if variant_id: course_entitlement.attr.variant_id = variant_id + if has_existing_course_entitlement: + # Calling `save` on the attributes is necessary for any updates to persist. This is not necessary + # for new attributes, only for existing attributes. This `save` method must be called before saving + # the associated course entitlement below. + course_entitlement.attr.save() course_entitlement.save() __, created = StockRecord.objects.update_or_create(