[DRAFT] CDC mirrors with SCHEMA MAPPING support #420

Closed · wants to merge 2 commits
140 changes: 125 additions & 15 deletions flow/activities/flowable.go

@@ -9,6 +9,7 @@ import (
 	"github.com/PeerDB-io/peer-flow/connectors"
 	connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
+	connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/connectors/utils/metrics"
 	"github.com/PeerDB-io/peer-flow/connectors/utils/monitoring"
@@ -18,12 +19,14 @@ import (
 	"github.com/jackc/pglogrepl"
 	log "github.com/sirupsen/logrus"
 	"go.temporal.io/sdk/activity"
+	"golang.org/x/exp/maps"
 )
 
 // CheckConnectionResult is the result of a CheckConnection call.
 type CheckConnectionResult struct {
-	// True of metadata tables need to be set up.
+	// True if metadata tables need to be set up.
 	NeedsSetupMetadataTables bool
+	SupportsSchemaMapping    bool
 }
 
 type SlotSnapshotSignal struct {
@@ -37,8 +40,17 @@ type FlowableActivity struct {
 	CatalogMirrorMonitor *monitoring.CatalogMirrorMonitor
 }
 
-// CheckConnection implements CheckConnection.
-func (a *FlowableActivity) CheckConnection(
+func (a *FlowableActivity) CheckPullConnection(ctx context.Context, config *protos.Peer) (bool, error) {
+	srcConn, err := connectors.GetCDCPullConnector(ctx, config)
+	if err != nil {
+		return false, fmt.Errorf("failed to get connector: %w", err)
+	}
+	defer connectors.CloseConnector(srcConn)
+
+	return srcConn.ConnectionActive(), nil
+}
+
+func (a *FlowableActivity) CheckSyncConnection(
 	ctx context.Context,
 	config *protos.Peer,
 ) (*CheckConnectionResult, error) {
@@ -50,8 +62,17 @@ func (a *FlowableActivity) CheckConnection(
 
 	needsSetup := dstConn.NeedsSetupMetadataTables()
 
+	supportsSchemaMapping := false
+	switch dstConn.(type) {
+	case *connpostgres.PostgresConnector:
+		supportsSchemaMapping = true
+	case *connsnowflake.SnowflakeConnector:
+		supportsSchemaMapping = true
+	}
+
 	return &CheckConnectionResult{
 		NeedsSetupMetadataTables: needsSetup,
+		SupportsSchemaMapping:    supportsSchemaMapping,
 	}, nil
 }
 
@@ -146,18 +167,18 @@ func (a *FlowableActivity) CreateNormalizedTable(
 	ctx context.Context,
 	config *protos.SetupNormalizedTableBatchInput,
 ) (*protos.SetupNormalizedTableBatchOutput, error) {
-	conn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
+	dstConn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get connector: %w", err)
 	}
-	defer connectors.CloseConnector(conn)
+	defer connectors.CloseConnector(dstConn)
 
-	return conn.SetupNormalizedTables(config)
+	return dstConn.SetupNormalizedTables(config)
 }
 
 // StartFlow implements StartFlow.
 func (a *FlowableActivity) StartFlow(ctx context.Context,
-	input *protos.StartFlowInput) (*model.SyncResponse, error) {
+	input *protos.StartFlowInput) (*protos.SyncResponse, error) {
 	activity.RecordHeartbeat(ctx, "starting flow...")
 	conn := input.FlowConnectionConfigs
 
@@ -188,8 +209,12 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 		"flowName": input.FlowConnectionConfigs.FlowJobName,
 	}).Info("pulling records...")
 
+	var schemas []string
+	if input.FlowConnectionConfigs.MappingType == protos.MappingType_SCHEMA {
+		schemas = maps.Keys(input.FlowConnectionConfigs.SchemaMapping)
+	}
 	startTime := time.Now()
-	recordsWithTableSchemaDelta, err := srcConn.PullRecords(&model.PullRecordsRequest{
+	recordsWithDeltaInfo, err := srcConn.PullRecords(&model.PullRecordsRequest{
 		FlowJobName:                 input.FlowConnectionConfigs.FlowJobName,
 		SrcTableIDNameMapping:       input.FlowConnectionConfigs.SrcTableIdNameMapping,
 		TableNameMapping:            input.FlowConnectionConfigs.TableNameMapping,
@@ -200,11 +225,13 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 		OverridePublicationName:     input.FlowConnectionConfigs.PublicationName,
 		OverrideReplicationSlotName: input.FlowConnectionConfigs.ReplicationSlotName,
 		RelationMessageMapping:      input.RelationMessageMapping,
+		Schemas:                     schemas,
+		AllowTableAdditions:         input.FlowConnectionConfigs.AllowTableAdditions,
 	})
 	if err != nil {
 		return nil, fmt.Errorf("failed to pull records: %w", err)
 	}
-	recordBatch := recordsWithTableSchemaDelta.RecordBatch
+	recordBatch := recordsWithDeltaInfo.RecordBatch
 
 	pullRecordWithCount := fmt.Sprintf("pulled %d records", len(recordBatch.Records))
 	activity.RecordHeartbeat(ctx, pullRecordWithCount)
@@ -242,9 +269,9 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 		metrics.LogSyncMetrics(ctx, input.FlowConnectionConfigs.FlowJobName, 0, 1)
 		metrics.LogNormalizeMetrics(ctx, input.FlowConnectionConfigs.FlowJobName, 0, 1, 0)
 		metrics.LogCDCRawThroughputMetrics(ctx, input.FlowConnectionConfigs.FlowJobName, 0)
-		return &model.SyncResponse{
-			RelationMessageMapping: recordsWithTableSchemaDelta.RelationMessageMapping,
-			TableSchemaDelta:       recordsWithTableSchemaDelta.TableSchemaDelta,
+		return &protos.SyncResponse{
+			RelationMessageMapping: recordsWithDeltaInfo.RelationMessageMapping,
+			MirrorDelta:            recordsWithDeltaInfo.MirrorDelta,
 		}, nil
 	}
 
@@ -284,16 +311,16 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 	}
 	if res.TableNameRowsMapping != nil {
 		err = a.CatalogMirrorMonitor.AddCDCBatchTablesForFlow(ctx, input.FlowConnectionConfigs.FlowJobName,
-			res.CurrentSyncBatchID, res.TableNameRowsMapping)
+			res.CurrentSyncBatchId, res.TableNameRowsMapping)
 		if err != nil {
 			return nil, err
 		}
 	}
 	if err != nil {
 		return nil, err
 	}
-	res.TableSchemaDelta = recordsWithTableSchemaDelta.TableSchemaDelta
-	res.RelationMessageMapping = recordsWithTableSchemaDelta.RelationMessageMapping
+	res.MirrorDelta = recordsWithDeltaInfo.MirrorDelta
+	res.RelationMessageMapping = recordsWithDeltaInfo.RelationMessageMapping
 
 	pushedRecordsWithCount := fmt.Sprintf("pushed %d records", numRecords)
 	activity.RecordHeartbeat(ctx, pushedRecordsWithCount)
@@ -627,3 +654,86 @@ func (a *FlowableActivity) DropFlow(ctx context.Context, config *protos.Shutdown
 	}
 	return nil
 }
+
+// PopulateTableMappingFromSchemas sets up the TableNameMapping from SchemaMapping for MappingType SCHEMA
+func (a *FlowableActivity) PopulateTableMappingFromSchemas(
+	ctx context.Context,
+	config *protos.ListTablesInSchemasInput,
+) (*protos.ListTablesInSchemasOutput, error) {
+	srcConn, err := connectors.GetCDCPullConnector(ctx, config.PeerConnectionConfig)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get source connector: %w", err)
+	}
+	defer connectors.CloseConnector(srcConn)
+
+	srcSchemas := maps.Keys(config.SchemaMapping)
+
+	schemaTablesMapping, err := srcConn.ListTablesInSchemas(srcSchemas)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get schemaTablesMapping: %w", err)
+	}
+
+	schemaTablesMappingProto := make(map[string]*protos.TablesList)
+	for schema, tables := range schemaTablesMapping {
+		schemaTablesMappingProto[schema] = &protos.TablesList{
+			Tables: tables,
+		}
+	}
+
+	return &protos.ListTablesInSchemasOutput{
+		SchemaToTables: schemaTablesMappingProto,
+	}, nil
+}
+
+func (a *FlowableActivity) CreateAdditionalTable(
+	ctx context.Context,
+	input *protos.CreateAdditionalTableInput) (*protos.AdditionalTableDelta, error) {
+	srcConn, err := connectors.GetCDCPullConnector(ctx, input.FlowConnectionConfigs.Source)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get source connector: %w", err)
+	}
+	defer connectors.CloseConnector(srcConn)
+
+	dstConn, err := connectors.GetCDCSyncConnector(ctx, input.FlowConnectionConfigs.Destination)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get destination connector: %w", err)
+	}
+	defer connectors.CloseConnector(dstConn)
+
+	srcTableIdentifier := fmt.Sprintf("%s.%s", input.AdditionalTableInfo.SrcSchema,
+		input.AdditionalTableInfo.TableName)
+	dstTableIdentifier := fmt.Sprintf("%s.%s", input.AdditionalTableInfo.DstSchema,
+		input.AdditionalTableInfo.TableName)
+
+	tableRelIDMapping, err := srcConn.EnsurePullability(&protos.EnsurePullabilityBatchInput{
+		SourceTableIdentifiers: []string{srcTableIdentifier},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to ensure pullability for additional table: %w", err)
+	}
+
+	tableNameSchemaMapping, err := srcConn.GetTableSchema(&protos.GetTableSchemaBatchInput{
+		TableIdentifiers: []string{srcTableIdentifier},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to get schema for additional table: %w", err)
+	}
+
+	_, err = dstConn.SetupNormalizedTables(&protos.SetupNormalizedTableBatchInput{
+		TableNameSchemaMapping: map[string]*protos.TableSchema{
+			dstTableIdentifier: tableNameSchemaMapping.TableNameSchemaMapping[srcTableIdentifier]},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create additional table at destination: %w", err)
+	}
+
+	input.AdditionalTableInfo.RelId = tableRelIDMapping.
+		TableIdentifierMapping[srcTableIdentifier].GetPostgresTableIdentifier().RelId
+	input.AdditionalTableInfo.TableSchema = tableNameSchemaMapping.TableNameSchemaMapping[srcTableIdentifier]
+
+	log.WithFields(log.Fields{
+		"flowName": input.FlowConnectionConfigs.FlowJobName,
+	}).Infof("finished creating additional table %s\n", dstTableIdentifier)
+
+	return input.AdditionalTableInfo, nil
+}
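The new PopulateTableMappingFromSchemas activity only returns the per-schema table listing; the calling workflow still has to combine that output with the SchemaMapping to produce concrete per-table name pairs. A minimal sketch of that expansion step follows. It is illustrative only: buildTableNameMapping and the local TablesList stand-in are hypothetical names, not code from this PR.

```go
package main

import "fmt"

// TablesList mirrors the shape of protos.TablesList for this example.
type TablesList struct {
	Tables []string
}

// buildTableNameMapping expands srcSchema -> dstSchema pairs into
// "srcSchema.table" -> "dstSchema.table" entries, one per listed table.
func buildTableNameMapping(
	schemaMapping map[string]string, // srcSchema -> dstSchema
	schemaToTables map[string]*TablesList, // per-schema listing from the activity
) map[string]string {
	mapping := make(map[string]string)
	for srcSchema, dstSchema := range schemaMapping {
		list := schemaToTables[srcSchema]
		if list == nil {
			continue
		}
		for _, table := range list.Tables {
			src := fmt.Sprintf("%s.%s", srcSchema, table)
			dst := fmt.Sprintf("%s.%s", dstSchema, table)
			mapping[src] = dst
		}
	}
	return mapping
}

func main() {
	out := buildTableNameMapping(
		map[string]string{"public": "analytics"},
		map[string]*TablesList{"public": {Tables: []string{"users", "orders"}}},
	)
	fmt.Println(out) // map[public.orders:analytics.orders public.users:analytics.users]
}
```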
36 changes: 18 additions & 18 deletions flow/connectors/bigquery/bigquery.go

@@ -438,11 +438,11 @@ func (r StagingBQRecord) Save() (map[string]bigquery.Value, string, error) {
 // SyncRecords pushes records to the destination.
 // currently only supports inserts,updates and deletes
 // more record types will be added in the future.
-func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
+func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*protos.SyncResponse, error) {
 	if len(req.Records.Records) == 0 {
-		return &model.SyncResponse{
-			FirstSyncedCheckPointID: 0,
-			LastSyncedCheckPointID:  0,
+		return &protos.SyncResponse{
+			FirstSyncedCheckpointId: 0,
+			LastSyncedCheckpointId:  0,
 			NumRecordsSynced:        0,
 		}, nil
 	}
@@ -460,7 +460,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
 	}
 	syncBatchID = syncBatchID + 1
 
-	var res *model.SyncResponse
+	var res *protos.SyncResponse
 	if req.SyncMode == protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO {
 		res, err = c.syncRecordsViaAvro(req, rawTableName, syncBatchID)
 		if err != nil {
@@ -478,7 +478,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
 }
 
 func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest,
-	rawTableName string, syncBatchID int64) (*model.SyncResponse, error) {
+	rawTableName string, syncBatchID int64) (*protos.SyncResponse, error) {
 	stagingTableName := c.getStagingTableName(req.FlowJobName)
 	stagingTable := c.client.Dataset(c.datasetID).Table(stagingTableName)
 	err := c.truncateTable(stagingTableName)
@@ -594,9 +594,9 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest,
 
 	numRecords := len(records)
 	if numRecords == 0 {
-		return &model.SyncResponse{
-			FirstSyncedCheckPointID: 0,
-			LastSyncedCheckPointID:  0,
+		return &protos.SyncResponse{
+			FirstSyncedCheckpointId: 0,
+			LastSyncedCheckpointId:  0,
 			NumRecordsSynced:        0,
 		}, nil
 	}
@@ -645,17 +645,17 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest,
 	metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(numRecords), time.Since(startTime))
 	log.Printf("pushed %d records to %s.%s", numRecords, c.datasetID, rawTableName)
 
-	return &model.SyncResponse{
-		FirstSyncedCheckPointID: firstCP,
-		LastSyncedCheckPointID:  lastCP,
+	return &protos.SyncResponse{
+		FirstSyncedCheckpointId: firstCP,
+		LastSyncedCheckpointId:  lastCP,
 		NumRecordsSynced:        int64(numRecords),
-		CurrentSyncBatchID:      syncBatchID,
+		CurrentSyncBatchId:      syncBatchID,
 		TableNameRowsMapping:    tableNameRowsMapping,
 	}, nil
 }
 
 func (c *BigQueryConnector) syncRecordsViaAvro(req *model.SyncRecordsRequest,
-	rawTableName string, syncBatchID int64) (*model.SyncResponse, error) {
+	rawTableName string, syncBatchID int64) (*protos.SyncResponse, error) {
 	tableNameRowsMapping := make(map[string]uint32)
 	first := true
 	var firstCP int64 = 0
@@ -867,11 +867,11 @@ func (c *BigQueryConnector) syncRecordsViaAvro(req *model.SyncRecordsRequest,
 	metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(numRecords), time.Since(startTime))
 	log.Printf("pushed %d records to %s.%s", numRecords, c.datasetID, rawTableName)
 
-	return &model.SyncResponse{
-		FirstSyncedCheckPointID: firstCP,
-		LastSyncedCheckPointID:  lastCP,
+	return &protos.SyncResponse{
+		FirstSyncedCheckpointId: firstCP,
+		LastSyncedCheckpointId:  lastCP,
 		NumRecordsSynced:        int64(numRecords),
-		CurrentSyncBatchID:      syncBatchID,
+		CurrentSyncBatchId:      syncBatchID,
 		TableNameRowsMapping:    tableNameRowsMapping,
 	}, nil
 }
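Every BigQuery change above is mechanical fallout from moving SyncResponse out of the Go model package into protobuf: protoc-gen-go derives Go field names by camel-casing the snake_case proto fields, so Go-idiomatic initialisms like CheckPointID become CheckpointId at every call site. A rough sketch of the generated struct's shape, assembled only from fields visible in this diff (the PR's actual .proto definition is not shown in this excerpt):

```go
package sketch

// SyncResponse approximates the protoc-generated struct behind
// protos.SyncResponse, limited to fields that appear in this diff.
type SyncResponse struct {
	FirstSyncedCheckpointId int64 // protoc casing of first_synced_checkpoint_id
	LastSyncedCheckpointId  int64
	NumRecordsSynced        int64
	CurrentSyncBatchId      int64
	TableNameRowsMapping    map[string]uint32
	// RelationMessageMapping and MirrorDelta also move here; their exact
	// generated types come from the PR's proto changes, not shown in this diff.
}
```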
9 changes: 5 additions & 4 deletions flow/connectors/core.go

@@ -32,14 +32,15 @@ type CDCPullConnector interface {
 	EnsurePullability(req *protos.EnsurePullabilityBatchInput) (
 		*protos.EnsurePullabilityBatchOutput, error)
 
-	// Methods related to retrieving and pusing records for this connector as a source and destination.
-
 	// PullRecords pulls records from the source, and returns a RecordBatch.
 	// This method should be idempotent, and should be able to be called multiple times with the same request.
-	PullRecords(req *model.PullRecordsRequest) (*model.RecordsWithTableSchemaDelta, error)
+	PullRecords(req *model.PullRecordsRequest) (*model.RecordsWithDeltaInfo, error)
 
 	// PullFlowCleanup drops both the Postgres publication and replication slot, as a part of DROP MIRROR
 	PullFlowCleanup(jobName string) error
+
+	// ListTablesInSchemas... gets all the tables in multiple schemas
+	ListTablesInSchemas(schemas []string) (map[string][]string, error)
 }
 
 type CDCSyncConnector interface {
@@ -69,7 +70,7 @@ type CDCSyncConnector interface {
 
 	// SyncRecords pushes records to the destination peer and stores it in PeerDB specific tables.
 	// This method should be idempotent, and should be able to be called multiple times with the same request.
-	SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error)
+	SyncRecords(req *model.SyncRecordsRequest) (*protos.SyncResponse, error)
 
 	// SyncFlowCleanup drops metadata tables on the destination, as a part of DROP MIRROR.
 	SyncFlowCleanup(jobName string) error
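ListTablesInSchemas is the source-side primitive that schema mapping rests on: given a list of schemas, return the tables in each. A plausible implementation sketch for a Postgres source follows, assuming a pgxpool.Pool and the pg_tables catalog view (pgx v5 import path shown; adjust to whichever pgx version the connector uses). This is not the PR's actual implementation, which is not included in this excerpt.

```go
package pgutil

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5/pgxpool"
)

// listTablesInSchemas returns schema -> table names for every schema in the
// input slice, using a single catalog query with ANY($1).
func listTablesInSchemas(ctx context.Context, pool *pgxpool.Pool,
	schemas []string) (map[string][]string, error) {
	rows, err := pool.Query(ctx,
		"SELECT schemaname, tablename FROM pg_tables WHERE schemaname = ANY($1)",
		schemas)
	if err != nil {
		return nil, fmt.Errorf("failed to list tables in schemas: %w", err)
	}
	defer rows.Close()

	result := make(map[string][]string, len(schemas))
	for rows.Next() {
		var schema, table string
		if err := rows.Scan(&schema, &table); err != nil {
			return nil, fmt.Errorf("failed to scan table row: %w", err)
		}
		result[schema] = append(result[schema], table)
	}
	return result, rows.Err()
}
```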
8 changes: 4 additions & 4 deletions flow/connectors/eventhub/eventhub.go

@@ -85,7 +85,7 @@ func (c *EventHubConnector) InitializeTableSchema(req map[string]*protos.TableSc
 	return nil
 }
 
-func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
+func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*protos.SyncResponse, error) {
 	shutdown := utils.HeartbeatRoutine(c.ctx, 10*time.Second, func() string {
 		return fmt.Sprintf("syncing records to eventhub with"+
 			" push parallelism %d and push batch size %d",
@@ -219,9 +219,9 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
 	metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(rowsSynced), time.Since(startTime))
 	metrics.LogNormalizeMetrics(c.ctx, req.FlowJobName, int64(rowsSynced),
 		time.Since(startTime), int64(rowsSynced))
-	return &model.SyncResponse{
-		FirstSyncedCheckPointID: batch.FirstCheckPointID,
-		LastSyncedCheckPointID:  batch.LastCheckPointID,
+	return &protos.SyncResponse{
+		FirstSyncedCheckpointId: batch.FirstCheckPointID,
+		LastSyncedCheckpointId:  batch.LastCheckPointID,
 		NumRecordsSynced:        int64(len(batch.Records)),
 		TableNameRowsMapping:    tableNameRowsMapping.Items(),
 	}, nil
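The EventHub hunk shows the naming seam most clearly: the model-side batch keeps Go initialism casing (batch.FirstCheckPointID) while the proto response it feeds uses protoc casing (FirstSyncedCheckpointId). A self-contained toy illustrating that boundary; both types here are stand-ins, not the repo's actual definitions:

```go
package main

import "fmt"

// eventBatch stands in for the connector's model-side batch type.
type eventBatch struct {
	FirstCheckPointID int64
	LastCheckPointID  int64
	Records           []string
}

// SyncResponse stands in for the proto-generated response type.
type SyncResponse struct {
	FirstSyncedCheckpointId int64
	LastSyncedCheckpointId  int64
	NumRecordsSynced        int64
}

// toSyncResponse crosses the casing boundary in one place, as SyncRecords does.
func toSyncResponse(b eventBatch) *SyncResponse {
	return &SyncResponse{
		FirstSyncedCheckpointId: b.FirstCheckPointID,
		LastSyncedCheckpointId:  b.LastCheckPointID,
		NumRecordsSynced:        int64(len(b.Records)),
	}
}

func main() {
	resp := toSyncResponse(eventBatch{FirstCheckPointID: 100, LastCheckPointID: 250,
		Records: []string{"r1", "r2"}})
	fmt.Printf("synced %d records, checkpoints %d..%d\n",
		resp.NumRecordsSynced, resp.FirstSyncedCheckpointId, resp.LastSyncedCheckpointId)
}
```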