diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index eb0b3e7619..b96e9a3030 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -408,18 +408,18 @@ func generateCreateTableSQLForNormalizedTable( ) string { createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2) utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) { - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName, - qValueKindToPostgresType(genericColumnType))) + createTableSQLArray = append(createTableSQLArray, + fmt.Sprintf("%s %s", QuoteIdentifier(columnName), qValueKindToPostgresType(genericColumnType))) }) if softDeleteColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`%s BOOL DEFAULT FALSE,`, QuoteIdentifier(softDeleteColName))) + fmt.Sprintf(`%s BOOL DEFAULT FALSE`, QuoteIdentifier(softDeleteColName))) } if syncedAtColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, QuoteIdentifier(syncedAtColName))) + fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtColName))) } // add composite primary key to the table @@ -428,12 +428,11 @@ func generateCreateTableSQLForNormalizedTable( for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol)) } - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", - strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ","))) + createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s)", + strings.Join(primaryKeyColsQuoted, ","))) } - return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier, - strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ",")) + return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier, strings.Join(createTableSQLArray, ",")) } func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 7fdd6b032d..9fbca8e449 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -742,8 +742,8 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas(flowJobName string, for _, addedColumn := range schemaDelta.AddedColumns { _, err = tableSchemaModifyTx.Exec(c.ctx, fmt.Sprintf( - "ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s", - schemaDelta.DstTableName, addedColumn.ColumnName, + "ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s %s", + schemaDelta.DstTableName, QuoteIdentifier(addedColumn.ColumnName), qValueKindToPostgresType(addedColumn.ColumnType))) if err != nil { return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName, diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index 51cb6e30e9..fc706c920a 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -84,7 +84,7 @@ func (c *PostgresConnector) getNumRowsPartitions( ) ([]*protos.QRepPartition, error) { var err error numRowsPerPartition := int64(config.NumRowsPerPartition) - quotedWatermarkColumn := fmt.Sprintf("\"%s\"", config.WatermarkColumn) + quotedWatermarkColumn := QuoteIdentifier(config.WatermarkColumn) whereClause := "" if last != nil && last.Range != nil { @@ -198,7 +198,7 @@ func (c *PostgresConnector) getMinMaxValues( last *protos.QRepPartition, ) (interface{}, interface{}, error) { var minValue, maxValue interface{} - quotedWatermarkColumn := fmt.Sprintf("\"%s\"", config.WatermarkColumn) + quotedWatermarkColumn := QuoteIdentifier(config.WatermarkColumn) parsedWatermarkTable, err := utils.ParseSchemaTable(config.WatermarkTable) if err != nil { diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index a26d981b9b..291b3314d9 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -41,31 +41,31 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { switch qvalue.QValueKind(genericColumnType) { case qvalue.QValueKindBytes, qvalue.QValueKindBit: flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ - "AS %s,", toVariantColumnName, columnName, targetColumnName)) + "AS %s", toVariantColumnName, columnName, targetColumnName)) case qvalue.QValueKindGeography: flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s,", + fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s", toVariantColumnName, columnName, targetColumnName)) case qvalue.QValueKindGeometry: flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s,", + fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s", toVariantColumnName, columnName, targetColumnName)) case qvalue.QValueKindJSON, qvalue.QValueKindHStore: flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s,", + fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s", toVariantColumnName, columnName, targetColumnName)) // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types // case model.ColumnTypeTime: // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ // "Microseconds*1000) "+ - // "AS %s,", toVariantColumnName, columnName, columnName)) + // "AS %s", toVariantColumnName, columnName, columnName)) default: if qvKind == qvalue.QValueKindNumeric { flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s,", + fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s", toVariantColumnName, columnName, sfType, targetColumnName)) } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s,", + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s", toVariantColumnName, columnName, sfType, targetColumnName)) } } @@ -74,7 +74,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { if err != nil { return "", err } - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",") + flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",") quotedUpperColNames := make([]string, 0, len(columnNames)) for _, columnName := range columnNames { @@ -85,7 +85,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { fmt.Sprintf(`"%s"`, strings.ToUpper(m.peerdbCols.SyncedAtColName)), ) - insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",") + insertColumnsSQL := strings.Join(quotedUpperColNames, ",") insertValuesSQLArray := make([]string, 0, len(columnNames)) for _, columnName := range columnNames { diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index b96d6d9360..658203382d 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -779,14 +779,14 @@ func generateCreateTableSQLForNormalizedTable( slog.Any("error", err)) return } - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s,`, normalizedColName, sfColType)) + createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s`, normalizedColName, sfColType)) }) // add a _peerdb_is_deleted column to the normalized table // this is boolean default false, and is used to mark records as deleted if softDeleteColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`%s BOOLEAN DEFAULT FALSE,`, softDeleteColName)) + fmt.Sprintf(`%s BOOLEAN DEFAULT FALSE`, softDeleteColName)) } // add a _peerdb_synced column to the normalized table @@ -794,7 +794,7 @@ func generateCreateTableSQLForNormalizedTable( // default value is the current timestamp (snowflake) if syncedAtColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName)) + fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, syncedAtColName)) } // add composite primary key to the table @@ -804,12 +804,12 @@ func generateCreateTableSQLForNormalizedTable( normalizedPrimaryKeyCols = append(normalizedPrimaryKeyCols, SnowflakeIdentifierNormalize(primaryKeyCol)) } - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", - strings.TrimSuffix(strings.Join(normalizedPrimaryKeyCols, ","), ","))) + createTableSQLArray = append(createTableSQLArray, + fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(normalizedPrimaryKeyCols, ","))) } return fmt.Sprintf(createNormalizedTableSQL, snowflakeSchemaTableNormalize(dstSchemaTable), - strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ",")) + strings.Join(createTableSQLArray, ",")) } func getRawTableIdentifier(jobName string) string {