Skip to content

Commit

Permalink
SCHEMA MAPPING support, pending tests
Browse files Browse the repository at this point in the history
  • Loading branch information
heavycrystal committed Sep 25, 2023
1 parent bb69450 commit 4a2cf37
Show file tree
Hide file tree
Showing 22 changed files with 3,629 additions and 1,930 deletions.
130 changes: 118 additions & 12 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/PeerDB-io/peer-flow/connectors"
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/connectors/utils/monitoring"
"github.com/PeerDB-io/peer-flow/generated/protos"
Expand All @@ -17,12 +18,14 @@ import (
"github.com/jackc/pglogrepl"
log "github.com/sirupsen/logrus"
"go.temporal.io/sdk/activity"
"golang.org/x/exp/maps"
)

// CheckConnectionResult is the result of a CheckConnection call.
type CheckConnectionResult struct {
// True of metadata tables need to be set up.
// True if metadata tables need to be set up.
NeedsSetupMetadataTables bool
SupportsSchemaMapping bool
}

type SlotSnapshotSignal struct {
Expand All @@ -36,8 +39,17 @@ type FlowableActivity struct {
CatalogMirrorMonitor *monitoring.CatalogMirrorMonitor
}

// CheckConnection implements CheckConnection.
func (a *FlowableActivity) CheckConnection(
func (a *FlowableActivity) CheckPullConnection(ctx context.Context, config *protos.Peer) (bool, error) {
srcConn, err := connectors.GetCDCPullConnector(ctx, config)
if err != nil {
return false, fmt.Errorf("failed to get connector: %w", err)
}
defer connectors.CloseConnector(srcConn)

return srcConn.ConnectionActive(), nil
}

func (a *FlowableActivity) CheckSyncConnection(
ctx context.Context,
config *protos.Peer,
) (*CheckConnectionResult, error) {
Expand All @@ -49,8 +61,17 @@ func (a *FlowableActivity) CheckConnection(

needsSetup := dstConn.NeedsSetupMetadataTables()

supportsSchemaMapping := false
switch dstConn.(type) {
case *connpostgres.PostgresConnector:
supportsSchemaMapping = true
case *connsnowflake.SnowflakeConnector:
supportsSchemaMapping = true
}

return &CheckConnectionResult{
NeedsSetupMetadataTables: needsSetup,
SupportsSchemaMapping: supportsSchemaMapping,
}, nil
}

Expand Down Expand Up @@ -145,13 +166,13 @@ func (a *FlowableActivity) CreateNormalizedTable(
ctx context.Context,
config *protos.SetupNormalizedTableBatchInput,
) (*protos.SetupNormalizedTableBatchOutput, error) {
conn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
dstConn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get connector: %w", err)
}
defer connectors.CloseConnector(conn)
defer connectors.CloseConnector(dstConn)

return conn.SetupNormalizedTables(config)
return dstConn.SetupNormalizedTables(config)
}

// StartFlow implements StartFlow.
Expand Down Expand Up @@ -187,8 +208,12 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
"flowName": input.FlowConnectionConfigs.FlowJobName,
}).Info("pulling records...")

var schemas []string
if input.FlowConnectionConfigs.MappingType == protos.MappingType_SCHEMA {
schemas = maps.Keys(input.FlowConnectionConfigs.SchemaMapping)
}
startTime := time.Now()
recordsWithTableSchemaDelta, err := srcConn.PullRecords(&model.PullRecordsRequest{
recordsWithDeltaInfo, err := srcConn.PullRecords(&model.PullRecordsRequest{
FlowJobName: input.FlowConnectionConfigs.FlowJobName,
SrcTableIDNameMapping: input.FlowConnectionConfigs.SrcTableIdNameMapping,
TableNameMapping: input.FlowConnectionConfigs.TableNameMapping,
Expand All @@ -199,11 +224,12 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
OverridePublicationName: input.FlowConnectionConfigs.PublicationName,
OverrideReplicationSlotName: input.FlowConnectionConfigs.ReplicationSlotName,
RelationMessageMapping: input.RelationMessageMapping,
Schemas: schemas,
})
if err != nil {
return nil, fmt.Errorf("failed to pull records: %w", err)
}
recordBatch := recordsWithTableSchemaDelta.RecordBatch
recordBatch := recordsWithDeltaInfo.RecordBatch

pullRecordWithCount := fmt.Sprintf("pulled %d records", len(recordBatch.Records))
activity.RecordHeartbeat(ctx, pullRecordWithCount)
Expand Down Expand Up @@ -239,8 +265,8 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
"flowName": input.FlowConnectionConfigs.FlowJobName,
}).Info("no records to push")
return &model.SyncResponse{
RelationMessageMapping: recordsWithTableSchemaDelta.RelationMessageMapping,
TableSchemaDelta: recordsWithTableSchemaDelta.TableSchemaDelta,
RelationMessageMapping: recordsWithDeltaInfo.RelationMessageMapping,
TableSchemaDelta: recordsWithDeltaInfo.TableSchemaDelta,
}, nil
}

