From 22517e97cf53a8630698ca6989f27b206a6ed477 Mon Sep 17 00:00:00 2001 From: Demur Rumed Date: Tue, 5 Dec 2023 12:11:48 +0000 Subject: [PATCH] Cleanup: two optimizations 1. Prefer strings.Cut to strings.Split 2. Prefer map[T]struct{} to map[T]bool when empty & false are equivalent --- flow/cmd/handler.go | 7 +++---- .../bigquery/merge_statement_generator.go | 4 ++-- flow/connectors/bigquery/qrep_avro_sync.go | 4 ++-- flow/connectors/postgres/client.go | 9 +++++---- flow/connectors/snowflake/snowflake.go | 13 +++++-------- flow/connectors/utils/array.go | 7 ++++--- flow/connectors/utils/avro/avro_writer.go | 2 +- flow/connectors/utils/aws.go | 11 ++--------- flow/model/conversion_avro.go | 14 +++++++------- flow/model/model.go | 6 +++--- 10 files changed, 34 insertions(+), 43 deletions(-) diff --git a/flow/cmd/handler.go b/flow/cmd/handler.go index 349501c306..489a503913 100644 --- a/flow/cmd/handler.go +++ b/flow/cmd/handler.go @@ -47,11 +47,10 @@ func (h *FlowRequestHandler) getPeerID(ctx context.Context, peerName string) (in } func schemaForTableIdentifier(tableIdentifier string, peerDBType int32) string { - tableIdentifierParts := strings.Split(tableIdentifier, ".") - if len(tableIdentifierParts) == 1 && peerDBType != int32(protos.DBType_BIGQUERY) { - tableIdentifierParts = append([]string{"public"}, tableIdentifierParts...) + if peerDBType != int32(protos.DBType_BIGQUERY) && !strings.ContainsRune(tableIdentifier, '.') { + return "public." + tableIdentifier } - return strings.Join(tableIdentifierParts, ".") + return tableIdentifier } func (h *FlowRequestHandler) createCdcJobEntry(ctx context.Context, diff --git a/flow/connectors/bigquery/merge_statement_generator.go b/flow/connectors/bigquery/merge_statement_generator.go index 1bef877a67..427fcd7884 100644 --- a/flow/connectors/bigquery/merge_statement_generator.go +++ b/flow/connectors/bigquery/merge_statement_generator.go @@ -172,12 +172,12 @@ and updating the other columns (not the unchanged toast columns) 7. Return the list of generated update statements. */ func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string, unchangedToastCols []string) []string { - updateStmts := make([]string, 0) + updateStmts := make([]string, 0, len(unchangedToastCols)) for _, cols := range unchangedToastCols { unchangedColsArray := strings.Split(cols, ", ") otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0) + tmpArray := make([]string, 0, len(otherCols)) for _, colName := range otherCols { tmpArray = append(tmpArray, fmt.Sprintf("`%s` = _peerdb_deduped.%s", colName, colName)) } diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index 97c043ac04..c92e17fa14 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -186,7 +186,7 @@ type AvroSchema struct { func DefineAvroSchema(dstTableName string, dstTableMetadata *bigquery.TableMetadata) (*model.QRecordAvroSchemaDefinition, error) { avroFields := []AvroField{} - nullableFields := map[string]bool{} + nullableFields := make(map[string]struct{}) for _, bqField := range dstTableMetadata.Schema { avroType, err := GetAvroType(bqField) @@ -197,7 +197,7 @@ func DefineAvroSchema(dstTableName string, // If a field is nullable, its Avro type should be ["null", actualType] if !bqField.Required { avroType = []interface{}{"null", avroType} - nullableFields[bqField.Name] = true + nullableFields[bqField.Name] = struct{}{} } avroFields = append(avroFields, AvroField{ diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 4c5fb6a3c1..da3096bcb9 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -572,15 +572,16 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st } func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedToastColsLists []string) string { - updateStmts := make([]string, 0) + updateStmts := make([]string, 0, len(unchangedToastColsLists)) for _, cols := range unchangedToastColsLists { - unchangedColsArray := make([]string, 0) - for _, unchangedToastCol := range strings.Split(cols, ",") { + unquotedUnchangedColsArray := strings.Split(cols, ",") + unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) + for _, unchangedToastCol := range unquotedUnchangedColsArray { unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) } otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0) + tmpArray := make([]string, 0, len(otherCols)) for _, colName := range otherCols { tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index c624234923..8fe26ce50f 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -926,15 +926,12 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement( // parseTableName parses a table name into schema and table name. func parseTableName(tableName string) (*tableNameComponents, error) { - parts := strings.Split(tableName, ".") - if len(parts) != 2 { + schemaIdentifier, tableIdentifier, hasDot := strings.Cut(tableName, ".") + if !hasDot || strings.ContainsRune(tableIdentifier, '.') { return nil, fmt.Errorf("invalid table name: %s", tableName) } - return &tableNameComponents{ - schemaIdentifier: parts[0], - tableIdentifier: parts[1], - }, nil + return &tableNameComponents{schemaIdentifier, tableIdentifier}, nil } func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) { @@ -1032,12 +1029,12 @@ and updating the other columns. func (c *SnowflakeConnector) generateUpdateStatements( syncedAtCol string, softDeleteCol string, softDelete bool, allCols []string, unchangedToastCols []string) []string { - updateStmts := make([]string, 0) + updateStmts := make([]string, 0, len(unchangedToastCols)) for _, cols := range unchangedToastCols { unchangedColsArray := strings.Split(cols, ",") otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0) + tmpArray := make([]string, 0, len(otherCols)+2) for _, colName := range otherCols { quotedUpperColName := fmt.Sprintf(`"%s"`, strings.ToUpper(colName)) tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", quotedUpperColName, quotedUpperColName)) diff --git a/flow/connectors/utils/array.go b/flow/connectors/utils/array.go index 37856657b0..ec83d9a6a9 100644 --- a/flow/connectors/utils/array.go +++ b/flow/connectors/utils/array.go @@ -1,15 +1,16 @@ package utils func ArrayMinus(first []string, second []string) []string { - lookup := make(map[string]bool) + lookup := make(map[string]struct{}) // Add elements from arrayB to the lookup map for _, element := range second { - lookup[element] = true + lookup[element] = struct{}{} } // Iterate over arrayA and check if the element is present in the lookup map var result []string for _, element := range first { - if !lookup[element] { + _, exists := lookup[element] + if !exists { result = append(result, element) } } diff --git a/flow/connectors/utils/avro/avro_writer.go b/flow/connectors/utils/avro/avro_writer.go index 36c8858aa4..9c5219f142 100644 --- a/flow/connectors/utils/avro/avro_writer.go +++ b/flow/connectors/utils/avro/avro_writer.go @@ -127,7 +127,7 @@ func (p *PeerDBOCFWriter) writeRecordsToOCFWriter(ocfWriter *goavro.OCFWriter) ( avroConverter := model.NewQRecordAvroConverter( qRecord, p.targetDWH, - &p.avroSchema.NullableFields, + p.avroSchema.NullableFields, colNames, ) diff --git a/flow/connectors/utils/aws.go b/flow/connectors/utils/aws.go index 5ba14d1616..473efd0ce5 100644 --- a/flow/connectors/utils/aws.go +++ b/flow/connectors/utils/aws.go @@ -81,18 +81,11 @@ func NewS3BucketAndPrefix(s3Path string) (*S3BucketAndPrefix, error) { stagingPath := strings.TrimPrefix(s3Path, "s3://") // Split into bucket and prefix - splitPath := strings.SplitN(stagingPath, "/", 2) - - bucket := splitPath[0] - prefix := "" - if len(splitPath) > 1 { - // Remove leading and trailing slashes from prefix - prefix = strings.Trim(splitPath[1], "/") - } + bucket, prefix, _ := strings.Cut(stagingPath, "/") return &S3BucketAndPrefix{ Bucket: bucket, - Prefix: prefix, + Prefix: strings.Trim(prefix, "/"), }, nil } diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go index 3c4ba07076..258e9c00e8 100644 --- a/flow/model/conversion_avro.go +++ b/flow/model/conversion_avro.go @@ -10,14 +10,14 @@ import ( type QRecordAvroConverter struct { QRecord *QRecord TargetDWH qvalue.QDWHType - NullableFields *map[string]bool + NullableFields map[string]struct{} ColNames []string } func NewQRecordAvroConverter( q *QRecord, targetDWH qvalue.QDWHType, - nullableFields *map[string]bool, + nullableFields map[string]struct{}, colNames []string, ) *QRecordAvroConverter { return &QRecordAvroConverter{ @@ -33,12 +33,12 @@ func (qac *QRecordAvroConverter) Convert() (map[string]interface{}, error) { for idx := range qac.QRecord.Entries { key := qac.ColNames[idx] - nullable, ok := (*qac.NullableFields)[key] + _, nullable := qac.NullableFields[key] avroConverter := qvalue.NewQValueAvroConverter( &qac.QRecord.Entries[idx], qac.TargetDWH, - nullable && ok, + nullable, ) avroVal, err := avroConverter.ToAvroValue() if err != nil { @@ -64,7 +64,7 @@ type QRecordAvroSchema struct { type QRecordAvroSchemaDefinition struct { Schema string - NullableFields map[string]bool + NullableFields map[string]struct{} } func GetAvroSchemaDefinition( @@ -72,7 +72,7 @@ func GetAvroSchemaDefinition( qRecordSchema *QRecordSchema, ) (*QRecordAvroSchemaDefinition, error) { avroFields := []QRecordAvroField{} - nullableFields := map[string]bool{} + nullableFields := make(map[string]struct{}) for _, qField := range qRecordSchema.Fields { avroType, err := qvalue.GetAvroSchemaFromQValueKind(qField.Type, qField.Nullable) @@ -84,7 +84,7 @@ func GetAvroSchemaDefinition( if qField.Nullable { consolidatedType = []interface{}{"null", consolidatedType} - nullableFields[qField.Name] = true + nullableFields[qField.Name] = struct{}{} } avroFields = append(avroFields, QRecordAvroField{ diff --git a/flow/model/model.go b/flow/model/model.go index 297313b0f2..e60db328fd 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -174,13 +174,13 @@ func (r *RecordItems) toMap() (map[string]interface{}, error) { } type ToJSONOptions struct { - UnnestColumns map[string]bool + UnnestColumns map[string]struct{} } func NewToJSONOptions(unnestCols []string) *ToJSONOptions { - unnestColumns := make(map[string]bool) + unnestColumns := make(map[string]struct{}) for _, col := range unnestCols { - unnestColumns[col] = true + unnestColumns[col] = struct{}{} } return &ToJSONOptions{ UnnestColumns: unnestColumns,