Skip to content

Commit a93d5f0

Browse files
author
Graham Jenson
authored
Merge pull request coinbase#55 from coinbase/ian-lai/dynamodb-locking
Migrate from S3 to DynamoDB for locking
2 parents 2631af5 + cbb46a2 commit a93d5f0

12 files changed

+468
-25
lines changed

aws/aws.go

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"github.com/aws/aws-sdk-go/aws"
77
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
88
"github.com/aws/aws-sdk-go/aws/session"
9+
"github.com/aws/aws-sdk-go/service/dynamodb"
10+
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
911
"github.com/aws/aws-sdk-go/service/lambda"
1012
"github.com/aws/aws-sdk-go/service/lambda/lambdaiface"
1113
"github.com/aws/aws-sdk-go/service/s3"
@@ -21,11 +23,13 @@ import (
2123
type S3API s3iface.S3API
2224
type LambdaAPI lambdaiface.LambdaAPI
2325
type SFNAPI sfniface.SFNAPI
26+
type DynamoDBAPI dynamodbiface.DynamoDBAPI
2427

2528
type AwsClients interface {
2629
S3Client(region *string, account_id *string, role *string) S3API
2730
LambdaClient(region *string, account_id *string, role *string) LambdaAPI
2831
SFNClient(region *string, account_id *string, role *string) SFNAPI
32+
DynamoDBClient(region *string, account_id *string, role *string) DynamoDBAPI
2933
}
3034

3135
////////////
@@ -106,3 +110,7 @@ func (c *Clients) LambdaClient(region *string, account_id *string, role *string)
106110
func (c *Clients) SFNClient(region *string, account_id *string, role *string) SFNAPI {
107111
return sfn.New(c.Session(), c.Config(region, account_id, role))
108112
}
113+
114+
func (c *Clients) DynamoDBClient(region *string, account_id *string, role *string) DynamoDBAPI {
115+
return dynamodb.New(c.Session(), c.Config(region, account_id, role))
116+
}

aws/dynamodb/lock.go

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package dynamodb
2+
3+
import (
4+
"fmt"
5+
"time"
6+
7+
awssdk "github.com/aws/aws-sdk-go/aws"
8+
"github.com/aws/aws-sdk-go/aws/awserr"
9+
"github.com/aws/aws-sdk-go/service/dynamodb"
10+
"github.com/aws/aws-sdk-go/service/dynamodb/expression"
11+
12+
stepaws "github.com/coinbase/step/aws"
13+
)
14+
15+
var (
16+
columnKey = "key"
17+
columnId = "id"
18+
columnTime = "time"
19+
)
20+
21+
type DynamoDBLocker struct {
22+
client stepaws.DynamoDBAPI
23+
}
24+
25+
func NewDynamoDBLocker(client stepaws.DynamoDBAPI) *DynamoDBLocker {
26+
return &DynamoDBLocker{client}
27+
}
28+
29+
func (l *DynamoDBLocker) GrabLock(namespace string, lockPath string, uuid string, reason string) (bool, error) {
30+
// Construct a conditional expression such that we only allow a new lock
31+
// to be created if there is not already one for the same key.
32+
condExp := expression.Name(columnKey).AttributeNotExists()
33+
condExp = condExp.Or(expression.Name(columnId).Equal(expression.Value(uuid)))
34+
35+
expr, err := expression.NewBuilder().WithCondition(condExp).Build()
36+
if err != nil {
37+
return false, err
38+
}
39+
40+
// Attempt to create a lock
41+
_, err = l.client.PutItem(&dynamodb.PutItemInput{
42+
TableName: awssdk.String(namespace),
43+
ConditionExpression: expr.Condition(),
44+
ExpressionAttributeNames: expr.Names(),
45+
ExpressionAttributeValues: expr.Values(),
46+
Item: map[string]*dynamodb.AttributeValue{
47+
columnKey: {
48+
S: awssdk.String(lockPath),
49+
},
50+
columnId: {
51+
S: awssdk.String(uuid),
52+
},
53+
columnTime: {
54+
S: awssdk.String(time.Now().Format(time.RFC3339)),
55+
},
56+
},
57+
})
58+
59+
if err != nil {
60+
awsErr, ok := err.(awserr.Error)
61+
// A lock already exists for the same key.
62+
if ok && awsErr.Code() == dynamodb.ErrCodeConditionalCheckFailedException {
63+
return false, nil
64+
}
65+
66+
return false, err
67+
}
68+
69+
return true, nil
70+
}
71+
72+
func (l *DynamoDBLocker) ReleaseLock(namespace string, lockPath string, uuid string) error {
73+
// Construct a condition expression such that we only allow a lock
74+
// to be deleted if the key, and the UUID aligns.
75+
condExp := expression.Name(columnId).Equal(expression.Value(uuid))
76+
expr, err := expression.NewBuilder().WithCondition(condExp).Build()
77+
if err != nil {
78+
return err
79+
}
80+
81+
// Attempt to delete lock
82+
_, err = l.client.DeleteItem(&dynamodb.DeleteItemInput{
83+
TableName: awssdk.String(namespace),
84+
ConditionExpression: expr.Condition(),
85+
ExpressionAttributeNames: expr.Names(),
86+
ExpressionAttributeValues: expr.Values(),
87+
Key: map[string]*dynamodb.AttributeValue{
88+
columnKey: {
89+
S: awssdk.String(lockPath),
90+
},
91+
},
92+
})
93+
94+
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == dynamodb.ErrCodeConditionalCheckFailedException {
95+
// A lock already exists, but with a different UUID.
96+
return fmt.Errorf("Lock was stolen for release with UUID(%v)", uuid)
97+
}
98+
return err
99+
}

aws/dynamodb/lock_test.go

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package dynamodb
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/aws/aws-sdk-go/aws/awserr"
8+
"github.com/aws/aws-sdk-go/service/dynamodb"
9+
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
type MockDynamoDBClient struct {
14+
dynamodbiface.DynamoDBAPI
15+
putItemCallback func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error)
16+
deleteItemCallback func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error)
17+
}
18+
19+
func (c *MockDynamoDBClient) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
20+
return c.putItemCallback(input)
21+
}
22+
23+
func (c *MockDynamoDBClient) DeleteItem(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) {
24+
return c.deleteItemCallback(input)
25+
}
26+
27+
func TestLock(t *testing.T) {
28+
t.Run("lock failure", func(t *testing.T) {
29+
client := &MockDynamoDBClient{}
30+
client.putItemCallback = func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
31+
return nil, awserr.New(dynamodb.ErrCodeConditionalCheckFailedException, "The conditional request failed.", errors.New("fake error"))
32+
}
33+
34+
locker := &DynamoDBLocker{client}
35+
36+
grabbed, err := locker.GrabLock("tableName", "lockPath", "uuid", "testing")
37+
assert.NoError(t, err)
38+
assert.False(t, grabbed)
39+
})
40+
41+
t.Run("lock acquired successfully", func(t *testing.T) {
42+
client := &MockDynamoDBClient{}
43+
client.putItemCallback = func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
44+
assert.Equal(t, "tableName", *input.TableName)
45+
assert.Equal(t, "lockPath", *input.Item[columnKey].S)
46+
assert.Equal(t, "uuid", *input.Item[columnId].S)
47+
assert.Equal(t, "(attribute_not_exists (#0)) OR (#1 = :0)", *input.ConditionExpression)
48+
49+
assert.Equal(t, "key", *input.ExpressionAttributeNames["#0"])
50+
assert.Equal(t, "id", *input.ExpressionAttributeNames["#1"])
51+
assert.Equal(t, "uuid", *input.ExpressionAttributeValues[":0"].S)
52+
53+
return &dynamodb.PutItemOutput{}, nil
54+
}
55+
56+
locker := &DynamoDBLocker{client}
57+
58+
grabbed, err := locker.GrabLock("tableName", "lockPath", "uuid", "testing")
59+
assert.NoError(t, err)
60+
assert.True(t, grabbed)
61+
})
62+
}
63+
64+
func TestUnlock(t *testing.T) {
65+
t.Run("unlock failure", func(t *testing.T) {
66+
client := &MockDynamoDBClient{}
67+
client.deleteItemCallback = func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) {
68+
return nil, awserr.New(dynamodb.ErrCodeConditionalCheckFailedException, "The conditional request failed.", errors.New("fake error"))
69+
}
70+
71+
locker := &DynamoDBLocker{client}
72+
73+
err := locker.ReleaseLock("tableName", "lockPath", "uuid")
74+
assert.Error(t, err)
75+
})
76+
77+
t.Run("unlock released", func(t *testing.T) {
78+
client := &MockDynamoDBClient{}
79+
client.deleteItemCallback = func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) {
80+
assert.Equal(t, "tableName", *input.TableName)
81+
82+
assert.Equal(t, "id", *input.ExpressionAttributeNames["#0"])
83+
assert.Equal(t, "uuid", *input.ExpressionAttributeValues[":0"].S)
84+
assert.Equal(t, "lockPath", *input.Key[columnKey].S)
85+
86+
return &dynamodb.DeleteItemOutput{}, nil
87+
}
88+
89+
locker := &DynamoDBLocker{client}
90+
91+
err := locker.ReleaseLock("tableName", "lockPath", "uuid")
92+
assert.NoError(t, err)
93+
})
94+
}