Expand Down Expand Up @@ -288,10 +314,11 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
if err != nil {
return nil, err
}
res.TableSchemaDelta = recordsWithTableSchemaDelta.TableSchemaDelta
res.RelationMessageMapping = recordsWithTableSchemaDelta.RelationMessageMapping
res.TableSchemaDelta = recordsWithDeltaInfo.TableSchemaDelta
res.RelationMessageMapping = recordsWithDeltaInfo.RelationMessageMapping

pushedRecordsWithCount := fmt.Sprintf("pushed %d records", numRecords)
res.AdditionalTableInfo = recordsWithDeltaInfo.AdditionalTableInfo
activity.RecordHeartbeat(ctx, pushedRecordsWithCount)

return res, nil
Expand Down Expand Up @@ -617,3 +644,82 @@ func (a *FlowableActivity) DropFlow(ctx context.Context, config *protos.Shutdown
}
return nil
}

// GetTableSchema returns the schema of a table.
func (a *FlowableActivity) PopulateTableMappingFromSchemas(
ctx context.Context,
config *protos.PopulateTableMappingFromSchemasInput,
) (*protos.PopulateTableMappingFromSchemasOutput, error) {
srcConn, err := connectors.GetCDCPullConnector(ctx, config.PeerConnectionConfig)
if err != nil {
return nil, fmt.Errorf("failed to get source connector: %w", err)
}
defer connectors.CloseConnector(srcConn)

srcSchemas := maps.Keys(config.SchemaMapping)

schemaTablesMapping, err := srcConn.GetAllTablesInSchemas(srcSchemas)
if err != nil {
return nil, fmt.Errorf("failed to get schemaTablesMapping: %w", err)
}

schemaTablesMappingProto := make(map[string]*protos.TablesList)
for schema, tables := range schemaTablesMapping {
schemaTablesMappingProto[schema] = &protos.TablesList{
Tables: tables,
}
}

return &protos.PopulateTableMappingFromSchemasOutput{
TableMapping: schemaTablesMappingProto,
}, nil
}

func (a *FlowableActivity) CreateAdditionalTable(
ctx context.Context,
input *protos.CreateAdditionalTableInput) (*protos.AdditionalTableInfo, error) {
srcConn, err := connectors.GetCDCPullConnector(ctx, input.FlowConnectionConfigs.Source)
if err != nil {
return nil, fmt.Errorf("failed to get source connector: %w", err)
}
defer connectors.CloseConnector(srcConn)

dstConn, err := connectors.GetCDCSyncConnector(ctx, input.FlowConnectionConfigs.Destination)
if err != nil {
return nil, fmt.Errorf("failed to get destination connector: %w", err)
}
defer connectors.CloseConnector(dstConn)

srcTableIdentifier := fmt.Sprintf("%s.%s", input.AdditionalTableInfo.SrcSchema,
input.AdditionalTableInfo.TableName)
dstTableIdentifier := fmt.Sprintf("%s.%s", input.AdditionalTableInfo.DstSchema,
input.AdditionalTableInfo.TableName)

tableRelIDMapping, err := srcConn.EnsurePullability(&protos.EnsurePullabilityBatchInput{
SourceTableIdentifiers: []string{srcTableIdentifier},
})
if err != nil {
return nil, fmt.Errorf("failed to ensure pullability for additional table: %w", err)
}

tableNameSchemaMapping, err := srcConn.GetTableSchema(&protos.GetTableSchemaBatchInput{
TableIdentifiers: []string{srcTableIdentifier},
})
if err != nil {
return nil, fmt.Errorf("failed to get schema for additional table: %w", err)
}

_, err = dstConn.SetupNormalizedTables(&protos.SetupNormalizedTableBatchInput{
TableNameSchemaMapping: map[string]*protos.TableSchema{
dstTableIdentifier: tableNameSchemaMapping.TableNameSchemaMapping[srcTableIdentifier]},
})
if err != nil {
return nil, fmt.Errorf("failed to create additional table at destination: %w", err)
}

input.AdditionalTableInfo.RelId = tableRelIDMapping.
TableIdentifierMapping[srcTableIdentifier].GetPostgresTableIdentifier().RelId
input.AdditionalTableInfo.TableSchema = tableNameSchemaMapping.TableNameSchemaMapping[srcTableIdentifier]

return input.AdditionalTableInfo, nil
}
7 changes: 4 additions & 3 deletions flow/connectors/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ type CDCPullConnector interface {
EnsurePullability(req *protos.EnsurePullabilityBatchInput) (
*protos.EnsurePullabilityBatchOutput, error)

// Methods related to retrieving and pusing records for this connector as a source and destination.

// PullRecords pulls records from the source, and returns a RecordBatch.
// This method should be idempotent, and should be able to be called multiple times with the same request.
PullRecords(req *model.PullRecordsRequest) (*model.RecordsWithTableSchemaDelta, error)
PullRecords(req *model.PullRecordsRequest) (*model.RecordsWithDeltaInfo, error)

// PullFlowCleanup drops both the Postgres publication and replication slot, as a part of DROP MIRROR
PullFlowCleanup(jobName string) error

// GetAllTablesInSchemas... gets all the tables in multiple schemas
GetAllTablesInSchemas(schemas []string) (map[string][]string, error)
}

type CDCSyncConnector interface {
Expand Down
Loading

0 comments on commit 4a2cf37

Please sign in to comment.