diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go
index 0a99bce668..8c8113911b 100644
--- a/flow/connectors/postgres/client.go
+++ b/flow/connectors/postgres/client.go
@@ -45,7 +45,6 @@ const (
     getTableNameToUnchangedToastColsSQL = `SELECT _peerdb_destination_table_name,
     ARRAY_AGG(DISTINCT _peerdb_unchanged_toast_columns) FROM %s.%s WHERE
     _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_record_type!=2 GROUP BY _peerdb_destination_table_name`
-    srcTableName = "src"
     mergeStatementSQL = `WITH src_rank AS (
         SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
         RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
@@ -55,10 +54,8 @@ const (
     USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE _peerdb_rank=1) src
     ON %s
     WHEN NOT MATCHED AND src._peerdb_record_type!=2 THEN
-    INSERT (%s) VALUES (%s)
-    %s
-    WHEN MATCHED AND src._peerdb_record_type=2 THEN
-    %s`
+    INSERT (%s) VALUES (%s) %s
+    WHEN MATCHED AND src._peerdb_record_type=2 THEN %s`
     fallbackUpsertStatementSQL = `WITH src_rank AS (
         SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
         RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
@@ -71,7 +68,7 @@ const (
         RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
         FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
     )
-    %s src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
+    DELETE FROM %s USING %s FROM src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`

     dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s"
     deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE mirror_job_name=$1"
@@ -246,15 +243,17 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str
 // If slotName input is empty, all slot info rows are returned - this is for UI.
 // Else, only the row pertaining to that slotName will be returned.
 func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) {
-    specificSlotClause := ""
+    whereClause := ""
     if slotName != "" {
-        specificSlotClause = fmt.Sprintf(" WHERE slot_name = '%s'", slotName)
+        whereClause = fmt.Sprintf(" WHERE slot_name = %s", QuoteLiteral(slotName))
+    } else {
+        whereClause = fmt.Sprintf(" WHERE database = %s", QuoteLiteral(c.config.Database))
     }
     rows, err := c.pool.Query(c.ctx, "SELECT slot_name, redo_lsn::Text,restart_lsn::text,wal_status,"+
         "confirmed_flush_lsn::text,active,"+
         "round((CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END"+
         " - confirmed_flush_lsn) / 1024 / 1024) AS MB_Behind"+
-        " FROM pg_control_checkpoint(), pg_replication_slots"+specificSlotClause+";")
+        " FROM pg_control_checkpoint(), pg_replication_slots"+whereClause)
     if err != nil {
         return nil, err
     }
@@ -403,20 +402,19 @@ func generateCreateTableSQLForNormalizedTable(
     if softDeleteColName != "" {
         createTableSQLArray = append(createTableSQLArray,
-            fmt.Sprintf(`"%s" BOOL DEFAULT FALSE,`, softDeleteColName))
+            fmt.Sprintf(`%s BOOL DEFAULT FALSE,`, QuoteIdentifier(softDeleteColName)))
     }
     if syncedAtColName != "" {
         createTableSQLArray = append(createTableSQLArray,
-            fmt.Sprintf(`"%s" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
+            fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, QuoteIdentifier(syncedAtColName)))
     }

     // add composite primary key to the table
     if len(sourceTableSchema.PrimaryKeyColumns) > 0 {
         primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
         for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
-            primaryKeyColsQuoted = append(primaryKeyColsQuoted,
-                fmt.Sprintf(`"%s"`, primaryKeyCol))
+            primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol))
         }
         createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
             strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))
