Skip to content

Commit

Permalink
GetSlotInfo: only list replication slots in peer's database (#1090)
Browse files Browse the repository at this point in the history
Also introduce escape.go to help with properly escaping strings/identifiers

Update all cases of `"%s"` / `'%s'` in queries in connectors/postgres
  • Loading branch information
serprex authored Jan 17, 2024
1 parent 8c5bb09 commit c92825e
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 90 deletions.
24 changes: 11 additions & 13 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ const (
getTableNameToUnchangedToastColsSQL = `SELECT _peerdb_destination_table_name,
ARRAY_AGG(DISTINCT _peerdb_unchanged_toast_columns) FROM %s.%s WHERE
_peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_record_type!=2 GROUP BY _peerdb_destination_table_name`
srcTableName = "src"
mergeStatementSQL = `WITH src_rank AS (
SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
Expand All @@ -55,10 +54,8 @@ const (
USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE _peerdb_rank=1) src
ON %s
WHEN NOT MATCHED AND src._peerdb_record_type!=2 THEN
INSERT (%s) VALUES (%s)
%s
WHEN MATCHED AND src._peerdb_record_type=2 THEN
%s`
INSERT (%s) VALUES (%s) %s
WHEN MATCHED AND src._peerdb_record_type=2 THEN %s`
fallbackUpsertStatementSQL = `WITH src_rank AS (
SELECT _peerdb_data,_peerdb_record_type,_peerdb_unchanged_toast_columns,
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
Expand All @@ -71,7 +68,7 @@ const (
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
)
%s src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
DELETE FROM %s USING %s FROM src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`

dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s"
deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE mirror_job_name=$1"
Expand Down Expand Up @@ -246,15 +243,17 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str
// If slotName input is empty, all slot info rows are returned - this is for UI.
// Else, only the row pertaining to that slotName will be returned.
func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) {
specificSlotClause := ""
whereClause := ""
if slotName != "" {
specificSlotClause = fmt.Sprintf(" WHERE slot_name = '%s'", slotName)
whereClause = fmt.Sprintf(" WHERE slot_name = %s", QuoteLiteral(slotName))
} else {
whereClause = fmt.Sprintf(" WHERE database = %s", QuoteLiteral(c.config.Database))
}
rows, err := c.pool.Query(c.ctx, "SELECT slot_name, redo_lsn::Text,restart_lsn::text,wal_status,"+
"confirmed_flush_lsn::text,active,"+
"round((CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END"+
" - confirmed_flush_lsn) / 1024 / 1024) AS MB_Behind"+
" FROM pg_control_checkpoint(), pg_replication_slots"+specificSlotClause+";")
" FROM pg_control_checkpoint(), pg_replication_slots"+whereClause)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -403,20 +402,19 @@ func generateCreateTableSQLForNormalizedTable(

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" BOOL DEFAULT FALSE,`, softDeleteColName))
fmt.Sprintf(`%s BOOL DEFAULT FALSE,`, QuoteIdentifier(softDeleteColName)))
}

if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, QuoteIdentifier(syncedAtColName)))
}

// add composite primary key to the table
if len(sourceTableSchema.PrimaryKeyColumns) > 0 {
primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsQuoted = append(primaryKeyColsQuoted,
fmt.Sprintf(`"%s"`, primaryKeyCol))
primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))
Expand Down
57 changes: 57 additions & 0 deletions flow/connectors/postgres/escape.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// from https://github.com/lib/pq/blob/v1.10.9/conn.go#L1656

package connpostgres

import "strings"

// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
// to DDL and other statements that do not accept parameters) to be used as part
// of an SQL statement. For example:
//
// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
// that PostgreSQL provides ('E') will be prepended to the string.
func QuoteLiteral(literal string) string {
// This follows the PostgreSQL internal algorithm for handling quoted literals
// from libpq, which can be found in the "PQEscapeStringInternal" function,
// which is found in the libpq/fe-exec.c source file:
// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
//
// substitute any single-quotes (') with two single-quotes ('')
literal = strings.Replace(literal, `'`, `''`, -1)
// determine if the string has any backslashes (\) in it.
// if it does, replace any backslashes (\) with two backslashes (\\)
// then, we need to wrap the entire string with a PostgreSQL
// C-style escape. Per how "PQEscapeStringInternal" handles this case, we
// also add a space before the "E"
if strings.Contains(literal, `\`) {
literal = strings.Replace(literal, `\`, `\\`, -1)
literal = ` E'` + literal + `'`
} else {
// otherwise, we can just wrap the literal with a pair of single quotes
literal = `'` + literal + `'`
}
return literal
}

// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
// used as part of an SQL statement. For example:
//
// tblname := "my_table"
// data := "my_data"
// quoted := pq.QuoteIdentifier(tblname)
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name will be escaped. The quoted identifier will be
// case sensitive when used in a query. If the input string contains a zero
// byte, the result will be truncated immediately before it.
func QuoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
132 changes: 66 additions & 66 deletions flow/connectors/postgres/normalize_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,65 +46,64 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
columnCount := utils.TableSchemaColumns(n.normalizedTableSchema)
columnNames := make([]string, 0, columnCount)
flattenedCastsSQLArray := make([]string, 0, columnCount)
primaryKeyColumnCasts := make(map[string]string)
primaryKeyColumnCasts := make(map[string]string, len(n.normalizedTableSchema.PrimaryKeyColumns))
utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName))
quotedCol := QuoteIdentifier(columnName)
stringCol := QuoteLiteral(columnName)
columnNames = append(columnNames, quotedCol)
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s",
stringCol, pgType, quotedCol))
} else {
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s",
stringCol, pgType, quotedCol))
}
if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
}
})
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)

insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
insertColumnsSQL := strings.Join(columnNames, ",")
updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) {
updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName))
quotedCol := QuoteIdentifier(columnName)
updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
})
updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",")
updateColumnsSQL := strings.Join(updateColumnsSQLArray, ",")
deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
for columnName, columnCast := range primaryKeyColumnCasts {
deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `,
parsedDstTable.String(), columnName, columnCast))
deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s.%s=%s`,
parsedDstTable.String(), QuoteIdentifier(columnName), columnCast))
}
deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ")
deletePart := fmt.Sprintf(
"DELETE FROM %s USING",
parsedDstTable.String())
deleteWhereClauseSQL := strings.Join(deleteWhereClauseArray, " AND ")

deleteUpdate := ""
if n.peerdbCols.SoftDelete {
deletePart = fmt.Sprintf(`UPDATE %s SET "%s"=TRUE`,
parsedDstTable.String(), n.peerdbCols.SoftDeleteColName)
deleteUpdate = fmt.Sprintf(`UPDATE %s SET %s=TRUE`,
parsedDstTable.String(), QuoteIdentifier(n.peerdbCols.SoftDeleteColName))
if n.peerdbCols.SyncedAtColName != "" {
deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`,
deletePart, n.peerdbCols.SyncedAtColName)
deleteUpdate += fmt.Sprintf(`,%s=CURRENT_TIMESTAMP`, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
}
deletePart += " FROM"
}
fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), n.metadataSchema,
strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema,
n.rawTableName, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL,
strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema,
n.rawTableName, deletePart, deleteWhereClauseSQL)
n.rawTableName, parsedDstTable.String(), deleteUpdate, deleteWhereClauseSQL)

return []string{fallbackUpsertStatement, fallbackDeleteStatement}
}

func (n *normalizeStmtGenerator) generateMergeStatement() string {
columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
for i, columnName := range columnNames {
columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
for i, columnName := range quotedColumnNames {
quotedColumnNames[i] = QuoteIdentifier(columnName)
}

flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
Expand All @@ -113,38 +112,41 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
primaryKeyColumnCasts := make(map[string]string)
primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
quotedCol := QuoteIdentifier(columnName)
stringCol := QuoteLiteral(columnName)
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s",
stringCol, pgType, quotedCol))
} else {
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s",
stringCol, pgType, quotedCol))
}
if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
columnName, columnName))
quotedCol, quotedCol))
}
})
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
insertValuesSQLArray := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName))
flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2)
for _, quotedCol := range quotedColumnNames {
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", quotedCol))
}

updateStatementsforToastCols := n.generateUpdateStatements(columnNames)
updateStatementsforToastCols := n.generateUpdateStatements(quotedColumnNames)
// append synced_at column
columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName))
insertColumnsSQL := strings.Join(columnNames, ",")
// fill in synced_at column
insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",")
if n.peerdbCols.SyncedAtColName != "" {
quotedColumnNames = append(quotedColumnNames, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
}
insertColumnsSQL := strings.Join(quotedColumnNames, ",")
insertValuesSQL := strings.Join(insertValuesSQLArray, ",")

if n.peerdbCols.SoftDelete {
softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames,
fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",")
softDeleteInsertColumnsSQL := strings.Join(
append(quotedColumnNames, QuoteIdentifier(n.peerdbCols.SoftDeleteColName)), ",")
softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",")

