Skip to content

Commit

Permalink
GetSlotInfo: only list replication slots in peer's database
Browse files Browse the repository at this point in the history
Also introduce escape.go to help with properly escaping strings/identifiers

Update all cases of "%s" in connectors/postgres,
along with a few strings such as in GetSlotInfo
  • Loading branch information
serprex committed Jan 17, 2024
1 parent 8c5bb09 commit 457d203
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 45 deletions.
11 changes: 6 additions & 5 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str
func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) {
specificSlotClause := ""
if slotName != "" {
specificSlotClause = fmt.Sprintf(" WHERE slot_name = '%s'", slotName)
specificSlotClause = fmt.Sprintf(" WHERE slot_name = %s", QuoteLiteral(slotName))
} else {
specificSlotClause = fmt.Sprintf(" WHERE database = %s", QuoteLiteral(c.config.Database))
}
rows, err := c.pool.Query(c.ctx, "SELECT slot_name, redo_lsn::Text,restart_lsn::text,wal_status,"+
"confirmed_flush_lsn::text,active,"+
Expand Down Expand Up @@ -403,20 +405,19 @@ func generateCreateTableSQLForNormalizedTable(

if softDeleteColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" BOOL DEFAULT FALSE,`, softDeleteColName))
fmt.Sprintf(`%s BOOL DEFAULT FALSE,`, QuoteIdentifier(softDeleteColName)))
}

if syncedAtColName != "" {
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName))
fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, QuoteIdentifier(syncedAtColName)))
}

// add composite primary key to the table
if len(sourceTableSchema.PrimaryKeyColumns) > 0 {
primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsQuoted = append(primaryKeyColsQuoted,
fmt.Sprintf(`"%s"`, primaryKeyCol))
primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))
Expand Down
57 changes: 57 additions & 0 deletions flow/connectors/postgres/escape.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// from https://github.com/lib/pq/blob/v1.10.9/conn.go#L1656

package connpostgres

import "strings"

// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
// to DDL and other statements that do not accept parameters) to be used as part
// of an SQL statement. For example:
//
// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
// that PostgreSQL provides ('E') will be prepended to the string.
func QuoteLiteral(literal string) string {
// This follows the PostgreSQL internal algorithm for handling quoted literals
// from libpq, which can be found in the "PQEscapeStringInternal" function,
// which is found in the libpq/fe-exec.c source file:
// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
//
// substitute any single-quotes (') with two single-quotes ('')
literal = strings.Replace(literal, `'`, `''`, -1)
// determine if the string has any backslashes (\) in it.
// if it does, replace any backslashes (\) with two backslashes (\\)
// then, we need to wrap the entire string with a PostgreSQL
// C-style escape. Per how "PQEscapeStringInternal" handles this case, we
// also add a space before the "E"
if strings.Contains(literal, `\`) {
literal = strings.Replace(literal, `\`, `\\`, -1)
literal = ` E'` + literal + `'`
} else {
// otherwise, we can just wrap the literal with a pair of single quotes
literal = `'` + literal + `'`
}
return literal
}

// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
// used as part of an SQL statement. For example:
//
// tblname := "my_table"
// data := "my_data"
// quoted := pq.QuoteIdentifier(tblname)
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name will be escaped. The quoted identifier will be
// case sensitive when used in a query. If the input string contains a zero
// byte, the result will be truncated immediately before it.
func QuoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
68 changes: 36 additions & 32 deletions flow/connectors/postgres/normalize_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,20 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
flattenedCastsSQLArray := make([]string, 0, columnCount)
primaryKeyColumnCasts := make(map[string]string)
utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName))
quotedCol := QuoteIdentifier(columnName)
stringCol := QuoteLiteral(columnName)
columnNames = append(columnNames, quotedCol)
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s",
stringCol, pgType, quotedCol))
} else {
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s",
stringCol, pgType, quotedCol))
}
if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
}
})
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
Expand All @@ -68,25 +70,26 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) {
updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName))
quotedCol := QuoteIdentifier(columnName)
updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
})
updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",")
deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
for columnName, columnCast := range primaryKeyColumnCasts {
deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `,
parsedDstTable.String(), columnName, columnCast))
deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s.%s=%s AND `,
parsedDstTable.String(), QuoteIdentifier(columnName), columnCast))
}
deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ")
deletePart := fmt.Sprintf(
"DELETE FROM %s USING",
parsedDstTable.String())

if n.peerdbCols.SoftDelete {
deletePart = fmt.Sprintf(`UPDATE %s SET "%s"=TRUE`,
parsedDstTable.String(), n.peerdbCols.SoftDeleteColName)
deletePart = fmt.Sprintf(`UPDATE %s SET %s=TRUE`,
parsedDstTable.String(), QuoteIdentifier(n.peerdbCols.SoftDeleteColName))
if n.peerdbCols.SyncedAtColName != "" {
deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`,
deletePart, n.peerdbCols.SyncedAtColName)
deletePart = fmt.Sprintf(`%s,%s=CURRENT_TIMESTAMP`,
deletePart, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
}
deletePart += " FROM"
}
Expand All @@ -104,7 +107,7 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
func (n *normalizeStmtGenerator) generateMergeStatement() string {
columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
for i, columnName := range columnNames {
columnNames[i] = fmt.Sprintf("\"%s\"", columnName)
columnNames[i] = QuoteIdentifier(columnName)
}

flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
Expand All @@ -113,38 +116,39 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
primaryKeyColumnCasts := make(map[string]string)
primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
quotedCol := QuoteIdentifier(columnName)
pgType := qValueKindToPostgresType(genericColumnType)
if qvalue.QValueKind(genericColumnType).IsArray() {
flattenedCastsSQLArray = append(flattenedCastsSQLArray,
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s",
QuoteLiteral(columnName), pgType, quotedCol))
} else {
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s",
QuoteLiteral(columnName), pgType, quotedCol))
}
if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", columnName, pgType)
primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
columnName, columnName))
quotedCol, quotedCol))
}
})
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
insertValuesSQLArray := make([]string, 0, len(columnNames))
for _, columnName := range columnNames {
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName))
insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", QuoteIdentifier(columnName)))
}

