From 1ee61d811ef690e37237f583dbf791f3992f3e6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Wed, 17 Jan 2024 04:20:53 +0000
Subject: [PATCH] GetSlotInfo: only list replication slots in peer's database

Also introduce escape.go to help with properly escaping strings/identifiers

Update all uses of "%s" in connectors/postgres, along with a few other strings
such as in GetSlotInfo
---
 flow/connectors/postgres/client.go            | 24 ++--
 flow/connectors/postgres/escape.go            | 57 ++++++++++
 .../postgres/normalize_stmt_generator.go      | 105 +++++++++---------
 flow/connectors/postgres/qrep.go              | 2 +-
 .../postgres/qrep_query_executor.go           | 2 +-
 flow/connectors/postgres/qrep_sync_method.go  | 17 +--
 flow/e2e/postgres/qrep_flow_pg_test.go        | 2 +-
 7 files changed, 133 insertions(+), 76 deletions(-)
 create mode 100644 flow/connectors/postgres/escape.go

diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go
index 0a99bce668..8c8113911b 100644
--- a/flow/connectors/postgres/client.go
+++ b/flow/connectors/postgres/client.go
@@ -45,7 +45,6 @@ const (
getTableNameToUnchangedToastColsSQL = `SELECT _peerdb_destination_table_name,
ARRAY_AGG(DISTINCT _peerdb_unchanged_toast_columns) FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND
_peerdb_record_type!=2 GROUP BY _peerdb_destination_table_name`
- srcTableName = "src"
mergeStatementSQL = `WITH src_rank AS (
SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
@@ -55,10 +54,8 @@ const (
USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE _peerdb_rank=1) src
ON %s
WHEN NOT MATCHED AND src._peerdb_record_type!=2 THEN
- INSERT (%s) VALUES (%s)
- %s
- WHEN MATCHED AND src._peerdb_record_type=2 THEN
- %s`
+ INSERT (%s) VALUES (%s) %s
+ WHEN MATCHED AND src._peerdb_record_type=2 THEN %s`
fallbackUpsertStatementSQL = `WITH src_rank AS (
SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
@@ -246,15 +243,17 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str
// If slotName input is empty, all slot info rows are returned - this is for UI.
// Else, only the row pertaining to that slotName will be returned.
func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) { - specificSlotClause := "" + whereClause := "" if slotName != "" { - specificSlotClause = fmt.Sprintf(" WHERE slot_name = '%s'", slotName) + whereClause = fmt.Sprintf(" WHERE slot_name = %s", QuoteLiteral(slotName)) + } else { + whereClause = fmt.Sprintf(" WHERE database = %s", QuoteLiteral(c.config.Database)) } rows, err := c.pool.Query(c.ctx, "SELECT slot_name, redo_lsn::Text,restart_lsn::text,wal_status,"+ "confirmed_flush_lsn::text,active,"+ "round((CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END"+ " - confirmed_flush_lsn) / 1024 / 1024) AS MB_Behind"+ - " FROM pg_control_checkpoint(), pg_replication_slots"+specificSlotClause+";") + " FROM pg_control_checkpoint(), pg_replication_slots"+whereClause) if err != nil { return nil, err } @@ -403,20 +402,19 @@ func generateCreateTableSQLForNormalizedTable( if softDeleteColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`"%s" BOOL DEFAULT FALSE,`, softDeleteColName)) + fmt.Sprintf(`%s BOOL DEFAULT FALSE,`, QuoteIdentifier(softDeleteColName))) } if syncedAtColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`"%s" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName)) + fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, QuoteIdentifier(syncedAtColName))) } // add composite primary key to the table if len(sourceTableSchema.PrimaryKeyColumns) > 0 { primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns)) for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { - primaryKeyColsQuoted = append(primaryKeyColsQuoted, - fmt.Sprintf(`"%s"`, primaryKeyCol)) + primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol)) } createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ","))) diff --git a/flow/connectors/postgres/escape.go b/flow/connectors/postgres/escape.go new file mode 100644 index 0000000000..280d108338 --- /dev/null +++ b/flow/connectors/postgres/escape.go @@ -0,0 +1,57 @@ +// from https://github.com/lib/pq/blob/v1.10.9/conn.go#L1656 + +package connpostgres + +import "strings" + +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. 
+ // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} + +// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be +// used as part of an SQL statement. For example: +// +// tblname := "my_table" +// data := "my_data" +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// +// Any double quotes in name will be escaped. The quoted identifier will be +// case sensitive when used in a query. If the input string contains a zero +// byte, the result will be truncated immediately before it. +func QuoteIdentifier(name string) string { + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return `"` + strings.Replace(name, `"`, `""`, -1) + `"` +} diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go index b541543fe2..9c292ef252 100644 --- a/flow/connectors/postgres/normalize_stmt_generator.go +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -48,18 +48,20 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { flattenedCastsSQLArray := make([]string, 0, columnCount) primaryKeyColumnCasts := make(map[string]string) utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { - columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) + quotedCol := QuoteIdentifier(columnName) + stringCol := QuoteLiteral(columnName) + columnNames = append(columnNames, quotedCol) pgType := qValueKindToPostgresType(genericColumnType) if qvalue.QValueKind(genericColumnType).IsArray() { flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s", + stringCol, pgType, quotedCol)) } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s", + stringCol, pgType, quotedCol)) } if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType) } }) flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") @@ -68,27 +70,24 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) { - updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, 
- columnName))
+ quotedCol := QuoteIdentifier(columnName)
+ updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
})
updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",")
deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
for columnName, columnCast := range primaryKeyColumnCasts {
- deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `,
- parsedDstTable.String(), columnName, columnCast))
+ deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s.%s=%s AND `,
+ parsedDstTable.String(), QuoteIdentifier(columnName), columnCast))
}
deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ")
deletePart := fmt.Sprintf(
"DELETE FROM %s USING",
parsedDstTable.String())
if n.peerdbCols.SoftDelete {
- deletePart = fmt.Sprintf(`UPDATE %s SET "%s"=TRUE`,
- parsedDstTable.String(), n.peerdbCols.SoftDeleteColName)
+ deletePart = fmt.Sprintf(`UPDATE %s SET %s=TRUE`,
+ parsedDstTable.String(), QuoteIdentifier(n.peerdbCols.SoftDeleteColName))
if n.peerdbCols.SyncedAtColName != "" {
- deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`,
- deletePart, n.peerdbCols.SyncedAtColName)
+ deletePart = fmt.Sprintf(`%s,%s=CURRENT_TIMESTAMP`,
+ deletePart, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
}
deletePart += " FROM"
}
fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), n.metadataSchema,
@@ -96,15 +95,15 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema,
n.rawTableName, deletePart, deleteWhereClauseSQL)

return []string{fallbackUpsertStatement, fallbackDeleteStatement}
}

func (n *normalizeStmtGenerator) generateMergeStatement() string {
- columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
- for i, columnName := range columnNames {
- columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
+ quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
+ for i, columnName := range quotedColumnNames {
+ quotedColumnNames[i] = QuoteIdentifier(columnName)
}

flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
@@ -113,38 +112,40 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
primaryKeyColumnCasts := make(map[string]string)
primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+ quotedCol := QuoteIdentifier(columnName)
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
- fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
- strings.Trim(columnName, "\""), pgType, columnName))
+ fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s",
+ QuoteLiteral(columnName), pgType, quotedCol))
} else {
- flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s", + QuoteLiteral(columnName), pgType, quotedCol)) } if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", columnName, pgType) primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", - columnName, columnName)) + quotedCol, quotedCol)) } }) flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") - insertValuesSQLArray := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { + insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2) + for _, columnName := range quotedColumnNames { insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) } - updateStatementsforToastCols := n.generateUpdateStatements(columnNames) + updateStatementsforToastCols := n.generateUpdateStatements(quotedColumnNames) // append synced_at column - columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName)) - insertColumnsSQL := strings.Join(columnNames, ",") - // fill in synced_at column - insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") - insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",") + if n.peerdbCols.SyncedAtColName != "" { + quotedColumnNames = append(quotedColumnNames, QuoteIdentifier(n.peerdbCols.SyncedAtColName)) + insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") + } + insertColumnsSQL := strings.Join(quotedColumnNames, ",") + insertValuesSQL := strings.Join(insertValuesSQLArray, ",") if n.peerdbCols.SoftDelete { - softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, - fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",") + softDeleteInsertColumnsSQL := strings.Join( + append(quotedColumnNames, QuoteIdentifier(n.peerdbCols.SoftDeleteColName)), ",") softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") updateStatementsforToastCols = append(updateStatementsforToastCols, @@ -156,10 +157,10 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string { deletePart := "DELETE" if n.peerdbCols.SoftDelete { colName := n.peerdbCols.SoftDeleteColName - deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName) + deletePart = fmt.Sprintf(`UPDATE SET %s=TRUE`, QuoteIdentifier(colName)) if n.peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, - deletePart, n.peerdbCols.SyncedAtColName) + deletePart = fmt.Sprintf(`%s,%s=CURRENT_TIMESTAMP`, + deletePart, QuoteIdentifier(n.peerdbCols.SyncedAtColName)) } } @@ -180,7 +181,7 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string { return mergeStmt } -func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []string { +func (n *normalizeStmtGenerator) generateUpdateStatements(quotedCols []string) []string { handleSoftDelete := n.peerdbCols.SoftDelete && (n.peerdbCols.SoftDeleteColName != "") // weird way of doing it but avoids prealloc lint updateStmts := make([]string, 0, func() int { @@ -194,28 +195,28 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st 
unquotedUnchangedColsArray := strings.Split(cols, ",") unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) for _, unchangedToastCol := range unquotedUnchangedColsArray { - unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) + unchangedColsArray = append(unchangedColsArray, QuoteIdentifier(unchangedToastCol)) } - otherCols := utils.ArrayMinus(allCols, unchangedColsArray) + otherCols := utils.ArrayMinus(quotedCols, unchangedColsArray) tmpArray := make([]string, 0, len(otherCols)) for _, colName := range otherCols { tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) } // set the synced at column to the current timestamp if n.peerdbCols.SyncedAtColName != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`, - n.peerdbCols.SyncedAtColName)) + tmpArray = append(tmpArray, fmt.Sprintf(`%s=CURRENT_TIMESTAMP`, + QuoteIdentifier(n.peerdbCols.SyncedAtColName))) } // set soft-deleted to false, tackles insert after soft-delete if handleSoftDelete { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`, - n.peerdbCols.SoftDeleteColName)) + tmpArray = append(tmpArray, fmt.Sprintf(`%s=FALSE`, + QuoteIdentifier(n.peerdbCols.SoftDeleteColName))) } ssep := strings.Join(tmpArray, ",") updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s`, cols, ssep) + src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns=%s + THEN UPDATE SET %s`, QuoteLiteral(cols), ssep) updateStmts = append(updateStmts, updateStmt) // generates update statements for the case where updates and deletes happen in the same branch @@ -223,11 +224,11 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st // and then set soft-delete to true. 
if handleSoftDelete { tmpArray = append(tmpArray[:len(tmpArray)-1], - fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName)) + fmt.Sprintf(`%s=TRUE`, QuoteIdentifier(n.peerdbCols.SoftDeleteColName))) ssep := strings.Join(tmpArray, ", ") updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s `, cols, ssep) + src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns=%s + THEN UPDATE SET %s`, QuoteLiteral(cols), ssep) updateStmts = append(updateStmts, updateStmt) } } diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index 2e45647280..f34e0a13bd 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -68,7 +68,7 @@ func (c *PostgresConnector) GetQRepPartitions( func (c *PostgresConnector) setTransactionSnapshot(tx pgx.Tx) error { snapshot := c.config.TransactionSnapshot if snapshot != "" { - if _, err := tx.Exec(c.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", snapshot)); err != nil { + if _, err := tx.Exec(c.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(snapshot))); err != nil { return fmt.Errorf("failed to set transaction snapshot: %w", err) } } diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 627c2e2fcf..44551a124a 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -365,7 +365,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx( }() if qe.snapshot != "" { - _, err = tx.Exec(qe.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", qe.snapshot)) + _, err = tx.Exec(qe.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(qe.snapshot))) if err != nil { stream.Records <- model.QRecordOrError{ Err: fmt.Errorf("failed to set snapshot: %w", err), diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go index 6725032411..6fdda9ab54 100644 --- a/flow/connectors/postgres/qrep_sync_method.go +++ b/flow/connectors/postgres/qrep_sync_method.go @@ -85,10 +85,10 @@ func (s *QRepStagingTableSync) SyncQRepRecords( if syncedAtCol != "" { updateSyncedAtStmt := fmt.Sprintf( - `UPDATE %s SET "%s" = CURRENT_TIMESTAMP WHERE "%s" IS NULL;`, + `UPDATE %s SET %s = CURRENT_TIMESTAMP WHERE %s IS NULL;`, pgx.Identifier{dstTableName.Schema, dstTableName.Table}.Sanitize(), - syncedAtCol, - syncedAtCol, + QuoteIdentifier(syncedAtCol), + QuoteIdentifier(syncedAtCol), ) _, err = tx.Exec(context.Background(), updateSyncedAtStmt) if err != nil { @@ -137,22 +137,23 @@ func (s *QRepStagingTableSync) SyncQRepRecords( selectStrArray := make([]string, 0) for _, col := range schema.GetColumnNames() { _, ok := upsertMatchCols[col] + quotedCol := QuoteIdentifier(col) if !ok { - setClauseArray = append(setClauseArray, fmt.Sprintf(`"%s" = EXCLUDED."%s"`, col, col)) + setClauseArray = append(setClauseArray, fmt.Sprintf(`%s = EXCLUDED.%s`, quotedCol, quotedCol)) } - selectStrArray = append(selectStrArray, fmt.Sprintf(`"%s"`, col)) + selectStrArray = append(selectStrArray, quotedCol) } setClauseArray = append(setClauseArray, - fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, syncedAtCol)) + fmt.Sprintf(`%s = CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtCol))) setClause := strings.Join(setClauseArray, ",") selectSQL := strings.Join(selectStrArray, ",") // Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE upsertStmt := fmt.Sprintf( - `INSERT INTO %s (%s, "%s") SELECT %s, 
CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`, + `INSERT INTO %s (%s, %s) SELECT %s, CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`, dstTableIdentifier.Sanitize(), selectSQL, - syncedAtCol, + QuoteIdentifier(syncedAtCol), selectSQL, stagingTableIdentifier.Sanitize(), strings.Join(writeMode.UpsertKeyColumns, ", "), diff --git a/flow/e2e/postgres/qrep_flow_pg_test.go b/flow/e2e/postgres/qrep_flow_pg_test.go index fa9ad948ab..2943e201b5 100644 --- a/flow/e2e/postgres/qrep_flow_pg_test.go +++ b/flow/e2e/postgres/qrep_flow_pg_test.go @@ -116,7 +116,7 @@ func (s PeerFlowE2ETestSuitePG) checkEnums(srcSchemaQualified, dstSchemaQualifie } if exists.Bool { - return fmt.Errorf("enum comparison failed: rows are not equal\n") + return fmt.Errorf("enum comparison failed: rows are not equal") } return nil }
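
Usage note (not part of the diff): the two helpers added in escape.go are meant to replace ad-hoc `"%s"`/`'%s'` formatting whenever an identifier or a literal has to be spliced into SQL text that cannot take bind parameters. The sketch below is illustrative only; the table, column, and slot names are made up, and the helper bodies are condensed copies of the escape.go versions so the snippet runs standalone.

package main

import (
	"fmt"
	"strings"
)

// Condensed copy of QuoteIdentifier from flow/connectors/postgres/escape.go.
func QuoteIdentifier(name string) string {
	if end := strings.IndexRune(name, 0); end > -1 {
		name = name[:end] // truncate at an embedded zero byte, as escape.go does
	}
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// Condensed copy of QuoteLiteral from flow/connectors/postgres/escape.go.
func QuoteLiteral(literal string) string {
	literal = strings.ReplaceAll(literal, `'`, `''`)
	if strings.Contains(literal, `\`) {
		// backslashes force the C-style escape form, e.g.  E'a\\b'
		return ` E'` + strings.ReplaceAll(literal, `\`, `\\`) + `'`
	}
	return `'` + literal + `'`
}

func main() {
	// Hypothetical names that would break naive "%s" formatting.
	column := `soft"deleted`
	slotName := "peerdb_slot'; --"

	// Identifier position: wrap in double quotes, doubling embedded ones.
	fmt.Printf("UPDATE dst SET %s=TRUE\n", QuoteIdentifier(column))
	// UPDATE dst SET "soft""deleted"=TRUE

	// Literal position: wrap in single quotes, doubling embedded ones.
	fmt.Printf("... WHERE slot_name = %s\n", QuoteLiteral(slotName))
	// ... WHERE slot_name = 'peerdb_slot''; --'
}

Both helpers follow lib/pq's escaping rules, which is why escape.go credits lib/pq's conn.go at the top.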