Skip to content

Commit

Permalink
QRep scripting (#1625)
Browse files Browse the repository at this point in the history
Scripts for qrep may define a function `transformRow(row)` which can reassign fields' values (without changing types)
  • Loading branch information
serprex authored May 10, 2024
1 parent b80d434 commit ffd494f
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 35 deletions.
23 changes: 20 additions & 3 deletions flow/activities/flowable_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/yuin/gopher-lua"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"
"go.temporal.io/sdk/temporal"
Expand All @@ -23,6 +24,7 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/peerdbenv"
"github.com/PeerDB-io/peer-flow/pua"
"github.com/PeerDB-io/peer-flow/shared"
)

Expand Down Expand Up @@ -343,10 +345,25 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
})
defer shutdown()

var rowsSynced int
bufferSize := shared.FetchAndChannelSize
errGroup, errCtx := errgroup.WithContext(ctx)
stream := model.NewQRecordStream(bufferSize)
outstream := stream
if config.Script != "" {
ls, err := utils.LoadScript(ctx, config.Script, utils.LuaPrintFn(func(s string) {
a.Alerter.LogFlowInfo(ctx, config.FlowJobName, s)
}))
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
}
lfn := ls.Env.RawGetString("transformRow")
if fn, ok := lfn.(*lua.LFunction); ok {
outstream = pua.AttachToStream(ls, fn, stream)
}
}

var rowsSynced int
errGroup, errCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
tmp, err := srcConn.PullQRepRecords(errCtx, config, partition, stream)
if err != nil {
Expand All @@ -363,7 +380,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
})

errGroup.Go(func() error {
rowsSynced, err = dstConn.SyncQRepRecords(errCtx, config, partition, stream)
rowsSynced, err = dstConn.SyncQRepRecords(errCtx, config, partition, outstream)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return fmt.Errorf("failed to sync records: %w", err)
Expand Down
13 changes: 3 additions & 10 deletions flow/connectors/eventhub/eventhub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"

Expand Down Expand Up @@ -196,15 +195,9 @@ func (c *EventHubConnector) processBatch(
var fn *lua.LFunction
if req.Script != "" {
var err error
ls, err = utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t"))
return 0
})
ls, err = utils.LoadScript(ctx, req.Script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, req.FlowJobName, s)
}))
if err != nil {
return 0, err
}
Expand Down
13 changes: 3 additions & 10 deletions flow/connectors/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/tls"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"

Expand Down Expand Up @@ -178,15 +177,9 @@ func (c *KafkaConnector) createPool(
}

return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t"))
return 0
})
ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, flowJobName, s)
}))
if err != nil {
return nil, err
}
Expand Down
13 changes: 3 additions & 10 deletions flow/connectors/pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -130,15 +129,9 @@ func (c *PubSubConnector) createPool(
queueErr func(error),
) (*utils.LPool[[]PubSubMessage], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t"))
return 0
})
ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, flowJobName, s)
}))
if err != nil {
return nil, err
}
Expand Down
13 changes: 13 additions & 0 deletions flow/connectors/utils/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package utils
import (
"context"
"fmt"
"strings"

"github.com/yuin/gopher-lua"

Expand Down Expand Up @@ -35,6 +36,18 @@ func LVAsStringOrNil(ls *lua.LState, v lua.LValue) (string, error) {
}
}

// LuaPrintFn wraps fn as a Lua-callable Go function intended to serve as a
// script's print implementation. Every argument currently on the Lua stack is
// stringified through ToStringMeta, the pieces are joined with tab characters,
// and the resulting line is passed to fn. The Lua function pushes no results.
func LuaPrintFn(fn func(string)) lua.LGFunction {
	return func(ls *lua.LState) int {
		argc := ls.GetTop()
		parts := make([]string, 0, argc)
		// Lua stack slots are 1-based, so walk 1..argc inclusive.
		for slot := 1; slot <= argc; slot++ {
			parts = append(parts, ls.ToStringMeta(ls.Get(slot)).String())
		}
		fn(strings.Join(parts, "\t"))
		return 0
	}
}

func LoadScript(ctx context.Context, script string, printfn lua.LGFunction) (*lua.LState, error) {
ls := lua.NewState(lua.Options{SkipOpenLibs: true})
ls.SetContext(ctx)
Expand Down
52 changes: 52 additions & 0 deletions flow/e2e/postgres/qrep_flow_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,55 @@ func (s PeerFlowE2ETestSuitePG) Test_Pause() {
env.Cancel()
e2e.RequireEnvCanceled(s.t, env)
}

// TestTransform runs a QRep flow with the "pgtransform" Lua script attached.
// The script defines transformRow, which assigns 1729 to row.myreal. Once the
// destination holds the expected number of rows, the test asserts that no row
// carries a myreal value other than 1729.
func (s PeerFlowE2ETestSuitePG) TestTransform() {
	rowCount := 10

	sourceName := "test_transform"
	s.setupSourceTable(sourceName, rowCount)

	destName := "test_transformdst"

	// qualify prefixes a table name with this suite's e2e schema.
	qualify := func(table string) string {
		return fmt.Sprintf("%s_%s.%s", "e2e_test", s.suffix, table)
	}
	srcQualified := qualify(sourceName)
	dstQualified := qualify(destName)

	watermarkQuery := fmt.Sprintf("SELECT * FROM %s WHERE updated_at BETWEEN {{.start}} AND {{.end}}", srcQualified)

	peer := e2e.GeneratePostgresPeer()

	// Register the transform script; a previous run may already have inserted it.
	insertScript := `insert into public.scripts (name, lang, source) values
('pgtransform', 'lua', 'function transformRow(row) row.myreal = 1729 end') on conflict do nothing`
	_, err := s.Conn().Exec(context.Background(), insertScript)
	require.NoError(s.t, err)

	cfg, err := e2e.CreateQRepWorkflowConfig(
		"test_transform",
		srcQualified,
		dstQualified,
		watermarkQuery,
		peer,
		"",
		true,
		"_PEERDB_SYNCED_AT",
		"",
	)
	require.NoError(s.t, err)
	cfg.WriteMode = &protos.QRepWriteMode{WriteType: protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE}
	cfg.InitialCopyOnly = false
	cfg.Script = "pgtransform"

	temporalClient := e2e.NewTemporalClient(s.t)
	env := e2e.RunQRepFlowWorkflow(temporalClient, cfg)

	// Poll until the destination row count matches the source.
	destinationFilled := func() bool {
		if cmpErr := s.compareCounts(dstQualified, int64(rowCount)); cmpErr != nil {
			return false
		}
		return true
	}
	e2e.EnvWaitFor(s.t, env, 3*time.Minute, "waiting for first sync to complete", destinationFilled)
	require.NoError(s.t, env.Error())

	// Any row whose myreal differs from 1729 escaped the transform.
	var untransformed bool
	verifySQL := fmt.Sprintf("select exists(select * from %s where myreal <> 1729)", dstQualified)
	err = s.Conn().QueryRow(context.Background(), verifySQL).Scan(&untransformed)
	require.NoError(s.t, err)
	require.False(s.t, untransformed)
}
Loading

0 comments on commit ffd494f

Please sign in to comment.