updateStatementsforToastCols := n.generateUpdateStatements(columnNames)
// append synced_at column
columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName))
columnNames = append(columnNames, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
insertColumnsSQL := strings.Join(columnNames, ",")
// fill in synced_at column
insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",")

if n.peerdbCols.SoftDelete {
softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames,
fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",")
QuoteIdentifier(n.peerdbCols.SoftDeleteColName)), ","), ",")
softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",")

updateStatementsforToastCols = append(updateStatementsforToastCols,
Expand All @@ -156,10 +160,10 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
deletePart := "DELETE"
if n.peerdbCols.SoftDelete {
colName := n.peerdbCols.SoftDeleteColName
deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName)
deletePart = fmt.Sprintf(`UPDATE SET %s=TRUE`, QuoteIdentifier(colName))
if n.peerdbCols.SyncedAtColName != "" {
deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`,
deletePart, n.peerdbCols.SyncedAtColName)
deletePart = fmt.Sprintf(`%s,%s=CURRENT_TIMESTAMP`,
deletePart, QuoteIdentifier(n.peerdbCols.SyncedAtColName))
}
}

Expand Down Expand Up @@ -194,7 +198,7 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st
unquotedUnchangedColsArray := strings.Split(cols, ",")
unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray))
for _, unchangedToastCol := range unquotedUnchangedColsArray {
unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol))
unchangedColsArray = append(unchangedColsArray, QuoteIdentifier(unchangedToastCol))
}
otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
tmpArray := make([]string, 0, len(otherCols))
Expand All @@ -203,13 +207,13 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st
}
// set the synced at column to the current timestamp
if n.peerdbCols.SyncedAtColName != "" {
tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`,
n.peerdbCols.SyncedAtColName))
tmpArray = append(tmpArray, fmt.Sprintf(`%s=CURRENT_TIMESTAMP`,
QuoteIdentifier(n.peerdbCols.SyncedAtColName)))
}
// set soft-deleted to false, tackles insert after soft-delete
if handleSoftDelete {
tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`,
n.peerdbCols.SoftDeleteColName))
tmpArray = append(tmpArray, fmt.Sprintf(`%s=FALSE`,
QuoteIdentifier(n.peerdbCols.SoftDeleteColName)))
}

ssep := strings.Join(tmpArray, ",")
Expand All @@ -223,7 +227,7 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st
// and then set soft-delete to true.
if handleSoftDelete {
tmpArray = append(tmpArray[:len(tmpArray)-1],
fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName))
fmt.Sprintf(`%s=TRUE`, QuoteIdentifier(n.peerdbCols.SoftDeleteColName)))
ssep := strings.Join(tmpArray, ", ")
updateStmt := fmt.Sprintf(`WHEN MATCHED AND
src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s'
Expand Down
17 changes: 9 additions & 8 deletions flow/connectors/postgres/qrep_sync_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ func (s *QRepStagingTableSync) SyncQRepRecords(

if syncedAtCol != "" {
updateSyncedAtStmt := fmt.Sprintf(
`UPDATE %s SET "%s" = CURRENT_TIMESTAMP WHERE "%s" IS NULL;`,
`UPDATE %s SET %s = CURRENT_TIMESTAMP WHERE %s IS NULL;`,
pgx.Identifier{dstTableName.Schema, dstTableName.Table}.Sanitize(),
syncedAtCol,
syncedAtCol,
QuoteIdentifier(syncedAtCol),
QuoteIdentifier(syncedAtCol),
)
_, err = tx.Exec(context.Background(), updateSyncedAtStmt)
if err != nil {
Expand Down Expand Up @@ -137,22 +137,23 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
selectStrArray := make([]string, 0)
for _, col := range schema.GetColumnNames() {
_, ok := upsertMatchCols[col]
quotedCol := QuoteIdentifier(col)
if !ok {
setClauseArray = append(setClauseArray, fmt.Sprintf(`"%s" = EXCLUDED."%s"`, col, col))
setClauseArray = append(setClauseArray, fmt.Sprintf(`%s = EXCLUDED.%s`, quotedCol, quotedCol))
}
selectStrArray = append(selectStrArray, fmt.Sprintf(`"%s"`, col))
selectStrArray = append(selectStrArray, quotedCol)
}
setClauseArray = append(setClauseArray,
fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, syncedAtCol))
fmt.Sprintf(`%s = CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtCol)))
setClause := strings.Join(setClauseArray, ",")
selectSQL := strings.Join(selectStrArray, ",")

// Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE
upsertStmt := fmt.Sprintf(
`INSERT INTO %s (%s, "%s") SELECT %s, CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`,
`INSERT INTO %s (%s, %s) SELECT %s, CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`,
dstTableIdentifier.Sanitize(),
selectSQL,
syncedAtCol,
QuoteIdentifier(syncedAtCol),
selectSQL,
stagingTableIdentifier.Sanitize(),
strings.Join(writeMode.UpsertKeyColumns, ", "),
Expand Down

0 comments on commit 457d203

Please sign in to comment.