diff --git a/flow/connectors/postgres/escape.go b/flow/connectors/postgres/escape.go
new file mode 100644
index 0000000000..280d108338
--- /dev/null
+++ b/flow/connectors/postgres/escape.go
@@ -0,0 +1,57 @@
+// from https://github.com/lib/pq/blob/v1.10.9/conn.go#L1656
+
+package connpostgres
+
+import "strings"
+
+// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
+// to DDL and other statements that do not accept parameters) to be used as part
+// of an SQL statement. For example:
+//
+//	exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
+//	err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
+//
+// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
+// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
+// that PostgreSQL provides ('E') will be prepended to the string.
+func QuoteLiteral(literal string) string {
+	// This follows the PostgreSQL internal algorithm for handling quoted literals
+	// from libpq, which can be found in the "PQEscapeStringInternal" function,
+	// which is found in the libpq/fe-exec.c source file:
+	// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
+	//
+	// substitute any single-quotes (') with two single-quotes ('')
+	literal = strings.Replace(literal, `'`, `''`, -1)
+	// determine if the string has any backslashes (\) in it.
+	// if it does, replace any backslashes (\) with two backslashes (\\)
+	// then, we need to wrap the entire string with a PostgreSQL
+	// C-style escape. Per how "PQEscapeStringInternal" handles this case, we
+	// also add a space before the "E"
+	if strings.Contains(literal, `\`) {
+		literal = strings.Replace(literal, `\`, `\\`, -1)
+		literal = ` E'` + literal + `'`
+	} else {
+		// otherwise, we can just wrap the literal with a pair of single quotes
+		literal = `'` + literal + `'`
+	}
+	return literal
+}
+
+// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
+// used as part of an SQL statement. For example:
+//
+//	tblname := "my_table"
+//	data := "my_data"
+//	quoted := pq.QuoteIdentifier(tblname)
+//	err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
+//
+// Any double quotes in name will be escaped. The quoted identifier will be
+// case sensitive when used in a query. If the input string contains a zero
+// byte, the result will be truncated immediately before it.
+func QuoteIdentifier(name string) string {
+	end := strings.IndexRune(name, 0)
+	if end > -1 {
+		name = name[:end]
+	}
+	return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
+}
diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go
index b541543fe2..083021926a 100644
--- a/flow/connectors/postgres/normalize_stmt_generator.go
+++ b/flow/connectors/postgres/normalize_stmt_generator.go
@@ -46,65 +46,64 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
     columnCount := utils.TableSchemaColumns(n.normalizedTableSchema)
     columnNames := make([]string, 0, columnCount)
     flattenedCastsSQLArray := make([]string, 0, columnCount)
-    primaryKeyColumnCasts := make(map[string]string)
+    primaryKeyColumnCasts := make(map[string]string, len(n.normalizedTableSchema.PrimaryKeyColumns))
     utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
-        columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName))
+        quotedCol := QuoteIdentifier(columnName)
+        stringCol := QuoteLiteral(columnName)
+        columnNames = append(columnNames, quotedCol)
         pgType := qValueKindToPostgresType(genericColumnType)
         if qvalue.QValueKind(genericColumnType).IsArray() {
             flattenedCastsSQLArray = append(flattenedCastsSQLArray,
-                fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
-                    strings.Trim(columnName, "\""), pgType, columnName))
+                fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s",
+                    stringCol, pgType, quotedCol))
         } else {
-            flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
-                strings.Trim(columnName, "\""), pgType, columnName))
+            flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s",
+                stringCol, pgType, quotedCol))
         }
         if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
-            primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
+            primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
         }
     })
-    flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
+    flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
     parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)
-    insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
+    insertColumnsSQL := strings.Join(columnNames, ",")
     updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
     utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) {
-        updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName))
+        quotedCol := QuoteIdentifier(columnName)
+        updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
     })
-    updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",")
+    updateColumnsSQL := strings.Join(updateColumnsSQLArray, ",")
     deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
     for columnName, columnCast := range primaryKeyColumnCasts {
-        deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `,
-            parsedDstTable.String(), columnName, columnCast))
+        deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s.%s=%s`,
+            parsedDstTable.String(), QuoteIdentifier(columnName), columnCast))
     }
-    deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ")
-    deletePart := fmt.Sprintf(
-        "DELETE FROM %s USING",
-        parsedDstTable.String())
+    deleteWhereClauseSQL := strings.Join(deleteWhereClauseArray, " AND ")
+    deleteUpdate := ""
     if n.peerdbCols.SoftDelete {
-        deletePart = fmt.Sprintf(`UPDATE %s SET "%s"=TRUE`,
-            parsedDstTable.String(), n.peerdbCols.SoftDeleteColName)
+        deleteUpdate = fmt.Sprintf(`UPDATE %s SET %s=TRUE`,
+            parsedDstTable.String(), QuoteIdentifier(n.peerdbCols.SoftDeleteColName))
         if n.peerdbCols.SyncedAtColName != "" {
-            deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`,
-                deletePart, n.peerdbCols.SyncedAtColName)
+            deleteUpdate += fmt.Sprintf(`,%s=CURRENT_TIMESTAMP`, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
         }
-        deletePart += " FROM"
     }
     fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
-        strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), n.metadataSchema,
+        strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema,
         n.rawTableName, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL,
         strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
     fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
         strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema,
-        n.rawTableName, deletePart, deleteWhereClauseSQL)
+        n.rawTableName, parsedDstTable.String(), deleteUpdate, deleteWhereClauseSQL)

     return []string{fallbackUpsertStatement, fallbackDeleteStatement}
 }

 func (n *normalizeStmtGenerator) generateMergeStatement() string {
-    columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
-    for i, columnName := range columnNames {
-        columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
+    quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
+    for i, columnName := range quotedColumnNames {
+        quotedColumnNames[i] = QuoteIdentifier(columnName)
     }

     flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
@@ -113,38 +112,41 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
     primaryKeyColumnCasts := make(map[string]string)
     primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
     utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+        quotedCol := QuoteIdentifier(columnName)
+        stringCol := QuoteLiteral(columnName)
         pgType := qValueKindToPostgresType(genericColumnType)
         if qvalue.QValueKind(genericColumnType).IsArray() {
             flattenedCastsSQLArray = append(flattenedCastsSQLArray,
-                fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
-                    strings.Trim(columnName, "\""), pgType, columnName))
+                fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s",
+                    stringCol, pgType, quotedCol))
         } else {
-            flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
-                strings.Trim(columnName, "\""), pgType, columnName))
+            flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s",
+                stringCol, pgType, quotedCol))
         }
         if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
