Skip to content

Commit

Permalink
adapt to new aws code architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed Apr 24, 2024
1 parent 2a378be commit 1b50375
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
2 changes: 1 addition & 1 deletion flow/connectors/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func NewClickhouseConnector(
AccessKeyID: config.AccessKeyId,
SecretAccessKey: config.SecretAccessKey,
},
EndpointUrl: nil,
EndpointUrl: config.Endpoint,
Region: config.Region,
})
if err != nil {
Expand Down
7 changes: 5 additions & 2 deletions flow/connectors/clickhouse/qrep_avro_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func (s *ClickhouseAvroSyncMethod) CopyStageToDestination(ctx context.Context, a

avroFileUrl := fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", s3o.Bucket,
s.connector.credsProvider.Provider.GetRegion(), avroFile.FilePath)
if strings.Contains(s.connector.creds.Endpoint, "storage.googleapis.com") {

endpoint := s.connector.credsProvider.Provider.GetEndpointURL()
if strings.Contains(endpoint, "storage.googleapis.com") {
avroFileUrl = fmt.Sprintf("https://storage.googleapis.com/%s/%s", s3o.Bucket, avroFile.FilePath)
}

Expand Down Expand Up @@ -126,7 +128,8 @@ func (s *ClickhouseAvroSyncMethod) SyncQRepRecords(

avroFileUrl := fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", s3o.Bucket,
s.connector.credsProvider.Provider.GetRegion(), avroFile.FilePath)
if strings.Contains(s.connector.creds.Endpoint, "storage.googleapis.com") {
endpoint := s.connector.credsProvider.Provider.GetEndpointURL()
if strings.Contains(endpoint, "storage.googleapis.com") {
avroFileUrl = fmt.Sprintf("https://storage.googleapis.com/%s/%s", s3o.Bucket, avroFile.FilePath)
}

Expand Down
38 changes: 27 additions & 11 deletions flow/connectors/utils/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type AWSCredentialsProvider interface {
Retrieve(ctx context.Context) (AWSCredentials, error)
GetUnderlyingProvider() aws.CredentialsProvider
GetRegion() string
GetEndpointURL() string
}

type ConfigBasedAWSCredentialsProvider struct {
Expand All @@ -71,6 +72,15 @@ func (r *ConfigBasedAWSCredentialsProvider) GetRegion() string {
return r.config.Region
}

func (r *ConfigBasedAWSCredentialsProvider) GetEndpointURL() string {
endpoint := ""
if r.config.BaseEndpoint != nil {
endpoint = *r.config.BaseEndpoint
}

return endpoint
}

// Retrieve should be called as late as possible in order to have credentials with latest expiry
func (r *ConfigBasedAWSCredentialsProvider) Retrieve(ctx context.Context) (AWSCredentials, error) {
retrieved, err := r.config.Credentials.Retrieve(ctx)
Expand Down Expand Up @@ -105,6 +115,15 @@ func (s *StaticAWSCredentialsProvider) Retrieve(ctx context.Context) (AWSCredent
return s.credentials, nil
}

func (s *StaticAWSCredentialsProvider) GetEndpointURL() string {
endpoint := ""
if s.credentials.EndpointUrl != nil {
endpoint = *s.credentials.EndpointUrl
}

return endpoint
}

func NewStaticAWSCredentialsProvider(credentials AWSCredentials, region string) AWSCredentialsProvider {
return &StaticAWSCredentialsProvider{
credentials: credentials,
Expand Down Expand Up @@ -206,17 +225,14 @@ func CreateS3Client(ctx context.Context, credsProvider AWSCredentialsProvider) (
options.Region = credsProvider.GetRegion()
options.Credentials = credsProvider.GetUnderlyingProvider()
if awsCredentials.EndpointUrl != nil {
options.BaseEndpoint = awsCredentials.EndpointUrl
if strings.Contains(*awsCredentials.EndpointUrl, "storage.googleapis.com") {
// Assign custom client with our own transport
options.HTTPClient = &http.Client{
Transport: &RecalculateV4Signature{
next: http.DefaultTransport,
signer: v4.NewSigner(),
credentials: credsProvider.GetUnderlyingProvider(),
region: credsProvider.GetRegion(),
},
}
// Assign custom client with our own transport
options.HTTPClient = &http.Client{
Transport: &RecalculateV4Signature{
next: http.DefaultTransport,
signer: v4.NewSigner(),
credentials: credsProvider.GetUnderlyingProvider(),
region: credsProvider.GetRegion(),
},
}
}
})
Expand Down

0 comments on commit 1b50375

Please sign in to comment.