From 461f6f103a6cf10f9cd15aa039a1153b6b198a57 Mon Sep 17 00:00:00 2001 From: Kevin K Biju <52661649+heavycrystal@users.noreply.github.com> Date: Fri, 27 Oct 2023 20:04:38 +0000 Subject: [PATCH] mixed case table and column name support for BigQuery (#585) --- flow/connectors/postgres/cdc.go | 1 - flow/connectors/postgres/client.go | 25 +++++++------ flow/connectors/postgres/postgres.go | 37 +++---------------- flow/connectors/postgres/postgres_cdc_test.go | 9 +++-- flow/connectors/postgres/qrep.go | 29 ++++++++++----- flow/connectors/postgres/qrep_sync_method.go | 3 +- flow/connectors/utils/partition/partition.go | 2 +- flow/connectors/utils/postgres.go | 24 ++++++++++++ flow/workflows/snapshot_flow.go | 12 +++++- nexus/analyzer/src/lib.rs | 6 +-- nexus/postgres-connection/src/lib.rs | 4 +- 11 files changed, 88 insertions(+), 64 deletions(-) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index bd3f9fc40e..b087e9d0f3 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -222,7 +222,6 @@ func (p *PostgresCDCSource) consumeStream( } numRowsProcessedMessage := fmt.Sprintf("processed %d rows", len(records.Records)) - utils.RecordHeartbeatWithRecover(p.ctx, numRowsProcessedMessage) if time.Since(standByLastLogged) > 10*time.Second { log.Infof("Sent Standby status message. %s", numRowsProcessedMessage) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 1c89edf6cf..c9bed66d2f 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -76,12 +76,12 @@ const ( ) // getRelIDForTable returns the relation ID for a table. -func (c *PostgresConnector) getRelIDForTable(schemaTable *SchemaTable) (uint32, error) { +func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (uint32, error) { var relID uint32 err := c.pool.QueryRow(c.ctx, `SELECT c.oid FROM pg_class c JOIN pg_namespace n - ON n.oid = c.relnamespace WHERE n.nspname = $1 AND c.relname = $2`, - strings.ToLower(schemaTable.Schema), strings.ToLower(schemaTable.Table)).Scan(&relID) + ON n.oid = c.relnamespace WHERE n.nspname=$1 AND c.relname=$2`, + schemaTable.Schema, schemaTable.Table).Scan(&relID) if err != nil { return 0, fmt.Errorf("error getting relation ID for table %s: %w", schemaTable, err) } @@ -90,7 +90,7 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *SchemaTable) (uint32, } // getReplicaIdentity returns the replica identity for a table. -func (c *PostgresConnector) isTableFullReplica(schemaTable *SchemaTable) (bool, error) { +func (c *PostgresConnector) isTableFullReplica(schemaTable *utils.SchemaTable) (bool, error) { relID, relIDErr := c.getRelIDForTable(schemaTable) if relIDErr != nil { return false, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr) @@ -108,7 +108,7 @@ func (c *PostgresConnector) isTableFullReplica(schemaTable *SchemaTable) (bool, // getPrimaryKeyColumns for table returns the primary key column for a given table // errors if there is no primary key column or if there is more than one primary key column. -func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *SchemaTable) ([]string, error) { +func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable) ([]string, error) { relID, err := c.getRelIDForTable(schemaTable) if err != nil { return nil, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err) @@ -144,7 +144,7 @@ func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *SchemaTable) ([]st return pkCols, nil } -func (c *PostgresConnector) tableExists(schemaTable *SchemaTable) (bool, error) { +func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (bool, error) { var exists bool err := c.pool.QueryRow(c.ctx, `SELECT EXISTS ( @@ -216,10 +216,11 @@ func (c *PostgresConnector) createSlotAndPublication( */ srcTableNames := make([]string, 0, len(tableNameMapping)) for srcTableName := range tableNameMapping { - if len(strings.Split(srcTableName, ".")) != 2 { - return fmt.Errorf("source tables identifier is invalid: %v", srcTableName) + parsedSrcTableName, err := utils.ParseSchemaTable(srcTableName) + if err != nil { + return fmt.Errorf("source table identifier %s is invalid", srcTableName) } - srcTableNames = append(srcTableNames, srcTableName) + srcTableNames = append(srcTableNames, parsedSrcTableName.String()) } tableNameString := strings.Join(srcTableNames, ", ") @@ -229,6 +230,7 @@ func (c *PostgresConnector) createSlotAndPublication( _, err := c.pool.Exec(c.ctx, stmt) if err != nil { log.Warnf("Error creating publication '%s': %v", publication, err) + return fmt.Errorf("error creating publication '%s' : %w", publication, err) } } @@ -588,13 +590,14 @@ func (c *PostgresConnector) getApproxTableCounts(tables []string) (int64, error) countTablesBatch := &pgx.Batch{} totalCount := int64(0) for _, table := range tables { - _, err := parseSchemaTable(table) + parsedTable, err := utils.ParseSchemaTable(table) if err != nil { log.Errorf("error while parsing table %s: %v", table, err) return 0, fmt.Errorf("error while parsing table %s: %w", table, err) } countTablesBatch.Queue( - fmt.Sprintf("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = '%s'::regclass;", table)). + fmt.Sprintf("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = '%s'::regclass;", + parsedTable.String())). QueryRow(func(row pgx.Row) error { var count int64 err := row.Scan(&count) diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index c550ae1445..b312115536 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" "regexp" - "strings" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -34,18 +33,6 @@ type PostgresConnector struct { customTypesMapping map[uint32]string } -// SchemaTable is a table in a schema. -type SchemaTable struct { - Schema string - Table string -} - -func (t *SchemaTable) String() string { - quotedSchema := fmt.Sprintf(`"%s"`, t.Schema) - quotedTable := fmt.Sprintf(`"%s"`, t.Table) - return fmt.Sprintf("%s.%s", quotedSchema, quotedTable) -} - // NewPostgresConnector creates a new instance of PostgresConnector. func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) (*PostgresConnector, error) { connectionString := utils.GetPGConnectionString(pgConfig) @@ -120,7 +107,7 @@ func (c *PostgresConnector) ConnectionActive() bool { // NeedsSetupMetadataTables returns true if the metadata tables need to be set up. func (c *PostgresConnector) NeedsSetupMetadataTables() bool { - result, err := c.tableExists(&SchemaTable{ + result, err := c.tableExists(&utils.SchemaTable{ Schema: internalSchema, Table: mirrorJobsTableIdentifier, }) @@ -582,7 +569,7 @@ func (c *PostgresConnector) GetTableSchema( func (c *PostgresConnector) getTableSchemaForTable( tableName string, ) (*protos.TableSchema, error) { - schemaTable, err := parseSchemaTable(tableName) + schemaTable, err := utils.ParseSchemaTable(tableName) if err != nil { return nil, err } @@ -594,7 +581,8 @@ func (c *PostgresConnector) getTableSchemaForTable( // Get the column names and types rows, err := c.pool.Query(c.ctx, - fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, tableName), pgx.QueryExecModeSimpleProtocol) + fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, schemaTable.String()), + pgx.QueryExecModeSimpleProtocol) if err != nil { return nil, fmt.Errorf("error getting table schema for table %s: %w", schemaTable, err) } @@ -655,7 +643,7 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab }() for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - normalizedTableNameComponents, err := parseSchemaTable(tableIdentifier) + normalizedTableNameComponents, err := utils.ParseSchemaTable(tableIdentifier) if err != nil { return nil, fmt.Errorf("error while parsing table schema and name: %w", err) } @@ -752,7 +740,7 @@ func (c *PostgresConnector) EnsurePullability(req *protos.EnsurePullabilityBatch tableIdentifierMapping := make(map[string]*protos.TableIdentifier) for _, tableName := range req.SourceTableIdentifiers { - schemaTable, err := parseSchemaTable(tableName) + schemaTable, err := utils.ParseSchemaTable(tableName) if err != nil { return nil, fmt.Errorf("error parsing schema and table: %w", err) } @@ -896,16 +884,3 @@ func (c *PostgresConnector) SendWALHeartbeat() error { return nil } - -// parseSchemaTable parses a table name into schema and table name. -func parseSchemaTable(tableName string) (*SchemaTable, error) { - parts := strings.Split(tableName, ".") - if len(parts) != 2 { - return nil, fmt.Errorf("invalid table name: %s", tableName) - } - - return &SchemaTable{ - Schema: parts[0], - Table: parts[1], - }, nil -} diff --git a/flow/connectors/postgres/postgres_cdc_test.go b/flow/connectors/postgres/postgres_cdc_test.go index 2ef5609359..439cb8d2c6 100644 --- a/flow/connectors/postgres/postgres_cdc_test.go +++ b/flow/connectors/postgres/postgres_cdc_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -345,19 +346,19 @@ func (suite *PostgresCDCTestSuite) TearDownSuite() { } func (suite *PostgresCDCTestSuite) TestParseSchemaTable() { - schemaTest1, err := parseSchemaTable("schema") + schemaTest1, err := utils.ParseSchemaTable("schema") suite.Nil(schemaTest1) suite.NotNil(err) - schemaTest2, err := parseSchemaTable("schema.table") - suite.Equal(&SchemaTable{ + schemaTest2, err := utils.ParseSchemaTable("schema.table") + suite.Equal(&utils.SchemaTable{ Schema: "schema", Table: "table", }, schemaTest2) suite.Equal("\"schema\".\"table\"", schemaTest2.String()) suite.Nil(err) - schemaTest3, err := parseSchemaTable("database.schema.table") + schemaTest3, err := utils.ParseSchemaTable("database.schema.table") suite.Nil(schemaTest3) suite.NotNil(err) } diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index 69f2957a4b..e1dd47c314 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -6,8 +6,9 @@ import ( "text/template" "time" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" - utils "github.com/PeerDB-io/peer-flow/connectors/utils/partition" + partition_utils "github.com/PeerDB-io/peer-flow/connectors/utils/partition" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/google/uuid" @@ -135,8 +136,13 @@ func (c *PostgresConnector) getNumRowsPartitions( whereClause = fmt.Sprintf(`WHERE %s > $1`, quotedWatermarkColumn) } + parsedWatermarkTable, err := utils.ParseSchemaTable(config.WatermarkTable) + if err != nil { + return nil, fmt.Errorf("unable to parse watermark table: %w", err) + } + // Query to get the total number of rows in the table - countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", config.WatermarkTable, whereClause) + countQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s %s`, parsedWatermarkTable.String(), whereClause) var row pgx.Row var minVal interface{} = nil if last != nil && last.Range != nil { @@ -184,7 +190,7 @@ func (c *PostgresConnector) getNumRowsPartitions( `, numPartitions, quotedWatermarkColumn, - config.WatermarkTable, + parsedWatermarkTable.String(), ) log.Infof("[row_based_next] partitions query: %s", partitionsQuery) rows, err = tx.Query(c.ctx, partitionsQuery, minVal) @@ -199,7 +205,7 @@ func (c *PostgresConnector) getNumRowsPartitions( `, numPartitions, quotedWatermarkColumn, - config.WatermarkTable, + parsedWatermarkTable.String(), ) log.Infof("[row_based] partitions query: %s", partitionsQuery) rows, err = tx.Query(c.ctx, partitionsQuery) @@ -211,7 +217,7 @@ func (c *PostgresConnector) getNumRowsPartitions( return nil, fmt.Errorf("failed to query for partitions: %w", err) } - partitionHelper := utils.NewPartitionHelper() + partitionHelper := partition_utils.NewPartitionHelper() for rows.Next() { var bucket int64 var start, end interface{} @@ -244,8 +250,13 @@ func (c *PostgresConnector) getMinMaxValues( quotedWatermarkColumn = fmt.Sprintf("%s::text::bigint", quotedWatermarkColumn) } + parsedWatermarkTable, err := utils.ParseSchemaTable(config.WatermarkTable) + if err != nil { + return nil, nil, fmt.Errorf("unable to parse watermark table: %w", err) + } + // Get the maximum value from the database - maxQuery := fmt.Sprintf("SELECT MAX(%[1]s) FROM %[2]s", quotedWatermarkColumn, config.WatermarkTable) + maxQuery := fmt.Sprintf("SELECT MAX(%[1]s) FROM %[2]s", quotedWatermarkColumn, parsedWatermarkTable.String()) row := tx.QueryRow(c.ctx, maxQuery) if err := row.Scan(&maxValue); err != nil { return nil, nil, fmt.Errorf("failed to query for max value: %w", err) @@ -273,7 +284,7 @@ func (c *PostgresConnector) getMinMaxValues( } } else { // Otherwise get the minimum value from the database - minQuery := fmt.Sprintf("SELECT MIN(%[1]s) FROM %[2]s", quotedWatermarkColumn, config.WatermarkTable) + minQuery := fmt.Sprintf("SELECT MIN(%[1]s) FROM %[2]s", quotedWatermarkColumn, parsedWatermarkTable.String()) row := tx.QueryRow(c.ctx, minQuery) if err := row.Scan(&minValue); err != nil { log.WithFields(log.Fields{ @@ -301,7 +312,7 @@ func (c *PostgresConnector) getMinMaxValues( } } - err := tx.Commit(c.ctx) + err = tx.Commit(c.ctx) if err != nil { return nil, nil, fmt.Errorf("failed to commit transaction: %w", err) } @@ -508,7 +519,7 @@ func (c *PostgresConnector) SyncQRepRecords( partition *protos.QRepPartition, stream *model.QRecordStream, ) (int, error) { - dstTable, err := parseSchemaTable(config.DestinationTableIdentifier) + dstTable, err := utils.ParseSchemaTable(config.DestinationTableIdentifier) if err != nil { return 0, fmt.Errorf("failed to parse destination table identifier: %w", err) } diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go index 09264cbaf3..84b7bb8949 100644 --- a/flow/connectors/postgres/qrep_sync_method.go +++ b/flow/connectors/postgres/qrep_sync_method.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -30,7 +31,7 @@ type QRepStagingTableSync struct { func (s *QRepStagingTableSync) SyncQRepRecords( flowJobName string, - dstTableName *SchemaTable, + dstTableName *utils.SchemaTable, partition *protos.QRepPartition, stream *model.QRecordStream, writeMode *protos.QRepWriteMode, diff --git a/flow/connectors/utils/partition/partition.go b/flow/connectors/utils/partition/partition.go index ad8bb6067a..7fcc40a69b 100644 --- a/flow/connectors/utils/partition/partition.go +++ b/flow/connectors/utils/partition/partition.go @@ -1,4 +1,4 @@ -package utils +package partition_utils import ( "fmt" diff --git a/flow/connectors/utils/postgres.go b/flow/connectors/utils/postgres.go index 2080d2f3dd..cd6d9983ac 100644 --- a/flow/connectors/utils/postgres.go +++ b/flow/connectors/utils/postgres.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/url" + "strings" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/jackc/pgx/v5/pgxpool" @@ -47,3 +48,26 @@ func GetCustomDataTypes(ctx context.Context, pool *pgxpool.Pool) (map[uint32]str } return customTypeMap, nil } + +// SchemaTable is a table in a schema. +type SchemaTable struct { + Schema string + Table string +} + +func (t *SchemaTable) String() string { + return fmt.Sprintf(`"%s"."%s"`, t.Schema, t.Table) +} + +// ParseSchemaTable parses a table name into schema and table name. +func ParseSchemaTable(tableName string) (*SchemaTable, error) { + parts := strings.Split(tableName, ".") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid table name: %s", tableName) + } + + return &SchemaTable{ + Schema: parts[0], + Table: parts[1], + }, nil +} diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index d8ff8f232b..fab06d3ce1 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -6,6 +6,7 @@ import ( "time" "github.com/PeerDB-io/peer-flow/concurrency" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/shared" "github.com/google/uuid" @@ -130,7 +131,16 @@ func (s *SnapshotFlowExecution) cloneTable( partitionCol = mapping.PartitionKey } - query := fmt.Sprintf("SELECT * FROM %s WHERE %s BETWEEN {{.start}} AND {{.end}}", srcName, partitionCol) + parsedSrcTable, err := utils.ParseSchemaTable(srcName) + if err != nil { + logrus.WithFields(logrus.Fields{ + "flowName": flowName, + "snapshotName": snapshotName, + }).Errorf("unable to parse source table") + return fmt.Errorf("unable to parse source table: %w", err) + } + query := fmt.Sprintf("SELECT * FROM %s WHERE %s BETWEEN {{.start}} AND {{.end}}", + parsedSrcTable.String(), partitionCol) numWorkers := uint32(8) if s.config.SnapshotMaxParallelWorkers > 0 { diff --git a/nexus/analyzer/src/lib.rs b/nexus/analyzer/src/lib.rs index c361bb5865..0525767d66 100644 --- a/nexus/analyzer/src/lib.rs +++ b/nexus/analyzer/src/lib.rs @@ -156,12 +156,10 @@ impl<'a> StatementAnalyzer for PeerDDLAnalyzer<'a> { flow_job_table_mappings.push(FlowJobTableMapping { source_table_identifier: table_mapping .source - .to_string() - .to_lowercase(), + .to_string(), destination_table_identifier: table_mapping .destination - .to_string() - .to_lowercase(), + .to_string(), partition_key: table_mapping .partition_key .clone() diff --git a/nexus/postgres-connection/src/lib.rs b/nexus/postgres-connection/src/lib.rs index 21ba7d0825..58e9ecc793 100644 --- a/nexus/postgres-connection/src/lib.rs +++ b/nexus/postgres-connection/src/lib.rs @@ -17,7 +17,7 @@ pub fn get_pg_connection_string(config: &PostgresConfig) -> String { connection_string.push('/'); connection_string.push_str(&config.database); - // Add the timeout as a query parameter + // Add the timeout as a query parameter, sslmode changes here appear to be useless connection_string.push_str("?connect_timeout=15"); connection_string @@ -27,6 +27,8 @@ pub async fn connect_postgres(config: &PostgresConfig) -> anyhow::Result