From f30259807551e25437e9f129aaaa5e5df466e3a6 Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Mon, 27 Nov 2023 14:01:23 -0500 Subject: [PATCH] support mixed case table names pg->pg (#722) Co-authored-by: Kevin Biju --- flow/connectors/postgres/client.go | 43 +++----------------- flow/connectors/postgres/postgres.go | 7 ++-- flow/connectors/postgres/qrep_sync_method.go | 14 ++++--- 3 files changed, 18 insertions(+), 46 deletions(-) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 816eecfb25..9a66d2a51b 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -508,14 +508,15 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie destinationTableIdentifier, columnName, columnCast)) } deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") + parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), c.metadataSchema, - rawTableIdentifier, destinationTableIdentifier, insertColumnsSQL, flattenedCastsSQL, + rawTableIdentifier, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL, strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","), c.metadataSchema, - rawTableIdentifier, destinationTableIdentifier, deleteWhereClauseSQL) + rawTableIdentifier, parsedDstTable.String(), deleteWhereClauseSQL) return []string{fallbackUpsertStatement, fallbackDeleteStatement} } @@ -529,6 +530,8 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st } flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns)) + parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) + primaryKeyColumnCasts := make(map[string]string) primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) for columnName, genericColumnType := range normalizedTableSchema.Columns { @@ -558,7 +561,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns) return fmt.Sprintf(mergeStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","), - c.metadataSchema, rawTableIdentifier, destinationTableIdentifier, flattenedCastsSQL, + c.metadataSchema, rawTableIdentifier, parsedDstTable.String(), flattenedCastsSQL, strings.Join(primaryKeySelectSQLArray, " AND "), insertColumnsSQL, insertValuesSQL, updateStatements) } @@ -584,40 +587,6 @@ func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedT return strings.Join(updateStmts, "\n") } -func (c *PostgresConnector) getApproxTableCounts(tables []string) (int64, error) { - countTablesBatch := &pgx.Batch{} - totalCount := int64(0) - for _, table := range tables { - parsedTable, err := utils.ParseSchemaTable(table) - if err != nil { - log.Errorf("error while parsing table %s: %v", table, err) - return 0, fmt.Errorf("error while parsing table %s: %w", table, err) - } - countTablesBatch.Queue( - fmt.Sprintf("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = '%s'::regclass;", - parsedTable.String())). - QueryRow(func(row pgx.Row) error { - var count int64 - err := row.Scan(&count) - if err != nil { - log.WithFields(log.Fields{ - "table": table, - }).Errorf("error while scanning row: %v", err) - return fmt.Errorf("error while scanning row: %w", err) - } - totalCount += count - return nil - }) - } - countTablesResults := c.pool.SendBatch(c.ctx, countTablesBatch) - err := countTablesResults.Close() - if err != nil { - log.Errorf("error while closing statement batch: %v", err) - return 0, fmt.Errorf("error while closing statement batch: %w", err) - } - return totalCount, nil -} - func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) { row := c.pool.QueryRow(c.ctx, "SELECT pg_current_wal_lsn();") var result string diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index ccabcfdb89..e46c383d4c 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -635,11 +635,11 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab }() for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - normalizedTableNameComponents, err := utils.ParseSchemaTable(tableIdentifier) + parsedNormalizedTable, err := utils.ParseSchemaTable(tableIdentifier) if err != nil { return nil, fmt.Errorf("error while parsing table schema and name: %w", err) } - tableAlreadyExists, err := c.tableExists(normalizedTableNameComponents) + tableAlreadyExists, err := c.tableExists(parsedNormalizedTable) if err != nil { return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) } @@ -649,7 +649,8 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab } // convert the column names and types to Postgres types - normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(tableIdentifier, tableSchema) + normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable( + parsedNormalizedTable.String(), tableSchema) _, err = createNormalizedTablesTx.Exec(c.ctx, normalizedTableCreateSQL) if err != nil { return nil, fmt.Errorf("error while creating normalized table: %w", err) diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go index c91d734764..46c7dcc63c 100644 --- a/flow/connectors/postgres/qrep_sync_method.go +++ b/flow/connectors/postgres/qrep_sync_method.go @@ -115,21 +115,23 @@ func (s *QRepStagingTableSync) SyncQRepRecords( // construct the SET clause for the upsert operation upsertMatchColsList := writeMode.UpsertKeyColumns - upsertMatchCols := make(map[string]bool) + upsertMatchCols := make(map[string]struct{}) for _, col := range upsertMatchColsList { - upsertMatchCols[col] = true + upsertMatchCols[col] = struct{}{} } - setClause := "" + setClauseArray := make([]string, 0) + selectStrArray := make([]string, 0) for _, col := range schema.GetColumnNames() { _, ok := upsertMatchCols[col] if !ok { - setClause += fmt.Sprintf("%s = EXCLUDED.%s,", col, col) + setClauseArray = append(setClauseArray, fmt.Sprintf(`"%s" = EXCLUDED."%s"`, col, col)) } + selectStrArray = append(selectStrArray, fmt.Sprintf(`"%s"`, col)) } - setClause = strings.TrimSuffix(setClause, ",") - selectStr := strings.Join(schema.GetColumnNames(), ", ") + setClause := strings.Join(setClauseArray, ",") + selectStr := strings.Join(selectStrArray, ",") // Step 2.3: Perform the upsert operation, ON CONFLICT UPDATE upsertStmt := fmt.Sprintf(