Skip to content

Commit

Permalink
ETag based retry logic for user updates (#2226)
Browse files Browse the repository at this point in the history
  • Loading branch information
zkokelj authored Dec 24, 2024
1 parent 0fca35e commit 2d20c8b
Showing 1 changed file with 64 additions and 71 deletions.
135 changes: 64 additions & 71 deletions tools/walletextension/storage/database/cosmosdb/cosmosdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ type userWithETag struct {
etag azcore.ETag
}

const MAX_RETRIES = 3

func NewCosmosDB(connectionString string, encryptionKey []byte) (*CosmosDB, error) {
// Create encryptor
encryptor, err := encryption.NewEncryptor(encryptionKey)
Expand Down Expand Up @@ -127,63 +129,52 @@ func (c *CosmosDB) DeleteUser(userID []byte) error {
return nil
}

// Adds or updates a session key for the user, with retries on ETag mismatch
func (c *CosmosDB) AddSessionKey(userID []byte, key common.GWSessionKey) error {
ctx := context.Background()

user, err := c.getUserDB(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
user.user.SessionKey = &dbcommon.GWSessionKeyDB{
PrivateKey: crypto.FromECDSA(key.PrivateKey.ExportECDSA()),
Account: dbcommon.GWAccountDB{
AccountAddress: key.Account.Address.Bytes(),
Signature: key.Account.Signature,
SignatureType: int(key.Account.SignatureType),
},
}
return c.updateUser(ctx, user.user)
return c.updateUserWithRetries(ctx, userID, func(u *dbcommon.GWUserDB) error {
u.SessionKey = &dbcommon.GWSessionKeyDB{
PrivateKey: crypto.FromECDSA(key.PrivateKey.ExportECDSA()),
Account: dbcommon.GWAccountDB{
AccountAddress: key.Account.Address.Bytes(),
Signature: key.Account.Signature,
SignatureType: int(key.Account.SignatureType),
},
}
return nil
})
}

// Sets the ActiveSK flag for the user, with retries on ETag mismatch
func (c *CosmosDB) ActivateSessionKey(userID []byte, active bool) error {
ctx := context.Background()

user, err := c.getUserDB(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
user.user.ActiveSK = active
return c.updateUser(ctx, user.user)
return c.updateUserWithRetries(ctx, userID, func(u *dbcommon.GWUserDB) error {
u.ActiveSK = active
return nil
})
}

// Removes the session key for the user, with retries on ETag mismatch
func (c *CosmosDB) RemoveSessionKey(userID []byte) error {
ctx := context.Background()

user, err := c.getUserDB(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
user.user.SessionKey = nil
return c.updateUser(ctx, user.user)
return c.updateUserWithRetries(ctx, userID, func(u *dbcommon.GWUserDB) error {
u.SessionKey = nil
return nil
})
}

// Adds a new account for the user, with retries on ETag mismatch
func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error {
ctx := context.Background()

user, err := c.getUserDB(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}

// Add the new account
newAccount := dbcommon.GWAccountDB{
AccountAddress: accountAddress,
Signature: signature,
SignatureType: int(signatureType),
}
user.user.Accounts = append(user.user.Accounts, newAccount)

return c.updateUser(ctx, user.user)
return c.updateUserWithRetries(ctx, userID, func(u *dbcommon.GWUserDB) error {
newAccount := dbcommon.GWAccountDB{
AccountAddress: accountAddress,
Signature: signature,
SignatureType: int(signatureType),
}
u.Accounts = append(u.Accounts, newAccount)
return nil
})
}

func (c *CosmosDB) GetUser(userID []byte) (*common.GWUser, error) {
Expand Down Expand Up @@ -223,34 +214,6 @@ func (c *CosmosDB) getUserDB(userID []byte) (userWithETag, error) {
return userWithETag{user: user, etag: itemResponse.ETag}, nil
}

func (c *CosmosDB) updateUser(ctx context.Context, user dbcommon.GWUserDB) error {
// Attempt to update without retries
currentUser, err := c.getUserDB(user.UserId)
if err != nil {
return fmt.Errorf("failed to get current user state: %w", err)
}

keyString, partitionKey := c.dbKey(user.UserId)
encryptedDoc, err := c.createEncryptedDoc(user, keyString)
if err != nil {
return fmt.Errorf("failed to marshal updated document: %w", err)
}

options := &azcosmos.ItemOptions{
IfMatchEtag: &currentUser.etag,
}

_, err = c.usersContainer.ReplaceItem(ctx, partitionKey, keyString, encryptedDoc, options)
if err != nil {
if strings.Contains(err.Error(), "Precondition Failed") {
return fmt.Errorf("ETag mismatch: the user document was modified by another process")
}
return fmt.Errorf("failed to update user: %w", err)
}

return nil
}

func (c *CosmosDB) createEncryptedDoc(user dbcommon.GWUserDB, keyString string) ([]byte, error) {
userJSON, err := json.Marshal(user)
if err != nil {
Expand Down Expand Up @@ -286,3 +249,33 @@ func (c *CosmosDB) dbKey(userID []byte) (string, azcosmos.PartitionKey) {
func (c *CosmosDB) GetEncryptionKey() []byte {
return c.encryptor.GetKey()
}

// Retries read–mutate–write if ETag mismatch occurs
func (c *CosmosDB) updateUserWithRetries(ctx context.Context, userID []byte, mutate func(*dbcommon.GWUserDB) error) error {
for i := 0; i < MAX_RETRIES; i++ {
current, err := c.getUserDB(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if err := mutate(&current.user); err != nil {
return fmt.Errorf("failed to mutate user: %w", err)
}
keyString, partitionKey := c.dbKey(current.user.UserId)
encryptedDoc, err := c.createEncryptedDoc(current.user, keyString)
if err != nil {
return fmt.Errorf("failed to marshal updated document: %w", err)
}
options := &azcosmos.ItemOptions{
IfMatchEtag: &current.etag,
}
_, err = c.usersContainer.ReplaceItem(ctx, partitionKey, keyString, encryptedDoc, options)
if err != nil {
if strings.Contains(err.Error(), "Precondition Failed") {
continue
}
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
return fmt.Errorf("exceeded max retries, user update failed")
}

0 comments on commit 2d20c8b

Please sign in to comment.