diff --git a/client.go b/client.go index b5f98fa..2d2a6cb 100644 --- a/client.go +++ b/client.go @@ -577,7 +577,7 @@ func (c *client[T]) queryAndCalculateQueueStats(ctx context.Context, expr expres ExclusiveStartKey: exclusiveStartKey, }) if err != nil { - return nil, err + return nil, handleDynamoDBError(err) } exclusiveStartKey = queryOutput.LastEvaluatedKey @@ -848,7 +848,7 @@ func handleDynamoDBError(err error) error { if errors.As(err, &cause) { return &ConditionalCheckFailedError{Cause: cause} } - return &DynamoDBAPIError{Cause: err} + return DynamoDBAPIError{Cause: err} } type Status string diff --git a/client_test.go b/client_test.go index d4039e2..b1439d9 100644 --- a/client_test.go +++ b/client_test.go @@ -2,7 +2,9 @@ package dynamomq_test import ( "context" + "errors" "fmt" + "reflect" "testing" "time" @@ -1011,6 +1013,76 @@ func TestTestDynamoMQClientReturnMarshalingAttributeError(t *testing.T) { } } +func TestTestDynamoMQClientReturnDynamoDBAPIError(t *testing.T) { + t.Parallel() + client, err := NewFromConfig[test.MessageData](aws.Config{}) + if err != nil { + t.Fatalf("failed to create DynamoMQ client: %s\n", err) + } + type testCase struct { + name string + operation func() (any, error) + } + tests := []testCase{ + { + name: "ReceiveMessage should return DynamoDBAPIError", + operation: func() (any, error) { + return client.ReceiveMessage(context.Background(), &ReceiveMessageInput{}) + }, + }, + { + name: "DeleteMessage should return DynamoDBAPIError", + operation: func() (any, error) { + return client.DeleteMessage(context.Background(), &DeleteMessageInput{ + ID: "A-101", + }) + }, + }, + { + name: "GetQueueStats should return DynamoDBAPIError", + operation: func() (any, error) { + return client.GetQueueStats(context.Background(), &GetQueueStatsInput{}) + }, + }, + { + name: "GetDLQStats should return DynamoDBAPIError", + operation: func() (any, error) { + return client.GetDLQStats(context.Background(), &GetDLQStatsInput{}) + }, + }, + { + name: "GetMessage should return DynamoDBAPIError", + operation: func() (any, error) { + return client.GetMessage(context.Background(), &GetMessageInput{ + ID: "A-101", + }) + }, + }, + { + name: "ListMessages should return DynamoDBAPIError", + operation: func() (any, error) { + return client.ListMessages(context.Background(), &ListMessagesInput{ + Size: DefaultMaxListMessages, + }) + }, + }, + } + for _, tt := range tests { + _, err := tt.operation() + if _, ok := assertErrorType[DynamoDBAPIError](err); !ok { + t.Errorf("error = %v, want %v", "DynamoDBAPIError", reflect.TypeOf(err)) + } + } +} + +func assertErrorType[T error](err error) (T, bool) { + var wantErr T + if errors.As(err, &wantErr) { + return wantErr, true + } + return wantErr, false +} + func WithUnmarshalMap(f func(m map[string]types.AttributeValue, out interface{}) error) func(s *ClientOptions) { return func(s *ClientOptions) { if f != nil {