updateStatementsforToastCols = append(updateStatementsforToastCols,
Expand All @@ -153,13 +155,12 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
}
updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n")

deletePart := "DELETE"
conflictPart := "DELETE"
if n.peerdbCols.SoftDelete {
colName := n.peerdbCols.SoftDeleteColName
deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName)
conflictPart = fmt.Sprintf(`UPDATE SET %s=TRUE`, QuoteIdentifier(colName))
if n.peerdbCols.SyncedAtColName != "" {
deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`,
deletePart, n.peerdbCols.SyncedAtColName)
conflictPart += fmt.Sprintf(`,%s=CURRENT_TIMESTAMP`, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
}
}

Expand All @@ -174,13 +175,13 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
insertColumnsSQL,
insertValuesSQL,
updateStringToastCols,
deletePart,
conflictPart,
)

return mergeStmt
}

func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []string {
func (n *normalizeStmtGenerator) generateUpdateStatements(quotedCols []string) []string {
handleSoftDelete := n.peerdbCols.SoftDelete && (n.peerdbCols.SoftDeleteColName != "")
// weird way of doing it but avoids prealloc lint
updateStmts := make([]string, 0, func() int {
Expand All @@ -191,43 +192,42 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st
}())

for _, cols := range n.unchangedToastColumns {
unquotedUnchangedColsArray := strings.Split(cols, ",")
unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray))
for _, unchangedToastCol := range unquotedUnchangedColsArray {
unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol))
unchangedColsArray := strings.Split(cols, ",")
for i, unchangedToastCol := range unchangedColsArray {
unchangedColsArray[i] = QuoteIdentifier(unchangedToastCol)
}
otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
otherCols := utils.ArrayMinus(quotedCols, unchangedColsArray)
tmpArray := make([]string, 0, len(otherCols))
for _, colName := range otherCols {
tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName))
}
// set the synced at column to the current timestamp
if n.peerdbCols.SyncedAtColName != "" {
tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`,
n.peerdbCols.SyncedAtColName))
tmpArray = append(tmpArray, fmt.Sprintf(`%s=CURRENT_TIMESTAMP`,
QuoteIdentifier(n.peerdbCols.SyncedAtColName)))
}
// set soft-deleted to false, tackles insert after soft-delete
if handleSoftDelete {
tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`,
n.peerdbCols.SoftDeleteColName))
tmpArray = append(tmpArray, fmt.Sprintf(`%s=FALSE`,
QuoteIdentifier(n.peerdbCols.SoftDeleteColName)))
}

quotedCols := QuoteLiteral(cols)
ssep := strings.Join(tmpArray, ",")
updateStmt := fmt.Sprintf(`WHEN MATCHED AND
src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s'
THEN UPDATE SET %s`, cols, ssep)
src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns=%s
THEN UPDATE SET %s`, quotedCols, ssep)
updateStmts = append(updateStmts, updateStmt)

// generates update statements for the case where updates and deletes happen in the same branch
// the backfill has happened from the pull side already, so treat the DeleteRecord as an update
// and then set soft-delete to true.
if handleSoftDelete {
tmpArray = append(tmpArray[:len(tmpArray)-1],
fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName))
tmpArray[len(tmpArray)-1] = fmt.Sprintf(`%s=TRUE`, QuoteIdentifier(n.peerdbCols.SoftDeleteColName))
ssep := strings.Join(tmpArray, ", ")
updateStmt := fmt.Sprintf(`WHEN MATCHED AND
src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s'
THEN UPDATE SET %s `, cols, ssep)
src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns=%s
THEN UPDATE SET %s`, quotedCols, ssep)
updateStmts = append(updateStmts, updateStmt)
}
}
Expand Down
2 changes: 1 addition & 1 deletion flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (c *PostgresConnector) GetQRepPartitions(
func (c *PostgresConnector) setTransactionSnapshot(tx pgx.Tx) error {
snapshot := c.config.TransactionSnapshot
if snapshot != "" {
if _, err := tx.Exec(c.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", snapshot)); err != nil {
if _, err := tx.Exec(c.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(snapshot))); err != nil {
return fmt.Errorf("failed to set transaction snapshot: %w", err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion flow/connectors/postgres/qrep_query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx(
}()

if qe.snapshot != "" {
_, err = tx.Exec(qe.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s'", qe.snapshot))
_, err = tx.Exec(qe.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(qe.snapshot)))
if err != nil {
stream.Records <- model.QRecordOrError{
Err: fmt.Errorf("failed to set snapshot: %w", err),
Expand Down
Loading

0 comments on commit c92825e

Please sign in to comment.