diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 0a99bce668..547e1cd8b3 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -248,7 +248,9 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) { specificSlotClause := "" if slotName != "" { - specificSlotClause = fmt.Sprintf(" WHERE slot_name = '%s'", slotName) + specificSlotClause = fmt.Sprintf(" WHERE slot_name = %s", QuoteLiteral(slotName)) + } else { + specificSlotClause = fmt.Sprintf(" WHERE database = %s", QuoteLiteral(c.config.Database)) } rows, err := c.pool.Query(c.ctx, "SELECT slot_name, redo_lsn::Text,restart_lsn::text,wal_status,"+ "confirmed_flush_lsn::text,active,"+ @@ -403,20 +405,19 @@ func generateCreateTableSQLForNormalizedTable( if softDeleteColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`"%s" BOOL DEFAULT FALSE,`, softDeleteColName)) + fmt.Sprintf(`%s BOOL DEFAULT FALSE,`, QuoteIdentifier(softDeleteColName))) } if syncedAtColName != "" { createTableSQLArray = append(createTableSQLArray, - fmt.Sprintf(`"%s" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, syncedAtColName)) + fmt.Sprintf(`%s TIMESTAMP DEFAULT CURRENT_TIMESTAMP,`, QuoteIdentifier(syncedAtColName))) } // add composite primary key to the table if len(sourceTableSchema.PrimaryKeyColumns) > 0 { primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns)) for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { - primaryKeyColsQuoted = append(primaryKeyColsQuoted, - fmt.Sprintf(`"%s"`, primaryKeyCol)) + primaryKeyColsQuoted = append(primaryKeyColsQuoted, QuoteIdentifier(primaryKeyCol)) } createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ","))) diff --git a/flow/connectors/postgres/escape.go 
b/flow/connectors/postgres/escape.go new file mode 100644 index 0000000000..280d108338 --- /dev/null +++ b/flow/connectors/postgres/escape.go @@ -0,0 +1,57 @@ +// from https://github.com/lib/pq/blob/v1.10.9/conn.go#L1656 + +package connpostgres + +import "strings" + +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} + +// QuoteIdentifier quotes an "identifier" (e.g. 
a table or a column name) to be +// used as part of an SQL statement. For example: +// +// tblname := "my_table" +// data := "my_data" +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// +// Any double quotes in name will be escaped. The quoted identifier will be +// case sensitive when used in a query. If the input string contains a zero +// byte, the result will be truncated immediately before it. +func QuoteIdentifier(name string) string { + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return `"` + strings.Replace(name, `"`, `""`, -1) + `"` +} diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go index b541543fe2..ef305f31e5 100644 --- a/flow/connectors/postgres/normalize_stmt_generator.go +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -48,18 +48,20 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { flattenedCastsSQLArray := make([]string, 0, columnCount) primaryKeyColumnCasts := make(map[string]string) utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { - columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) + quotedCol := QuoteIdentifier(columnName) + stringCol := QuoteLiteral(columnName) + columnNames = append(columnNames, quotedCol) pgType := qValueKindToPostgresType(genericColumnType) if qvalue.QValueKind(genericColumnType).IsArray() { flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s", + stringCol, pgType, quotedCol)) } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - 
strings.Trim(columnName, "\""), pgType, columnName)) + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s", + stringCol, pgType, quotedCol)) } if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType) } }) flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") @@ -68,13 +70,14 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) { - updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName)) + quotedCol := QuoteIdentifier(columnName) + updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol)) }) updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",") deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) for columnName, columnCast := range primaryKeyColumnCasts { - deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `, - parsedDstTable.String(), columnName, columnCast)) + deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s.%s=%s AND `, + parsedDstTable.String(), QuoteIdentifier(columnName), columnCast)) } deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") deletePart := fmt.Sprintf( @@ -82,11 +85,11 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { parsedDstTable.String()) if n.peerdbCols.SoftDelete { - deletePart = fmt.Sprintf(`UPDATE 
%s SET "%s"=TRUE`, - parsedDstTable.String(), n.peerdbCols.SoftDeleteColName) + deletePart = fmt.Sprintf(`UPDATE %s SET %s=TRUE`, + parsedDstTable.String(), QuoteIdentifier(n.peerdbCols.SoftDeleteColName)) if n.peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, - deletePart, n.peerdbCols.SyncedAtColName) + deletePart = fmt.Sprintf(`%s,%s=CURRENT_TIMESTAMP`, + deletePart, QuoteIdentifier(n.peerdbCols.SyncedAtColName)) } deletePart += " FROM" } @@ -104,7 +107,7 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string { func (n *normalizeStmtGenerator) generateMergeStatement() string { columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema) for i, columnName := range columnNames { - columnNames[i] = fmt.Sprintf("\"%s\"", columnName) + columnNames[i] = QuoteIdentifier(columnName) } flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) @@ -113,30 +116,31 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string { primaryKeyColumnCasts := make(map[string]string) primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + quotedCol := QuoteIdentifier(columnName) pgType := qValueKindToPostgresType(genericColumnType) if qvalue.QValueKind(genericColumnType).IsArray() { flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>%s)::JSON))::%s AS %s", + QuoteLiteral(columnName), pgType, quotedCol)) } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) + flattenedCastsSQLArray = 
append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>%s)::%s AS %s", + QuoteLiteral(columnName), pgType, quotedCol)) } if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", QuoteLiteral(columnName), pgType) primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", - columnName, columnName)) + quotedCol, quotedCol)) } }) flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") insertValuesSQLArray := make([]string, 0, len(columnNames)) for _, columnName := range columnNames { - insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) + insertValuesSQLArray = append(insertValuesSQLArray, "src."+columnName) } updateStatementsforToastCols := n.generateUpdateStatements(columnNames) // append synced_at column - columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName)) + columnNames = append(columnNames, QuoteIdentifier(n.peerdbCols.SyncedAtColName)) insertColumnsSQL := strings.Join(columnNames, ",") // fill in synced_at column insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") @@ -144,7 +148,7 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string { if n.peerdbCols.SoftDelete { softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, - fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",") + QuoteIdentifier(n.peerdbCols.SoftDeleteColName)), ","), ",") softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") updateStatementsforToastCols = append(updateStatementsforToastCols, @@ -156,10 +160,10 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string { deletePart := "DELETE" if n.peerdbCols.SoftDelete { colName := 
n.peerdbCols.SoftDeleteColName - deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName) + deletePart = fmt.Sprintf(`UPDATE SET %s=TRUE`, QuoteIdentifier(colName)) if n.peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, - deletePart, n.peerdbCols.SyncedAtColName) + deletePart = fmt.Sprintf(`%s,%s=CURRENT_TIMESTAMP`, + deletePart, QuoteIdentifier(n.peerdbCols.SyncedAtColName)) } } @@ -194,7 +198,7 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st unquotedUnchangedColsArray := strings.Split(cols, ",") unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) for _, unchangedToastCol := range unquotedUnchangedColsArray { - unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) + unchangedColsArray = append(unchangedColsArray, QuoteIdentifier(unchangedToastCol)) } otherCols := utils.ArrayMinus(allCols, unchangedColsArray) tmpArray := make([]string, 0, len(otherCols)) @@ -203,13 +207,13 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st } // set the synced at column to the current timestamp if n.peerdbCols.SyncedAtColName != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`, - n.peerdbCols.SyncedAtColName)) + tmpArray = append(tmpArray, fmt.Sprintf(`%s=CURRENT_TIMESTAMP`, + QuoteIdentifier(n.peerdbCols.SyncedAtColName))) } // set soft-deleted to false, tackles insert after soft-delete if handleSoftDelete { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`, - n.peerdbCols.SoftDeleteColName)) + tmpArray = append(tmpArray, fmt.Sprintf(`%s=FALSE`, + QuoteIdentifier(n.peerdbCols.SoftDeleteColName))) } ssep := strings.Join(tmpArray, ",") @@ -223,7 +227,7 @@ func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []st // and then set soft-delete to true. 
if handleSoftDelete { tmpArray = append(tmpArray[:len(tmpArray)-1], - fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName)) + fmt.Sprintf(`%s=TRUE`, QuoteIdentifier(n.peerdbCols.SoftDeleteColName))) ssep := strings.Join(tmpArray, ", ") updateStmt := fmt.Sprintf(`WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go index 6725032411..6fdda9ab54 100644 --- a/flow/connectors/postgres/qrep_sync_method.go +++ b/flow/connectors/postgres/qrep_sync_method.go @@ -85,10 +85,10 @@ func (s *QRepStagingTableSync) SyncQRepRecords( if syncedAtCol != "" { updateSyncedAtStmt := fmt.Sprintf( - `UPDATE %s SET "%s" = CURRENT_TIMESTAMP WHERE "%s" IS NULL;`, + `UPDATE %s SET %s = CURRENT_TIMESTAMP WHERE %s IS NULL;`, pgx.Identifier{dstTableName.Schema, dstTableName.Table}.Sanitize(), - syncedAtCol, - syncedAtCol, + QuoteIdentifier(syncedAtCol), + QuoteIdentifier(syncedAtCol), ) _, err = tx.Exec(context.Background(), updateSyncedAtStmt) if err != nil { @@ -137,22 +137,23 @@ func (s *QRepStagingTableSync) SyncQRepRecords( selectStrArray := make([]string, 0) for _, col := range schema.GetColumnNames() { _, ok := upsertMatchCols[col] + quotedCol := QuoteIdentifier(col) if !ok { - setClauseArray = append(setClauseArray, fmt.Sprintf(`"%s" = EXCLUDED."%s"`, col, col)) + setClauseArray = append(setClauseArray, fmt.Sprintf(`%s = EXCLUDED.%s`, quotedCol, quotedCol)) } - selectStrArray = append(selectStrArray, fmt.Sprintf(`"%s"`, col)) + selectStrArray = append(selectStrArray, quotedCol) } setClauseArray = append(setClauseArray, - fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, syncedAtCol)) + fmt.Sprintf(`%s = CURRENT_TIMESTAMP`, QuoteIdentifier(syncedAtCol))) setClause := strings.Join(setClauseArray, ",") selectSQL := strings.Join(selectStrArray, ",") // Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE upsertStmt := fmt.Sprintf( - `INSERT INTO %s (%s, "%s") 
SELECT %s, CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`, + `INSERT INTO %s (%s, %s) SELECT %s, CURRENT_TIMESTAMP FROM %s ON CONFLICT (%s) DO UPDATE SET %s;`, dstTableIdentifier.Sanitize(), selectSQL, - syncedAtCol, + QuoteIdentifier(syncedAtCol), selectSQL, stagingTableIdentifier.Sanitize(), strings.Join(writeMode.UpsertKeyColumns, ", "),