-            primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
+            primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
             primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
-                columnName, columnName))
+                quotedCol, quotedCol))
         }
     })
-    flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
-    insertValuesSQLArray := make([]string, 0, len(columnNames))
-    for _, columnName := range columnNames {
-        insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName))
+    flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
+    insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2)
+    for _, quotedCol := range quotedColumnNames {
+        insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", quotedCol))
     }
-    updateStatementsforToastCols := n.generateUpdateStatements(columnNames)
+    updateStatementsforToastCols := n.generateUpdateStatements(quotedColumnNames)
     // append synced_at column
-    columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName))
-    insertColumnsSQL := strings.Join(columnNames, ",")
-    // fill in synced_at column
-    insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
-    insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",")
+    if n.peerdbCols.SyncedAtColName != "" {
+        quotedColumnNames = append(quotedColumnNames, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
+        insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
+    }
+    insertColumnsSQL := strings.Join(quotedColumnNames, ",")
+    insertValuesSQL := strings.Join(insertValuesSQLArray, ",")

     if n.peerdbCols.SoftDelete {
-        softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames,
-            fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",")
+        softDeleteInsertColumnsSQL := strings.Join(
+            append(quotedColumnNames, QuoteIdentifier(n.peerdbCols.SoftDeleteColName)), ",")
         softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",")

         updateStatementsforToastCols = append(updateStatementsforToastCols,
@@ -153,13 +155,12 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
     }
     updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n")

