Skip to content

Commit

Permalink
Remove use of strings.TrimSuffix(strings.Join(..., ""), ",")
Browse files Browse the repository at this point in the history
Remove the trailing comma from each generated SQL fragment so a plain `strings.Join(..., ",")` suffices, with no `strings.TrimSuffix` cleanup afterwards

Also fix a few remaining spots that formatted identifiers with `\"%s\"` instead of calling `QuoteIdentifier`
  • Loading branch information
serprex committed Jan 23, 2024
1 parent cd6086c commit b9ffebc
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 27 deletions.
15 changes: 7 additions & 8 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,18 +408,18 @@ func generateCreateTableSQLForNormalizedTable(
) string {
createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
qValueKindToPostgresType(genericColumnType)))
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf("%s %s", QuoteIdentifier(columnName), qValueKindToPostgresType(genericColumnType)))
})

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s BOOL DEFAULT FALSE,`, QuoteIdentifier(softDeleteColName)))
fmt.Sprintf(`%s BOOL DEFAULT FALSE`, QuoteIdentifier(softDeleteColName)))
}

if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, QuoteIdentifier(syncedAtColName)))
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtColName)))
}

// add composite primary key to the table
Expand All @@ -428,12 +428,11 @@ func generateCreateTableSQLForNormalizedTable(
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s)",
strings.Join(primaryKeyColsQuoted, ",")))
}

return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier, strings.Join(createTableSQLArray, ","))
}

func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) {
Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,8 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas(flowJobName string,

for _, addedColumn := range schemaDelta.AddedColumns {
_, err = tableSchemaModifyTx.Exec(c.ctx, fmt.Sprintf(
"ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s",
schemaDelta.DstTableName, addedColumn.ColumnName,
"ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s %s",
schemaDelta.DstTableName, QuoteIdentifier(addedColumn.ColumnName),
qValueKindToPostgresType(addedColumn.ColumnType)))
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
Expand Down
4 changes: 2 additions & 2 deletions flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (c *PostgresConnector) getNumRowsPartitions(
) ([]*protos.QRepPartition, error) {
var err error
numRowsPerPartition := int64(config.NumRowsPerPartition)
quotedWatermarkColumn := fmt.Sprintf("\"%s\"", config.WatermarkColumn)
quotedWatermarkColumn := QuoteIdentifier(config.WatermarkColumn)

whereClause := ""
if last != nil && last.Range != nil {
Expand Down Expand Up @@ -198,7 +198,7 @@ func (c *PostgresConnector) getMinMaxValues(
last *protos.QRepPartition,
) (interface{}, interface{}, error) {
var minValue, maxValue interface{}
quotedWatermarkColumn := fmt.Sprintf("\"%s\"", config.WatermarkColumn)
quotedWatermarkColumn := QuoteIdentifier(config.WatermarkColumn)

parsedWatermarkTable, err := utils.ParseSchemaTable(config.WatermarkTable)
if err != nil {
Expand Down
18 changes: 9 additions & 9 deletions flow/connectors/snowflake/merge_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,31 +41,31 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
switch qvalue.QValueKind(genericColumnType) {
case qvalue.QValueKindBytes, qvalue.QValueKindBit:
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+
"AS %s,", toVariantColumnName, columnName, targetColumnName))
"AS %s", toVariantColumnName, columnName, targetColumnName))
case qvalue.QValueKindGeography:
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s,",
fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s",
toVariantColumnName, columnName, targetColumnName))
case qvalue.QValueKindGeometry:
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s,",
fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s",
toVariantColumnName, columnName, targetColumnName))
case qvalue.QValueKindJSON, qvalue.QValueKindHStore:
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s,",
fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s",
toVariantColumnName, columnName, targetColumnName))
// TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types
// case model.ColumnTypeTime:
// flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+
// "Microseconds*1000) "+
// "AS %s,", toVariantColumnName, columnName, columnName))
// "AS %s", toVariantColumnName, columnName, columnName))
default:
if qvKind == qvalue.QValueKindNumeric {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s,",
fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s",
toVariantColumnName, columnName, sfType, targetColumnName))
} else {
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s,",
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s",
toVariantColumnName, columnName, sfType, targetColumnName))
}
}
Expand All @@ -74,7 +74,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
if err != nil {
return "", err
}
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",")
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")

quotedUpperColNames := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
Expand All @@ -85,7 +85,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
fmt.Sprintf(`"%s"`, strings.ToUpper(m.peerdbCols.SyncedAtColName)),
)

insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",")
insertColumnsSQL := strings.Join(quotedUpperColNames, ",")

insertValuesSQLArray := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
Expand Down
12 changes: 6 additions & 6 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,22 +779,22 @@ func generateCreateTableSQLForNormalizedTable(
slog.Any("error", err))
return
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s,`, normalizedColName, sfColType))
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s`, normalizedColName, sfColType))
})

// add a _peerdb_is_deleted column to the normalized table
// this is boolean default false, and is used to mark records as deleted
if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s BOOLEAN DEFAULT FALSE,`, softDeleteColName))
fmt.Sprintf(`%s BOOLEAN DEFAULT FALSE`, softDeleteColName))
}

// add a _peerdb_synced column to the normalized table
// this is a timestamp column that is used to mark records as synced
// default value is the current timestamp (snowflake)
if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, syncedAtColName))
}

// add composite primary key to the table
Expand All @@ -804,12 +804,12 @@ func generateCreateTableSQLForNormalizedTable(
normalizedPrimaryKeyCols = append(normalizedPrimaryKeyCols,
SnowflakeIdentifierNormalize(primaryKeyCol))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(normalizedPrimaryKeyCols, ","), ",")))
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(normalizedPrimaryKeyCols, ",")))
}

return fmt.Sprintf(createNormalizedTableSQL, snowflakeSchemaTableNormalize(dstSchemaTable),
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
strings.Join(createTableSQLArray, ","))
}

func getRawTableIdentifier(jobName string) string {
Expand Down

0 comments on commit b9ffebc

Please sign in to comment.