aws/mocks/mock_dynamodb.go

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package mocks
2+
3+
import (
4+
"github.com/aws/aws-sdk-go/service/dynamodb"
5+
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
6+
)
7+
8+
type MockDynamoDBClient struct {
9+
dynamodbiface.DynamoDBAPI
10+
11+
PutItemInputs []*dynamodb.PutItemInput
12+
DeleteItemInputs []*dynamodb.DeleteItemInput
13+
}
14+
15+
func (m *MockDynamoDBClient) init() {
16+
if m.PutItemInputs == nil {
17+
m.PutItemInputs = []*dynamodb.PutItemInput{}
18+
}
19+
20+
if m.DeleteItemInputs == nil {
21+
m.DeleteItemInputs = []*dynamodb.DeleteItemInput{}
22+
}
23+
}
24+
25+
func (m *MockDynamoDBClient) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
26+
m.PutItemInputs = append(m.PutItemInputs, input)
27+
return &dynamodb.PutItemOutput{}, nil
28+
}
29+
30+
func (m *MockDynamoDBClient) DeleteItem(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) {
31+
m.DeleteItemInputs = append(m.DeleteItemInputs, input)
32+
return &dynamodb.DeleteItemOutput{}, nil
33+
}

aws/mocks/mocks.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ package mocks
33
import "github.com/coinbase/step/aws"
44

