Skip to content

Commit

Permalink
Merge branch 'main' into soft-delete-post-update
Browse files Browse the repository at this point in the history
  • Loading branch information
iskakaushik authored Nov 24, 2023
2 parents f4e9e10 + 13c1c57 commit 9538ca4
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 49 deletions.
96 changes: 79 additions & 17 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package activities

import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
Expand All @@ -12,14 +13,17 @@ import (
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
"github.com/PeerDB-io/peer-flow/connectors/utils"
catalog "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/connectors/utils/monitoring"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/shared"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
log "github.com/sirupsen/logrus"
"go.temporal.io/sdk/activity"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
)

// CheckConnectionResult is the result of a CheckConnection call.
Expand Down Expand Up @@ -659,31 +663,89 @@ func (a *FlowableActivity) DropFlow(ctx context.Context, config *protos.Shutdown
return nil
}

func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context, config *protos.FlowConnectionConfigs) error {
srcConn, err := connectors.GetCDCPullConnector(ctx, config.Source)
func getPostgresPeerConfigs(ctx context.Context) ([]*protos.Peer, error) {
var peerOptions sql.RawBytes
catalogPool, catalogErr := catalog.GetCatalogConnectionPoolFromEnv()
if catalogErr != nil {
return nil, fmt.Errorf("error getting catalog connection pool: %w", catalogErr)
}
defer catalogPool.Close()

optionRows, err := catalogPool.Query(ctx,
"SELECT name, options FROM peers WHERE type=$1", protos.DBType_POSTGRES)
if err != nil {
return fmt.Errorf("failed to get destination connector: %w", err)
return nil, err
}
defer connectors.CloseConnector(srcConn)
log.WithFields(log.Fields{"flowName": config.FlowJobName}).Info("sending walheartbeat every 10 minutes")
ticker := time.NewTicker(10 * time.Minute)
defer optionRows.Close()
var peerName string
var postgresPeers []*protos.Peer
for optionRows.Next() {
err := optionRows.Scan(&peerName, &peerOptions)
if err != nil {
return nil, err
}
var pgPeerConfig protos.PostgresConfig
unmarshalErr := proto.Unmarshal(peerOptions, &pgPeerConfig)
if unmarshalErr != nil {
return nil, unmarshalErr
}
postgresPeers = append(postgresPeers, &protos.Peer{
Name: peerName,
Type: protos.DBType_POSTGRES,
Config: &protos.Peer_PostgresConfig{PostgresConfig: &pgPeerConfig},
})
}
return postgresPeers, nil
}

func (a *FlowableActivity) SendWALHeartbeat(ctx context.Context) error {
sendTimeout := 10 * time.Minute
ticker := time.NewTicker(sendTimeout)
defer ticker.Stop()

pgPeers, err := getPostgresPeerConfigs(ctx)
if err != nil {
return fmt.Errorf("error getting postgres peer configs: %w", err)
}

activity.RecordHeartbeat(ctx, "sending walheartbeat every 10 minutes")
for {
select {
case <-ctx.Done():
log.WithFields(
log.Fields{
"flowName": config.FlowJobName,
}).Info("context is done, exiting wal heartbeat send loop")
log.Info("context is done, exiting wal heartbeat send loop")
return nil
case <-ticker.C:
err = srcConn.SendWALHeartbeat()
if err != nil {
return fmt.Errorf("failed to send WAL heartbeat: %w", err)
command := `
BEGIN;
DROP aggregate IF EXISTS PEERDB_EPHEMERAL_HEARTBEAT(float4);
CREATE AGGREGATE PEERDB_EPHEMERAL_HEARTBEAT(float4) (SFUNC = float4pl, STYPE = float4);
DROP aggregate PEERDB_EPHEMERAL_HEARTBEAT(float4);
END;
`
// run above command for each Postgres peer
for _, pgPeer := range pgPeers {
pgConfig := pgPeer.GetPostgresConfig()
peerConn, peerErr := pgx.Connect(ctx, utils.GetPGConnectionString(pgConfig))
if peerErr != nil {
return fmt.Errorf("error creating pool for postgres peer %v with host %v: %w",
pgPeer.Name, pgConfig.Host, peerErr)
}

_, err := peerConn.Exec(ctx, command)
if err != nil {
log.Warnf("warning: could not send walheartbeat to peer %v: %v", pgPeer.Name, err)
}

closeErr := peerConn.Close(ctx)
if closeErr != nil {
return fmt.Errorf("error closing postgres connection for peer %v with host %v: %w",
pgPeer.Name, pgConfig.Host, closeErr)
}
log.Infof("sent walheartbeat to peer %v", pgPeer.Name)
}
log.WithFields(
log.Fields{
"flowName": config.FlowJobName,
}).Info("sent wal heartbeat")
ticker.Stop()
ticker = time.NewTicker(sendTimeout)

}
}
}
Expand Down
45 changes: 45 additions & 0 deletions flow/cmd/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ import (

utils "github.com/PeerDB-io/peer-flow/connectors/utils/catalog"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/shared"
peerflow "github.com/PeerDB-io/peer-flow/workflows"
"github.com/google/uuid"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/reflection"

"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/sdk/client"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
Expand Down Expand Up @@ -58,6 +62,27 @@ func setupGRPCGatewayServer(args *APIServerParams) (*http.Server, error) {
return server, nil
}

// killExistingHeartbeatFlows requests cancellation of every workflow execution
// of type HeartbeatFlowWorkflow that the visibility query returns for the given
// namespace.
//
// The listing is paginated: the request is re-issued with the NextPageToken
// from the previous response until the server reports no further pages, so
// executions beyond the first page are cancelled as well.
//
// An execution that has already completed is skipped; any other cancellation
// failure, or a failure to list workflows, is returned wrapped.
func killExistingHeartbeatFlows(ctx context.Context, tc client.Client, namespace string) error {
	log.Info("Requesting cancellation of pre-existing heartbeat flows")

	var nextPageToken []byte
	for {
		listRes, err := tc.ListWorkflow(ctx,
			&workflowservice.ListWorkflowExecutionsRequest{
				Namespace:     namespace,
				Query:         "WorkflowType = 'HeartbeatFlowWorkflow'",
				NextPageToken: nextPageToken,
			})
		if err != nil {
			return fmt.Errorf("unable to list workflows: %w", err)
		}

		for _, workflow := range listRes.Executions {
			workflowID := workflow.Execution.WorkflowId
			log.Info("Cancelling workflow: ", workflowID)
			err := tc.CancelWorkflow(ctx, workflowID, workflow.Execution.RunId)
			// The query is not restricted to running executions, so closed ones
			// show up here too; those are not an error.
			// NOTE(review): this matches on the error's message text, which is
			// brittle — consider a typed check once the error type is confirmed.
			if err != nil && err.Error() != "workflow execution already completed" {
				return fmt.Errorf("unable to cancel workflow %s: %w", workflowID, err)
			}
		}

		// An empty token means the last page has been processed.
		nextPageToken = listRes.NextPageToken
		if len(nextPageToken) == 0 {
			return nil
		}
	}
}

func APIMain(args *APIServerParams) error {
ctx := args.ctx
clientOptions := client.Options{
Expand Down Expand Up @@ -91,6 +116,26 @@ func APIMain(args *APIServerParams) error {
flowHandler := NewFlowRequestHandler(tc, catalogConn)
defer flowHandler.Close()

err = killExistingHeartbeatFlows(ctx, tc, args.TemporalNamespace)
if err != nil {
return fmt.Errorf("unable to kill existing heartbeat flows: %w", err)
}

workflowID := fmt.Sprintf("heartbeatflow-%s", uuid.New())
workflowOptions := client.StartWorkflowOptions{
ID: workflowID,
TaskQueue: shared.PeerFlowTaskQueue,
}

_, err = flowHandler.temporalClient.ExecuteWorkflow(
ctx, // context
workflowOptions, // workflow start options
peerflow.HeartbeatFlowWorkflow, // workflow function
)
if err != nil {
return fmt.Errorf("unable to start heartbeat workflow: %w", err)
}

protos.RegisterFlowServiceServer(grpcServer, flowHandler)
grpc_health_v1.RegisterHealthServer(grpcServer, health.NewServer())
reflection.Register(grpcServer)
Expand Down
1 change: 1 addition & 0 deletions flow/cmd/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func WorkerMain(opts *WorkerOptions) error {
w.RegisterWorkflow(peerflow.QRepFlowWorkflow)
w.RegisterWorkflow(peerflow.QRepPartitionWorkflow)
w.RegisterWorkflow(peerflow.DropFlowWorkflow)
w.RegisterWorkflow(peerflow.HeartbeatFlowWorkflow)
w.RegisterActivity(&activities.FlowableActivity{
CatalogMirrorMonitor: catalogMirrorMonitor,
})
Expand Down
24 changes: 11 additions & 13 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,15 @@ func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable)
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}
defer rows.Close()
// 0 rows returned, table has no primary keys
if !rows.Next() {
return nil, fmt.Errorf("table %s has no primary keys", schemaTable)
}
for {
if !rows.Next() {
break
}
err = rows.Scan(&pkCol)
if err != nil {
return nil, fmt.Errorf("error scanning primary key column for table %s: %w", schemaTable, err)
}
pkCols = append(pkCols, pkCol)
if !rows.Next() {
break
}
}

return pkCols, nil
Expand Down Expand Up @@ -314,13 +310,15 @@ func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string,
}

// add composite primary key to the table
primaryKeyColsQuoted := make([]string, 0)
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsQuoted = append(primaryKeyColsQuoted,
fmt.Sprintf(`"%s"`, primaryKeyCol))
if len(sourceTableSchema.PrimaryKeyColumns) > 0 {
primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsQuoted = append(primaryKeyColsQuoted,
fmt.Sprintf(`"%s"`, primaryKeyCol))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))

return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
Expand Down
27 changes: 20 additions & 7 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,11 @@ func (c *PostgresConnector) getTableSchemaForTable(
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
}

pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}

// Get the column names and types
rows, err := c.pool.Query(c.ctx,
fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, schemaTable.String()),
Expand All @@ -582,13 +587,6 @@ func (c *PostgresConnector) getTableSchemaForTable(
}
defer rows.Close()

pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
if err != nil {
if !isFullReplica {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}
}

res := &protos.TableSchema{
TableIdentifier: tableName,
Columns: make(map[string]string),
Expand Down Expand Up @@ -744,6 +742,21 @@ func (c *PostgresConnector) EnsurePullability(req *protos.EnsurePullabilityBatch
return nil, err
}

isFullReplica, replErr := c.isTableFullReplica(schemaTable)
if replErr != nil {
return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr)
}

pKeyCols, err := c.getPrimaryKeyColumns(schemaTable)
if err != nil {
return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err)
}

// we only allow no primary key if the table has REPLICA IDENTITY FULL
if len(pKeyCols) == 0 && !isFullReplica {
return nil, fmt.Errorf("table %s has no primary keys and does not have REPLICA IDENTITY FULL", schemaTable)
}

tableIdentifierMapping[tableName] = &protos.TableIdentifier{
TableIdentifier: &protos.TableIdentifier_PostgresTableIdentifier{
PostgresTableIdentifier: &protos.PostgresTableIdentifier{
Expand Down
14 changes: 8 additions & 6 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,13 +783,15 @@ func generateCreateTableSQLForNormalizedTable(
}

// add composite primary key to the table
primaryKeyColsUpperQuoted := make([]string, 0)
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted,
fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol)))
if len(sourceTableSchema.PrimaryKeyColumns) > 0 {
primaryKeyColsUpperQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns))
for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted,
fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol)))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ",")))
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ",")))

