Skip to content

Commit

Permalink
stashing before context switch
Browse files Browse the repository at this point in the history
  • Loading branch information
heavycrystal committed Jan 15, 2024
1 parent 1b906e7 commit 2782cbb
Show file tree
Hide file tree
Showing 17 changed files with 410 additions and 93 deletions.
21 changes: 21 additions & 0 deletions flow/activities/flowable.go
Original file line number Diff line number Diff line change
Expand Up @@ -988,3 +988,24 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context,

return currentSnapshotXmin, nil
}

// AddTablesToPublication asks the source connector of the given flow to add
// additionalTableMappings to the flow's publication.
//
// The flow job name is attached to ctx before the connector is created, and the
// connector is closed when the activity returns. An error from the connector's
// AddTablesToPublication call is reported through the Alerter and then returned
// unchanged; a failure to obtain the connector is wrapped and returned without alerting.
func (a *FlowableActivity) AddTablesToPublication(ctx context.Context, cfg *protos.FlowConnectionConfigs,
	additionalTableMappings []*protos.TableMapping,
) error {
	ctx = context.WithValue(ctx, shared.FlowNameKey, cfg.FlowJobName)

	srcConn, connErr := connectors.GetCDCPullConnector(ctx, cfg.Source)
	if connErr != nil {
		return fmt.Errorf("failed to get source connector: %w", connErr)
	}
	defer connectors.CloseConnector(srcConn)

	input := &protos.AddTablesToPublicationInput{
		FlowJobName:      cfg.FlowJobName,
		PublicationName:  cfg.PublicationName,
		AdditionalTables: additionalTableMappings,
	}
	if addErr := srcConn.AddTablesToPublication(input); addErr != nil {
		a.Alerter.LogFlowError(ctx, cfg.FlowJobName, addErr)
		return addErr
	}
	return nil
}
26 changes: 19 additions & 7 deletions flow/cmd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ func (h *FlowRequestHandler) ShutdownFlow(
ctx,
req.WorkflowId,
"",
shared.CDCFlowSignalName,
shared.FlowSignalName,
shared.ShutdownSignal,
)
if err != nil {
Expand Down Expand Up @@ -444,8 +444,13 @@ func (h *FlowRequestHandler) FlowStateChange(
if err != nil {
return nil, err
}
isCDCFlow, err := h.isCDCFlow(ctx, req.FlowJobName)
if err != nil {
return nil, err
}

if req.RequestedFlowState == protos.FlowStatus_STATUS_PAUSED &&
*currState == protos.FlowStatus_STATUS_RUNNING {
currState == protos.FlowStatus_STATUS_RUNNING {
err = h.updateWorkflowStatus(ctx, workflowID, protos.FlowStatus_STATUS_PAUSING)
if err != nil {
return nil, err
Expand All @@ -454,20 +459,27 @@ func (h *FlowRequestHandler) FlowStateChange(
ctx,
workflowID,
"",
shared.CDCFlowSignalName,
shared.FlowSignalName,
shared.PauseSignal,
)
} else if req.RequestedFlowState == protos.FlowStatus_STATUS_RUNNING &&
*currState == protos.FlowStatus_STATUS_PAUSED {
currState == protos.FlowStatus_STATUS_PAUSED {
if isCDCFlow && req.FlowConfigUpdate.GetCdcFlowConfigUpdate() != nil {
err = h.attemptCDCFlowConfigUpdate(ctx, workflowID,
req.FlowConfigUpdate.GetCdcFlowConfigUpdate())
if err != nil {
return nil, err
}
}
err = h.temporalClient.SignalWorkflow(
ctx,
workflowID,
"",
shared.CDCFlowSignalName,
shared.FlowSignalName,
shared.NoopSignal,
)
} else if req.RequestedFlowState == protos.FlowStatus_STATUS_TERMINATED &&
(*currState == protos.FlowStatus_STATUS_RUNNING || *currState == protos.FlowStatus_STATUS_PAUSED) {
(currState == protos.FlowStatus_STATUS_RUNNING || currState == protos.FlowStatus_STATUS_PAUSED) {
err = h.updateWorkflowStatus(ctx, workflowID, protos.FlowStatus_STATUS_TERMINATING)
if err != nil {
return nil, err
Expand All @@ -484,7 +496,7 @@ func (h *FlowRequestHandler) FlowStateChange(
req.RequestedFlowState, currState)
}
if err != nil {
return nil, fmt.Errorf("unable to signal CDCFlow workflow: %w", err)
return nil, fmt.Errorf("unable to signal workflow: %w", err)
}

return &protos.FlowStateChangeResponse{
Expand Down
47 changes: 41 additions & 6 deletions flow/cmd/mirror_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (h *FlowRequestHandler) MirrorStatus(
Status: &protos.MirrorStatusResponse_CdcStatus{
CdcStatus: cdcStatus,
},
CurrentFlowState: *currState,
CurrentFlowState: currState,
}, nil
} else {
qrepStatus, err := h.QRepFlowStatus(ctx, req)
Expand All @@ -66,7 +66,7 @@ func (h *FlowRequestHandler) MirrorStatus(
Status: &protos.MirrorStatusResponse_QrepStatus{
QrepStatus: qrepStatus,
},
CurrentFlowState: *currState,
CurrentFlowState: currState,
}, nil
}
}
Expand Down Expand Up @@ -334,17 +334,41 @@ func (h *FlowRequestHandler) isCDCFlow(ctx context.Context, flowJobName string)
return false, nil
}

func (h *FlowRequestHandler) getWorkflowStatus(ctx context.Context, workflowID string) (*protos.FlowStatus, error) {
// getCloneTableFlowNames returns the flow_name of every row in
// peerdb_stats.qrep_runs whose name matches the pattern "clone_<flowJobName>_%"
// (case-insensitively, via ILIKE). Rows whose flow_name is NULL are skipped.
//
// The returned slice is empty but non-nil when nothing matches. An error is
// returned if the query fails, a row cannot be scanned, or iteration stops
// early because of an error.
//
// Note: the pattern is built by concatenation, so any `_` or `%` inside
// flowJobName (and the literal underscores of the pattern itself) act as LIKE
// wildcards.
func (h *FlowRequestHandler) getCloneTableFlowNames(ctx context.Context, flowJobName string) ([]string, error) {
	q := "SELECT flow_name FROM peerdb_stats.qrep_runs WHERE flow_name ILIKE $1"
	rows, err := h.pool.Query(ctx, q, "clone_"+flowJobName+"_%")
	if err != nil {
		return nil, fmt.Errorf("unable to getCloneTableFlowNames: %w", err)
	}
	defer rows.Close()

	flowNames := []string{}
	for rows.Next() {
		var name pgtype.Text
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("unable to scan flow row: %w", err)
		}
		if name.Valid {
			flowNames = append(flowNames, name.String)
		}
	}
	// rows.Next() returns false both when the result set is exhausted and when
	// reading fails; without this check a failure would be reported as a
	// successful (but truncated) result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to read flow rows: %w", err)
	}

	return flowNames, nil
}

func (h *FlowRequestHandler) getWorkflowStatus(ctx context.Context, workflowID string) (protos.FlowStatus, error) {
res, err := h.temporalClient.QueryWorkflow(ctx, workflowID, "", shared.FlowStatusQuery)
if err != nil {
slog.Error(fmt.Sprintf("failed to get state in workflow with ID %s: %s", workflowID, err.Error()))
return nil, fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err)
return protos.FlowStatus_STATUS_UNKNOWN,
fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err)
}
var state *protos.FlowStatus
var state protos.FlowStatus
err = res.Get(&state)
if err != nil {
slog.Error(fmt.Sprintf("failed to get state in workflow with ID %s: %s", workflowID, err.Error()))
return nil, fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err)
return protos.FlowStatus_STATUS_UNKNOWN,
fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err)
}
return state, nil
}
Expand All @@ -361,3 +385,14 @@ func (h *FlowRequestHandler) updateWorkflowStatus(
}
return nil
}

