Skip to content

Commit

Permalink
parallelized lua script execution with serialized output (#1613)
Browse files (browse the repository at this point in the history)
  • Loading branch information
serprex authored Apr 15, 2024
1 parent: 0072bb3; commit: 60bcef9
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 139 deletions.
109 changes: 64 additions & 45 deletions flow/connectors/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,38 +169,43 @@ func lvalueToKafkaRecord(ls *lua.LState, value lua.LValue) (*kgo.Record, error)
func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) {
var wg sync.WaitGroup
wgCtx, wgErr := context.WithCancelCause(ctx)
produceCb := func(r *kgo.Record, err error) {
produceCb := func(_ *kgo.Record, err error) {
if err != nil {
wgErr(err)
}
wg.Done()
}

numRecords := int64(0)
numRecords := atomic.Int64{}
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)

ls, err := utils.LoadScript(wgCtx, req.Script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
pool, err := utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(wgCtx, req.Script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t"))
return 0
})
if err != nil {
return nil, err
}
if req.Script == "" {
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}
return ls, nil
}, func(krs []*kgo.Record) {
wg.Add(len(krs))
for _, kr := range krs {
c.client.Produce(wgCtx, kr, produceCb)
}
_ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t"))
return 0
})
if err != nil {
return nil, err
}
defer ls.Close()
if req.Script == "" {
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}

lfn := ls.Env.RawGetString("onRecord")
fn, ok := lfn.(*lua.LFunction)
if !ok {
return nil, fmt.Errorf("script should define `onRecord` as function, not %s", lfn)
}
defer pool.Close()

lastSeenLSN := atomic.Int64{}
flushLoopDone := make(chan struct{})
Expand Down Expand Up @@ -242,39 +247,54 @@ Loop:
break Loop
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
return nil, fmt.Errorf("script failed: %w", err)
}
args := ls.GetTop()
for i := range args {
kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
pool.Run(func(ls *lua.LState) []*kgo.Record {
lfn := ls.Env.RawGetString("onRecord")
fn, ok := lfn.(*lua.LFunction)
if !ok {
wgErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
return nil, err
wgErr(fmt.Errorf("script failed: %w", err))
return nil
}
if kr != nil {
if kr.Topic == "" {
kr.Topic = record.GetDestinationTableName()
}

wg.Add(1)
c.client.Produce(wgCtx, kr, produceCb)
record.PopulateCountMap(tableNameRowsMapping)
args := ls.GetTop()
results := make([]*kgo.Record, 0, args)
for i := range args {
kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
if err != nil {
wgErr(err)
return nil
}
if kr != nil {
if kr.Topic == "" {
kr.Topic = record.GetDestinationTableName()
}
results = append(results, kr)
record.PopulateCountMap(tableNameRowsMapping)
}
}
}
ls.SetTop(0)
numRecords += 1
shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID())
ls.SetTop(0)
numRecords.Add(1)
shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID())
return results
})

case <-wgCtx.Done():
return nil, wgCtx.Err()
break Loop
}
}

close(flushLoopDone)
if err := c.client.Flush(ctx); err != nil {
if err := pool.Wait(wgCtx); err != nil {
return nil, err
}
if err := c.client.Flush(wgCtx); err != nil {
return nil, fmt.Errorf("[kafka] final flush error: %w", err)
}
waitChan := make(chan struct{})
Expand All @@ -289,15 +309,14 @@ Loop:
}

lastCheckpoint := req.Records.GetLastCheckpoint()
err = c.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint)
if err != nil {
if err := c.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint); err != nil {
return nil, err
}

return &model.SyncResponse{
CurrentSyncBatchID: req.SyncBatchID,
LastSyncedCheckpointID: lastCheckpoint,
NumRecordsSynced: numRecords,
NumRecordsSynced: numRecords.Load(),
TableNameRowsMapping: tableNameRowsMapping,
TableSchemaDeltas: req.Records.SchemaDeltas,
}, nil
Expand Down
Loading

0 comments on commit 60bcef9

Please sign in to comment.