Skip to content

Commit

Permalink
test: add test for DynamoMQClient return UnmarshalingAttributeError
Browse files Browse the repository at this point in the history
  • Loading branch information
vvatanabe committed Nov 28, 2023
1 parent 8223db8 commit 0c961ca
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 41 deletions.
50 changes: 31 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
DefaultQueueingIndexName = "dynamo-mq-index-queue_type-queue_add_timestamp"
DefaultRetryMaxAttempts = 10
DefaultVisibilityTimeoutInMinutes = 1
DefaultMaxListMessages = 10
)

type Client[T any] interface {
Expand All @@ -42,10 +43,12 @@ type ClientOptions struct {
VisibilityTimeoutInMinutes int
MaximumReceives int
UseFIFO bool
BaseEndpoint string
RetryMaxAttempts int
Clock clock.Clock

BaseEndpoint string
RetryMaxAttempts int
MarshalMap func(in interface{}) (map[string]types.AttributeValue, error)
UnmarshalMap func(m map[string]types.AttributeValue, out interface{}) error
}

func WithAWSDynamoDBClient(client *dynamodb.Client) func(*ClientOptions) {
Expand Down Expand Up @@ -98,6 +101,8 @@ func NewFromConfig[T any](cfg aws.Config, optFns ...func(*ClientOptions)) (Clien
VisibilityTimeoutInMinutes: DefaultVisibilityTimeoutInMinutes,
UseFIFO: false,
Clock: &clock.RealClock{},
MarshalMap: attributevalue.MarshalMap,
UnmarshalMap: attributevalue.UnmarshalMap,
}
for _, opt := range optFns {
opt(o)
Expand All @@ -110,6 +115,8 @@ func NewFromConfig[T any](cfg aws.Config, optFns ...func(*ClientOptions)) (Clien
useFIFO: o.UseFIFO,
dynamoDB: o.DynamoDB,
clock: o.Clock,
marshalMap: o.MarshalMap,
unmarshalMap: o.UnmarshalMap,
}
if c.dynamoDB != nil {
return c, nil
Expand All @@ -131,6 +138,8 @@ type client[T any] struct {
maximumReceives int
useFIFO bool
clock clock.Clock
marshalMap func(in interface{}) (map[string]types.AttributeValue, error)
unmarshalMap func(m map[string]types.AttributeValue, out interface{}) error
}

type SendMessageInput[T any] struct {
Expand Down Expand Up @@ -270,8 +279,8 @@ func (c *client[T]) processQueryResult(queryResult *dynamodb.QueryOutput) (*Mess

for _, itemMap := range queryResult.Items {
item := Message[T]{}
if err := attributevalue.UnmarshalMap(itemMap, &item); err != nil {
return nil, &UnmarshalingAttributeError{Cause: err}
if err := c.unmarshalMap(itemMap, &item); err != nil {
return nil, UnmarshalingAttributeError{Cause: err}
}

if err := item.markAsProcessing(c.clock.Now(), visibilityTimeout); err == nil {
Expand Down Expand Up @@ -572,7 +581,7 @@ func (c *client[T]) queryAndCalculateQueueStats(ctx context.Context, expr expres
}
exclusiveStartKey = queryOutput.LastEvaluatedKey

err = processQueryItemsForQueueStats[T](queryOutput.Items, stats)
err = c.processQueryItemsForQueueStats(queryOutput.Items, stats)
if err != nil {
return nil, err
}
Expand All @@ -585,13 +594,13 @@ func (c *client[T]) queryAndCalculateQueueStats(ctx context.Context, expr expres
return stats, nil
}

func processQueryItemsForQueueStats[T any](items []map[string]types.AttributeValue, stats *GetQueueStatsOutput) error {
func (c *client[T]) processQueryItemsForQueueStats(items []map[string]types.AttributeValue, stats *GetQueueStatsOutput) error {
for _, itemMap := range items {
stats.TotalRecordsInQueue++
item := Message[T]{}
err := attributevalue.UnmarshalMap(itemMap, &item)
err := c.unmarshalMap(itemMap, &item)
if err != nil {
return &UnmarshalingAttributeError{Cause: err}
return UnmarshalingAttributeError{Cause: err}
}

updateQueueStatsFromItem[T](&item, stats)
Expand Down Expand Up @@ -662,7 +671,7 @@ func (c *client[T]) queryAndCalculateDLQStats(ctx context.Context, expr expressi
}
lastEvaluatedKey = queryOutput.LastEvaluatedKey

err = processQueryItemsForDLQStats[T](queryOutput.Items, stats)
err = c.processQueryItemsForDLQStats(queryOutput.Items, stats)
if err != nil {
return nil, err
}
Expand All @@ -674,14 +683,14 @@ func (c *client[T]) queryAndCalculateDLQStats(ctx context.Context, expr expressi
return stats, nil
}

func processQueryItemsForDLQStats[T any](items []map[string]types.AttributeValue, stats *GetDLQStatsOutput) error {
func (c *client[T]) processQueryItemsForDLQStats(items []map[string]types.AttributeValue, stats *GetDLQStatsOutput) error {
for _, itemMap := range items {
stats.TotalRecordsInDLQ++
if len(stats.First100IDsInQueue) < 100 {
item := Message[T]{}
err := attributevalue.UnmarshalMap(itemMap, &item)
err := c.unmarshalMap(itemMap, &item)
if err != nil {
return &UnmarshalingAttributeError{Cause: err}
return UnmarshalingAttributeError{Cause: err}
}
stats.First100IDsInQueue = append(stats.First100IDsInQueue, item.ID)
}
Expand Down Expand Up @@ -718,9 +727,9 @@ func (c *client[T]) GetMessage(ctx context.Context, params *GetMessageInput) (*G
return &GetMessageOutput[T]{}, nil
}
item := Message[T]{}
err = attributevalue.UnmarshalMap(resp.Item, &item)
err = c.unmarshalMap(resp.Item, &item)
if err != nil {
return &GetMessageOutput[T]{}, &UnmarshalingAttributeError{Cause: err}
return &GetMessageOutput[T]{}, UnmarshalingAttributeError{Cause: err}
}
return &GetMessageOutput[T]{
Message: &item,
Expand All @@ -739,6 +748,9 @@ func (c *client[T]) ListMessages(ctx context.Context, params *ListMessagesInput)
if params == nil {
params = &ListMessagesInput{}
}
if params.Size <= 0 {
params.Size = DefaultMaxListMessages
}
output, err := c.dynamoDB.Scan(ctx, &dynamodb.ScanInput{
TableName: &c.tableName,
Limit: aws.Int32(params.Size),
Expand All @@ -749,7 +761,7 @@ func (c *client[T]) ListMessages(ctx context.Context, params *ListMessagesInput)
var messages []*Message[T]
err = attributevalue.UnmarshalListOfMaps(output.Items, &messages)
if err != nil {
return &ListMessagesOutput[T]{}, &UnmarshalingAttributeError{Cause: err}
return &ListMessagesOutput[T]{}, UnmarshalingAttributeError{Cause: err}
}
sort.Slice(messages, func(i, j int) bool {
return messages[i].LastUpdatedTimestamp < messages[j].LastUpdatedTimestamp
Expand Down Expand Up @@ -791,9 +803,9 @@ func (c *client[T]) ReplaceMessage(ctx context.Context, params *ReplaceMessageIn
}

func (c *client[T]) put(ctx context.Context, message *Message[T]) error {
item, err := message.marshalMap()
item, err := c.marshalMap(message)
if err != nil {
return err
return MarshalingAttributeError{Cause: err}
}
_, err = c.dynamoDB.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(c.tableName),
Expand Down Expand Up @@ -824,9 +836,9 @@ func (c *client[T]) updateDynamoDBItem(ctx context.Context,
return nil, handleDynamoDBError(err)
}
message := Message[T]{}
err = attributevalue.UnmarshalMap(outcome.Attributes, &message)
err = c.unmarshalMap(outcome.Attributes, &message)
if err != nil {
return nil, &UnmarshalingAttributeError{Cause: err}
return nil, UnmarshalingAttributeError{Cause: err}
}
return &message, nil
}
Expand Down
134 changes: 122 additions & 12 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ type ClientTestCase[Args any, Want any] struct {
func TestDynamoMQClientShouldReturnError(t *testing.T) {
t.Parallel()
client, cancel := prepareTestClient(t, context.Background(),
NewSetupFunc(newPutRequestWithReadyItem("A-101", clock.Now())), mock.Clock{}, false)
NewSetupFunc(newPutRequestWithReadyItem("A-101", clock.Now())), mock.Clock{}, false, nil)
defer cancel()
type testCase struct {
name string
Expand Down Expand Up @@ -293,7 +293,7 @@ func TestDynamoMQClientReceiveMessage(t *testing.T) {
}
runTestsParallel[any, *ReceiveMessageOutput[test.MessageData]](t, "ReceiveMessage()", tests,
func(client Client[test.MessageData], _ any) (*ReceiveMessageOutput[test.MessageData], error) {
return client.ReceiveMessage(context.Background(), &ReceiveMessageInput{})
return client.ReceiveMessage(context.Background(), nil)
})
}

Expand All @@ -306,7 +306,7 @@ func testDynamoMQClientReceiveMessageSequence(t *testing.T, useFIFO bool) {
newPutRequestWithReadyItem("A-303", test.DefaultTestDate.Add(1*time.Second))),
mock.Clock{
T: now,
}, useFIFO)
}, useFIFO, nil)
defer clean()

wants := []*ReceiveMessageOutput[test.MessageData]{
Expand Down Expand Up @@ -609,7 +609,7 @@ func TestDynamoMQClientGetQueueStats(t *testing.T) {
}
runTestsParallel[any, *GetQueueStatsOutput](t, "GetQueueStats()", tests,
func(client Client[test.MessageData], _ any) (*GetQueueStatsOutput, error) {
return client.GetQueueStats(context.Background(), &GetQueueStatsInput{})
return client.GetQueueStats(context.Background(), nil)
})
}

Expand Down Expand Up @@ -646,7 +646,7 @@ func TestDynamoMQClientGetDLQStats(t *testing.T) {
}
runTestsParallel[any, *GetDLQStatsOutput](t, "GetDLQStats()", tests,
func(client Client[test.MessageData], _ any) (*GetDLQStatsOutput, error) {
return client.GetDLQStats(context.Background(), &GetDLQStatsInput{})
return client.GetDLQStats(context.Background(), nil)
})
}

Expand Down Expand Up @@ -736,11 +736,11 @@ func TestDynamoMQClientListMessages(t *testing.T) {
type args struct {
size int32
}
tests := []ClientTestCase[args, []*Message[test.MessageData]]{
tests := []ClientTestCase[*args, []*Message[test.MessageData]]{
{
name: "should return empty list when no messages",
setup: NewSetupFunc(),
args: args{
args: &args{
size: 10,
},
want: []*Message[test.MessageData]{},
Expand All @@ -753,17 +753,35 @@ func TestDynamoMQClientListMessages(t *testing.T) {
puts := generatePutRequests(messages)
return SetupDynamoDB(t, puts...)
},
args: args{
args: &args{
size: 10,
},
want: generateExpectedMessages("A",
test.DefaultTestDate, 10),
wantErr: nil,
},
{
name: "should return list of messages when messages exist and args is nil",
setup: func(t *testing.T) (string, *dynamodb.Client, func()) {
messages := generateExpectedMessages("A", test.DefaultTestDate, 10)
puts := generatePutRequests(messages)
return SetupDynamoDB(t, puts...)
},
args: nil,
want: generateExpectedMessages("A",
test.DefaultTestDate, 10),
wantErr: nil,
},
}
runTestsParallel[args, []*Message[test.MessageData]](t, "ListMessages()", tests,
func(client Client[test.MessageData], args args) ([]*Message[test.MessageData], error) {
out, err := client.ListMessages(context.Background(), &ListMessagesInput{Size: args.size})
runTestsParallel[*args, []*Message[test.MessageData]](t, "ListMessages()", tests,
func(client Client[test.MessageData], args *args) ([]*Message[test.MessageData], error) {
var in *ListMessagesInput
if args != nil {
in = &ListMessagesInput{
Size: args.size,
}
}
out, err := client.ListMessages(context.Background(), in)
return out.Messages, err
})
}
Expand All @@ -774,7 +792,7 @@ func runTestsParallel[Args any, Want any](t *testing.T, prefix string,
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client, clean := prepareTestClient(t, context.Background(), tt.setup, tt.sdkClock, false)
client, clean := prepareTestClient(t, context.Background(), tt.setup, tt.sdkClock, false, nil)
defer clean()
result, err := operation(client, tt.args)
if tt.wantErr != nil {
Expand All @@ -790,6 +808,7 @@ func prepareTestClient(t *testing.T, ctx context.Context,
setupTable func(*testing.T) (string, *dynamodb.Client, func()),
sdkClock clock.Clock,
useFIFO bool,
unmarshalMap func(m map[string]types.AttributeValue, out interface{}) error,
) (Client[test.MessageData], func()) {
t.Helper()
tableName, raw, clean := setupTable(t)
Expand All @@ -802,6 +821,7 @@ func prepareTestClient(t *testing.T, ctx context.Context,
WithUseFIFO(useFIFO),
WithAWSVisibilityTimeout(1),
WithAWSRetryMaxAttempts(DefaultRetryMaxAttempts),
WithUnmarshalMap(unmarshalMap),
}
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
Expand Down Expand Up @@ -865,3 +885,93 @@ func marshalMap[T any](m *Message[T]) (map[string]types.AttributeValue, error) {
}
return item, nil
}

func TestNewFromConfig(t *testing.T) {
_, err := NewFromConfig[any](aws.Config{}, WithAWSBaseEndpoint("https://localhost:8000"))
if err != nil {
t.Errorf("failed to new client from config: %s\n", err)
}
}

func TestTestDynamoMQClientReturnUnmarshalingAttributeError(t *testing.T) {
t.Parallel()
setupFunc := NewSetupFunc(
newPutRequestWithReadyItem("A-101", clock.Now()),
newPutRequestWithDLQItem("B-101", clock.Now()),
)
client, cancel := prepareTestClient(t, context.Background(), setupFunc, mock.Clock{}, false,
func(m map[string]types.AttributeValue, out interface{}) error {
return test.ErrorTest
})
defer cancel()
type testCase struct {
name string
operation func() (any, error)
}
tests := []testCase{
{
name: "ReceiveMessage should return UnmarshalingAttributeError",
operation: func() (any, error) {
return client.ReceiveMessage(context.Background(), &ReceiveMessageInput{})
},
},
{
name: "GetQueueStats should return UnmarshalingAttributeError",
operation: func() (any, error) {
return client.GetQueueStats(context.Background(), &GetQueueStatsInput{})
},
},
{
name: "GetDLQStats should return UnmarshalingAttributeError",
operation: func() (any, error) {
return client.GetDLQStats(context.Background(), &GetDLQStatsInput{})
},
},
{
name: "GetMessage should return UnmarshalingAttributeError",
operation: func() (any, error) {
return client.GetMessage(context.Background(), &GetMessageInput{
ID: "A-101",
})
},
},
{
name: "UpdateMessageAsVisible should return UnmarshalingAttributeError",
operation: func() (any, error) {
return client.UpdateMessageAsVisible(context.Background(), &UpdateMessageAsVisibleInput{
ID: "A-101",
})
},
},
{
name: "MoveMessageToDLQ should return UnmarshalingAttributeError",
operation: func() (any, error) {
return client.MoveMessageToDLQ(context.Background(), &MoveMessageToDLQInput{
ID: "A-101",
})
},
},
{
name: "RedriveMessage should return UnmarshalingAttributeError",
operation: func() (any, error) {
return client.RedriveMessage(context.Background(), &RedriveMessageInput{
ID: "B-101",
})
},
},
}
for _, tt := range tests {
_, err := tt.operation()
test.AssertError(t, err, UnmarshalingAttributeError{
Cause: test.ErrorTest,
}, tt.name)
}
}

func WithUnmarshalMap(f func(m map[string]types.AttributeValue, out interface{}) error) func(s *ClientOptions) {
return func(s *ClientOptions) {
if f != nil {
s.UnmarshalMap = f
}
}
}
Loading

0 comments on commit 0c961ca

Please sign in to comment.