// attemptCDCFlowConfigUpdate sends cdcFlowConfigUpdate to the workflow identified
// by workflowID as a shared.CDCFlowConfigUpdate update, using an empty run ID.
// It returns a wrapped error if the update call fails and nil otherwise.
//
// NOTE(review): the handle returned by UpdateWorkflow is discarded, so nothing
// carried by that handle is inspected here — confirm this is intended.
func (h *FlowRequestHandler) attemptCDCFlowConfigUpdate(ctx context.Context,
	workflowID string, cdcFlowConfigUpdate *protos.CDCFlowConfigUpdate,
) error {
	if _, updateErr := h.temporalClient.UpdateWorkflow(
		ctx, workflowID, "", shared.CDCFlowConfigUpdate, cdcFlowConfigUpdate,
	); updateErr != nil {
		return fmt.Errorf("failed to update config in CDC workflow with ID %s: %w", workflowID, updateErr)
	}
	return nil
}
3 changes: 3 additions & 0 deletions flow/connectors/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type CDCPullConnector interface {

// GetOpenConnectionsForUser returns the number of open connections for the user configured in the peer.
GetOpenConnectionsForUser() (*protos.GetOpenConnectionsForUserResult, error)

// AddTablesToPublication adds the tables newly added to a mirror to that mirror's publication as well.
AddTablesToPublication(req *protos.AddTablesToPublicationInput) error
}