-    deletePart := "DELETE"
+    conflictPart := "DELETE"
     if n.peerdbCols.SoftDelete {
         colName := n.peerdbCols.SoftDeleteColName
-        deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName)
+        conflictPart = fmt.Sprintf(`UPDATE SET %s=TRUE`, QuoteIdentifier(colName))
         if n.peerdbCols.SyncedAtColName != "" {
fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, - deletePart, n.peerdbCols.SyncedAtColName) + conflictPart += fmt.Sprintf(`,%s=CURRENT_TIMESTAMP`, QuoteIdentifier(n.peerdbCols.SyncedAtColName)) } } @@ -174,13 +175,13 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string { insertColumnsSQL, insertValuesSQL, updateStringToastCols, - deletePart, + conflictPart, ) return mergeStmt } -func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []string { +func (n *normalizeStmtGenerator) generateUpdateStatements(quotedCols []string) []string { handleSoftDelete := n.peerdbCols.SoftDelete && (n.peerdbCols.SoftDeleteColName != "") // weird way of doing it but avoids prealloc lint updateStmts := make([]string, 0, func() int { @@ -191,43 +192,42 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st }()) for _, cols := range n.unchangedToastColumns { - unquotedUnchangedColsArray := strings.Split(cols, ",") - unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) - for _, unchangedToastCol := range unquotedUnchangedColsArray { - unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) + unchangedColsArray := strings.Split(cols, ",") + for i, unchangedToastCol := range unchangedColsArray { + unchangedColsArray[i] = QuoteIdentifier(unchangedToastCol) } - otherCols := utils.ArrayMinus(allCols, unchangedColsArray) + otherCols := utils.ArrayMinus(quotedCols, unchangedColsArray) tmpArray := make([]string, 0, len(otherCols)) for _, colName := range otherCols { tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) } // set the synced at column to the current timestamp if n.peerdbCols.SyncedAtColName != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`, - n.peerdbCols.SyncedAtColName)) + tmpArray = append(tmpArray, fmt.Sprintf(`%s=CURRENT_TIMESTAMP`, + QuoteIdentifier(n.peerdbCols.SyncedAtColName))) } // set soft-deleted to false, tackles insert after soft-delete if handleSoftDelete { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`, - n.peerdbCols.SoftDeleteColName)) + tmpArray = append(tmpArray, fmt.Sprintf(`%s=FALSE`, + QuoteIdentifier(n.peerdbCols.SoftDeleteColName))) } + quotedCols := QuoteLiteral(cols) ssep := strings.Join(tmpArray, ",") updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s`, cols, ssep) + src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns=%s + THEN UPDATE SET %s`, quotedCols, ssep) updateStmts = append(updateStmts, updateStmt) // generates update statements for the case where updates and deletes happen in the same branch // the backfill has happened from the pull side already, so treat the DeleteRecord as an update // and then set soft-delete to true. 
         if handleSoftDelete {
-            tmpArray = append(tmpArray[:len(tmpArray)-1],
-                fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName))
+            tmpArray[len(tmpArray)-1] = fmt.Sprintf(`%s=TRUE`, QuoteIdentifier(n.peerdbCols.SoftDeleteColName))
             ssep := strings.Join(tmpArray, ", ")
             updateStmt := fmt.Sprintf(`WHEN MATCHED AND
-                src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s'
-                THEN UPDATE SET %s `, cols, ssep)
+                src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns=%s
+                THEN UPDATE SET %s`, quotedCols, ssep)
             updateStmts = append(updateStmts, updateStmt)
         }
     }
diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go
index 2e45647280..f34e0a13bd 100644
--- a/flow/connectors/postgres/qrep.go
+++ b/flow/connectors/postgres/qrep.go
@@ -68,7 +68,7 @@ func (c *PostgresConnector) GetQRepPartitions(
 func (c *PostgresConnector) setTransactionSnapshot(tx pgx.Tx) error {
     snapshot := c.config.TransactionSnapshot
     if snapshot != "" {
-        if _, err := tx.Exec(c.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", snapshot)); err != nil {
+        if _, err := tx.Exec(c.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(snapshot))); err != nil {
             return fmt.Errorf("failed to set transaction snapshot: %w", err)
         }
     }
diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go
index 627c2e2fcf..44551a124a 100644
--- a/flow/connectors/postgres/qrep_query_executor.go
+++ b/flow/connectors/postgres/qrep_query_executor.go
@@ -365,7 +365,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx(
     }()

     if qe.snapshot != "" {
-        _, err = tx.Exec(qe.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", qe.snapshot))
+        _, err = tx.Exec(qe.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(qe.snapshot)))
         if err != nil {
             stream.Records <- model.QRecordOrError{
                 Err: fmt.Errorf("failed to set snapshot: %w", err),
diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go
index 6725032411..6fdda9ab54 100644
--- a/flow/connectors/postgres/qrep_sync_method.go
+++ b/flow/connectors/postgres/qrep_sync_method.go
@@ -85,10 +85,10 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
     if syncedAtCol != "" {
         updateSyncedAtStmt := fmt.Sprintf(
-            `UPDATE %s SET "%s" = CURRENT_TIMESTAMP WHERE "%s" IS NULL;`,
+            `UPDATE %s SET %s = CURRENT_TIMESTAMP WHERE %s IS NULL;`,
             pgx.Identifier{dstTableName.Schema, dstTableName.Table}.Sanitize(),
-            syncedAtCol,
-            syncedAtCol,
+            QuoteIdentifier(syncedAtCol),
+            QuoteIdentifier(syncedAtCol),
         )
         _, err = tx.Exec(context.Background(), updateSyncedAtStmt)
         if err != nil {
@@ -137,22 +137,23 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
         selectStrArray := make([]string, 0)
         for _, col := range schema.GetColumnNames() {
             _, ok := upsertMatchCols[col]
+            quotedCol := QuoteIdentifier(col)
             if !ok {
-                setClauseArray = append(setClauseArray, fmt.Sprintf(`"%s" = EXCLUDED."%s"`, col, col))
+                setClauseArray = append(setClauseArray, fmt.Sprintf(`%s = EXCLUDED.%s`, quotedCol, quotedCol))
             }
-            selectStrArray = append(selectStrArray, fmt.Sprintf(`"%s"`, col))
+            selectStrArray = append(selectStrArray, quotedCol)
         }
         setClauseArray = append(setClauseArray,
-            fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, syncedAtCol))
+            fmt.Sprintf(`%s = CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtCol)))
         setClause := strings.Join(setClauseArray, ",")
         selectSQL := strings.Join(selectStrArray, ",")

         // Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE
         upsertStmt := fmt.Sprintf(
"%s") SELECT %s, CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`, + `INSERT INTO %s (%s, %s) SELECT %s, CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`, dstTableIdentifier.Sanitize(), selectSQL, - syncedAtCol, + QuoteIdentifier(syncedAtCol), selectSQL, stagingTableIdentifier.Sanitize(), strings.Join(writeMode.UpsertKeyColumns, ", "), diff --git a/flow/e2e/postgres/qrep_flow_pg_test.go b/flow/e2e/postgres/qrep_flow_pg_test.go index fa9ad948ab..2943e201b5 100644 --- a/flow/e2e/postgres/qrep_flow_pg_test.go +++ b/flow/e2e/postgres/qrep_flow_pg_test.go @@ -116,7 +116,7 @@ func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualifie } if exists.Bool { - return fmt.Errorf("enum comparison failed: rows are not equal\n") + return fmt.Errorf("enum comparison failed: rows are not equal") } return nil }