Refactor aws dynamodb streams scaler #6089

Open · wants to merge 2 commits into base: main
49 changes: 18 additions & 31 deletions pkg/scalers/aws_dynamodb_streams_scaler.go
```diff
@@ -31,13 +31,13 @@ type awsDynamoDBStreamsScaler struct {
 }
 
 type awsDynamoDBStreamsMetadata struct {
-	targetShardCount           int64
-	activationTargetShardCount int64
-	tableName                  string
-	awsRegion                  string
-	awsEndpoint                string
 	awsAuthorization awsutils.AuthorizationMetadata
 	triggerIndex     int
+	targetShardCount           int64
+	activationTargetShardCount int64
```
Comment on lines +36 to +37 — Contributor (PR author):

@wozniakjan I'm leaving these two fields out of the declarative parsing for now because the error handling is a bit different in this case: instead of failing on invalid input, the current code just logs the error and falls back to the default value:

```go
if val, ok := config.TriggerMetadata["shardCount"]; ok && val != "" {
	shardCount, err := strconv.ParseInt(val, 10, 64)
	if err != nil {
		meta.targetShardCount = defaultTargetDBStreamsShardCount
		logger.Error(err, "error parsing dynamodb stream metadata shardCount, using default", "default", defaultTargetDBStreamsShardCount)
	} else {
		meta.targetShardCount = shardCount
	}
}
if val, ok := config.TriggerMetadata["activationShardCount"]; ok && val != "" {
	shardCount, err := strconv.ParseInt(val, 10, 64)
	if err != nil {
		meta.activationTargetShardCount = defaultActivationTargetDBStreamsShardCount
		logger.Error(err, "error parsing dynamodb stream metadata activationShardCount, using default", "default", defaultActivationTargetDBStreamsShardCount)
	} else {
		meta.activationTargetShardCount = shardCount
	}
}
```

wozniakjan (Member), Oct 21, 2024:

Thanks for pointing this out; I'm OK with this. But if you and the other @kedacore/keda-core-contributors are willing to make an exception for a breaking change here, I'd like to advocate for throwing an error. IMHO it leads to better UX overall, after the initial surprise that something that used to work (incorrectly) no longer does.
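For illustration, the stricter variant being advocated might look something like this (a sketch only, not code from this PR):

```go
if val, ok := config.TriggerMetadata["shardCount"]; ok && val != "" {
	shardCount, err := strconv.ParseInt(val, 10, 64)
	if err != nil {
		// Reject the trigger outright instead of silently using the default.
		return nil, fmt.Errorf("error parsing shardCount %q: %w", val, err)
	}
	meta.targetShardCount = shardCount
}
```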

Member:

Hmm, I have no idea why we allowed this in the first place :) I don't like this kind of unexpected default, and I'm inclined to change the behavior to an error here.

WDYT @kedacore/keda-contributors?

JorTurFer (Member), Nov 5, 2024:

Is something undocumented a breaking change? IMHO it depends on how long it has been there: if it were something recent we could change it, but this has been there for 2 years, so we should keep it IMO.

Contributor (PR author):

Hmm, how about we add a deprecation warning on this line for now, and then remove the fallback in a follow-up PR after the next two releases? WDYT?
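One possible shape for that transition period (a hypothetical sketch, not part of this PR):

```go
if err != nil {
	// Keep the legacy fallback for now, but warn loudly so users can fix
	// their triggers before the fallback becomes a hard error.
	meta.targetShardCount = defaultTargetDBStreamsShardCount
	logger.Error(err, "invalid shardCount; falling back to the default value. "+
		"This fallback is deprecated and will be removed in a future release",
		"default", defaultTargetDBStreamsShardCount)
}
```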

The struct diff resumes with the migrated fields:

```diff
+	TableName                  string `keda:"name=tableName, order=triggerMetadata"`
+	AwsRegion                  string `keda:"name=awsRegion, order=triggerMetadata"`
+	AwsEndpoint                string `keda:"name=awsEndpoint, order=triggerMetadata, optional"`
 }
 
 // NewAwsDynamoDBStreamsScaler creates a new awsDynamoDBStreamsScaler
```
```diff
@@ -58,7 +58,7 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
 	if err != nil {
 		return nil, fmt.Errorf("error when creating dynamodbstream client: %w", err)
 	}
-	streamArn, err := getDynamoDBStreamsArn(ctx, dbClient, &meta.tableName)
+	streamArn, err := getDynamoDBStreamsArn(ctx, dbClient, &meta.TableName)
 	if err != nil {
 		return nil, fmt.Errorf("error dynamodb stream arn: %w", err)
 	}
```
```diff
@@ -75,24 +75,11 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
 }
 
 func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig, logger logr.Logger) (*awsDynamoDBStreamsMetadata, error) {
-	meta := awsDynamoDBStreamsMetadata{}
-	meta.targetShardCount = defaultTargetDBStreamsShardCount
-
-	if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
-		meta.awsRegion = val
-	} else {
-		return nil, fmt.Errorf("no awsRegion given")
-	}
-
-	if val, ok := config.TriggerMetadata["awsEndpoint"]; ok {
-		meta.awsEndpoint = val
-	}
-
-	if val, ok := config.TriggerMetadata["tableName"]; ok && val != "" {
-		meta.tableName = val
-	} else {
-		return nil, fmt.Errorf("no tableName given")
-	}
+	meta := &awsDynamoDBStreamsMetadata{}
+	if err := config.TypedConfig(meta); err != nil {
+		return nil, fmt.Errorf("error parsing dynamodb stream metadata: %w", err)
+	}
+	meta.targetShardCount = defaultTargetDBStreamsShardCount
 
 	if val, ok := config.TriggerMetadata["shardCount"]; ok && val != "" {
 		shardCount, err := strconv.ParseInt(val, 10, 64)
```
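For context: the `keda` struct tags shown earlier let `TypedConfig` populate and validate the exported fields from the trigger metadata, which is what makes the removed hand-rolled "no awsRegion given" / "no tableName given" checks redundant. A minimal sketch of the behavior this refactor relies on (my reading of the tags, not code from this PR):

```go
cfg := &scalersconfig.ScalerConfig{
	TriggerMetadata: map[string]string{
		"awsRegion": "eu-west-1",
		// "tableName" deliberately omitted to show required-field validation.
	},
}
meta := &awsDynamoDBStreamsMetadata{}
// Fields tagged order=triggerMetadata are read from TriggerMetadata by the
// tag's name; a field without the `optional` flag must be present and non-empty.
if err := cfg.TypedConfig(meta); err != nil {
	// tableName is required, so err is non-nil here. AwsEndpoint, tagged
	// `optional`, would simply keep its zero value when absent.
	fmt.Println(err)
}
```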
```diff
@@ -121,22 +108,22 @@ func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig, logger logr.Logger) (*awsDynamoDBStreamsMetadata, error) {
 	meta.awsAuthorization = auth
 	meta.triggerIndex = config.TriggerIndex
 
-	return &meta, nil
+	return meta, nil
 }
 
 func createClientsForDynamoDBStreamsScaler(ctx context.Context, metadata *awsDynamoDBStreamsMetadata) (*dynamodb.Client, *dynamodbstreams.Client, error) {
-	cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
+	cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
 	if err != nil {
 		return nil, nil, err
 	}
 	dbClient := dynamodb.NewFromConfig(*cfg, func(options *dynamodb.Options) {
-		if metadata.awsEndpoint != "" {
-			options.BaseEndpoint = aws.String(metadata.awsEndpoint)
+		if metadata.AwsEndpoint != "" {
+			options.BaseEndpoint = aws.String(metadata.AwsEndpoint)
 		}
 	})
 	dbStreamClient := dynamodbstreams.NewFromConfig(*cfg, func(options *dynamodbstreams.Options) {
-		if metadata.awsEndpoint != "" {
-			options.BaseEndpoint = aws.String(metadata.awsEndpoint)
+		if metadata.AwsEndpoint != "" {
+			options.BaseEndpoint = aws.String(metadata.AwsEndpoint)
 		}
 	})
```
```diff
@@ -176,7 +163,7 @@ func (s *awsDynamoDBStreamsScaler) Close(_ context.Context) error {
 func (s *awsDynamoDBStreamsScaler) GetMetricSpecForScaling(_ context.Context) []v2.MetricSpec {
 	externalMetric := &v2.ExternalMetricSource{
 		Metric: v2.MetricIdentifier{
-			Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-dynamodb-streams-%s", s.metadata.tableName))),
+			Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-dynamodb-streams-%s", s.metadata.TableName))),
 		},
 		Target: GetMetricTarget(s.metricType, s.metadata.targetShardCount),
 	}
```
```diff
@@ -208,7 +195,7 @@ func (s *awsDynamoDBStreamsScaler) getDynamoDBStreamShardCount(ctx context.Context) (int64, error) {
 	}
 	for {
 		if lastShardID != nil {
-			// The upper limit of shard num to retrun is 100.
+			// The upper limit of shard num to return is 100.
 			// ExclusiveStartShardId is the shard ID of the first item that the operation will evaluate.
 			input = dynamodbstreams.DescribeStreamInput{
 				StreamArn: s.streamArn,
```
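As the comment in this hunk notes, DescribeStream returns at most 100 shards per call, so the surrounding loop pages through the stream. A rough sketch of that pattern (reconstructed from the visible context; variable and field names here are assumptions, not the PR's verbatim code):

```go
var shardCount int64
var lastShardID *string
for {
	input := dynamodbstreams.DescribeStreamInput{
		StreamArn:             s.streamArn,
		ExclusiveStartShardId: lastShardID, // nil on the first call
	}
	output, err := client.DescribeStream(ctx, &input)
	if err != nil {
		return 0, err
	}
	shardCount += int64(len(output.StreamDescription.Shards))
	// A nil LastEvaluatedShardId means the final page has been read.
	lastShardID = output.StreamDescription.LastEvaluatedShardId
	if lastShardID == nil {
		return shardCount, nil
	}
}
```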
48 changes: 24 additions & 24 deletions pkg/scalers/aws_dynamodb_streams_scaler_test.go
```diff
@@ -137,8 +137,8 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
 		expected: &awsDynamoDBStreamsMetadata{
 			targetShardCount:           2,
 			activationTargetShardCount: 1,
-			tableName:                  testAWSDynamoDBSmallTable,
-			awsRegion:                  testAWSDynamoDBStreamsRegion,
+			TableName:                  testAWSDynamoDBSmallTable,
+			AwsRegion:                  testAWSDynamoDBStreamsRegion,
 			awsAuthorization: awsutils.AuthorizationMetadata{
 				AwsAccessKeyID:     testAWSDynamoDBStreamsAccessKeyID,
 				AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
```
```diff
@@ -161,9 +161,9 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
 		expected: &awsDynamoDBStreamsMetadata{
 			targetShardCount:           2,
 			activationTargetShardCount: 1,
-			tableName:                  testAWSDynamoDBSmallTable,
-			awsRegion:                  testAWSDynamoDBStreamsRegion,
-			awsEndpoint:                testAWSDynamoDBStreamsEndpoint,
+			TableName:                  testAWSDynamoDBSmallTable,
+			AwsRegion:                  testAWSDynamoDBStreamsRegion,
+			AwsEndpoint:                testAWSDynamoDBStreamsEndpoint,
 			awsAuthorization: awsutils.AuthorizationMetadata{
 				AwsAccessKeyID:     testAWSDynamoDBStreamsAccessKeyID,
 				AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
```
```diff
@@ -206,8 +206,8 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
 		expected: &awsDynamoDBStreamsMetadata{
 			targetShardCount:           defaultTargetDBStreamsShardCount,
 			activationTargetShardCount: defaultActivationTargetDBStreamsShardCount,
-			tableName:                  testAWSDynamoDBSmallTable,
-			awsRegion:                  testAWSDynamoDBStreamsRegion,
+			TableName:                  testAWSDynamoDBSmallTable,
+			AwsRegion:                  testAWSDynamoDBStreamsRegion,
 			awsAuthorization: awsutils.AuthorizationMetadata{
 				AwsAccessKeyID:     testAWSDynamoDBStreamsAccessKeyID,
 				AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
```
```diff
@@ -227,8 +227,8 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
 		authParams: testAWSKinesisAuthentication,
 		expected: &awsDynamoDBStreamsMetadata{
 			targetShardCount: defaultTargetDBStreamsShardCount,
-			tableName:        testAWSDynamoDBSmallTable,
-			awsRegion:        testAWSDynamoDBStreamsRegion,
+			TableName:        testAWSDynamoDBSmallTable,
+			AwsRegion:        testAWSDynamoDBStreamsRegion,
 			awsAuthorization: awsutils.AuthorizationMetadata{
 				AwsAccessKeyID:     testAWSDynamoDBStreamsAccessKeyID,
 				AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
```
```diff
@@ -279,8 +279,8 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
 		},
 		expected: &awsDynamoDBStreamsMetadata{
 			targetShardCount: 2,
-			tableName:        testAWSDynamoDBSmallTable,
-			awsRegion:        testAWSDynamoDBStreamsRegion,
+			TableName:        testAWSDynamoDBSmallTable,
+			AwsRegion:        testAWSDynamoDBStreamsRegion,
 			awsAuthorization: awsutils.AuthorizationMetadata{
 				AwsAccessKeyID:     testAWSDynamoDBStreamsAccessKeyID,
 				AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
```
```diff
@@ -331,8 +331,8 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
 		},
 		expected: &awsDynamoDBStreamsMetadata{
 			targetShardCount: 2,
-			tableName:        testAWSDynamoDBSmallTable,
-			awsRegion:        testAWSDynamoDBStreamsRegion,
+			TableName:        testAWSDynamoDBSmallTable,
+			AwsRegion:        testAWSDynamoDBStreamsRegion,
 			awsAuthorization: awsutils.AuthorizationMetadata{
 				AwsRoleArn:       testAWSDynamoDBStreamsRoleArn,
 				PodIdentityOwner: true,
```
```diff
@@ -351,8 +351,8 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
 		authParams: map[string]string{},
 		expected: &awsDynamoDBStreamsMetadata{
 			targetShardCount: 2,
-			tableName:        testAWSDynamoDBSmallTable,
-			awsRegion:        testAWSDynamoDBStreamsRegion,
+			TableName:        testAWSDynamoDBSmallTable,
+			AwsRegion:        testAWSDynamoDBStreamsRegion,
 			awsAuthorization: awsutils.AuthorizationMetadata{
 				PodIdentityOwner: false,
```
},
Expand All @@ -370,10 +370,10 @@ var awsDynamoDBStreamMetricIdentifiers = []awsDynamoDBStreamsMetricIdentifier{
}

var awsDynamoDBStreamsGetMetricTestData = []*awsDynamoDBStreamsMetadata{
{tableName: testAWSDynamoDBBigTable},
{tableName: testAWSDynamoDBSmallTable},
{tableName: testAWSDynamoDBErrorTable},
{tableName: testAWSDynamoDBInvalidTable},
{TableName: testAWSDynamoDBBigTable},
{TableName: testAWSDynamoDBSmallTable},
{TableName: testAWSDynamoDBErrorTable},
{TableName: testAWSDynamoDBInvalidTable},
}

```diff
 func TestParseAwsDynamoDBStreamsMetadata(t *testing.T) {
@@ -399,7 +399,7 @@ func TestAwsDynamoDBStreamsGetMetricSpecForScaling(t *testing.T) {
 		if err != nil {
 			t.Fatal("Could not parse metadata:", err)
 		}
-		streamArn, err := getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
+		streamArn, err := getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
 		if err != nil {
 			t.Fatal("Could not get dynamodb stream arn:", err)
 		}
```
```diff
@@ -418,12 +418,12 @@ func TestAwsDynamoDBStreamsScalerGetMetrics(t *testing.T) {
 		var err error
 		var streamArn *string
 		ctx := context.Background()
-		streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
+		streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
 		if err == nil {
 			scaler := awsDynamoDBStreamsScaler{"", meta, streamArn, &mockAwsDynamoDBStreams{}, logr.Discard()}
 			value, _, err = scaler.GetMetricsAndActivity(context.Background(), "MetricName")
 		}
-		switch meta.tableName {
+		switch meta.TableName {
 		case testAWSDynamoDBErrorTable:
 			assert.Error(t, err, "expect error because of dynamodb stream api error")
 		case testAWSDynamoDBInvalidTable:
```
```diff
@@ -442,12 +442,12 @@ func TestAwsDynamoDBStreamsScalerIsActive(t *testing.T) {
 		var err error
 		var streamArn *string
 		ctx := context.Background()
-		streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
+		streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
 		if err == nil {
 			scaler := awsDynamoDBStreamsScaler{"", meta, streamArn, &mockAwsDynamoDBStreams{}, logr.Discard()}
 			_, value, err = scaler.GetMetricsAndActivity(context.Background(), "MetricName")
 		}
-		switch meta.tableName {
+		switch meta.TableName {
 		case testAWSDynamoDBErrorTable:
 			assert.Error(t, err, "expect error because of dynamodb stream api error")
 		case testAWSDynamoDBInvalidTable:
```