type CDCSyncConnector interface {
Expand Down
2 changes: 1 addition & 1 deletion flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ func (p *PostgresCDCSource) processMessage(batch *model.CDCRecordStream, xld pgl
// TODO (kaushik): consider persistent state for a mirror job
// to be stored somewhere in temporal state. We might need to persist
// the state of the relation message somewhere
p.logger.Debug(fmt.Sprintf("RelationMessage => RelationID: %d, Namespace: %s, RelationName: %s, Columns: %v",
p.logger.Info(fmt.Sprintf("RelationMessage => RelationID: %d, Namespace: %s, RelationName: %s, Columns: %v",
msg.RelationID, msg.Namespace, msg.RelationName, msg.Columns))
if p.relationMessageMapping[msg.RelationID] == nil {
p.relationMessageMapping[msg.RelationID] = convertRelationMessageToProto(msg)
Expand Down
4 changes: 4 additions & 0 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,7 @@ func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) {
}
return pglogrepl.ParseLSN(result.String)
}

// getDefaultPublicationName returns the default publication name for a job:
// the job name prefixed with "peerflow_pub_".
func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {
	const publicationPrefix = "peerflow_pub_"
	return publicationPrefix + jobName
}
30 changes: 24 additions & 6 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"regexp"
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
Expand Down Expand Up @@ -206,8 +207,7 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu
slotName = req.OverrideReplicationSlotName
}

// Publication name would be the job name prefixed with "peerflow_pub_"
publicationName := fmt.Sprintf("peerflow_pub_%s", req.FlowJobName)
publicationName := c.getDefaultPublicationName(req.FlowJobName)
if req.OverridePublicationName != "" {
publicationName = req.OverridePublicationName
}
Expand Down Expand Up @@ -820,8 +820,7 @@ func (c *PostgresConnector) SetupReplication(signal SlotSignal, req *protos.Setu
slotName = req.ExistingReplicationSlotName
}

// Publication name would be the job name prefixed with "peerflow_pub_"
publicationName := fmt.Sprintf("peerflow_pub_%s", req.FlowJobName)
publicationName := c.getDefaultPublicationName(req.FlowJobName)
if req.ExistingPublicationName != "" {
publicationName = req.ExistingPublicationName
}
Expand Down Expand Up @@ -853,8 +852,7 @@ func (c *PostgresConnector) PullFlowCleanup(jobName string) error {
// Slotname would be the job name prefixed with "peerflow_slot_"
slotName := fmt.Sprintf("peerflow_slot_%s", jobName)

// Publication name would be the job name prefixed with "peerflow_pub_"
publicationName := fmt.Sprintf("peerflow_pub_%s", jobName)
publicationName := c.getDefaultPublicationName(jobName)

pullFlowCleanupTx, err := c.pool.Begin(c.ctx)
if err != nil {
Expand Down Expand Up @@ -932,3 +930,23 @@ func (c *PostgresConnector) GetOpenConnectionsForUser() (*protos.GetOpenConnecti
CurrentOpenConnections: result.Int64,
}, nil
}

// AddTablesToPublication adds the source tables of req.AdditionalTables to the
// job's default publication with a single ALTER PUBLICATION ... ADD TABLE statement.
//
// It is a no-op (returning nil) when req is nil, when req.PublicationName is set
// (custom publications are never modified), or when there are no additional tables.
// A failure of the ALTER PUBLICATION statement is wrapped and returned.
//
// NOTE(review): the publication name and the table identifiers are interpolated
// into the statement as-is, without quoting — confirm callers only pass
// identifiers that are safe to embed in SQL.
func (c *PostgresConnector) AddTablesToPublication(req *protos.AddTablesToPublicationInput) error {
	if req == nil {
		return nil
	}
	// don't modify custom publications
	if req.PublicationName != "" {
		return nil
	}
	if len(req.AdditionalTables) == 0 {
		return nil
	}

	srcTables := make([]string, len(req.AdditionalTables))
	for i, mapping := range req.AdditionalTables {
		srcTables[i] = mapping.SourceTableIdentifier
	}

	alterStmt := fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s",
		c.getDefaultPublicationName(req.FlowJobName), strings.Join(srcTables, ","))
	if _, execErr := c.pool.Exec(c.ctx, alterStmt); execErr != nil {
		return fmt.Errorf("failed to alter publication: %w", execErr)
	}
	return nil
}
22 changes: 19 additions & 3 deletions flow/connectors/utils/array.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package utils

func ArrayMinus(first []string, second []string) []string {
lookup := make(map[string]struct{}, len(second))
func ArrayMinus[T comparable](first, second []T) []T {
lookup := make(map[T]struct{}, len(second))
// Add elements from arrayB to the lookup map
for _, element := range second {
lookup[element] = struct{}{}
}
// Iterate over arrayA and check if the element is present in the lookup map
var result []string
var result []T
for _, element := range first {
_, exists := lookup[element]
if !exists {
Expand All @@ -29,3 +29,19 @@ func ArrayChunks[T any](slice []T, size int) [][]T {

return partitions
}

// ArraysHaveOverlap reports whether first and second have at least one element
// in common. It returns false when either slice is nil or empty.
func ArraysHaveOverlap[T comparable](first, second []T) bool {
	// Nothing can overlap with an empty slice; skip building the lookup map.
	if len(first) == 0 || len(second) == 0 {
		return false
	}

	// Pre-size the set, as ArrayMinus does, to avoid rehashing while filling it.
	lookup := make(map[T]struct{}, len(first))
	for _, element := range first {
		lookup[element] = struct{}{}
	}

	for _, element := range second {
		if _, exists := lookup[element]; exists {
			return true
		}
	}

	return false
}
52 changes: 22 additions & 30 deletions flow/e2e/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,38 +135,30 @@ func EnvWaitForEqualTables(
table string,
cols string,
) {
suite.T().Helper()
EnvWaitForEqualTablesWithNames(env, suite, reason, table, table, cols)
}

func EnvWaitForEqualTablesWithNames(
env *testsuite.TestWorkflowEnvironment,
suite e2eshared.RowSource,
reason string,
srcTable string,
dstTable string,
cols string,
) {
t := suite.T()
t.Helper()

EnvWaitFor(t, env, 3*time.Minute, reason, func() bool {
t.Helper()

suffix := suite.Suffix()
pool := suite.Pool()
pgRows, err := GetPgRows(pool, suffix, srcTable, cols)
if err != nil {
return false
}
// wait for PeerFlowStatusQuery to finish setup
// sleep for 5 second to allow the workflow to start
time.Sleep(5 * time.Second)
for {
response, err := env.QueryWorkflow(
shared.CDCFlowStateQuery,
connectionGen.FlowJobName,
)
if err == nil {
var state peerflow.CDCFlowWorkflowState
err = response.Get(&state)
if err != nil {
slog.Error(err.Error())
}

rows, err := suite.GetRows(dstTable, cols)
if err != nil {
return false
if *state.CurrentFlowState == protos.FlowStatus_STATUS_RUNNING {
break
}
} else {
// log the error for informational purposes
slog.Error(err.Error())
}

return e2eshared.CheckEqualRecordBatches(t, pgRows, rows)
})
time.Sleep(1 * time.Second)
}
}

func SetupCDCFlowStatusQuery(t *testing.T, env *testsuite.TestWorkflowEnvironment,
Expand Down
5 changes: 3 additions & 2 deletions flow/shared/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const (
snapshotFlowTaskQueue = "snapshot-flow-task-queue"

// Signals
CDCFlowSignalName = "peer-flow-signal"
FlowSignalName = "peer-flow-signal"
CDCDynamicPropertiesSignalName = "cdc-dynamic-properties"

// Queries
Expand All @@ -21,7 +21,8 @@ const (
FlowStatusQuery = "q-flow-status"

// Updates
FlowStatusUpdate = "u-flow-status"
FlowStatusUpdate = "u-flow-status"
CDCFlowConfigUpdate = "u-cdc-flow-config-update"
)

const MirrorNameSearchAttribute = "MirrorName"
Expand Down
Loading

0 comments on commit 2782cbb

Please sign in to comment.