mixed case table and column name support for BigQuery (#585)
heavycrystal authored Oct 27, 2023
1 parent 00ea6f0 commit 461f6f1
Showing 11 changed files with 88 additions and 64 deletions.
1 change: 0 additions & 1 deletion flow/connectors/postgres/cdc.go
@@ -222,7 +222,6 @@ func (p *PostgresCDCSource) consumeStream(
}

numRowsProcessedMessage := fmt.Sprintf("processed %d rows", len(records.Records))
utils.RecordHeartbeatWithRecover(p.ctx, numRowsProcessedMessage)

if time.Since(standByLastLogged) > 10*time.Second {
log.Infof("Sent Standby status message. %s", numRowsProcessedMessage)
25 changes: 14 additions & 11 deletions flow/connectors/postgres/client.go
@@ -76,12 +76,12 @@ const (
)

// getRelIDForTable returns the relation ID for a table.
func (c *PostgresConnector) getRelIDForTable(schemaTable *SchemaTable) (uint32, error) {
func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (uint32, error) {
var relID uint32
err := c.pool.QueryRow(c.ctx,
`SELECT c.oid FROM pg_class c JOIN pg_namespace n
ON n.oid = c.relnamespace WHERE n.nspname = $1 AND c.relname = $2`,
strings.ToLower(schemaTable.Schema), strings.ToLower(schemaTable.Table)).Scan(&relID)
ON n.oid = c.relnamespace WHERE n.nspname=$1 AND c.relname=$2`,
schemaTable.Schema, schemaTable.Table).Scan(&relID)
if err != nil {
return 0, fmt.Errorf("error getting relation ID for table %s: %w", schemaTable, err)
}
@@ -90,7 +90,7 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *SchemaTable) (uint32,
}

// getReplicaIdentity returns the replica identity for a table.
func (c *PostgresConnector) isTableFullReplica(schemaTable *SchemaTable) (bool, error) {
func (c *PostgresConnector) isTableFullReplica(schemaTable *utils.SchemaTable) (bool, error) {
relID, relIDErr := c.getRelIDForTable(schemaTable)
if relIDErr != nil {
return false, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr)
@@ -108,7 +108,7 @@ func (c *PostgresConnector) isTableFullReplica(schemaTable *SchemaTable) (bool,

// getPrimaryKeyColumns for table returns the primary key column for a given table
// errors if there is no primary key column or if there is more than one primary key column.
func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *SchemaTable) ([]string, error) {
func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable) ([]string, error) {
relID, err := c.getRelIDForTable(schemaTable)
if err != nil {
return nil, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err)
@@ -144,7 +144,7 @@ func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *SchemaTable) ([]st
return pkCols, nil
}

func (c *PostgresConnector) tableExists(schemaTable *SchemaTable) (bool, error) {
func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (bool, error) {
var exists bool
err := c.pool.QueryRow(c.ctx,
`SELECT EXISTS (
@@ -216,10 +216,11 @@ func (c *PostgresConnector) createSlotAndPublication(
*/
srcTableNames := make([]string, 0, len(tableNameMapping))
for srcTableName := range tableNameMapping {
if len(strings.Split(srcTableName, ".")) != 2 {
return fmt.Errorf("source tables identifier is invalid: %v", srcTableName)
parsedSrcTableName, err := utils.ParseSchemaTable(srcTableName)
if err != nil {
return fmt.Errorf("source table identifier %s is invalid", srcTableName)
}
srcTableNames = append(srcTableNames, srcTableName)
srcTableNames = append(srcTableNames, parsedSrcTableName.String())
}
tableNameString := strings.Join(srcTableNames, ", ")

@@ -229,6 +230,7 @@ func (c *PostgresConnector) createSlotAndPublication(
_, err := c.pool.Exec(c.ctx, stmt)
if err != nil {
log.Warnf("Error creating publication '%s': %v", publication, err)
return fmt.Errorf("error creating publication '%s' : %w", publication, err)
}
}

@@ -588,13 +590,14 @@ func (c *PostgresConnector) getApproxTableCounts(tables []string) (int64, error)
countTablesBatch := &pgx.Batch{}
totalCount := int64(0)
for _, table := range tables {
_, err := parseSchemaTable(table)
parsedTable, err := utils.ParseSchemaTable(table)
if err != nil {
log.Errorf("error while parsing table %s: %v", table, err)
return 0, fmt.Errorf("error while parsing table %s: %w", table, err)
}
countTablesBatch.Queue(
fmt.Sprintf("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = '%s'::regclass;", table)).
fmt.Sprintf("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = '%s'::regclass;",
parsedTable.String())).
QueryRow(func(row pgx.Row) error {
var count int64
err := row.Scan(&count)
37 changes: 6 additions & 31 deletions flow/connectors/postgres/postgres.go
@@ -5,7 +5,6 @@ import (
"database/sql"
"fmt"
"regexp"
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
@@ -34,18 +33,6 @@ type PostgresConnector struct {
customTypesMapping map[uint32]string
}

// SchemaTable is a table in a schema.
type SchemaTable struct {
Schema string
Table string
}

func (t *SchemaTable) String() string {
quotedSchema := fmt.Sprintf(`"%s"`, t.Schema)
quotedTable := fmt.Sprintf(`"%s"`, t.Table)
return fmt.Sprintf("%s.%s", quotedSchema, quotedTable)
}

// NewPostgresConnector creates a new instance of PostgresConnector.
func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) (*PostgresConnector, error) {
connectionString := utils.GetPGConnectionString(pgConfig)
@@ -120,7 +107,7 @@ func (c *PostgresConnector) ConnectionActive() bool {

// NeedsSetupMetadataTables returns true if the metadata tables need to be set up.
func (c *PostgresConnector) NeedsSetupMetadataTables() bool {
result, err := c.tableExists(&SchemaTable{
result, err := c.tableExists(&utils.SchemaTable{
Schema: internalSchema,
Table: mirrorJobsTableIdentifier,
})
@@ -582,7 +569,7 @@ func (c *PostgresConnector) GetTableSchema(
func (c *PostgresConnector) getTableSchemaForTable(
tableName string,
) (*protos.TableSchema, error) {
schemaTable, err := parseSchemaTable(tableName)
schemaTable, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, err
}
@@ -594,7 +581,8 @@ func (c *PostgresConnector) getTableSchemaForTable(

// Get the column names and types
rows, err := c.pool.Query(c.ctx,
fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, tableName), pgx.QueryExecModeSimpleProtocol)
fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, schemaTable.String()),
pgx.QueryExecModeSimpleProtocol)
if err != nil {
return nil, fmt.Errorf("error getting table schema for table %s: %w", schemaTable, err)
}
@@ -655,7 +643,7 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
}()

for tableIdentifier, tableSchema := range req.TableNameSchemaMapping {
normalizedTableNameComponents, err := parseSchemaTable(tableIdentifier)
normalizedTableNameComponents, err := utils.ParseSchemaTable(tableIdentifier)
if err != nil {
return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
}
@@ -752,7 +740,7 @@ func (c *PostgresConnector) EnsurePullability(req *protos.EnsurePullabilityBatch

tableIdentifierMapping := make(map[string]*protos.TableIdentifier)
for _, tableName := range req.SourceTableIdentifiers {
schemaTable, err := parseSchemaTable(tableName)
schemaTable, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, fmt.Errorf("error parsing schema and table: %w", err)
}
@@ -896,16 +884,3 @@ func (c *PostgresConnector) SendWALHeartbeat() error {

return nil
}

// parseSchemaTable parses a table name into schema and table name.
func parseSchemaTable(tableName string) (*SchemaTable, error) {
parts := strings.Split(tableName, ".")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid table name: %s", tableName)
}

return &SchemaTable{
Schema: parts[0],
Table: parts[1],
}, nil
}
9 changes: 5 additions & 4 deletions flow/connectors/postgres/postgres_cdc_test.go
@@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -345,19 +346,19 @@ func (suite *PostgresCDCTestSuite) TearDownSuite() {
}

func (suite *PostgresCDCTestSuite) TestParseSchemaTable() {
schemaTest1, err := parseSchemaTable("schema")
schemaTest1, err := utils.ParseSchemaTable("schema")
suite.Nil(schemaTest1)
suite.NotNil(err)

schemaTest2, err := parseSchemaTable("schema.table")
suite.Equal(&SchemaTable{
schemaTest2, err := utils.ParseSchemaTable("schema.table")
suite.Equal(&utils.SchemaTable{
Schema: "schema",
Table: "table",
}, schemaTest2)
suite.Equal("\"schema\".\"table\"", schemaTest2.String())
suite.Nil(err)

schemaTest3, err := parseSchemaTable("database.schema.table")
schemaTest3, err := utils.ParseSchemaTable("database.schema.table")
suite.Nil(schemaTest3)
suite.NotNil(err)
}
29 changes: 20 additions & 9 deletions flow/connectors/postgres/qrep.go
@@ -6,8 +6,9 @@ import (
"text/template"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/connectors/utils/metrics"
utils "github.com/PeerDB-io/peer-flow/connectors/utils/partition"
partition_utils "github.com/PeerDB-io/peer-flow/connectors/utils/partition"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/google/uuid"
@@ -135,8 +136,13 @@ func (c *PostgresConnector) getNumRowsPartitions(
whereClause = fmt.Sprintf(`WHERE %s > $1`, quotedWatermarkColumn)
}

parsedWatermarkTable, err := utils.ParseSchemaTable(config.WatermarkTable)
if err != nil {
return nil, fmt.Errorf("unable to parse watermark table: %w", err)
}

// Query to get the total number of rows in the table
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", config.WatermarkTable, whereClause)
countQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s %s`, parsedWatermarkTable.String(), whereClause)
var row pgx.Row
var minVal interface{} = nil
if last != nil && last.Range != nil {
@@ -184,7 +190,7 @@ func (c *PostgresConnector) getNumRowsPartitions(
`,
numPartitions,
quotedWatermarkColumn,
config.WatermarkTable,
parsedWatermarkTable.String(),
)
log.Infof("[row_based_next] partitions query: %s", partitionsQuery)
rows, err = tx.Query(c.ctx, partitionsQuery, minVal)
@@ -199,7 +205,7 @@ func (c *PostgresConnector) getNumRowsPartitions(
`,
numPartitions,
quotedWatermarkColumn,
config.WatermarkTable,
parsedWatermarkTable.String(),
)
log.Infof("[row_based] partitions query: %s", partitionsQuery)
rows, err = tx.Query(c.ctx, partitionsQuery)
@@ -211,7 +217,7 @@ func (c *PostgresConnector) getNumRowsPartitions(
return nil, fmt.Errorf("failed to query for partitions: %w", err)
}

partitionHelper := utils.NewPartitionHelper()
partitionHelper := partition_utils.NewPartitionHelper()
for rows.Next() {
var bucket int64
var start, end interface{}
@@ -244,8 +250,13 @@ func (c *PostgresConnector) getMinMaxValues(
quotedWatermarkColumn = fmt.Sprintf("%s::text::bigint", quotedWatermarkColumn)
}

parsedWatermarkTable, err := utils.ParseSchemaTable(config.WatermarkTable)
if err != nil {
return nil, nil, fmt.Errorf("unable to parse watermark table: %w", err)
}

// Get the maximum value from the database
maxQuery := fmt.Sprintf("SELECT MAX(%[1]s) FROM %[2]s", quotedWatermarkColumn, config.WatermarkTable)
maxQuery := fmt.Sprintf("SELECT MAX(%[1]s) FROM %[2]s", quotedWatermarkColumn, parsedWatermarkTable.String())
row := tx.QueryRow(c.ctx, maxQuery)
if err := row.Scan(&maxValue); err != nil {
return nil, nil, fmt.Errorf("failed to query for max value: %w", err)
@@ -273,7 +284,7 @@ func (c *PostgresConnector) getMinMaxValues(
}
} else {
// Otherwise get the minimum value from the database
minQuery := fmt.Sprintf("SELECT MIN(%[1]s) FROM %[2]s", quotedWatermarkColumn, config.WatermarkTable)
minQuery := fmt.Sprintf("SELECT MIN(%[1]s) FROM %[2]s", quotedWatermarkColumn, parsedWatermarkTable.String())
row := tx.QueryRow(c.ctx, minQuery)
if err := row.Scan(&minValue); err != nil {
log.WithFields(log.Fields{
@@ -301,7 +312,7 @@ func (c *PostgresConnector) getMinMaxValues(
}
}

err := tx.Commit(c.ctx)
err = tx.Commit(c.ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to commit transaction: %w", err)
}
@@ -508,7 +519,7 @@ func (c *PostgresConnector) SyncQRepRecords(
partition *protos.QRepPartition,
stream *model.QRecordStream,
) (int, error) {
dstTable, err := parseSchemaTable(config.DestinationTableIdentifier)
dstTable, err := utils.ParseSchemaTable(config.DestinationTableIdentifier)
if err != nil {
return 0, fmt.Errorf("failed to parse destination table identifier: %w", err)
}
3 changes: 2 additions & 1 deletion flow/connectors/postgres/qrep_sync_method.go
@@ -6,6 +6,7 @@ import (
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/connectors/utils/metrics"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
@@ -30,7 +31,7 @@ type QRepStagingTableSync struct {

func (s *QRepStagingTableSync) SyncQRepRecords(
flowJobName string,
dstTableName *SchemaTable,
dstTableName *utils.SchemaTable,
partition *protos.QRepPartition,
stream *model.QRecordStream,
writeMode *protos.QRepWriteMode,
2 changes: 1 addition & 1 deletion flow/connectors/utils/partition/partition.go
@@ -1,4 +1,4 @@
package utils
package partition_utils

import (
"fmt"
24 changes: 24 additions & 0 deletions flow/connectors/utils/postgres.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/url"
"strings"

"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/jackc/pgx/v5/pgxpool"
@@ -47,3 +48,26 @@ func GetCustomDataTypes(ctx context.Context, pool *pgxpool.Pool) (map[uint32]str
}
return customTypeMap, nil
}

// SchemaTable is a table in a schema.
type SchemaTable struct {
Schema string
Table string
}

func (t *SchemaTable) String() string {
return fmt.Sprintf(`"%s"."%s"`, t.Schema, t.Table)
}

// ParseSchemaTable parses a table name into schema and table name.
func ParseSchemaTable(tableName string) (*SchemaTable, error) {
parts := strings.Split(tableName, ".")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid table name: %s", tableName)
}

return &SchemaTable{
Schema: parts[0],
Table: parts[1],
}, nil
}
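
For reference, a minimal sketch of how the new shared helper behaves. The schema and table names below are made up for illustration; the import path and function signatures are taken from this diff:

```go
package main

import (
	"fmt"

	"github.com/PeerDB-io/peer-flow/connectors/utils"
)

func main() {
	// ParseSchemaTable expects exactly "schema.table"; anything else errors.
	st, err := utils.ParseSchemaTable("MySchema.MyTable")
	if err != nil {
		panic(err)
	}

	// String() double-quotes both parts, so Postgres preserves mixed case
	// instead of folding the identifiers to lowercase.
	fmt.Println(st.String()) // "MySchema"."MyTable"

	// A name without a schema qualifier is rejected.
	if _, err := utils.ParseSchemaTable("MyTable"); err != nil {
		fmt.Println("invalid identifier:", err)
	}
}
```

Callers throughout the connector now interpolate the quoted form from String() into their SQL rather than the raw (previously lowercased) identifier, which is what carries mixed-case names through intact.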
12 changes: 11 additions & 1 deletion flow/workflows/snapshot_flow.go
@@ -6,6 +6,7 @@ import (
"time"

"github.com/PeerDB-io/peer-flow/concurrency"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/shared"
"github.com/google/uuid"
@@ -130,7 +131,16 @@ func (s *SnapshotFlowExecution) cloneTable(
partitionCol = mapping.PartitionKey
}

query := fmt.Sprintf("SELECT * FROM %s WHERE %s BETWEEN {{.start}} AND {{.end}}", srcName, partitionCol)
parsedSrcTable, err := utils.ParseSchemaTable(srcName)
if err != nil {
logrus.WithFields(logrus.Fields{
"flowName": flowName,
"snapshotName": snapshotName,
}).Errorf("unable to parse source table")
return fmt.Errorf("unable to parse source table: %w", err)
}
query := fmt.Sprintf("SELECT * FROM %s WHERE %s BETWEEN {{.start}} AND {{.end}}",
parsedSrcTable.String(), partitionCol)

numWorkers := uint32(8)
if s.config.SnapshotMaxParallelWorkers > 0 {