From e3f6ef658be8073490ba29de1c2180dbbb1fd77a Mon Sep 17 00:00:00 2001 From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com> Date: Tue, 30 Apr 2024 10:26:19 +0530 Subject: [PATCH] Elasticsearch connector for CDC (#1649) Current limitations: 1) soft-delete does not work, needs to be handled specifically from the ES side 2) relies on implicit index creation, no custom mapping or other index configuration is possible 3) TOAST columns aren't handled properly --- flow/activities/flowable_core.go | 15 +- .../connelasticsearch/elasticsearch.go | 250 ++++++++++++------ flow/connectors/connelasticsearch/qrep.go | 175 ++++++++++++ flow/connectors/core.go | 1 + flow/connectors/postgres/cdc.go | 29 +- flow/connectors/postgres/postgres.go | 2 +- flow/e2e/elasticsearch/elasticsearch.go | 6 +- flow/e2e/elasticsearch/peer_flow_es_test.go | 150 +++++++++++ flow/e2e/elasticsearch/qrep_flow_es_test.go | 76 +++++- flow/model/model.go | 28 +- flow/workflows/snapshot_flow.go | 18 +- 11 files changed, 623 insertions(+), 127 deletions(-) create mode 100644 flow/connectors/connelasticsearch/qrep.go create mode 100644 flow/e2e/elasticsearch/peer_flow_es_test.go diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index e9687a2240..319da34c49 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -187,13 +187,14 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon syncStartTime = time.Now() res, err = sync(dstConn, errCtx, &model.SyncRecordsRequest[Items]{ - SyncBatchID: syncBatchID, - Records: recordBatch, - ConsumedOffset: &consumedOffset, - FlowJobName: flowName, - TableMappings: options.TableMappings, - StagingPath: config.CdcStagingPath, - Script: config.Script, + SyncBatchID: syncBatchID, + Records: recordBatch, + ConsumedOffset: &consumedOffset, + FlowJobName: flowName, + TableMappings: options.TableMappings, + StagingPath: config.CdcStagingPath, + Script: config.Script, + TableNameSchemaMapping: options.TableNameSchemaMapping, }) if err != nil { a.Alerter.LogFlowError(ctx, flowName, err) diff --git a/flow/connectors/connelasticsearch/elasticsearch.go b/flow/connectors/connelasticsearch/elasticsearch.go index 0672e05c20..8d5f092790 100644 --- a/flow/connectors/connelasticsearch/elasticsearch.go +++ b/flow/connectors/connelasticsearch/elasticsearch.go @@ -4,26 +4,36 @@ import ( "bytes" "context" "crypto/tls" + "encoding/base64" "encoding/json" + "errors" "fmt" "log/slog" "net/http" "sync" + "sync/atomic" "time" "github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8/esutil" - "github.com/google/uuid" "go.temporal.io/sdk/log" + "golang.org/x/exp/maps" metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/peerdbenv" "github.com/PeerDB-io/peer-flow/shared" ) +const ( + actionIndex = "index" + actionDelete = "delete" +) + type ElasticsearchConnector struct { *metadataStore.PostgresMetadata client *elasticsearch.Client @@ -78,93 +88,170 @@ func (esc *ElasticsearchConnector) Close() error { return nil } -func (esc *ElasticsearchConnector) SetupQRepMetadataTables(ctx context.Context, - config *protos.QRepConfig, +// ES is queue-like, no raw table staging needed +func (esc *ElasticsearchConnector) CreateRawTable(ctx context.Context, + req *protos.CreateRawTableInput, +) (*protos.CreateRawTableOutput, error) { + return &protos.CreateRawTableOutput{TableIdentifier: "n/a"}, nil +} + +// we handle schema changes by not handling them since no mapping is being enforced right now +func (esc *ElasticsearchConnector) ReplayTableSchemaDeltas(ctx context.Context, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { return nil } -func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config *protos.QRepConfig, - partition *protos.QRepPartition, stream *model.QRecordStream, -) (int, error) { - startTime := time.Now() +func recordItemsProcessor(items model.RecordItems) ([]byte, error) { + qRecordJsonMap := make(map[string]any) + + for key, val := range items.ColToVal { + if r, ok := val.(qvalue.QValueJSON); ok { // JSON is stored as a string, fix that + qRecordJsonMap[key] = json.RawMessage( + shared.UnsafeFastStringToReadOnlyBytes(r.Val)) + } else { + qRecordJsonMap[key] = val.Value() + } + } - schema := stream.Schema() + return json.Marshal(qRecordJsonMap) +} - var bulkIndexFatalError error - var bulkIndexErrors []error - var bulkIndexMutex sync.Mutex - var docId string - numRecords := 0 - bulkIndexerHasShutdown := false - - // -1 means use UUID, >=0 means column in the record - upsertColIndex := -1 - // only support single upsert column for now - if config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_UPSERT && - len(config.WriteMode.UpsertKeyColumns) == 1 { - for i, field := range schema.Fields { - if config.WriteMode.UpsertKeyColumns[0] == field.Name { - upsertColIndex = i +func (esc *ElasticsearchConnector) SyncRecords(ctx context.Context, + req *model.SyncRecordsRequest[model.RecordItems], +) (*model.SyncResponse, error) { + tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) + // atomics for counts will be unnecessary in other destinations, using a mutex instead + var recordCountsUpdateMutex sync.Mutex + // we're taking a mutex anyway, avoid atomic + var lastSeenLSN atomic.Int64 + var numRecords atomic.Int64 + + // no I don't like this either + esBulkIndexerCache := make(map[string]esutil.BulkIndexer) + bulkIndexersHaveShutdown := false + // true if we saw errors while closing + cacheCloser := func() bool { + closeHasErrors := false + if bulkIndexersHaveShutdown { + for _, esBulkIndexer := range maps.Values(esBulkIndexerCache) { + err := esBulkIndexer.Close(context.Background()) + if err != nil { + esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) + closeHasErrors = true + } } + bulkIndexersHaveShutdown = true } + return closeHasErrors } + defer cacheCloser() - esBulkIndexer, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ - Index: config.DestinationTableIdentifier, - Client: esc.client, - // parallelism comes from the workflow design itself, no need for this - NumWorkers: 1, - FlushInterval: 10 * time.Second, - }) - if err != nil { - esc.logger.Error("[es] failed to initialize bulk indexer", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to initialize bulk indexer: %w", err) - } - defer func() { - if !bulkIndexerHasShutdown { - err := esBulkIndexer.Close(context.Background()) - if err != nil { - esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) + flushLoopDone := make(chan struct{}) + // we only update lastSeenLSN in the OnSuccess call, so this should be safe even if race + // between loop breaking and closing flushLoopDone + go func() { + ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds()) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-flushLoopDone: + return + case <-ticker.C: + lastSeen := lastSeenLSN.Load() + if lastSeen > req.ConsumedOffset.Load() { + if err := esc.SetLastOffset(ctx, req.FlowJobName, lastSeen); err != nil { + esc.logger.Warn("[es] SetLastOffset error", slog.Any("error", err)) + } else { + shared.AtomicInt64Max(req.ConsumedOffset, lastSeen) + esc.logger.Info("processBatch", slog.Int64("updated last offset", lastSeen)) + } + } } } }() - for qRecord := range stream.Records { - qRecordJsonMap := make(map[string]any) + var docId string + var bulkIndexFatalError error + var bulkIndexErrors []error + var bulkIndexOnFailureMutex sync.Mutex - if upsertColIndex >= 0 { - docId = fmt.Sprintf("%v", qRecord[upsertColIndex].Value()) - } else { - docId = uuid.New().String() + for record := range req.Records.GetRecords() { + var bodyBytes []byte + var err error + action := actionIndex + + switch record.(type) { + case *model.InsertRecord[model.RecordItems], *model.UpdateRecord[model.RecordItems]: + bodyBytes, err = recordItemsProcessor(record.GetItems()) + if err != nil { + esc.logger.Error("[es] failed to json.Marshal record", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to json.Marshal record: %w", err) + } + case *model.DeleteRecord[model.RecordItems]: + action = actionDelete + // no need to supply the document since we are deleting + bodyBytes = nil } - for i, field := range schema.Fields { - switch r := qRecord[i].(type) { - // JSON is stored as a string, fix that - case qvalue.QValueJSON: - qRecordJsonMap[field.Name] = json.RawMessage(shared. - UnsafeFastStringToReadOnlyBytes(r.Val)) - default: - qRecordJsonMap[field.Name] = r.Value() + + bulkIndexer, ok := esBulkIndexerCache[record.GetDestinationTableName()] + if !ok { + bulkIndexer, err = esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Index: record.GetDestinationTableName(), + Client: esc.client, + // can't really ascertain how many tables present to provide a reasonable value + NumWorkers: 1, + FlushInterval: 10 * time.Second, + }) + if err != nil { + esc.logger.Error("[es] failed to initialize bulk indexer", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to initialize bulk indexer: %w", err) } + esBulkIndexerCache[record.GetDestinationTableName()] = bulkIndexer } - qRecordJsonBytes, err := json.Marshal(qRecordJsonMap) - if err != nil { - esc.logger.Error("[es] failed to json.Marshal record", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to json.Marshal record: %w", err) + + if len(req.TableNameSchemaMapping[record.GetDestinationTableName()].PrimaryKeyColumns) == 1 { + qValue, err := record.GetItems().GetValueByColName( + req.TableNameSchemaMapping[record.GetDestinationTableName()].PrimaryKeyColumns[0]) + if err != nil { + esc.logger.Error("[es] failed to process record", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to process record: %w", err) + } + docId = fmt.Sprint(qValue.Value()) + } else { + tablePkey, err := model.RecToTablePKey(req.TableNameSchemaMapping, record) + if err != nil { + esc.logger.Error("[es] failed to process record", slog.Any("error", err)) + return nil, fmt.Errorf("[es] failed to process record: %w", err) + } + docId = base64.RawURLEncoding.EncodeToString(tablePkey.PkeyColVal[:]) } - err = esBulkIndexer.Add(ctx, esutil.BulkIndexerItem{ - Action: "index", + err = bulkIndexer.Add(ctx, esutil.BulkIndexerItem{ + Action: action, DocumentID: docId, - Body: bytes.NewReader(qRecordJsonBytes), + Body: bytes.NewReader(bodyBytes), + OnSuccess: func(_ context.Context, _ esutil.BulkIndexerItem, _ esutil.BulkIndexerResponseItem) { + shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) + numRecords.Add(1) + recordCountsUpdateMutex.Lock() + defer recordCountsUpdateMutex.Unlock() + record.PopulateCountMap(tableNameRowsMapping) + }, // OnFailure is called for each failed operation, log and let parent handle OnFailure: func(ctx context.Context, item esutil.BulkIndexerItem, res esutil.BulkIndexerResponseItem, err error, ) { - bulkIndexMutex.Lock() - defer bulkIndexMutex.Unlock() + // attempt to delete a record that wasn't present, possible from no initial load + if item.Action == actionDelete && res.Status == 404 { + return + } + bulkIndexOnFailureMutex.Lock() + defer bulkIndexOnFailureMutex.Unlock() if err != nil { bulkIndexErrors = append(bulkIndexErrors, err) } else { @@ -172,7 +259,7 @@ func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config * if res.Error.Cause.Type != "" || res.Error.Cause.Reason != "" { causeString = fmt.Sprintf("(caused by type:%s reason:%s)", res.Error.Cause.Type, res.Error.Cause.Reason) } - cbErr := fmt.Errorf("id:%s type:%s reason:%s %s", item.DocumentID, res.Error.Type, + cbErr := fmt.Errorf("id:%s action:%s type:%s reason:%s %s", item.DocumentID, item.Action, res.Error.Type, res.Error.Reason, causeString) bulkIndexErrors = append(bulkIndexErrors, cbErr) if res.Error.Type == "illegal_argument_exception" { @@ -183,36 +270,37 @@ func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config * }) if err != nil { esc.logger.Error("[es] failed to add record to bulk indexer", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to add record to bulk indexer: %w", err) + return nil, fmt.Errorf("[es] failed to add record to bulk indexer: %w", err) } if bulkIndexFatalError != nil { esc.logger.Error("[es] fatal error while indexing record", slog.Any("error", bulkIndexFatalError)) - return 0, fmt.Errorf("[es] fatal error while indexing record: %w", bulkIndexFatalError) + return nil, fmt.Errorf("[es] fatal error while indexing record: %w", bulkIndexFatalError) } - - // update here instead of OnSuccess, if we close successfully it should match - numRecords++ } + // "Receive on a closed channel yields the zero value after all elements in the channel are received." + close(flushLoopDone) - if err := stream.Err(); err != nil { - esc.logger.Error("[es] failed to get record from stream", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to get record from stream: %w", err) + if cacheCloser() { + esc.logger.Error("[es] failed to close bulk indexer(s)") + return nil, errors.New("[es] failed to close bulk indexer(s)") } - if err := esBulkIndexer.Close(ctx); err != nil { - esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to close bulk indexer: %w", err) - } - bulkIndexerHasShutdown = true + bulkIndexersHaveShutdown = true if len(bulkIndexErrors) > 0 { for _, err := range bulkIndexErrors { esc.logger.Error("[es] failed to index record", slog.Any("err", err)) } } - err = esc.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime) - if err != nil { - esc.logger.Error("[es] failed to log partition info", slog.Any("error", err)) - return 0, fmt.Errorf("[es] failed to log partition info: %w", err) + lastCheckpoint := req.Records.GetLastCheckpoint() + if err := esc.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint); err != nil { + return nil, err } - return numRecords, nil + + return &model.SyncResponse{ + CurrentSyncBatchID: req.SyncBatchID, + LastSyncedCheckpointID: lastCheckpoint, + NumRecordsSynced: numRecords.Load(), + TableNameRowsMapping: tableNameRowsMapping, + TableSchemaDeltas: req.Records.SchemaDeltas, + }, nil } diff --git a/flow/connectors/connelasticsearch/qrep.go b/flow/connectors/connelasticsearch/qrep.go new file mode 100644 index 0000000000..142c6de363 --- /dev/null +++ b/flow/connectors/connelasticsearch/qrep.go @@ -0,0 +1,175 @@ +package connelasticsearch + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "log/slog" + "slices" + "sync" + "time" + + "github.com/elastic/go-elasticsearch/v8/esutil" + + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/shared" +) + +func (esc *ElasticsearchConnector) SetupQRepMetadataTables(ctx context.Context, + config *protos.QRepConfig, +) error { + return nil +} + +func upsertKeyColsHash(qRecord []qvalue.QValue, upsertColIndices []int) string { + hasher := sha256.New() + + for _, upsertColIndex := range upsertColIndices { + // cannot return an error + _, _ = fmt.Fprint(hasher, qRecord[upsertColIndex].Value()) + } + hashBytes := hasher.Sum(nil) + return base64.RawURLEncoding.EncodeToString(hashBytes) +} + +func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config *protos.QRepConfig, + partition *protos.QRepPartition, stream *model.QRecordStream, +) (int, error) { + startTime := time.Now() + + schema := stream.Schema() + + var bulkIndexFatalError error + var bulkIndexErrors []error + var bulkIndexOnFailureMutex sync.Mutex + var docId string + numRecords := 0 + bulkIndexerHasShutdown := false + + // len == 0 means use UUID + // len == 1 means single column, use value directly + // len > 1 means SHA256 hash of upsert key columns + // ordered such that we preserve order of UpsertKeyColumns + var upsertKeyColIndices []int + if config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_UPSERT { + schemaColNames := schema.GetColumnNames() + for _, upsertCol := range config.WriteMode.UpsertKeyColumns { + idx := slices.Index(schemaColNames, upsertCol) + if idx != -1 { + upsertKeyColIndices = append(upsertKeyColIndices, idx) + } + } + } + + esBulkIndexer, err := esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Index: config.DestinationTableIdentifier, + Client: esc.client, + // parallelism comes from the workflow design itself, no need for this + NumWorkers: 1, + FlushInterval: 10 * time.Second, + }) + if err != nil { + esc.logger.Error("[es] failed to initialize bulk indexer", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to initialize bulk indexer: %w", err) + } + defer func() { + if !bulkIndexerHasShutdown { + err := esBulkIndexer.Close(context.Background()) + if err != nil { + esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) + } + } + }() + + for qRecord := range stream.Records { + qRecordJsonMap := make(map[string]any) + + switch len(upsertKeyColIndices) { + case 0: + // relying on autogeneration of document ID + case 1: + docId = fmt.Sprint(qRecord[upsertKeyColIndices[0]].Value()) + default: + docId = upsertKeyColsHash(qRecord, upsertKeyColIndices) + } + for i, field := range schema.Fields { + if r, ok := qRecord[i].(qvalue.QValueJSON); ok { // JSON is stored as a string, fix that + qRecordJsonMap[field.Name] = json.RawMessage( + shared.UnsafeFastStringToReadOnlyBytes(r.Val)) + } else { + qRecordJsonMap[field.Name] = qRecord[i].Value() + } + } + qRecordJsonBytes, err := json.Marshal(qRecordJsonMap) + if err != nil { + esc.logger.Error("[es] failed to json.Marshal record", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to json.Marshal record: %w", err) + } + + err = esBulkIndexer.Add(ctx, esutil.BulkIndexerItem{ + Action: actionIndex, + DocumentID: docId, + Body: bytes.NewReader(qRecordJsonBytes), + + // OnFailure is called for each failed operation, log and let parent handle + OnFailure: func(ctx context.Context, item esutil.BulkIndexerItem, + res esutil.BulkIndexerResponseItem, err error, + ) { + bulkIndexOnFailureMutex.Lock() + defer bulkIndexOnFailureMutex.Unlock() + if err != nil { + bulkIndexErrors = append(bulkIndexErrors, err) + } else { + causeString := "" + if res.Error.Cause.Type != "" || res.Error.Cause.Reason != "" { + causeString = fmt.Sprintf("(caused by type:%s reason:%s)", res.Error.Cause.Type, res.Error.Cause.Reason) + } + cbErr := fmt.Errorf("id:%s type:%s reason:%s %s", item.DocumentID, res.Error.Type, + res.Error.Reason, causeString) + bulkIndexErrors = append(bulkIndexErrors, cbErr) + if res.Error.Type == "illegal_argument_exception" { + bulkIndexFatalError = cbErr + } + } + }, + }) + if err != nil { + esc.logger.Error("[es] failed to add record to bulk indexer", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to add record to bulk indexer: %w", err) + } + if bulkIndexFatalError != nil { + esc.logger.Error("[es] fatal error while indexing record", slog.Any("error", bulkIndexFatalError)) + return 0, fmt.Errorf("[es] fatal error while indexing record: %w", bulkIndexFatalError) + } + + // update here instead of OnSuccess, if we close successfully it should match + numRecords++ + } + + if err := stream.Err(); err != nil { + esc.logger.Error("[es] failed to get record from stream", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to get record from stream: %w", err) + } + if err := esBulkIndexer.Close(ctx); err != nil { + esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to close bulk indexer: %w", err) + } + bulkIndexerHasShutdown = true + if len(bulkIndexErrors) > 0 { + for _, err := range bulkIndexErrors { + esc.logger.Error("[es] failed to index record", slog.Any("err", err)) + } + } + + err = esc.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime) + if err != nil { + esc.logger.Error("[es] failed to log partition info", slog.Any("error", err)) + return 0, fmt.Errorf("[es] failed to log partition info: %w", err) + } + return numRecords, nil +} diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 1e3321fa72..545269c4d1 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -304,6 +304,7 @@ var ( _ CDCSyncConnector = &connpubsub.PubSubConnector{} _ CDCSyncConnector = &conns3.S3Connector{} _ CDCSyncConnector = &connclickhouse.ClickhouseConnector{} + _ CDCSyncConnector = &connelasticsearch.ElasticsearchConnector{} _ CDCSyncPgConnector = &connpostgres.PostgresConnector{} diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 4adcc1d0d4..ae1e84292e 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -2,10 +2,8 @@ package connpostgres import ( "context" - "crypto/sha256" "fmt" "log/slog" - "slices" "time" "github.com/jackc/pglogrepl" @@ -488,7 +486,7 @@ func PullCdcRecords[Items model.Items]( return err } } else { - tablePkeyVal, err := recToTablePKey(req, rec) + tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec) if err != nil { return err } @@ -520,7 +518,7 @@ func PullCdcRecords[Items model.Items]( return err } } else { - tablePkeyVal, err := recToTablePKey(req, rec) + tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec) if err != nil { return err } @@ -538,7 +536,7 @@ func PullCdcRecords[Items model.Items]( return err } } else { - tablePkeyVal, err := recToTablePKey(req, rec) + tablePkeyVal, err := model.RecToTablePKey[Items](req.TableNameSchemaMapping, rec) if err != nil { return err } @@ -865,27 +863,6 @@ func processRelationMessage[Items model.Items]( return nil, nil } -func recToTablePKey[Items model.Items]( - req *model.PullRecordsRequest[Items], - rec model.Record[Items], -) (model.TableWithPkey, error) { - tableName := rec.GetDestinationTableName() - pkeyColsMerged := make([][]byte, 0, len(req.TableNameSchemaMapping[tableName].PrimaryKeyColumns)) - - for _, pkeyCol := range req.TableNameSchemaMapping[tableName].PrimaryKeyColumns { - pkeyColBytes, err := rec.GetItems().GetBytesByColName(pkeyCol) - if err != nil { - return model.TableWithPkey{}, fmt.Errorf("error getting pkey column value: %w", err) - } - pkeyColsMerged = append(pkeyColsMerged, pkeyColBytes) - } - - return model.TableWithPkey{ - TableName: tableName, - PkeyColVal: sha256.Sum256(slices.Concat(pkeyColsMerged...)), - }, nil -} - func (p *PostgresCDCSource) getParentRelIDIfPartitioned(relID uint32) uint32 { parentRelID, ok := p.childToParentRelIDMapping[relID] if ok { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index e88a736428..a20291b04d 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -706,7 +706,7 @@ func (c *PostgresConnector) CreateRawTable(ctx context.Context, req *protos.Crea _, err = createRawTableTx.Exec(ctx, fmt.Sprintf(createRawTableDstTableIndexSQL, rawTableIdentifier, c.metadataSchema, rawTableIdentifier)) if err != nil { - return nil, fmt.Errorf("error creating destion table index on raw table: %w", err) + return nil, fmt.Errorf("error creating destination table index on raw table: %w", err) } err = createRawTableTx.Commit(ctx) diff --git a/flow/e2e/elasticsearch/elasticsearch.go b/flow/e2e/elasticsearch/elasticsearch.go index d0abea9413..ed3389adb7 100644 --- a/flow/e2e/elasticsearch/elasticsearch.go +++ b/flow/e2e/elasticsearch/elasticsearch.go @@ -80,8 +80,12 @@ func (s elasticsearchSuite) Peer() *protos.Peer { return s.peer } -func (s elasticsearchSuite) CountDocumentsInIndex(index string) int64 { +func (s elasticsearchSuite) countDocumentsInIndex(index string) int64 { res, err := s.esClient.Count().Index(index).Do(context.Background()) + // index may not exist yet, don't error out for that + if err != nil && strings.Contains(err.Error(), "index_not_found_exception") { + return 0 + } require.NoError(s.t, err, "failed to get count of documents in index") return res.Count } diff --git a/flow/e2e/elasticsearch/peer_flow_es_test.go b/flow/e2e/elasticsearch/peer_flow_es_test.go new file mode 100644 index 0000000000..04e1e6d6ab --- /dev/null +++ b/flow/e2e/elasticsearch/peer_flow_es_test.go @@ -0,0 +1,150 @@ +package e2e_elasticsearch + +import ( + "context" + "fmt" + "time" + + "github.com/stretchr/testify/require" + + "github.com/PeerDB-io/peer-flow/e2e" + peerflow "github.com/PeerDB-io/peer-flow/workflows" +) + +func (s elasticsearchSuite) Test_Simple_PKey_CDC_Mirror() { + srcTableName := e2e.AttachSchema(s, "es_simple_pkey_cdc") + + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + c1 INT, + val TEXT, + updated_at TIMESTAMP DEFAULT now() + ); + `, srcTableName)) + require.NoError(s.t, err, "failed creating table") + + tc := e2e.NewTemporalClient(s.t) + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: e2e.AddSuffix(s, "es_simple_pkey_cdc"), + TableNameMapping: map[string]string{srcTableName: srcTableName}, + Destination: s.peer, + } + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + flowConnConfig.MaxBatchSize = 100 + flowConnConfig.DoInitialSnapshot = true + + rowCount := 10 + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) + e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) + + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for initial snapshot + inserted rows", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(2*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + UPDATE %s SET c1=c1+2,updated_at=now() WHERE id%%2=0;`, srcTableName)) + require.NoError(s.t, err, "failed to update rows on source") + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for updates + new inserts", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + DELETE FROM %s WHERE id%%2=1;`, srcTableName)) + require.NoError(s.t, err, "failed to delete rows on source") + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for deletes", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount/2) + }) + + env.Cancel() + e2e.RequireEnvCanceled(s.t, env) +} + +func (s elasticsearchSuite) Test_Composite_PKey_CDC_Mirror() { + srcTableName := e2e.AttachSchema(s, "es_composite_pkey_cdc") + + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT GENERATED ALWAYS AS IDENTITY, + c1 INT, + val TEXT, + updated_at TIMESTAMP DEFAULT now(), + PRIMARY KEY(id,val) + ); + `, srcTableName)) + require.NoError(s.t, err, "failed creating table") + + tc := e2e.NewTemporalClient(s.t) + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: e2e.AddSuffix(s, "es_composite_pkey_cdc"), + TableNameMapping: map[string]string{srcTableName: srcTableName}, + Destination: s.peer, + } + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs() + flowConnConfig.MaxBatchSize = 100 + flowConnConfig.DoInitialSnapshot = true + + rowCount := 10 + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) + e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) + + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for initial snapshot + inserted rows", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(2*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + UPDATE %s SET c1=c1+2,updated_at=now() WHERE id%%2=0;`, srcTableName)) + require.NoError(s.t, err, "failed to update rows on source") + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for updates + new inserts", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount) + }) + + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + DELETE FROM %s WHERE id%%2=1;`, srcTableName)) + require.NoError(s.t, err, "failed to delete rows on source") + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "wait for deletes", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount/2) + }) + + env.Cancel() + e2e.RequireEnvCanceled(s.t, env) +} diff --git a/flow/e2e/elasticsearch/qrep_flow_es_test.go b/flow/e2e/elasticsearch/qrep_flow_es_test.go index dafa1b2d79..85ac99043e 100644 --- a/flow/e2e/elasticsearch/qrep_flow_es_test.go +++ b/flow/e2e/elasticsearch/qrep_flow_es_test.go @@ -10,18 +10,20 @@ import ( "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/e2eshared" + "github.com/PeerDB-io/peer-flow/generated/protos" ) func Test_Elasticsearch(t *testing.T) { e2eshared.RunSuite(t, SetupSuite) } -func (s elasticsearchSuite) Test_Simple_Qrep() { - srcTableName := e2e.AttachSchema(s, "es_simple") +func (s elasticsearchSuite) Test_Simple_QRep_Append() { + srcTableName := e2e.AttachSchema(s, "es_simple_append") _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + c1 INT, val TEXT, updated_at TIMESTAMP DEFAULT now() ); @@ -31,8 +33,8 @@ func (s elasticsearchSuite) Test_Simple_Qrep() { rowCount := 10 for i := range rowCount { _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(val) VALUES('val%d') - `, srcTableName, i)) + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) require.NoError(s.t, err, "failed to insert row") } @@ -51,10 +53,72 @@ func (s elasticsearchSuite) Test_Simple_Qrep() { "", "", ) + qrepConfig.InitialCopyOnly = false + require.NoError(s.t, err) env := e2e.RunQRepFlowWorkflow(tc, qrepConfig) - e2e.EnvWaitForFinished(s.t, env, 3*time.Minute) + + e2e.EnvWaitFor(s.t, env, 10*time.Second, "waiting for ES to catch up", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(rowCount) + }) + _, err = s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + UPDATE %s SET c1=c1+2,updated_at=now() WHERE id%%2=0;`, srcTableName)) + require.NoError(s.t, err, "failed to update rows on source") + e2e.EnvWaitFor(s.t, env, 20*time.Second, "waiting for ES to catch up", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(3*rowCount/2) + }) + require.NoError(s.t, env.Error()) +} + +func (s elasticsearchSuite) Test_Simple_QRep_Upsert() { + srcTableName := e2e.AttachSchema(s, "es_simple_upsert") + + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + c1 INT, + val TEXT, + updated_at TIMESTAMP DEFAULT now() + ); + `, srcTableName)) + require.NoError(s.t, err, "failed creating table") + + rowCount := 10 + for i := range rowCount { + _, err := s.conn.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c1,val) VALUES(%d,'val%d') + `, srcTableName, i, i)) + require.NoError(s.t, err, "failed to insert row") + } + + tc := e2e.NewTemporalClient(s.t) + + query := fmt.Sprintf("SELECT * FROM %s WHERE updated_at BETWEEN {{.start}} AND {{.end}}", + srcTableName) - require.EqualValues(s.t, rowCount, s.CountDocumentsInIndex(srcTableName)) + qrepConfig, err := e2e.CreateQRepWorkflowConfig("test_es_simple_qrep", + srcTableName, + srcTableName, + query, + s.peer, + "", + false, + "", + "", + ) + qrepConfig.WriteMode = &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_UPSERT, + UpsertKeyColumns: []string{"id"}, + } + qrepConfig.InitialCopyOnly = false + + require.NoError(s.t, err) + env := e2e.RunQRepFlowWorkflow(tc, qrepConfig) + + e2e.EnvWaitFor(s.t, env, 10*time.Second, "waiting for ES to catch up", func() bool { + return s.countDocumentsInIndex(srcTableName) == int64(rowCount) + }) + + require.NoError(s.t, env.Error()) } diff --git a/flow/model/model.go b/flow/model/model.go index 54c2ef41ce..acdcdc5176 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -1,6 +1,8 @@ package model import ( + "crypto/sha256" + "fmt" "sync/atomic" "time" @@ -75,12 +77,36 @@ type TableWithPkey struct { PkeyColVal [32]byte } +func RecToTablePKey[T Items]( + tableNameSchemaMapping map[string]*protos.TableSchema, + rec Record[T], +) (TableWithPkey, error) { + tableName := rec.GetDestinationTableName() + hasher := sha256.New() + + for _, pkeyCol := range tableNameSchemaMapping[tableName].PrimaryKeyColumns { + pkeyColBytes, err := rec.GetItems().GetBytesByColName(pkeyCol) + if err != nil { + return TableWithPkey{}, fmt.Errorf("error getting pkey column value: %w", err) + } + // cannot return an error + _, _ = hasher.Write(pkeyColBytes) + } + + return TableWithPkey{ + TableName: tableName, + PkeyColVal: [32]byte(hasher.Sum(nil)), + }, nil +} + type SyncRecordsRequest[T Items] struct { Records *CDCStream[T] // ConsumedOffset allows destination to confirm lsn for slot ConsumedOffset *atomic.Int64 // FlowJobName is the name of the flow job. - FlowJobName string + // destination table name -> schema mapping + TableNameSchemaMapping map[string]*protos.TableSchema + FlowJobName string // Staging path for AVRO files in CDC StagingPath string // Lua script diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index f1c8c7d7f3..fdec9d15c6 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -171,6 +171,18 @@ func (s *SnapshotFlowExecution) cloneTable( numRowsPerPartition = s.config.SnapshotNumRowsPerPartition } + snapshotWriteMode := &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_APPEND, + } + // ensure document IDs are synchronized across initial load and CDC + // for the same document + if s.config.Destination.Type == protos.DBType_ELASTICSEARCH { + snapshotWriteMode = &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_UPSERT, + UpsertKeyColumns: s.tableNameSchemaMapping[mapping.DestinationTableIdentifier].PrimaryKeyColumns, + } + } + config := &protos.QRepConfig{ FlowJobName: childWorkflowID, SourcePeer: sourcePostgres, @@ -185,10 +197,8 @@ func (s *SnapshotFlowExecution) cloneTable( StagingPath: s.config.SnapshotStagingPath, SyncedAtColName: s.config.SyncedAtColName, SoftDeleteColName: s.config.SoftDeleteColName, - WriteMode: &protos.QRepWriteMode{ - WriteType: protos.QRepWriteType_QREP_WRITE_MODE_APPEND, - }, - System: s.config.System, + WriteMode: snapshotWriteMode, + System: s.config.System, } state := NewQRepFlowState()