Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: ensure updates to product attributes are persisted during updates
Browse files Browse the repository at this point in the history
  • Loading branch information
adamstankiewicz authored and macdiesel committed Oct 17, 2022
1 parent 6348631 commit 05404eb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ecommerce/entitlements/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@ def test_course_entitlement_creation(self):

def test_course_entitlement_update(self):
""" Test course entitlement product update """
original_variant_id = '00000000-0000-0000-0000-000000000000'
product = create_or_update_course_entitlement(
'verified', 100, self.partner, 'foo-bar', 'Foo Bar Entitlement')
'verified', 100, self.partner, 'foo-bar', 'Foo Bar Entitlement', variant_id=original_variant_id)
assert product.attr.variant_id == original_variant_id
stock_record = StockRecord.objects.get(product=product, partner=self.partner)

self.assertEqual(stock_record.price_excl_tax, 100)
self.assertEqual(product.title, 'Course Foo Bar Entitlement')

variant_id = '00000000-0000-0000-0000-000000000000'
new_variant_id = '11111111-1111-1111-1111-11111111'
product = create_or_update_course_entitlement(
'verified', 200, self.partner, 'foo-bar', 'Foo Bar Entitlement', variant_id=variant_id)
'verified', 200, self.partner, 'foo-bar', 'Foo Bar Entitlement', variant_id=new_variant_id)

stock_record = StockRecord.objects.get(product=product, partner=self.partner)
self.assertEqual(stock_record.price_excl_tax, 200)
self.assertEqual(stock_record.price_excl_tax, 200)

product.refresh_from_db()
assert product.attr.variant_id == '00000000-0000-0000-0000-000000000000'
assert product.attr.variant_id == new_variant_id
7 changes: 7 additions & 0 deletions ecommerce/entitlements/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ def create_or_update_course_entitlement(
""" Create or Update Course Entitlement Products """
certificate_type = certificate_type.lower()
UUID = str(UUID)
has_existing_course_entitlement = False

try:
parent_entitlement, __ = create_parent_course_entitlement(title, UUID)
course_entitlement = get_entitlement(UUID, certificate_type)
has_existing_course_entitlement = True
except Product.DoesNotExist:
course_entitlement = Product()

Expand All @@ -79,6 +81,11 @@ def create_or_update_course_entitlement(
course_entitlement.parent = parent_entitlement
if variant_id:
course_entitlement.attr.variant_id = variant_id
if has_existing_course_entitlement:
# Calling `save` on the attributes is necessary for any updates to persist. This is not necessary
# for new attributes, only for existing attributes. This `save` method must be called before saving
# the associated course entitlement below.
course_entitlement.attr.save()
course_entitlement.save()

__, created = StockRecord.objects.update_or_create(
Expand Down

0 comments on commit 05404eb

Please sign in to comment.