diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index e9687a2240..319da34c49 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -187,13 +187,14 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon syncStartTime = time.Now() res, err = sync(dstConn, errCtx, &model.SyncRecordsRequest[Items]{ - SyncBatchID: syncBatchID, - Records: recordBatch, - ConsumedOffset: &consumedOffset, - FlowJobName: flowName, - TableMappings: options.TableMappings, - StagingPath: config.CdcStagingPath, - Script: config.Script, + SyncBatchID: syncBatchID, + Records: recordBatch, + ConsumedOffset: &consumedOffset, + FlowJobName: flowName, + TableMappings: options.TableMappings, + StagingPath: config.CdcStagingPath, + Script: config.Script, + TableNameSchemaMapping: options.TableNameSchemaMapping, }) if err != nil { a.Alerter.LogFlowError(ctx, flowName, err) diff --git a/flow/connectors/connelasticsearch/elasticsearch.go b/flow/connectors/connelasticsearch/elasticsearch.go index 0672e05c20..8d5f092790 100644 --- a/flow/connectors/connelasticsearch/elasticsearch.go +++ b/flow/connectors/connelasticsearch/elasticsearch.go @@ -4,26 +4,36 @@ import ( "bytes" "context" "crypto/tls" + "encoding/base64" "encoding/json" + "errors" "fmt" "log/slog" "net/http" "sync" + "sync/atomic" "time" "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/esutil" - "github.com/google/uuid" "go.temporal.io/sdk/log" + "golang.org/x/exp/maps" metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/peerdbenv" "github.com/PeerDB-io/peer-flow/shared" ) +const ( + actionIndex = "index" + actionDelete = "delete" +) + type ElasticsearchConnector struct { *metadataStore.PostgresMetadata client *elasticsearch.Client @@ -78,93 +88,170 @@ func (esc *ElasticsearchConnector) Close() error { return nil } -func (esc *ElasticsearchConnector) SetupQRepMetadataTables(ctx context.Context, - config *protos.QRepConfig, +// ES is queue-like, no raw table staging needed +func (esc *ElasticsearchConnector) CreateRawTable(ctx context.Context, + req *protos.CreateRawTableInput, +) (*protos.CreateRawTableOutput, error) { + return &protos.CreateRawTableOutput{TableIdentifier: "n/a"}, nil +} + +// we handle schema changes by not handling them since no mapping is being enforced right now +func (esc *ElasticsearchConnector) ReplayTableSchemaDeltas(ctx context.Context, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { return nil } -func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config *protos.QRepConfig, - partition *protos.QRepPartition, stream *model.QRecordStream, -) (int, error) { - startTime := time.Now() +func recordItemsProcessor(items model.RecordItems) ([]byte, error) { + qRecordJsonMap := make(map[string]any) + + for key, val := range items.ColToVal { + if r, ok := val.(qvalue.QValueJSON); ok { // JSON is stored as a string, fix that + qRecordJsonMap[key] = json.RawMessage( + shared.UnsafeFastStringToReadOnlyBytes(r.Val)) + } else { + qRecordJsonMap[key] = val.Value() + } + } - schema := stream.Schema() + return json.Marshal(qRecordJsonMap) +} - var bulkIndexFatalError error - var bulkIndexErrors []error - var 
bulkIndexMutex sync.Mutex
-	var docId string
-	numRecords := 0
-	bulkIndexerHasShutdown := false
-
-	// -1 means use UUID, >=0 means column in the record
-	upsertColIndex := -1
-	// only support single upsert column for now
-	if config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_UPSERT &&
-		len(config.WriteMode.UpsertKeyColumns) == 1 {
-		for i, field := range schema.Fields {
-			if config.WriteMode.UpsertKeyColumns[0] == field.Name {
-				upsertColIndex = i
+func (esc *ElasticsearchConnector) SyncRecords(ctx context.Context,
+	req *model.SyncRecordsRequest[model.RecordItems],
+) (*model.SyncResponse, error) {
+	tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
+	// atomics for counts will be unnecessary in other destinations, using a mutex instead
+	var recordCountsUpdateMutex sync.Mutex
+	// lastSeenLSN and the record count are updated from bulk indexer callbacks, so they need to be atomic
+	var lastSeenLSN atomic.Int64
+	var numRecords atomic.Int64
+
+	// no I don't like this either
+	esBulkIndexerCache := make(map[string]esutil.BulkIndexer)
+	bulkIndexersHaveShutdown := false
+	// true if we saw errors while closing
+	cacheCloser := func() bool {
+		closeHasErrors := false
+		if !bulkIndexersHaveShutdown {
+			for _, esBulkIndexer := range maps.Values(esBulkIndexerCache) {
+				err := esBulkIndexer.Close(context.Background())
+				if err != nil {
+					esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err))
+					closeHasErrors = true
+				}
 			}
+			bulkIndexersHaveShutdown = true
 		}
+		return closeHasErrors
 	}
+	defer cacheCloser()

-	esBulkIndexer, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{
-		Index: config.DestinationTableIdentifier,
-		Client: esc.client,
-		// parallelism comes from the workflow design itself, no need for this
-		NumWorkers: 1,
-		FlushInterval: 10 * time.Second,
-	})
-	if err != nil {
-		esc.logger.Error("[es] failed to initialize bulk indexer", slog.Any("error", err))
-		return 0, fmt.Errorf("[es] failed to initialize bulk indexer: %w", err)
-	}
-	defer func() {
-		if !bulkIndexerHasShutdown {
-			err := esBulkIndexer.Close(context.Background())
-			if err != nil {
-				esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err))
+	flushLoopDone := make(chan struct{})
+	// we only update lastSeenLSN in the OnSuccess callback, so this should be safe even if there is a race
+	// between the loop breaking and flushLoopDone being closed
+	go func() {
+		ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds())
+		defer ticker.Stop()
+
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-flushLoopDone:
+				return
+			case <-ticker.C:
+				lastSeen := lastSeenLSN.Load()
+				if lastSeen > req.ConsumedOffset.Load() {
+					if err := esc.SetLastOffset(ctx, req.FlowJobName, lastSeen); err != nil {
+						esc.logger.Warn("[es] SetLastOffset error", slog.Any("error", err))
+					} else {
+						shared.AtomicInt64Max(req.ConsumedOffset, lastSeen)
+						esc.logger.Info("processBatch", slog.Int64("updated last offset", lastSeen))
+					}
+				}
 			}
 		}
 	}()

-	for qRecord := range stream.Records {
-		qRecordJsonMap := make(map[string]any)
+	var docId string
+	var bulkIndexFatalError error
+	var bulkIndexErrors []error
+	var bulkIndexOnFailureMutex sync.Mutex

-		if upsertColIndex >= 0 {
-			docId = fmt.Sprintf("%v", qRecord[upsertColIndex].Value())
-		} else {
-			docId = uuid.New().String()
+	for record := range req.Records.GetRecords() {
+		var bodyBytes []byte
+		var err error
+		action := actionIndex
+
+		switch record.(type) {
+		case *model.InsertRecord[model.RecordItems], *model.UpdateRecord[model.RecordItems]:
+			bodyBytes, err = 
recordItemsProcessor(record.GetItems()) + if err != nil { + esc.logger.Error("[es] failed to json.Marshal record", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to json.Marshal record: %w", err) + } + case *model.DeleteRecord[model.RecordItems]: + action = actionDelete + // no need to supply the document since we are deleting + bodyBytes = nil } - for i, field := range schema.Fields { - switch r := qRecord[i].(type) { - // JSON is stored as a string, fix that - case qvalue.QValueJSON: - qRecordJsonMap[field.Name] = json.RawMessage(shared. - UnsafeFastStringToReadOnlyBytes(r.Val)) - default: - qRecordJsonMap[field.Name] = r.Value() + + bulkIndexer, ok := esBulkIndexerCache[record.GetDestinationTableName()] + if !ok { + bulkIndexer, err = esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Index: record.GetDestinationTableName(), + Client: esc.client, + // can't really ascertain how many tables present to provide a reasonable value + NumWorkers: 1, + FlushInterval: 10 * time.Second, + }) + if err != nil { + esc.logger.Error("[es] failed to initialize bulk indexer", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to initialize bulk indexer: %w", err) } + esBulkIndexerCache[record.GetDestinationTableName()] = bulkIndexer } - qRecordJsonBytes, err := json.Marshal(qRecordJsonMap) - if err != nil { - esc.logger.Error("[es] failed to json.Marshal record", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to json.Marshal record: %w", err) + + if len(req.TableNameSchemaMapping[record.GetDestinationTableName()].PrimaryKeyColumns) == 1 { + qValue, err := record.GetItems().GetValueByColName( + req.TableNameSchemaMapping[record.GetDestinationTableName()].PrimaryKeyColumns[0]) + if err != nil { + esc.logger.Error("[es] failed to process record", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to process record: %w", err) + } + docId = fmt.Sprint(qValue.Value()) + } else { + tablePkey, err := model.RecToTablePKey(req.TableNameSchemaMapping, record) + if err != nil { + esc.logger.Error("[es] failed to process record", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to process record: %w", err) + } + docId = base64.RawURLEncoding.EncodeToString(tablePkey.PkeyColVal[:]) } - err = esBulkIndexer.Add(ctx, esutil.BulkIndexerItem{ - Action: "index", + err = bulkIndexer.Add(ctx, esutil.BulkIndexerItem{ + Action: action, DocumentID: docId, - Body: bytes.NewReader(qRecordJsonBytes), + Body: bytes.NewReader(bodyBytes), + OnSuccess: func(_ context.Context, _ esutil.BulkIndexerItem, _ esutil.BulkIndexerResponseItem) { + shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) + numRecords.Add(1) + recordCountsUpdateMutex.Lock() + defer recordCountsUpdateMutex.Unlock() + record.PopulateCountMap(tableNameRowsMapping) + }, // OnFailure is called for each failed operation, log and let parent handle OnFailure: func(ctx context.Context, item esutil.BulkIndexerItem, res esutil.BulkIndexerResponseItem, err error, ) { - bulkIndexMutex.Lock() - defer bulkIndexMutex.Unlock() + // attempt to delete a record that wasn't present, possible from no initial load + if item.Action == actionDelete && res.Status == 404 { + return + } + bulkIndexOnFailureMutex.Lock() + defer bulkIndexOnFailureMutex.Unlock() if err != nil { bulkIndexErrors = append(bulkIndexErrors, err) } else { @@ -172,7 +259,7 @@ func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config * if res.Error.Cause.Type != "" || res.Error.Cause.Reason != "" { causeString = 
fmt.Sprintf("(caused by type:%s reason:%s)", res.Error.Cause.Type, res.Error.Cause.Reason) } - cbErr := fmt.Errorf("id:%s type:%s reason:%s %s", item.DocumentID, res.Error.Type, + cbErr := fmt.Errorf("id:%s action:%s type:%s reason:%s %s", item.DocumentID, item.Action, res.Error.Type, res.Error.Reason, causeString) bulkIndexErrors = append(bulkIndexErrors, cbErr) if res.Error.Type == "illegal_argument_exception" { @@ -183,36 +270,37 @@ func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config * }) if err != nil { esc.logger.Error("[es] failed to add record to bulk indexer", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to add record to bulk indexer: %w", err) + return nil, fmt.Errorf("[es] failed to add record to bulk indexer: %w", err) } if bulkIndexFatalError != nil { esc.logger.Error("[es] fatal error while indexing record", slog.Any("error", bulkIndexFatalError)) - return 0, fmt.Errorf("[es] fatal error while indexing record: %w", bulkIndexFatalError) + return nil, fmt.Errorf("[es] fatal error while indexing record: %w", bulkIndexFatalError) } - - // update here instead of OnSuccess, if we close successfully it should match - numRecords++ } + // "Receive on a closed channel yields the zero value after all elements in the channel are received." + close(flushLoopDone) - if err := stream.Err(); err != nil { - esc.logger.Error("[es] failed to get record from stream", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to get record from stream: %w", err) + if cacheCloser() { + esc.logger.Error("[es] failed to close bulk indexer(s)") + return nil, errors.New("[es] failed to close bulk indexer(s)") } - if err := esBulkIndexer.Close(ctx); err != nil { - esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to close bulk indexer: %w", err) - } - bulkIndexerHasShutdown = true + bulkIndexersHaveShutdown = true if len(bulkIndexErrors) > 0 { for _, err := range bulkIndexErrors { esc.logger.Error("[es] failed to index record", slog.Any("err", err)) } } - err = esc.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime) - if err != nil { - esc.logger.Error("[es] failed to log partition info", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to log partition info: %w", err) + lastCheckpoint := req.Records.GetLastCheckpoint() + if err := esc.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint); err != nil { + return nil, err } - return numRecords, nil + + return &model.SyncResponse{ + CurrentSyncBatchID: req.SyncBatchID, + LastSyncedCheckpointID: lastCheckpoint, + NumRecordsSynced: numRecords.Load(), + TableNameRowsMapping: tableNameRowsMapping, + TableSchemaDeltas: req.Records.SchemaDeltas, + }, nil } diff --git a/flow/connectors/connelasticsearch/qrep.go b/flow/connectors/connelasticsearch/qrep.go new file mode 100644 index 0000000000..142c6de363 --- /dev/null +++ b/flow/connectors/connelasticsearch/qrep.go @@ -0,0 +1,175 @@ +package connelasticsearch + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "log/slog" + "slices" + "sync" + "time" + + "github.com/elastic/go-elasticsearch/v8/esutil" + + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/shared" +) + +func (esc *ElasticsearchConnector) SetupQRepMetadataTables(ctx context.Context, + config *protos.QRepConfig, +) error { + return nil +} + 
+func upsertKeyColsHash(qRecord []qvalue.QValue, upsertColIndices []int) string { + hasher := sha256.New() + + for _, upsertColIndex := range upsertColIndices { + // cannot return an error + _, _ = fmt.Fprint(hasher, qRecord[upsertColIndex].Value()) + } + hashBytes := hasher.Sum(nil) + return base64.RawURLEncoding.EncodeToString(hashBytes) +} + +func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config *protos.QRepConfig, + partition *protos.QRepPartition, stream *model.QRecordStream, +) (int, error) { + startTime := time.Now() + + schema := stream.Schema() + + var bulkIndexFatalError error + var bulkIndexErrors []error + var bulkIndexOnFailureMutex sync.Mutex + var docId string + numRecords := 0 + bulkIndexerHasShutdown := false + + // len == 0 means use UUID + // len == 1 means single column, use value directly + // len > 1 means SHA256 hash of upsert key columns + // ordered such that we preserve order of UpsertKeyColumns + var upsertKeyColIndices []int + if config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_UPSERT { + schemaColNames := schema.GetColumnNames() + for _, upsertCol := range config.WriteMode.UpsertKeyColumns { + idx := slices.Index(schemaColNames, upsertCol) + if idx != -1 { + upsertKeyColIndices = append(upsertKeyColIndices, idx) + } + } + } + + esBulkIndexer, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Index: config.DestinationTableIdentifier, + Client: esc.client, + // parallelism comes from the workflow design itself, no need for this + NumWorkers: 1, + FlushInterval: 10 * time.Second, + }) + if err != nil { + esc.logger.Error("[es] failed to initialize bulk indexer", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to initialize bulk indexer: %w", err) + } + defer func() { + if !bulkIndexerHasShutdown { + err := esBulkIndexer.Close(context.Background()) + if err != nil { + esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) + } + } + }() + + for qRecord := range stream.Records { + qRecordJsonMap := make(map[string]any) + + switch len(upsertKeyColIndices) { + case 0: + // relying on autogeneration of document ID + case 1: + docId = fmt.Sprint(qRecord[upsertKeyColIndices[0]].Value()) + default: + docId = upsertKeyColsHash(qRecord, upsertKeyColIndices) + } + for i, field := range schema.Fields { + if r, ok := qRecord[i].(qvalue.QValueJSON); ok { // JSON is stored as a string, fix that + qRecordJsonMap[field.Name] = json.RawMessage( + shared.UnsafeFastStringToReadOnlyBytes(r.Val)) + } else { + qRecordJsonMap[field.Name] = qRecord[i].Value() + } + } + qRecordJsonBytes, err := json.Marshal(qRecordJsonMap) + if err != nil { + esc.logger.Error("[es] failed to json.Marshal record", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to json.Marshal record: %w", err) + } + + err = esBulkIndexer.Add(ctx, esutil.BulkIndexerItem{ + Action: actionIndex, + DocumentID: docId, + Body: bytes.NewReader(qRecordJsonBytes), + + // OnFailure is called for each failed operation, log and let parent handle + OnFailure: func(ctx context.Context, item esutil.BulkIndexerItem, + res esutil.BulkIndexerResponseItem, err error, + ) { + bulkIndexOnFailureMutex.Lock() + defer bulkIndexOnFailureMutex.Unlock() + if err != nil { + bulkIndexErrors = append(bulkIndexErrors, err) + } else { + causeString := "" + if res.Error.Cause.Type != "" || res.Error.Cause.Reason != "" { + causeString = fmt.Sprintf("(caused by type:%s reason:%s)", res.Error.Cause.Type, res.Error.Cause.Reason) + } + cbErr := fmt.Errorf("id:%s 
type:%s reason:%s %s", item.DocumentID, res.Error.Type, + res.Error.Reason, causeString) + bulkIndexErrors = append(bulkIndexErrors, cbErr) + if res.Error.Type == "illegal_argument_exception" { + bulkIndexFatalError = cbErr + } + } + }, + }) + if err != nil { + esc.logger.Error("[es] failed to add record to bulk indexer", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to add record to bulk indexer: %w", err) + } + if bulkIndexFatalError != nil { + esc.logger.Error("[es] fatal error while indexing record", slog.Any("error", bulkIndexFatalError)) + return 0, fmt.Errorf("[es] fatal error while indexing record: %w", bulkIndexFatalError) + } + + // update here instead of OnSuccess, if we close successfully it should match + numRecords++ + } + + if err := stream.Err(); err != nil { + esc.logger.Error("[es] failed to get record from stream", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to get record from stream: %w", err) + } + if err := esBulkIndexer.Close(ctx); err != nil { + esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to close bulk indexer: %w", err) + } + bulkIndexerHasShutdown = true + if len(bulkIndexErrors) > 0 { + for _, err := range bulkIndexErrors { + esc.logger.Error("[es] failed to index record", slog.Any("err", err)) + } + } + + err = esc.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime) + if err != nil { + esc.logger.Error("[es] failed to log partition info", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to log partition info: %w", err) + } + return numRecords, nil +} diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 1e3321fa72..545269c4d1 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -304,6 +304,7 @@ var ( _ CDCSyncConnector = &connpubsub.PubSubConnector{} _ CDCSyncConnector = &conns3.S3Connector{} _ CDCSyncConnector = &connclickhouse.ClickhouseConnector{} + _ CDCSyncConnector = &connelasticsearch.ElasticsearchConnector{} _ CDCSyncPgConnector = &connpostgres.PostgresConnector{} diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 4adcc1d0d4..ae1e84292e 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -2,10 +2,8 @@ package connpostgres import ( "context" - "crypto/sha256" "fmt" "log/slog" - "slices" "time" "github.com/jackc/pglogrepl" @@ -488,7 +486,7 @@ func PullCdcRecords[Items model.Items]( return err } } else { - tablePkeyVal, err := recToTablePKey(req, rec) + tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec) if err != nil { return err } @@ -520,7 +518,7 @@ func PullCdcRecords[Items model.Items]( return err } } else { - tablePkeyVal, err := recToTablePKey(req, rec) + tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec) if err != nil { return err } @@ -538,7 +536,7 @@ func PullCdcRecords[Items model.Items]( return err } } else { - tablePkeyVal, err := recToTablePKey(req, rec) + tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec) if err != nil { return err } @@ -865,27 +863,6 @@ func processRelationMessage[Items model.Items]( return nil, nil } -func recToTablePKey[Items model.Items]( - req *model.PullRecordsRequest[Items], - rec model.Record[Items], -) (model.TableWithPkey, error) { - tableName := rec.GetDestinationTableName() - pkeyColsMerged := make([][]byte, 0, len(req.TableNameSchemaMapping[tableName].PrimaryKeyColumns)) - - for _, pkeyCol := range 
req.TableNameSchemaMapping[tableName].PrimaryKeyColumns { - pkeyColBytes, err := rec.GetItems().GetBytesByColName(pkeyCol) - if err != nil { - return model.TableWithPkey{}, fmt.Errorf("error getting pkey column value: %w", err) - } - pkeyColsMerged = append(pkeyColsMerged, pkeyColBytes) - } - - return model.TableWithPkey{ - TableName: tableName, - PkeyColVal: sha256.Sum256(slices.Concat(pkeyColsMerged...)), - }, nil -} - func (p *PostgresCDCSource) getParentRelIDIfPartitioned(relID uint32) uint32 { parentRelID, ok := p.childToParentRelIDMapping[relID] if ok { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index e88a736428..a20291b04d 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -706,7 +706,7 @@ func (c *PostgresConnector) CreateRawTable(ctx context.Context, req *protos.Crea _, err = createRawTableTx.Exec(ctx, fmt.Sprintf(createRawTableDstTableIndexSQL, rawTableIdentifier, c.metadataSchema, rawTableIdentifier)) if err != nil { - return nil, fmt.Errorf("error creating destion table index on raw table: %w", err) + return nil, fmt.Errorf("error creating destination table index on raw table: %w", err) } err = createRawTableTx.Commit(ctx) diff --git a/flow/e2e/elasticsearch/elasticsearch.go b/flow/e2e/elasticsearch/elasticsearch.go index d0abea9413..ed3389adb7 100644 --- a/flow/e2e/elasticsearch/elasticsearch.go +++ b/flow/e2e/elasticsearch/elasticsearch.go @@ -80,8 +80,12 @@ func (s elasticsearchSuite) Peer() *protos.Peer { return s.peer } -func (s elasticsearchSuite) CountDocumentsInIndex(index string) int64 { +func (s elasticsearchSuite) countDocumentsInIndex(index string) int64 { res, err := s.esClient.Count().Index(index).Do(context.Background()) + // index may not exist yet, don't error out for that + if err != nil && strings.Contains(err.Error(), "index_not_found_exception") { + return 0 + } require.NoError(s.t, err, "failed to get count of documents in index") return res.Count } diff --git a/flow/e2e/elasticsearch/peer_flow_es_test.go b/flow/e2e/elasticsearch/peer_flow_es_test.go new file mode 100644 index 0000000000..04e1e6d6ab --- /dev/null +++ b/flow/e2e/elasticsearch/peer_flow_es_test.go @@ -0,0 +1,150 @@ +package e2e_elasticsearch + +import ( + "context" + "fmt" + "time" + + "github.com/stretchr/testify/require" + + "github.com/PeerDB-io/peer-flow/e2e" + peerflow "github.com/PeerDB-io/peer-flow/workflows" +) + +func (s elasticsearchSuite) Test_Simple_PKey_CDC_Mirror() { + srcTableName := e2e.AttachSchema(s, "es_simple_pkey_cdc") + + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + c1 INT, + val TEXT, + updated_at TIMESTAMP DEFAULT now() + ); + `, srcTableName)) + require.NoError(s.t, err, "failed creating table") + + tc := e2e.NewTemporalClient(s.t) + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: e2e.AddSuffix(s, "es_simple_pkey_cdc"), + TableNameMapping: map[string]string{srcTableName: srcTableName}, + Destination: s.peer, + } + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + flowConnConfig.MaxBatchSize = 100 + flowConnConfig.DoInitialSnapshot = true + + rowCount := 10 + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, 
flowConnConfig, nil) + e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) + + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for initial snapshot + inserted rows", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(2*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + UPDATE %s SET c1=c1+2,updated_at=now() WHERE id%%2=0;`, srcTableName)) + require.NoError(s.t, err, "failed to update rows on source") + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for updates + new inserts", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + DELETE FROM %s WHERE id%%2=1;`, srcTableName)) + require.NoError(s.t, err, "failed to delete rows on source") + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for deletes", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount/2) + }) + + env.Cancel() + e2e.RequireEnvCanceled(s.t, env) +} + +func (s elasticsearchSuite) Test_Composite_PKey_CDC_Mirror() { + srcTableName := e2e.AttachSchema(s, "es_composite_pkey_cdc") + + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT GENERATED ALWAYS AS IDENTITY, + c1 INT, + val TEXT, + updated_at TIMESTAMP DEFAULT now(), + PRIMARY KEY(id,val) + ); + `, srcTableName)) + require.NoError(s.t, err, "failed creating table") + + tc := e2e.NewTemporalClient(s.t) + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: e2e.AddSuffix(s, "es_composite_pkey_cdc"), + TableNameMapping: map[string]string{srcTableName: srcTableName}, + Destination: s.peer, + } + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + flowConnConfig.MaxBatchSize = 100 + flowConnConfig.DoInitialSnapshot = true + + rowCount := 10 + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) + e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) + + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for initial snapshot + inserted rows", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(2*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + UPDATE %s SET c1=c1+2,updated_at=now() WHERE id%%2=0;`, srcTableName)) + require.NoError(s.t, err, "failed to update rows on source") + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for updates + new inserts", func() bool 
{ + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + DELETE FROM %s WHERE id%%2=1;`, srcTableName)) + require.NoError(s.t, err, "failed to delete rows on source") + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for deletes", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount/2) + }) + + env.Cancel() + e2e.RequireEnvCanceled(s.t, env) +} diff --git a/flow/e2e/elasticsearch/qrep_flow_es_test.go b/flow/e2e/elasticsearch/qrep_flow_es_test.go index dafa1b2d79..85ac99043e 100644 --- a/flow/e2e/elasticsearch/qrep_flow_es_test.go +++ b/flow/e2e/elasticsearch/qrep_flow_es_test.go @@ -10,18 +10,20 @@ import ( "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/e2eshared" + "github.com/PeerDB-io/peer-flow/generated/protos" ) func Test_Elasticsearch(t *testing.T) { e2eshared.RunSuite(t, SetupSuite) } -func (s elasticsearchSuite) Test_Simple_Qrep() { - srcTableName := e2e.AttachSchema(s, "es_simple") +func (s elasticsearchSuite) Test_Simple_QRep_Append() { + srcTableName := e2e.AttachSchema(s, "es_simple_append") _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + c1 INT, val TEXT, updated_at TIMESTAMP DEFAULT now() ); @@ -31,8 +33,8 @@ func (s elasticsearchSuite) Test_Simple_Qrep() { rowCount := 10 for i := range rowCount { _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(val) VALUES('val%d') - `, srcTableName, i)) + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) require.NoError(s.t, err, "failed to insert row") } @@ -51,10 +53,72 @@ func (s elasticsearchSuite) Test_Simple_Qrep() { "", "", ) + qrepConfig.InitialCopyOnly = false + require.NoError(s.t, err) env := e2e.RunQRepFlowWorkflow(tc, qrepConfig) - e2e.EnvWaitForFinished(s.t, env, 3*time.Minute) + + e2e.EnvWaitFor(s.t, env, 10*time.Second, "waiting for ES to catch up", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(rowCount) + }) + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + UPDATE %s SET c1=c1+2,updated_at=now() WHERE id%%2=0;`, srcTableName)) + require.NoError(s.t, err, "failed to update rows on source") + e2e.EnvWaitFor(s.t, env, 20*time.Second, "waiting for ES to catch up", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount/2) + }) + require.NoError(s.t, env.Error()) +} + +func (s elasticsearchSuite) Test_Simple_QRep_Upsert() { + srcTableName := e2e.AttachSchema(s, "es_simple_upsert") + + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + c1 INT, + val TEXT, + updated_at TIMESTAMP DEFAULT now() + ); + `, srcTableName)) + require.NoError(s.t, err, "failed creating table") + + rowCount := 10 + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + tc := e2e.NewTemporalClient(s.t) + + query := fmt.Sprintf("SELECT * FROM %s WHERE updated_at BETWEEN {{.start}} AND {{.end}}", + srcTableName) - require.EqualValues(s.t, rowCount, s.CountDocumentsInIndex(srcTableName)) + qrepConfig, err := e2e.CreateQRepWorkflowConfig("test_es_simple_qrep", + srcTableName, + srcTableName, + query, + s.peer, + "", + false, + "", + "", + ) + 
qrepConfig.WriteMode = &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_UPSERT, + UpsertKeyColumns: []string{"id"}, + } + qrepConfig.InitialCopyOnly = false + + require.NoError(s.t, err) + env := e2e.RunQRepFlowWorkflow(tc, qrepConfig) + + e2e.EnvWaitFor(s.t, env, 10*time.Second, "waiting for ES to catch up", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(rowCount) + }) + + require.NoError(s.t, env.Error()) } diff --git a/flow/model/model.go b/flow/model/model.go index 54c2ef41ce..acdcdc5176 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -1,6 +1,8 @@ package model import ( + "crypto/sha256" + "fmt" "sync/atomic" "time" @@ -75,12 +77,36 @@ type TableWithPkey struct { PkeyColVal [32]byte } +func RecToTablePKey[T Items]( + tableNameSchemaMapping map[string]*protos.TableSchema, + rec Record[T], +) (TableWithPkey, error) { + tableName := rec.GetDestinationTableName() + hasher := sha256.New() + + for _, pkeyCol := range tableNameSchemaMapping[tableName].PrimaryKeyColumns { + pkeyColBytes, err := rec.GetItems().GetBytesByColName(pkeyCol) + if err != nil { + return TableWithPkey{}, fmt.Errorf("error getting pkey column value: %w", err) + } + // cannot return an error + _, _ = hasher.Write(pkeyColBytes) + } + + return TableWithPkey{ + TableName: tableName, + PkeyColVal: [32]byte(hasher.Sum(nil)), + }, nil +} + type SyncRecordsRequest[T Items] struct { Records *CDCStream[T] // ConsumedOffset allows destination to confirm lsn for slot ConsumedOffset *atomic.Int64 // FlowJobName is the name of the flow job. - FlowJobName string + // destination table name -> schema mapping + TableNameSchemaMapping map[string]*protos.TableSchema + FlowJobName string // Staging path for AVRO files in CDC StagingPath string // Lua script diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index f1c8c7d7f3..fdec9d15c6 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -171,6 +171,18 @@ func (s *SnapshotFlowExecution) cloneTable( numRowsPerPartition = s.config.SnapshotNumRowsPerPartition } + snapshotWriteMode := &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_APPEND, + } + // ensure document IDs are synchronized across initial load and CDC + // for the same document + if s.config.Destination.Type == protos.DBType_ELASTICSEARCH { + snapshotWriteMode = &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_UPSERT, + UpsertKeyColumns: s.tableNameSchemaMapping[mapping.DestinationTableIdentifier].PrimaryKeyColumns, + } + } + config := &protos.QRepConfig{ FlowJobName: childWorkflowID, SourcePeer: sourcePostgres, @@ -185,10 +197,8 @@ func (s *SnapshotFlowExecution) cloneTable( StagingPath: s.config.SnapshotStagingPath, SyncedAtColName: s.config.SyncedAtColName, SoftDeleteColName: s.config.SoftDeleteColName, - WriteMode: &protos.QRepWriteMode{ - WriteType: protos.QRepWriteType_QREP_WRITE_MODE_APPEND, - }, - System: s.config.System, + WriteMode: snapshotWriteMode, + System: s.config.System, } state := NewQRepFlowState()
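
Note on the document-ID scheme introduced above (an illustrative sketch, not part of the diff): the CDC path uses the primary-key value directly when a table has a single key column and otherwise base64-encodes a SHA-256 over the key columns (model.RecToTablePKey), while the QRep upsert path hashes the configured upsert key columns via upsertKeyColsHash. That is why snapshot_flow.go switches Elasticsearch initial loads to upsert mode: initial load and CDC then derive the same _id for the same row. Below is a minimal standalone sketch of the composite-key hashing with a hypothetical helper name; it follows the QRep path in hashing fmt.Fprint-ed values, whereas the CDC path hashes the raw column bytes from GetBytesByColName.

```go
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"
)

// compositeDocID is a hypothetical helper mirroring upsertKeyColsHash /
// RecToTablePKey: SHA-256 over the key column values, base64 (raw URL)
// encoded so the resulting _id stays short, stable, and URL-safe.
func compositeDocID(keyColVals ...any) string {
	hasher := sha256.New()
	for _, v := range keyColVals {
		_, _ = fmt.Fprint(hasher, v) // writes to a hash.Hash never fail
	}
	return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
}

func main() {
	fmt.Println(fmt.Sprint(42))              // single-column key: value used as-is
	fmt.Println(compositeDocID(42, "val42")) // composite key (id, val): stable hash
}
```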