diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go index e09b14ab33..89adbbd6a4 100644 --- a/flow/connectors/kafka/kafka.go +++ b/flow/connectors/kafka/kafka.go @@ -165,39 +165,48 @@ func lvalueToKafkaRecord(ls *lua.LState, value lua.LValue) (*kgo.Record, error) return kr, nil } -func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) { - wgCtx, wgErr := context.WithCancelCause(ctx) +func (c *KafkaConnector) createPool( + ctx context.Context, + script string, + flowJobName string, + wgErr func(error), +) (*utils.LPool[[]*kgo.Record], error) { produceCb := func(_ *kgo.Record, err error) { if err != nil { wgErr(err) } } - numRecords := atomic.Int64{} - tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) - - pool, err := utils.LuaPool(func() (*lua.LState, error) { - ls, err := utils.LoadScript(wgCtx, req.Script, func(ls *lua.LState) int { + return utils.LuaPool(func() (*lua.LState, error) { + ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int { top := ls.GetTop() ss := make([]string, top) for i := range top { ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() } - _ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t")) + _ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t")) return 0 }) if err != nil { return nil, err } - if req.Script == "" { + if script == "" { ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) } return ls, nil }, func(krs []*kgo.Record) { for _, kr := range krs { - c.client.Produce(wgCtx, kr, produceCb) + c.client.Produce(ctx, kr, produceCb) } }) +} + +func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) { + numRecords := atomic.Int64{} + tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) + + wgCtx, wgErr := context.WithCancelCause(ctx) + pool, err := c.createPool(wgCtx, req.Script, req.FlowJobName, wgErr) if err != nil { return nil, err } diff --git a/flow/connectors/kafka/qrep.go b/flow/connectors/kafka/qrep.go index 0e4867ae02..60725ae842 100644 --- a/flow/connectors/kafka/qrep.go +++ b/flow/connectors/kafka/qrep.go @@ -3,14 +3,12 @@ package connkafka import ( "context" "fmt" - "strings" "sync/atomic" "time" "github.com/twmb/franz-go/pkg/kgo" lua "github.com/yuin/gopher-lua" - "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/pua" @@ -27,43 +25,16 @@ func (c *KafkaConnector) SyncQRepRecords( stream *model.QRecordStream, ) (int, error) { startTime := time.Now() + numRecords := atomic.Int64{} schema := stream.Schema() wgCtx, wgErr := context.WithCancelCause(ctx) - produceCb := func(_ *kgo.Record, err error) { - if err != nil { - wgErr(err) - } - } - - pool, err := utils.LuaPool(func() (*lua.LState, error) { - ls, err := utils.LoadScript(wgCtx, config.Script, func(ls *lua.LState) int { - top := ls.GetTop() - ss := make([]string, top) - for i := range top { - ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() - } - _ = c.LogFlowInfo(ctx, config.FlowJobName, strings.Join(ss, "\t")) - return 0 - }) - if err != nil { - return nil, err - } - if config.Script == "" { - ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) - } - return ls, nil - }, func(krs []*kgo.Record) { - for _, kr := range krs { - c.client.Produce(wgCtx, kr, produceCb) - } - }) + pool, err := c.createPool(wgCtx, config.Script, config.FlowJobName, wgErr) if err != nil { return 0, err } defer pool.Close() - numRecords := atomic.Int64{} Loop: for { select {