return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
Expand Down
22 changes: 22 additions & 0 deletions flow/workflows/heartbeat_flow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package peerflow

import (
"time"

"go.temporal.io/sdk/workflow"
)

// HeartbeatFlowWorkflow runs the SendWALHeartbeat activity and blocks until it
// finishes, returning the activity's error (nil when it completes cleanly).
func HeartbeatFlowWorkflow(ctx workflow.Context) error {
	// The activity may run for up to one week before it times out.
	activityOpts := workflow.ActivityOptions{
		StartToCloseTimeout: 7 * 24 * time.Hour,
	}
	activityCtx := workflow.WithActivityOptions(ctx, activityOpts)

	return workflow.ExecuteActivity(activityCtx, flowable.SendWALHeartbeat).Get(activityCtx, nil)
}
12 changes: 6 additions & 6 deletions ui/app/mirrors/create/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,24 @@ export const cdcSchema = z.object({
.optional(),
snapshotNumRowsPerPartition: z
.number({
invalid_type_error: 'Snapshow rows per partition must be a number',
invalid_type_error: 'Snapshot rows per partition must be a number',
})
.int()
.min(1, 'Snapshow rows per partition must be a positive integer')
.min(1, 'Snapshot rows per partition must be a positive integer')
.optional(),
snapshotMaxParallelWorkers: z
.number({
invalid_type_error: 'Snapshow max workers must be a number',
invalid_type_error: 'Snapshot max workers must be a number',
})
.int()
.min(1, 'Snapshow max workers must be a positive integer')
.min(1, 'Snapshot max workers must be a positive integer')
.optional(),
snapshotNumTablesInParallel: z
.number({
invalid_type_error: 'Snapshow parallel tables must be a number',
invalid_type_error: 'Snapshot parallel tables must be a number',
})
.int()
.min(1, 'Snapshow parallel tables must be a positive integer')
.min(1, 'Snapshot parallel tables must be a positive integer')
.optional(),
snapshotStagingPath: z
.string({
Expand Down

0 comments on commit 9538ca4

Please sign in to comment.