Merge branch 'main' into fix-parallel-lsn
serprex authored May 10, 2024
2 parents 317e35d + ffd494f commit 7113344
Showing 28 changed files with 415 additions and 190 deletions.
23 changes: 20 additions & 3 deletions flow/activities/flowable_core.go
@@ -11,6 +11,7 @@ import (

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/yuin/gopher-lua"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"
"go.temporal.io/sdk/temporal"
@@ -23,6 +24,7 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/peerdbenv"
"github.com/PeerDB-io/peer-flow/pua"
"github.com/PeerDB-io/peer-flow/shared"
)

@@ -343,10 +345,25 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
})
defer shutdown()

var rowsSynced int
bufferSize := shared.FetchAndChannelSize
errGroup, errCtx := errgroup.WithContext(ctx)
stream := model.NewQRecordStream(bufferSize)
outstream := stream
if config.Script != "" {
ls, err := utils.LoadScript(ctx, config.Script, utils.LuaPrintFn(func(s string) {
a.Alerter.LogFlowInfo(ctx, config.FlowJobName, s)
}))
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
}
lfn := ls.Env.RawGetString("transformRow")
if fn, ok := lfn.(*lua.LFunction); ok {
outstream = pua.AttachToStream(ls, fn, stream)
}
}

var rowsSynced int
errGroup, errCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
tmp, err := srcConn.PullQRepRecords(errCtx, config, partition, stream)
if err != nil {
@@ -363,7 +380,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
})

errGroup.Go(func() error {
rowsSynced, err = dstConn.SyncQRepRecords(errCtx, config, partition, stream)
rowsSynced, err = dstConn.SyncQRepRecords(errCtx, config, partition, outstream)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to sync records: %w", err)
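The hunk above gives replicateQRepPartition the same Lua hook the queue connectors use: load the script once, look up a global transformRow function, and swap in outstream only when the script actually defines it. The pull side keeps writing to the original stream, so the source connector never sees the transformation; only the sync side reads from the wrapped stream. A channel-based sketch of that wrapping pattern, with a hypothetical Record type standing in for the real model.QRecordStream elements (an illustration of the idea, not the pua.AttachToStream implementation):

    // transformStream pipes every record from in through fn on a
    // background goroutine and exposes the results as a new stream.
    type Record []any

    func transformStream(in <-chan Record, fn func(Record) Record) <-chan Record {
        out := make(chan Record)
        go func() {
            defer close(out)
            for rec := range in {
                out <- fn(rec)
            }
        }()
        return out
    }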
22 changes: 6 additions & 16 deletions flow/cmd/handler.go
@@ -142,27 +142,19 @@ func (h *FlowRequestHandler) CreateCDCFlow(

if req.ConnectionConfigs.SoftDeleteColName == "" {
req.ConnectionConfigs.SoftDeleteColName = "_PEERDB_IS_DELETED"
} else {
// make them all uppercase
req.ConnectionConfigs.SoftDeleteColName = strings.ToUpper(req.ConnectionConfigs.SoftDeleteColName)
}

if req.ConnectionConfigs.SyncedAtColName == "" {
req.ConnectionConfigs.SyncedAtColName = "_PEERDB_SYNCED_AT"
} else {
// make them all uppercase
req.ConnectionConfigs.SyncedAtColName = strings.ToUpper(req.ConnectionConfigs.SyncedAtColName)
}

if req.CreateCatalogEntry {
err := h.createCdcJobEntry(ctx, req, workflowID)
if err != nil {
slog.Error("unable to create flow job entry", slog.Any("error", err))
return nil, fmt.Errorf("unable to create flow job entry: %w", err)
}
err := h.createCdcJobEntry(ctx, req, workflowID)
if err != nil {
slog.Error("unable to create flow job entry", slog.Any("error", err))
return nil, fmt.Errorf("unable to create flow job entry: %w", err)
}

err := h.updateFlowConfigInCatalog(ctx, cfg)
err = h.updateFlowConfigInCatalog(ctx, cfg)
if err != nil {
slog.Error("unable to update flow config in catalog", slog.Any("error", err))
return nil, fmt.Errorf("unable to update flow config in catalog: %w", err)
@@ -258,10 +250,8 @@ func (h *FlowRequestHandler) CreateQRepFlow(

if req.QrepConfig.SyncedAtColName == "" {
cfg.SyncedAtColName = "_PEERDB_SYNCED_AT"
} else {
// make them all uppercase
cfg.SyncedAtColName = strings.ToUpper(req.QrepConfig.SyncedAtColName)
}

_, err := h.temporalClient.ExecuteWorkflow(ctx, workflowOptions, workflowFn, cfg, state)
if err != nil {
slog.Error("unable to start QRepFlow workflow",
28 changes: 15 additions & 13 deletions flow/cmd/validate_mirror.go
@@ -17,7 +17,7 @@ import (
func (h *FlowRequestHandler) ValidateCDCMirror(
ctx context.Context, req *protos.CreateCDCFlowRequest,
) (*protos.ValidateCDCMirrorResponse, error) {
if req.CreateCatalogEntry && !req.ConnectionConfigs.Resync {
if !req.ConnectionConfigs.Resync {
mirrorExists, existCheckErr := h.CheckIfMirrorNameExists(ctx, req.ConnectionConfigs.FlowJobName)
if existCheckErr != nil {
slog.Error("/validatecdc failed to check if mirror name exists", slog.Any("error", existCheckErr))
@@ -46,7 +46,9 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
if sourcePeerConfig == nil {
slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
return nil, errors.New("source peer config is nil")
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, errors.New("source peer config is nil")
}

pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
@@ -103,17 +105,17 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

pubName := req.ConnectionConfigs.PublicationName
if pubName != "" {
err = pgPeer.CheckSourceTables(ctx, sourceTables, pubName)
if err != nil {
displayErr := fmt.Errorf("provided source tables invalidated: %v", err)
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}

err = pgPeer.CheckSourceTables(ctx, sourceTables, pubName)
if err != nil {
displayErr := fmt.Errorf("provided source tables invalidated: %v", err)
slog.Error(displayErr.Error())
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}

return &protos.ValidateCDCMirrorResponse{
18 changes: 14 additions & 4 deletions flow/connectors/clickhouse/qrep_avro_sync.go
@@ -47,10 +47,15 @@ func (s *ClickhouseAvroSyncMethod) CopyStageToDestination(ctx context.Context, a
if err != nil {
return err
}

sessionTokenPart := ""
if creds.AWS.SessionToken != "" {
sessionTokenPart = fmt.Sprintf(", '%s'", creds.AWS.SessionToken)
}
//nolint:gosec
query := fmt.Sprintf("INSERT INTO %s SELECT * FROM s3('%s','%s','%s', '%s', 'Avro')",
query := fmt.Sprintf("INSERT INTO %s SELECT * FROM s3('%s','%s','%s'%s, 'Avro')",
s.config.DestinationTableIdentifier, avroFileUrl,
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, creds.AWS.SessionToken)
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, sessionTokenPart)

_, err = s.connector.database.ExecContext(ctx, query)

@@ -137,10 +142,15 @@ func (s *ClickhouseAvroSyncMethod) SyncQRepRecords(
selector = append(selector, "`"+colName+"`")
}
selectorStr := strings.Join(selector, ",")

sessionTokenPart := ""
if creds.AWS.SessionToken != "" {
sessionTokenPart = fmt.Sprintf(", '%s'", creds.AWS.SessionToken)
}
//nolint:gosec
query := fmt.Sprintf("INSERT INTO %s(%s) SELECT %s FROM s3('%s','%s','%s', '%s', 'Avro')",
query := fmt.Sprintf("INSERT INTO %s(%s) SELECT %s FROM s3('%s','%s','%s'%s, 'Avro')",
config.DestinationTableIdentifier, selectorStr, selectorStr, avroFileUrl,
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, creds.AWS.SessionToken)
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, sessionTokenPart)

_, err = s.connector.database.ExecContext(ctx, query)
if err != nil {
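Both ClickHouse call sites get the same fix: the AWS session token becomes an optional trailing argument to the s3() table function rather than an always-present fourth parameter, so static credentials no longer send an empty token string. A standalone sketch of the string construction (table name, URL, and credentials are placeholders):

    package main

    import "fmt"

    func main() {
        avroFileUrl := "https://bucket.s3.amazonaws.com/stage/file.avro.zst" // placeholder
        for _, token := range []string{"", "SESSION_TOKEN"} {
            sessionTokenPart := ""
            if token != "" {
                sessionTokenPart = fmt.Sprintf(", '%s'", token)
            }
            // Without a token: ...FROM s3('<url>','KEY','SECRET', 'Avro')
            // With a token:    ...FROM s3('<url>','KEY','SECRET', 'SESSION_TOKEN', 'Avro')
            fmt.Printf("INSERT INTO tbl SELECT * FROM s3('%s','%s','%s'%s, 'Avro')\n",
                avroFileUrl, "KEY", "SECRET", sessionTokenPart)
        }
    }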
13 changes: 3 additions & 10 deletions flow/connectors/eventhub/eventhub.go
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"

@@ -196,15 +195,9 @@ func (c *EventHubConnector) processBatch(
var fn *lua.LFunction
if req.Script != "" {
var err error
ls, err = utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t"))
return 0
})
ls, err = utils.LoadScript(ctx, req.Script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, req.FlowJobName, s)
}))
if err != nil {
return 0, err
}
13 changes: 3 additions & 10 deletions flow/connectors/kafka/kafka.go
@@ -5,7 +5,6 @@ import (
"crypto/tls"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"

@@ -178,15 +177,9 @@ func (c *KafkaConnector) createPool(
queueErr func(error),
) (*utils.LPool[poolResult], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t"))
return 0
})
ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, flowJobName, s)
}))
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion flow/connectors/postgres/postgres.go
@@ -62,11 +62,11 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// create a separate connection pool for non-replication queries as replication connections cannot
// be used for extended query protocol, i.e. prepared statements
connConfig, err := pgx.ParseConfig(connectionString)
replConfig := connConfig.Copy()
if err != nil {
return nil, fmt.Errorf("failed to parse connection string: %w", err)
}

replConfig := connConfig.Copy()
runtimeParams := connConfig.Config.RuntimeParams
runtimeParams["idle_in_transaction_session_timeout"] = "0"
runtimeParams["statement_timeout"] = "0"
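This one-line move matters because pgx.ParseConfig returns a nil config alongside its error, so copying before the error check could panic with a nil-pointer dereference instead of reporting the bad connection string. The corrected ordering in isolation:

    connConfig, err := pgx.ParseConfig(connectionString)
    if err != nil {
        return nil, fmt.Errorf("failed to parse connection string: %w", err)
    }
    // Safe now: connConfig is guaranteed non-nil once err is nil.
    replConfig := connConfig.Copy()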
33 changes: 18 additions & 15 deletions flow/connectors/postgres/validate.go
@@ -34,28 +34,31 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
}

tableStr := strings.Join(tableArr, ",")
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)

if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}
}

return nil
13 changes: 3 additions & 10 deletions flow/connectors/pubsub/pubsub.go
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
@@ -140,15 +139,9 @@ func (c *PubSubConnector) createPool(
queueErr func(error),
) (*utils.LPool[poolResult], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t"))
return 0
})
ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, flowJobName, s)
}))
if err != nil {
return nil, err
}
1 change: 1 addition & 0 deletions flow/connectors/snowflake/snowflake.go
@@ -44,6 +44,7 @@ const (
SELECT _PEERDB_UID,_PEERDB_TIMESTAMP,TO_VARIANT(PARSE_JSON(_PEERDB_DATA)) %s,_PEERDB_RECORD_TYPE,
_PEERDB_MATCH_DATA,_PEERDB_BATCH_ID,_PEERDB_UNCHANGED_TOAST_COLUMNS
FROM _PEERDB_INTERNAL.%s WHERE _PEERDB_BATCH_ID = %d AND
_PEERDB_DATA != '' AND
_PEERDB_DESTINATION_TABLE_NAME = ? ), FLATTENED AS
(SELECT _PEERDB_UID,_PEERDB_TIMESTAMP,_PEERDB_RECORD_TYPE,_PEERDB_MATCH_DATA,_PEERDB_BATCH_ID,
_PEERDB_UNCHANGED_TOAST_COLUMNS,%s
13 changes: 13 additions & 0 deletions flow/connectors/utils/lua.go
@@ -3,6 +3,7 @@ package utils
import (
"context"
"fmt"
"strings"

"github.com/yuin/gopher-lua"

@@ -35,6 +36,18 @@ func LVAsStringOrNil(ls *lua.LState, v lua.LValue) (string, error) {
}
}

func LuaPrintFn(fn func(string)) lua.LGFunction {
return func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
fn(strings.Join(ss, "\t"))
return 0
}
}

func LoadScript(ctx context.Context, script string, printfn lua.LGFunction) (*lua.LState, error) {
ls := lua.NewState(lua.Options{SkipOpenLibs: true})
ls.SetContext(ctx)
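LuaPrintFn factors out the print handler that the Kafka, PubSub, and EventHub connectors previously each inlined: Lua's print(...) arguments are stringified (honoring __tostring metamethods via ToStringMeta), tab-joined, and passed to the supplied callback. The resulting call shape, as in the connector hunks above (import path assumed from the repository layout):

    import (
        "context"
        "log"

        "github.com/PeerDB-io/peer-flow/connectors/utils"
    )

    func loadWithLogging(ctx context.Context, script string) error {
        ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) {
            log.Println(s) // each Lua print(...) call arrives as one tab-joined line
        }))
        if err != nil {
            return err
        }
        defer ls.Close()
        return nil
    }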
6 changes: 3 additions & 3 deletions flow/e2e/bigquery/peer_flow_bq_test.go
@@ -1110,8 +1110,8 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_IUD_Same_Batch() {
Source: e2e.GeneratePostgresPeer(),
CdcStagingPath: connectionGen.CdcStagingPath,
SoftDelete: true,
SoftDeleteColName: "_PEERDB_IS_DELETED",
SyncedAtColName: "_PEERDB_SYNCED_AT",
SoftDeleteColName: "_custom_deleted",
SyncedAtColName: "_custom_synced",
MaxBatchSize: 100,
}

@@ -1141,7 +1141,7 @@
e2e.EnvWaitForEqualTables(env, s, "normalizing tx", "test_softdel_iud", "id,c1,c2,t")
e2e.EnvWaitFor(s.t, env, 3*time.Minute, "checking soft delete", func() bool {
newerSyncedAtQuery := fmt.Sprintf(
"SELECT COUNT(*) FROM `%s.%s` WHERE _PEERDB_IS_DELETED",
"SELECT COUNT(*) FROM `%s.%s` WHERE _custom_deleted",
s.bqHelper.Config.DatasetId, dstTableName)
numNewRows, err := s.bqHelper.RunInt64Query(newerSyncedAtQuery)
e2e.EnvNoError(s.t, env, err)