diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index 319da34c49..ad3dca378a 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + "github.com/yuin/gopher-lua" "go.temporal.io/sdk/activity" "go.temporal.io/sdk/log" "go.temporal.io/sdk/temporal" @@ -23,6 +24,7 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/peerdbenv" + "github.com/PeerDB-io/peer-flow/pua" "github.com/PeerDB-io/peer-flow/shared" ) @@ -343,10 +345,25 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, }) defer shutdown() - var rowsSynced int bufferSize := shared.FetchAndChannelSize - errGroup, errCtx := errgroup.WithContext(ctx) stream := model.NewQRecordStream(bufferSize) + outstream := stream + if config.Script != "" { + ls, err := utils.LoadScript(ctx, config.Script, utils.LuaPrintFn(func(s string) { + a.Alerter.LogFlowInfo(ctx, config.FlowJobName, s) + })) + if err != nil { + a.Alerter.LogFlowError(ctx, config.FlowJobName, err) + return err + } + lfn := ls.Env.RawGetString("transformRow") + if fn, ok := lfn.(*lua.LFunction); ok { + outstream = pua.AttachToStream(ls, fn, stream) + } + } + + var rowsSynced int + errGroup, errCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { tmp, err := srcConn.PullQRepRecords(errCtx, config, partition, stream) if err != nil { @@ -363,7 +380,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, }) errGroup.Go(func() error { - rowsSynced, err = dstConn.SyncQRepRecords(errCtx, config, partition, stream) + rowsSynced, err = dstConn.SyncQRepRecords(errCtx, config, partition, outstream) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to sync records: %w", err) diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 4f95d1007e..1182f4d413 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "strings" "sync/atomic" "time" @@ -196,15 +195,9 @@ func (c *EventHubConnector) processBatch( var fn *lua.LFunction if req.Script != "" { var err error - ls, err = utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int { - top := ls.GetTop() - ss := make([]string, top) - for i := range top { - ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() - } - _ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t")) - return 0 - }) + ls, err = utils.LoadScript(ctx, req.Script, utils.LuaPrintFn(func(s string) { + _ = c.LogFlowInfo(ctx, req.FlowJobName, s) + })) if err != nil { return 0, err } diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go index cfe4652598..c58da5e50f 100644 --- a/flow/connectors/kafka/kafka.go +++ b/flow/connectors/kafka/kafka.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "fmt" "log/slog" - "strings" "sync/atomic" "time" @@ -178,15 +177,9 @@ func (c *KafkaConnector) createPool( } return utils.LuaPool(func() (*lua.LState, error) { - ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int { - top := ls.GetTop() - ss := make([]string, top) - for i := range top { - ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() - } - _ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t")) - return 0 - }) + ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) { + _ = c.LogFlowInfo(ctx, flowJobName, s) + })) if err != nil { return nil, err } diff --git a/flow/connectors/pubsub/pubsub.go b/flow/connectors/pubsub/pubsub.go index 54031f016d..0a8709b3b2 100644 --- a/flow/connectors/pubsub/pubsub.go +++ b/flow/connectors/pubsub/pubsub.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "strings" "sync" "sync/atomic" "time" @@ -130,15 +129,9 @@ func (c *PubSubConnector) createPool( queueErr func(error), ) (*utils.LPool[[]PubSubMessage], error) { return utils.LuaPool(func() (*lua.LState, error) { - ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int { - top := ls.GetTop() - ss := make([]string, top) - for i := range top { - ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() - } - _ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t")) - return 0 - }) + ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) { + _ = c.LogFlowInfo(ctx, flowJobName, s) + })) if err != nil { return nil, err } diff --git a/flow/connectors/utils/lua.go b/flow/connectors/utils/lua.go index 47676721b3..f1d82f373f 100644 --- a/flow/connectors/utils/lua.go +++ b/flow/connectors/utils/lua.go @@ -3,6 +3,7 @@ package utils import ( "context" "fmt" + "strings" "github.com/yuin/gopher-lua" @@ -35,6 +36,18 @@ func LVAsStringOrNil(ls *lua.LState, v lua.LValue) (string, error) { } } +func LuaPrintFn(fn func(string)) lua.LGFunction { + return func(ls *lua.LState) int { + top := ls.GetTop() + ss := make([]string, top) + for i := range top { + ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() + } + fn(strings.Join(ss, "\t")) + return 0 + } +} + func LoadScript(ctx context.Context, script string, printfn lua.LGFunction) (*lua.LState, error) { ls := lua.NewState(lua.Options{SkipOpenLibs: true}) ls.SetContext(ctx) diff --git a/flow/e2e/postgres/qrep_flow_pg_test.go b/flow/e2e/postgres/qrep_flow_pg_test.go index 63a226ae13..4f2f944a97 100644 --- a/flow/e2e/postgres/qrep_flow_pg_test.go +++ b/flow/e2e/postgres/qrep_flow_pg_test.go @@ -406,3 +406,55 @@ func (s PeerFlowE2ETestSuitePG) Test_Pause() { env.Cancel() e2e.RequireEnvCanceled(s.t, env) } + +func (s PeerFlowE2ETestSuitePG) TestTransform() { + numRows := 10 + + srcTable := "test_transform" + s.setupSourceTable(srcTable, numRows) + + dstTable := "test_transformdst" + + srcSchemaQualified := fmt.Sprintf("%s_%s.%s", "e2e_test", s.suffix, srcTable) + dstSchemaQualified := fmt.Sprintf("%s_%s.%s", "e2e_test", s.suffix, dstTable) + + query := fmt.Sprintf("SELECT * FROM %s WHERE updated_at BETWEEN {{.start}} AND {{.end}}", srcSchemaQualified) + + postgresPeer := e2e.GeneratePostgresPeer() + + _, err := s.Conn().Exec(context.Background(), `insert into public.scripts (name, lang, source) values + ('pgtransform', 'lua', 'function transformRow(row) row.myreal = 1729 end') on conflict do nothing`) + require.NoError(s.t, err) + + qrepConfig, err := e2e.CreateQRepWorkflowConfig( + "test_transform", + srcSchemaQualified, + dstSchemaQualified, + query, + postgresPeer, + "", + true, + "_PEERDB_SYNCED_AT", + "", + ) + require.NoError(s.t, err) + qrepConfig.WriteMode = &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE, + } + qrepConfig.InitialCopyOnly = false + qrepConfig.Script = "pgtransform" + + tc := e2e.NewTemporalClient(s.t) + env := e2e.RunQRepFlowWorkflow(tc, qrepConfig) + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "waiting for first sync to complete", func() bool { + err = s.compareCounts(dstSchemaQualified, int64(numRows)) + return err == nil + }) + require.NoError(s.t, env.Error()) + + var exists bool + err = s.Conn().QueryRow(context.Background(), + fmt.Sprintf("select exists(select * from %s where myreal <> 1729)", dstSchemaQualified)).Scan(&exists) + require.NoError(s.t, err) + require.False(s.t, exists) +} diff --git a/flow/pua/peerdb.go b/flow/pua/peerdb.go index 7330cb9d09..c44aeb497a 100644 --- a/flow/pua/peerdb.go +++ b/flow/pua/peerdb.go @@ -3,6 +3,7 @@ package pua import ( "bytes" "fmt" + "math" "math/big" "time" @@ -52,6 +53,7 @@ func RegisterTypes(ls *lua.LState) { mt = LuaRow.NewMetatable(ls) mt.RawSetString("__index", ls.NewFunction(LuaRowIndex)) + mt.RawSetString("__newindex", ls.NewFunction(LuaRowNewIndex)) mt.RawSetString("__len", ls.NewFunction(LuaRowLen)) mt = shared.LuaUuid.NewMetatable(ls) @@ -157,6 +159,178 @@ func LuaRowIndex(ls *lua.LState) int { return 1 } +func LVAsTime(ls *lua.LState, lv lua.LValue) time.Time { + switch v := lv.(type) { + case lua.LNumber: + ipart, fpart := math.Modf(float64(v)) + return time.Unix(int64(ipart), int64(fpart*1e9)) + case *lua.LUserData: + if tm, ok := v.Value.(time.Time); ok { + return tm + } + } + ls.RaiseError("Cannot convert %T to time.Time", lv) + return time.Time{} +} + +func LuaRowNewIndex(ls *lua.LState) int { + _, row := LuaRow.Check(ls, 1) + key := ls.CheckString(2) + val := ls.Get(3) + qv := row.GetColumnValue(key) + kind := qv.Kind() + if val == lua.LNil { + row.AddColumn(key, qvalue.QValueNull(kind)) + } + var newqv qvalue.QValue + switch kind { + case qvalue.QValueKindInvalid: + newqv = qvalue.QValueInvalid{Val: lua.LVAsString(val)} + case qvalue.QValueKindFloat32: + newqv = qvalue.QValueFloat32{Val: float32(lua.LVAsNumber(val))} + case qvalue.QValueKindFloat64: + newqv = qvalue.QValueFloat64{Val: float64(lua.LVAsNumber(val))} + case qvalue.QValueKindInt16: + newqv = qvalue.QValueInt16{Val: int16(lua.LVAsNumber(val))} + case qvalue.QValueKindInt32: + newqv = qvalue.QValueInt32{Val: int32(lua.LVAsNumber(val))} + case qvalue.QValueKindInt64: + switch v := val.(type) { + case lua.LNumber: + newqv = qvalue.QValueInt64{Val: int64(v)} + case *lua.LUserData: + switch i64 := v.Value.(type) { + case int64: + newqv = qvalue.QValueInt64{Val: i64} + case uint64: + newqv = qvalue.QValueInt64{Val: int64(i64)} + } + } + if newqv == nil { + ls.RaiseError("invalid int64") + } + case qvalue.QValueKindBoolean: + newqv = qvalue.QValueBoolean{Val: lua.LVAsBool(val)} + case qvalue.QValueKindQChar: + switch v := val.(type) { + case lua.LNumber: + newqv = qvalue.QValueQChar{Val: uint8(v)} + case lua.LString: + if len(v) > 0 { + newqv = qvalue.QValueQChar{Val: v[0]} + } + default: + ls.RaiseError("invalid \"char\"") + } + case qvalue.QValueKindString: + newqv = qvalue.QValueString{Val: lua.LVAsString(val)} + case qvalue.QValueKindTimestamp: + newqv = qvalue.QValueTimestamp{Val: LVAsTime(ls, val)} + case qvalue.QValueKindTimestampTZ: + newqv = qvalue.QValueTimestampTZ{Val: LVAsTime(ls, val)} + case qvalue.QValueKindDate: + newqv = qvalue.QValueDate{Val: LVAsTime(ls, val)} + case qvalue.QValueKindTime: + newqv = qvalue.QValueTime{Val: LVAsTime(ls, val)} + case qvalue.QValueKindTimeTZ: + newqv = qvalue.QValueTimeTZ{Val: LVAsTime(ls, val)} + case qvalue.QValueKindNumeric: + newqv = qvalue.QValueNumeric{Val: LVAsDecimal(ls, val)} + case qvalue.QValueKindBytes: + newqv = qvalue.QValueBytes{Val: []byte(lua.LVAsString(val))} + case qvalue.QValueKindUUID: + if ud, ok := val.(*lua.LUserData); ok { + if id, ok := ud.Value.(uuid.UUID); ok { + newqv = qvalue.QValueUUID{Val: [16]byte(id)} + } + } + case qvalue.QValueKindJSON: + newqv = qvalue.QValueJSON{Val: lua.LVAsString(val)} + case qvalue.QValueKindBit: + newqv = qvalue.QValueBit{Val: []byte(lua.LVAsString(val))} + case qvalue.QValueKindArrayFloat32: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat32{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float32 { + return float32(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayFloat64: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayInt16: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayInt32: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayInt64: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayString: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayString{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) string { + return lua.LVAsString(v) + }), + } + } + case qvalue.QValueKindArrayDate: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayDate{ + Val: shared.LTableToSlice(ls, tbl, LVAsTime), + } + } + case qvalue.QValueKindArrayTimestamp: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayDate{ + Val: shared.LTableToSlice(ls, tbl, LVAsTime), + } + } + case qvalue.QValueKindArrayTimestampTZ: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayDate{ + Val: shared.LTableToSlice(ls, tbl, LVAsTime), + } + } + case qvalue.QValueKindArrayBoolean: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayBoolean{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) bool { + return lua.LVAsBool(v) + }), + } + } + default: + ls.RaiseError(fmt.Sprintf("no support for reassigning %s", kind)) + return 0 + } + + row.AddColumn(key, newqv) + return 1 +} + func LuaRowLen(ls *lua.LState) int { row := LuaRow.StartMethod(ls) ls.Push(lua.LNumber(len(row.ColToVal))) diff --git a/flow/pua/stream_adapter.go b/flow/pua/stream_adapter.go new file mode 100644 index 0000000000..0367134368 --- /dev/null +++ b/flow/pua/stream_adapter.go @@ -0,0 +1,33 @@ +package pua + +import ( + "github.com/yuin/gopher-lua" + + "github.com/PeerDB-io/peer-flow/model" +) + +func AttachToStream(ls *lua.LState, lfn *lua.LFunction, stream *model.QRecordStream) *model.QRecordStream { + output := model.NewQRecordStream(0) + go func() { + schema := stream.Schema() + output.SetSchema(schema) + for record := range stream.Records { + row := model.NewRecordItems(len(record)) + for i, qv := range record { + row.AddColumn(schema.Fields[i].Name, qv) + } + ls.Push(lfn) + ls.Push(LuaRow.New(ls, row)) + if err := ls.PCall(1, 0, nil); err != nil { + output.Close(err) + return + } + for i, field := range schema.Fields { + record[i] = row.GetColumnValue(field.Name) + } + output.Records <- record + } + output.Close(stream.Err()) + }() + return output +} diff --git a/flow/shared/lua.go b/flow/shared/lua.go index 2b95c3464c..26aeb4fe2a 100644 --- a/flow/shared/lua.go +++ b/flow/shared/lua.go @@ -18,7 +18,7 @@ var ( LuaDecimal = glua64.UserDataType[decimal.Decimal]{Name: "peerdb_decimal"} ) -func SliceToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) *lua.LTable { +func SliceToLTable[T any](ls *lua.LState, s []T, f func(T) lua.LValue) *lua.LTable { tbl := ls.CreateTable(len(s), 0) tbl.Metatable = ls.GetTypeMetatable("Array") for idx, val := range s { @@ -26,3 +26,12 @@ func SliceToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) *lua.LT } return tbl } + +func LTableToSlice[T any](ls *lua.LState, tbl *lua.LTable, f func(*lua.LState, lua.LValue) T) []T { + tlen := tbl.Len() + slice := make([]T, 0, tlen) + for i := range tlen { + slice = append(slice, f(ls, tbl.RawGetInt(i))) + } + return slice +} diff --git a/ui/app/mirrors/create/helpers/qrep.ts b/ui/app/mirrors/create/helpers/qrep.ts index 392439ee77..aa2b80d462 100644 --- a/ui/app/mirrors/create/helpers/qrep.ts +++ b/ui/app/mirrors/create/helpers/qrep.ts @@ -123,8 +123,14 @@ export const qrepSettings: MirrorSetting[] = [ ...curr, waitBetweenBatchesSeconds: parseInt(value as string, 10) || 30, })), - tips: 'Time to wait (in seconds) between getting partitions to process. The default is 30 seconds', + tips: 'Time to wait (in seconds) between getting partitions to process. The default is 30 seconds.', default: 30, type: 'number', }, + { + label: 'Script', + stateHandler: (value, setter) => + setter((curr: QRepConfig) => ({ ...curr, script: value as string })), + tips: 'Script to use for row transformations. The default is no scripting.', + }, ];