diff --git a/flow/connectors/pubsub/pubsub.go b/flow/connectors/pubsub/pubsub.go index b08150960c..842a8200a3 100644 --- a/flow/connectors/pubsub/pubsub.go +++ b/flow/connectors/pubsub/pubsub.go @@ -121,6 +121,57 @@ func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (PubSubMessage, err }, nil } +func (c *PubSubConnector) createPool( + ctx context.Context, + script string, + flowJobName string, + topiccache *topicCache, + publish chan<- *pubsub.PublishResult, + wgErr func(error), +) (*utils.LPool[[]PubSubMessage], error) { + return utils.LuaPool(func() (*lua.LState, error) { + ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int { + top := ls.GetTop() + ss := make([]string, top) + for i := range top { + ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() + } + _ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t")) + return 0 + }) + if err != nil { + return nil, err + } + if script == "" { + ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) + } + return ls, nil + }, func(messages []PubSubMessage) { + for _, message := range messages { + topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) { + topicClient := c.client.Topic(message.Topic) + exists, err := topicClient.Exists(ctx) + if err != nil { + return nil, fmt.Errorf("error checking if topic exists: %w", err) + } + if !exists { + topicClient, err = c.client.CreateTopic(ctx, message.Topic) + if err != nil { + return nil, fmt.Errorf("error creating topic: %w", err) + } + } + return topicClient, nil + }) + if err != nil { + wgErr(err) + return + } + + publish <- topicClient.Publish(ctx, message.Message) + } + }) +} + type topicCache struct { cache map[string]*pubsub.Topic lock sync.RWMutex @@ -175,51 +226,10 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) topiccache := topicCache{cache: make(map[string]*pubsub.Topic)} publish := make(chan *pubsub.PublishResult, 32) - waitChan := make(chan struct{}) - wgCtx, wgErr := context.WithCancelCause(ctx) - - pool, err := utils.LuaPool(func() (*lua.LState, error) { - ls, err := utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int { - top := ls.GetTop() - ss := make([]string, top) - for i := range top { - ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() - } - _ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t")) - return 0 - }) - if err != nil { - return nil, err - } - if req.Script == "" { - ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) - } - return ls, nil - }, func(messages []PubSubMessage) { - for _, message := range messages { - topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) { - topicClient := c.client.Topic(message.Topic) - exists, err := topicClient.Exists(wgCtx) - if err != nil { - return nil, fmt.Errorf("error checking if topic exists: %w", err) - } - if !exists { - topicClient, err = c.client.CreateTopic(wgCtx, message.Topic) - if err != nil { - return nil, fmt.Errorf("error creating topic: %w", err) - } - } - return topicClient, nil - }) - if err != nil { - wgErr(err) - return - } - publish <- topicClient.Publish(ctx, message.Message) - } - }) + wgCtx, wgErr := context.WithCancelCause(ctx) + pool, err := c.createPool(wgCtx, req.Script, req.FlowJobName, &topiccache, publish, wgErr) if err != nil { return nil, err } diff --git a/flow/connectors/pubsub/qrep.go b/flow/connectors/pubsub/qrep.go new file mode 100644 index 0000000000..a70ff8dbee --- /dev/null +++ b/flow/connectors/pubsub/qrep.go @@ -0,0 +1,128 @@ +package connpubsub + +import ( + "context" + "fmt" + "sync/atomic" + "time" + + "cloud.google.com/go/pubsub" + lua "github.com/yuin/gopher-lua" + + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/pua" +) + +func (*PubSubConnector) SetupQRepMetadataTables(_ context.Context, _ *protos.QRepConfig) error { + return nil +} + +func (c *PubSubConnector) SyncQRepRecords( + ctx context.Context, + config *protos.QRepConfig, + partition *protos.QRepPartition, + stream *model.QRecordStream, +) (int, error) { + startTime := time.Now() + numRecords := atomic.Int64{} + schema := stream.Schema() + topiccache := topicCache{cache: make(map[string]*pubsub.Topic)} + publish := make(chan *pubsub.PublishResult, 32) + waitChan := make(chan struct{}) + + wgCtx, wgErr := context.WithCancelCause(ctx) + pool, err := c.createPool(wgCtx, config.Script, config.FlowJobName, &topiccache, publish, wgErr) + if err != nil { + return 0, err + } + defer pool.Close() + + go func() { + for curpub := range publish { + if _, err := curpub.Get(ctx); err != nil { + wgErr(err) + break + } + } + close(waitChan) + }() + +Loop: + for { + select { + case qrecord, ok := <-stream.Records: + if !ok { + c.logger.Info("flushing batches because no more records") + break Loop + } + + pool.Run(func(ls *lua.LState) []PubSubMessage { + items := model.NewRecordItems(len(qrecord)) + for i, val := range qrecord { + items.AddColumn(schema.Fields[i].Name, val) + } + record := &model.InsertRecord[model.RecordItems]{ + BaseRecord: model.BaseRecord{}, + Items: items, + SourceTableName: config.WatermarkTable, + DestinationTableName: config.DestinationTableIdentifier, + CommitID: 0, + } + + lfn := ls.Env.RawGetString("onRecord") + fn, ok := lfn.(*lua.LFunction) + if !ok { + wgErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn)) + return nil + } + + ls.Push(fn) + ls.Push(pua.LuaRecord.New(ls, record)) + err := ls.PCall(1, -1, nil) + if err != nil { + wgErr(fmt.Errorf("script failed: %w", err)) + return nil + } + + args := ls.GetTop() + results := make([]PubSubMessage, 0, args) + for i := range args { + msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args)) + if err != nil { + wgErr(err) + return nil + } + if msg.Message != nil { + if msg.Topic == "" { + msg.Topic = record.GetDestinationTableName() + } + results = append(results, msg) + } + } + ls.SetTop(0) + numRecords.Add(1) + return results + }) + + case <-wgCtx.Done(): + break Loop + } + } + + close(publish) + if err := pool.Wait(wgCtx); err != nil { + return 0, err + } + topiccache.Stop(wgCtx) + select { + case <-wgCtx.Done(): + return 0, wgCtx.Err() + case <-waitChan: + } + + if err := c.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime); err != nil { + return 0, err + } + return int(numRecords.Load()), nil +} diff --git a/flow/e2e/pubsub/pubsub_test.go b/flow/e2e/pubsub/pubsub_test.go index ec03fbb111..0261752b65 100644 --- a/flow/e2e/pubsub/pubsub_test.go +++ b/flow/e2e/pubsub/pubsub_test.go @@ -237,3 +237,74 @@ func (s PubSubSuite) TestSimple() { env.Cancel() e2e.RequireEnvCanceled(s.t, env) } + +func (s PubSubSuite) TestInitialLoad() { + srcTableName := e2e.AttachSchema(s, "psinitial") + + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id SERIAL PRIMARY KEY, + val text + ); + `, srcTableName)) + require.NoError(s.t, err) + + sa, err := ServiceAccount() + require.NoError(s.t, err) + + _, err = s.Conn().Exec(context.Background(), `insert into public.scripts (name, lang, source) values + ('e2e_psinitial', 'lua', 'function onRecord(r) return r.row and r.row.val end') on conflict do nothing`) + require.NoError(s.t, err) + + flowName := e2e.AddSuffix(s, "e2epsinitial") + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: flowName, + TableNameMapping: map[string]string{srcTableName: flowName}, + Destination: s.Peer(sa), + } + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + flowConnConfig.Script = "e2e_psinitial" + flowConnConfig.DoInitialSnapshot = true + + psclient, err := sa.CreatePubSubClient(context.Background()) + require.NoError(s.t, err) + defer psclient.Close() + topic, err := psclient.CreateTopic(context.Background(), flowName) + require.NoError(s.t, err) + sub, err := psclient.CreateSubscription(context.Background(), flowName, pubsub.SubscriptionConfig{ + Topic: topic, + RetentionDuration: 10 * time.Minute, + ExpirationPolicy: 24 * time.Hour, + }) + require.NoError(s.t, err) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` + + INSERT INTO %s (id, val) VALUES (1, 'testval') + `, srcTableName)) + require.NoError(s.t, err) + + tc := e2e.NewTemporalClient(s.t) + env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) + e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + + msgs := make(chan *pubsub.Message) + go func() { + _ = sub.Receive(ctx, func(_ context.Context, m *pubsub.Message) { + msgs <- m + }) + }() + select { + case msg := <-msgs: + require.NotNil(s.t, msg) + require.Equal(s.t, "testval", string(msg.Data)) + case <-ctx.Done(): + s.t.Log("UNEXPECTED TIMEOUT PubSub subscription waiting on message") + s.t.Fail() + } + + env.Cancel() + e2e.RequireEnvCanceled(s.t, env) +}