55
type MockClients struct {
6-
S3 *MockS3Client
7-
Lambda *MockLambdaClient
8-
SFN *MockSFNClient
6+
S3 *MockS3Client
7+
Lambda *MockLambdaClient
8+
SFN *MockSFNClient
9+
DynamoDB *MockDynamoDBClient
910
}
1011

1112
func (awsc *MockClients) S3Client(*string, *string, *string) aws.S3API {
@@ -20,10 +21,15 @@ func (awsc *MockClients) SFNClient(*string, *string, *string) aws.SFNAPI {
2021
return awsc.SFN
2122
}
2223

24+
func (awsc *MockClients) DynamoDBClient(*string, *string, *string) aws.DynamoDBAPI {
25+
return awsc.DynamoDB
26+
}
27+
2328
func MockAwsClients() *MockClients {
2429
return &MockClients{
2530
&MockS3Client{},
2631
&MockLambdaClient{},
2732
&MockSFNClient{},
33+
&MockDynamoDBClient{},
2834
}
2935
}

aws/s3/s3.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func PutWithTypeAndCacheControl(s3c aws.S3API, bucket *string, path *string, con
157157
Key: path,
158158
Body: bytes.NewReader(*content),
159159
ACL: to.Strp("private"),
160-
ContentType: contentType,
160+
ContentType: contentType,
161161
CacheControl: cacheControl,
162162
})
163163
}

bifrost/inmemory_locker.go

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package bifrost
2+
3+
import (
4+
"fmt"
5+
"sync"
6+
)
7+
8+
type Lock struct {
9+
lockPath string
10+
uuid string
11+
reason string
12+
}
13+
14+
type InMemoryLocker struct {
15+
mu sync.RWMutex
16+
locks map[string][]*Lock
17+
}
18+
19+
func NewInMemoryLocker() *InMemoryLocker {
20+
return &InMemoryLocker{
21+
locks: make(map[string][]*Lock),
22+
}
23+
}
24+
25+
func (l *InMemoryLocker) GrabLock(namespace string, lockPath string, uuid string, reason string) (bool, error) {
26+
existingLock := l.GetLockByPath(namespace, lockPath)
27+
if existingLock != nil {
28+
return existingLock.uuid == uuid, nil
29+
}
30+
31+
l.mu.Lock()
32+
defer l.mu.Unlock()
33+
34+
l.locks[namespace] = append(l.locks[namespace], &Lock{
35+
lockPath: lockPath,
36+
uuid: uuid,
37+
reason: reason,
38+
})
39+
40+
return true, nil
41+
}
42+
43+
func (l *InMemoryLocker) ReleaseLock(namespace string, lockPath string, uuid string) error {
44+
existingLock := l.GetLockByPath(namespace, lockPath)
45+
if existingLock != nil && existingLock.uuid != uuid {
46+
return fmt.Errorf("failed to release lock: %s is currently held by UUID(%v)", lockPath, existingLock.uuid)
47+
}
48+
49+
l.mu.Lock()
50+
defer l.mu.Unlock()
51+
52+
var updatedLocks []*Lock
53+
for _, lock := range l.locks[namespace] {
54+
if lock.uuid == uuid {
55+
continue
56+
}
57+
updatedLocks = append(updatedLocks, lock)
58+
}
59+
60+
l.locks[namespace] = updatedLocks
61+
62+
return nil
63+
}
64+
65+
func (l *InMemoryLocker) GetLockByNamespace(namespace string) []*Lock {
66+
l.mu.RLock()
67+
defer l.mu.RUnlock()
68+
69+
locks, found := l.locks[namespace]
70+
if !found {
71+
return []*Lock{}
72+
}
73+
74+
return locks
75+
}
76+
77+
func (l *InMemoryLocker) GetLockByPath(namespace string, lockPath string) *Lock {
78+
l.mu.RLock()
79+
defer l.mu.RUnlock()
80+
81+
for _, lock := range l.GetLockByNamespace(namespace) {
82+
if lock.lockPath == lockPath {
83+
return lock
84+
}
85+
}
86+
87+
return nil
88+
}

0 commit comments

Comments
 (0)