diff --git a/flow/connectors/clickhouse/cdc.go b/flow/connectors/clickhouse/cdc.go
new file mode 100644
index 0000000000..6c286e01a5
--- /dev/null
+++ b/flow/connectors/clickhouse/cdc.go
@@ -0,0 +1,560 @@
+package connclickhouse
+
+import (
+	"database/sql"
+	"fmt"
+	"log/slog"
+	"regexp"
+	"strings"
+	"sync/atomic"
+	"time"
+
+	_ "github.com/ClickHouse/clickhouse-go/v2"
+	_ "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
+	"github.com/jackc/pgx/v5/pgtype"
+	"golang.org/x/sync/errgroup"
+
+	"github.com/PeerDB-io/peer-flow/connectors/utils"
+	"github.com/PeerDB-io/peer-flow/generated/protos"
+	"github.com/PeerDB-io/peer-flow/model"
+	"github.com/PeerDB-io/peer-flow/model/qvalue"
+)
+
+const (
+	checkIfTableExistsSQL     = `SELECT exists(SELECT 1 FROM system.tables WHERE database = ? AND name = ?) AS table_exists;`
+	mirrorJobsTableIdentifier = "PEERDB_MIRROR_JOBS"
+)
+
+// getRawTableName returns the raw table name for the given flow job name.
+func (c *ClickhouseConnector) getRawTableName(flowJobName string) string {
+	// replace all non-alphanumeric characters with _
+	flowJobName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(flowJobName, "_")
+	return fmt.Sprintf("_peerdb_raw_%s", flowJobName)
+}
+
+func (c *ClickhouseConnector) checkIfTableExists(databaseName string, tableIdentifier string) (bool, error) {
+	var result pgtype.Bool
+	err := c.database.QueryRowContext(c.ctx, checkIfTableExistsSQL, databaseName, tableIdentifier).Scan(&result)
+	if err != nil {
+		return false, fmt.Errorf("error while reading result row: %w", err)
+	}
+	return result.Bool, nil
+}
+
+type MirrorJobRow struct {
+	MirrorJobName    string
+	Offset           int
+	SyncBatchID      int
+	NormalizeBatchID int
+}
+
+func (c *ClickhouseConnector) getMirrorRowByJobName(jobName string) (*MirrorJobRow, error) {
+	getLastOffsetSQL := "SELECT MIRROR_JOB_NAME, OFFSET, SYNC_BATCH_ID, NORMALIZE_BATCH_ID FROM %s WHERE MIRROR_JOB_NAME=? LIMIT 1"
+
+	row := c.database.QueryRowContext(c.ctx, fmt.Sprintf(getLastOffsetSQL, mirrorJobsTableIdentifier), jobName)
+
+	var result MirrorJobRow
+
+	err := row.Scan(
+		&result.MirrorJobName,
+		&result.Offset,
+		&result.SyncBatchID,
+		&result.NormalizeBatchID,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return &result, nil
+}
+
+func (c *ClickhouseConnector) NeedsSetupMetadataTables() bool {
+	result, err := c.checkIfTableExists(c.config.Database, mirrorJobsTableIdentifier)
+	if err != nil {
+		return true
+	}
+	return !result
+}
+
+func (c *ClickhouseConnector) SetupMetadataTables() error {
+	createMirrorJobsTableSQL := `CREATE TABLE IF NOT EXISTS %s (
+		MIRROR_JOB_NAME String NOT NULL,
+		OFFSET Int64 NOT NULL,
+		SYNC_BATCH_ID Int64 NOT NULL,
+		NORMALIZE_BATCH_ID Int64 NOT NULL
+	) ENGINE = MergeTree()
+	ORDER BY MIRROR_JOB_NAME;`
+
+	// NOTE that ClickHouse does not support transactional DDL, so the metadata
+	// table is created directly rather than inside a transaction. There is also
+	// no separate internal metadata schema to create for ClickHouse.
+	_, err := c.database.ExecContext(c.ctx, fmt.Sprintf(createMirrorJobsTableSQL, mirrorJobsTableIdentifier))
+	if err != nil {
+		return fmt.Errorf("error while setting up mirror jobs table: %w", err)
+	}
+
+	return nil
+}
+
+func (c *ClickhouseConnector) GetLastOffset(jobName string) (int64, error) {
+	getLastOffsetSQL := "SELECT OFFSET FROM %s WHERE MIRROR_JOB_NAME=?"
+
+	rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastOffsetSQL,
+		mirrorJobsTableIdentifier), jobName)
+	if err != nil {
+		return 0, fmt.Errorf("error querying Clickhouse peer for last syncedID: %w", err)
+	}
+	defer func() {
+		err = rows.Close()
+		if err != nil {
+			c.logger.Error("error while closing rows for reading last offset", slog.Any("error", err))
+		}
+	}()
+
+	if !rows.Next() {
+		c.logger.Warn("No row found, returning 0")
+		return 0, nil
+	}
+	var result pgtype.Int8
+	err = rows.Scan(&result)
+	if err != nil {
+		return 0, fmt.Errorf("error while reading result row: %w", err)
+	}
+	if result.Int64 == 0 {
+		c.logger.Warn("Assuming zero offset means no sync has happened")
+		return 0, nil
+	}
+	return result.Int64, nil
+}
+
+func (c *ClickhouseConnector) SetLastOffset(jobName string, lastOffset int64) error {
+	currentRow, err := c.getMirrorRowByJobName(jobName)
+	if err != nil {
+		return err
+	}
+
+	//setLastOffsetSQL = "UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?) WHERE MIRROR_JOB_NAME=?"
+	setLastOffsetSQL := `INSERT INTO %s
+		(MIRROR_JOB_NAME, OFFSET, SYNC_BATCH_ID, NORMALIZE_BATCH_ID)
+		VALUES (?, ?, ?, ?);`
+	_, err = c.database.ExecContext(c.ctx, fmt.Sprintf(setLastOffsetSQL,
+		mirrorJobsTableIdentifier), currentRow.MirrorJobName, lastOffset, currentRow.SyncBatchID, currentRow.NormalizeBatchID)
+	if err != nil {
+		return fmt.Errorf("error while setting last offset for Clickhouse peer: %w", err)
+	}
+	return nil
+}
+
+func (c *ClickhouseConnector) GetLastSyncBatchID(jobName string) (int64, error) {
+	getLastSyncBatchIDSQL := "SELECT SYNC_BATCH_ID FROM %s WHERE MIRROR_JOB_NAME=?"
+
+	rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastSyncBatchIDSQL,
+		mirrorJobsTableIdentifier), jobName)
+	if err != nil {
+		return 0, fmt.Errorf("error querying Clickhouse peer for last syncBatchId: %w", err)
+	}
+	defer rows.Close()
+
+	var result pgtype.Int8
+	if !rows.Next() {
+		c.logger.Warn("No row found, returning 0")
+		return 0, nil
+	}
+	err = rows.Scan(&result)
+	if err != nil {
+		return 0, fmt.Errorf("error while reading result row: %w", err)
+	}
+	return result.Int64, nil
+}
+
+func (c *ClickhouseConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) {
+	rawTableName := c.getRawTableName(req.FlowJobName)
+
+	// ClickHouse DDL is not transactional, so the raw table is created directly.
+	createRawTableSQL := `CREATE TABLE IF NOT EXISTS %s (
+		_PEERDB_UID String NOT NULL,
+		_PEERDB_TIMESTAMP Int64 NOT NULL,
+		_PEERDB_DESTINATION_TABLE_NAME String NOT NULL,
+		_PEERDB_DATA String NOT NULL,
+		_PEERDB_RECORD_TYPE Int32 NOT NULL,
+		_PEERDB_MATCH_DATA String,
+		_PEERDB_BATCH_ID Int64,
+		_PEERDB_UNCHANGED_TOAST_COLUMNS String
+	) ENGINE = ReplacingMergeTree ORDER BY _PEERDB_UID;`
+
+	_, err := c.database.ExecContext(c.ctx,
+		fmt.Sprintf(createRawTableSQL, rawTableName))
+	if err != nil {
+		return nil, fmt.Errorf("unable to create raw table: %w", err)
+	}
+
+	stage := c.getStageNameForJob(req.FlowJobName)
+	err = c.createStage(stage, &protos.QRepConfig{})
+	if err != nil {
+		return nil, err
+	}
+
+	return &protos.CreateRawTableOutput{
+		TableIdentifier: rawTableName,
+	}, nil
+}
+
+func (c *ClickhouseConnector) syncRecordsViaAvro(
+	req *model.SyncRecordsRequest,
+	rawTableIdentifier string,
+	syncBatchID int64,
+) (*model.SyncResponse, error) {
+	tableNameRowsMapping := make(map[string]uint32)
+	streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID)
+	streamRes, err := utils.RecordsToRawTableStream(streamReq)
+	if err != nil {
+		return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err)
+	}
+
+	qrepConfig := &protos.QRepConfig{
+		StagingPath:                "",
+		FlowJobName:                req.FlowJobName,
+		DestinationTableIdentifier: strings.ToLower(rawTableIdentifier),
+	}
+	avroSyncer := NewSnowflakeAvroSyncMethod(qrepConfig, c)
+	destinationTableSchema, err := c.getTableSchema(qrepConfig.DestinationTableIdentifier)
+	if err != nil {
+		return nil, err
+	}
+
+	numRecords, err := avroSyncer.SyncRecords(destinationTableSchema, streamRes.Stream, req.FlowJobName)
+	if err != nil {
+		return nil, err
+	}
+
+	tableSchemaDeltas := req.Records.WaitForSchemaDeltas(req.TableMappings)
+	err = c.ReplayTableSchemaDeltas(req.FlowJobName, tableSchemaDeltas)
+	if err != nil {
+		return nil, fmt.Errorf("failed to sync schema changes: %w", err)
+	}
+
+	lastCheckpoint, err := req.Records.GetLastCheckpoint()
+	if err != nil {
+		return nil, err
+	}
+
+	return &model.SyncResponse{
+		LastSyncedCheckPointID: lastCheckpoint,
+		NumRecordsSynced:       int64(numRecords),
+		CurrentSyncBatchID:     syncBatchID,
+		TableNameRowsMapping:   tableNameRowsMapping,
+		TableSchemaDeltas:      tableSchemaDeltas,
+		RelationMessageMapping: <-req.Records.RelationMessageMapping,
+	}, nil
+}
+
+func (c *ClickhouseConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
+	rawTableName := c.getRawTableName(req.FlowJobName)
+	c.logger.Info(fmt.Sprintf("pushing records to Clickhouse table %s", rawTableName))
+
+	syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err)
+	}
+	syncBatchID += 1
+
+	res, err := c.syncRecordsViaAvro(req, rawTableName, syncBatchID)
+	if err != nil {
+		return nil, err
+	}
+
+	// transaction for SyncRecords
+	syncRecordsTx, err := c.database.BeginTx(c.ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	// in case we return after error, ensure transaction is rolled back
+	defer func() {
+		deferErr := syncRecordsTx.Rollback()
+		if deferErr != sql.ErrTxDone && deferErr != nil {
+			c.logger.Error("error while rolling back transaction for SyncRecords",
+				slog.Any("error", deferErr), slog.Int64("syncBatchID", syncBatchID))
+		}
+	}()
+
+	// updating metadata with new offset and syncBatchID
+	err = c.updateSyncMetadata(req.FlowJobName, res.LastSyncedCheckPointID, syncBatchID, syncRecordsTx)
+	if err != nil {
+		return nil, err
+	}
+	// transaction commits
+	err = syncRecordsTx.Commit()
+	if err != nil {
+		return nil, err
+	}
+
+	return res, nil
+}
+
+func (c *ClickhouseConnector) SyncFlowCleanup(jobName string) error {
+	syncFlowCleanupTx, err := c.database.BeginTx(c.ctx, nil)
+	if err != nil {
+		return fmt.Errorf("unable to begin transaction for sync flow cleanup: %w", err)
+	}
+	defer func() {
+		deferErr := syncFlowCleanupTx.Rollback()
+		if deferErr != sql.ErrTxDone && deferErr != nil {
+			c.logger.Error("error while rolling back transaction for flow cleanup", slog.Any("error", deferErr))
+		}
+	}()
+
+	row := syncFlowCleanupTx.QueryRowContext(c.ctx, checkSchemaExistsSQL, c.metadataSchema)
+	var schemaExists pgtype.Bool
+	err = row.Scan(&schemaExists)
+	if err != nil {
+		return fmt.Errorf("unable to check if internal schema exists: %w", err)
+	}
+
+	if schemaExists.Bool {
+		_, err = syncFlowCleanupTx.ExecContext(c.ctx, fmt.Sprintf(dropTableIfExistsSQL, c.metadataSchema,
+			c.getRawTableName(jobName)))
+		if err != nil {
+			return fmt.Errorf("unable to drop raw table: %w", err)
+		}
+		_, err = syncFlowCleanupTx.ExecContext(c.ctx,
+			fmt.Sprintf(deleteJobMetadataSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName)
+		if err != nil {
+			return fmt.Errorf("unable to delete job metadata: %w", err)
+		}
+	}
+
+	err = syncFlowCleanupTx.Commit()
+	if err != nil {
+		return fmt.Errorf("unable to commit transaction for sync flow cleanup: %w", err)
+	}
+
+	err = c.dropStage("", jobName)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *ClickhouseConnector) SetupNormalizedTables(
+	req *protos.SetupNormalizedTableBatchInput,
+) (*protos.SetupNormalizedTableBatchOutput, error) {
+	tableExistsMapping := make(map[string]bool)
+	for tableIdentifier, tableSchema := range req.TableNameSchemaMapping {
+		normalizedSchemaTable, err := utils.ParseSchemaTable(tableIdentifier)
+		if err != nil {
+			return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
+		}
+		tableAlreadyExists, err := c.checkIfTableExists(normalizedSchemaTable.Schema, normalizedSchemaTable.Table)
+		if err != nil {
+			return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err)
+		}
+		if tableAlreadyExists {
+			tableExistsMapping[tableIdentifier] = true
+			continue
+		}
+
+		normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(
+			normalizedSchemaTable, tableSchema, req.SoftDeleteColName, req.SyncedAtColName)
+		_, err = c.database.ExecContext(c.ctx, normalizedTableCreateSQL)
+		if err != nil {
+			return nil, fmt.Errorf("[clickhouse] error while creating normalized table: %w", err)
+		}
+		tableExistsMapping[tableIdentifier] = false
+	}
+
+	return &protos.SetupNormalizedTableBatchOutput{
+		TableExistsMapping: tableExistsMapping,
+	}, nil
+}
+
+// ReplayTableSchemaDeltas changes a destination table to match the schema at source.
+// This could involve adding or dropping multiple columns.
+func (c *ClickhouseConnector) ReplayTableSchemaDeltas(flowJobName string,
+	schemaDeltas []*protos.TableSchemaDelta,
+) error {
+	if len(schemaDeltas) == 0 {
+		return nil
+	}
+
+	tableSchemaModifyTx, err := c.database.Begin()
+	if err != nil {
+		return fmt.Errorf("error starting transaction for schema modification: %w",
+			err)
+	}
+	defer func() {
+		deferErr := tableSchemaModifyTx.Rollback()
+		if deferErr != sql.ErrTxDone && deferErr != nil {
+			c.logger.Error("error rolling back transaction for table schema modification", slog.Any("error", deferErr))
+		}
+	}()
+
+	for _, schemaDelta := range schemaDeltas {
+		if schemaDelta == nil || len(schemaDelta.AddedColumns) == 0 {
+			continue
+		}
+
+		for _, addedColumn := range schemaDelta.AddedColumns {
+			sfColtype, err := qValueKindToSnowflakeType(qvalue.QValueKind(addedColumn.ColumnType))
+			if err != nil {
+				return fmt.Errorf("failed to convert column type %s to snowflake type: %w",
+					addedColumn.ColumnType, err)
+			}
+			_, err = tableSchemaModifyTx.ExecContext(c.ctx,
+				fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s",
+					schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype))
+			if err != nil {
+				return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
+					schemaDelta.DstTableName, err)
+			}
+			c.logger.Info(fmt.Sprintf("[schema delta replay] added column %s with data type %s", addedColumn.ColumnName,
+				addedColumn.ColumnType),
+				slog.String("destination table name", schemaDelta.DstTableName),
+				slog.String("source table name", schemaDelta.SrcTableName))
+		}
+	}
+
+	err = tableSchemaModifyTx.Commit()
+	if err != nil {
+		return fmt.Errorf("failed to commit transaction for table schema modification: %w",
+			err)
+	}
+
+	return nil
+}
+
+func (c *ClickhouseConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) {
+	batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName)
+	if err != nil {
+		return nil, err
+	}
+	// normalize has caught up with sync, chill until more records are loaded.
+	if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID {
+		return &model.NormalizeResponse{
+			Done:         false,
+			StartBatchID: batchIDs.NormalizeBatchID,
+			EndBatchID:   batchIDs.SyncBatchID,
+		}, nil
+	}
+
+	jobMetadataExists, err := c.jobMetadataExists(req.FlowJobName)
+	if err != nil {
+		return nil, err
+	}
+	// sync hasn't created job metadata yet, chill.
+	if !jobMetadataExists {
+		return &model.NormalizeResponse{
+			Done: false,
+		}, nil
+	}
+	destinationTableNames, err := c.getDistinctTableNamesInBatch(
+		req.FlowJobName,
+		batchIDs.SyncBatchID,
+		batchIDs.NormalizeBatchID,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)
+	if err != nil {
+		return nil, fmt.Errorf("couldn't get table name to unchanged TOAST columns mapping: %w", err)
+	}
+
+	var totalRowsAffected int64 = 0
+	g, gCtx := errgroup.WithContext(c.ctx)
+	g.SetLimit(8) // limit parallel merges to 8
+
+	for _, destinationTableName := range destinationTableNames {
+		tableName := destinationTableName // local variable for the closure
+
+		g.Go(func() error {
+			mergeGen := &mergeStmtGenerator{
+				rawTableName:          c.getRawTableName(req.FlowJobName),
+				dstTableName:          tableName,
+				syncBatchID:           batchIDs.SyncBatchID,
+				normalizeBatchID:      batchIDs.NormalizeBatchID,
+				normalizedTableSchema: req.TableNameSchemaMapping[tableName],
+				unchangedToastColumns: tableNametoUnchangedToastCols[tableName],
+				peerdbCols: &protos.PeerDBColumns{
+					SoftDelete:        req.SoftDelete,
+					SoftDeleteColName: req.SoftDeleteColName,
+					SyncedAtColName:   req.SyncedAtColName,
+				},
+			}
+			mergeStatement, err := mergeGen.generateMergeStmt()
+			if err != nil {
+				return err
+			}
+
+			startTime := time.Now()
+			c.logger.Info("[merge] merging records...", slog.String("destTable", tableName))
+
+			result, err := c.database.ExecContext(gCtx, mergeStatement, tableName)
+			if err != nil {
+				return fmt.Errorf("failed to merge records into %s (statement: %s): %w",
+					tableName, mergeStatement, err)
+			}
+
+			endTime := time.Now()
+			c.logger.Info(fmt.Sprintf("[merge] merged records into %s, took: %d seconds",
+				tableName, endTime.Sub(startTime)/time.Second))
+
+			rowsAffected, err := result.RowsAffected()
+			if err != nil {
+				return fmt.Errorf("failed to get rows affected by merge statement for table %s: %w", tableName, err)
+			}
+
+			atomic.AddInt64(&totalRowsAffected, rowsAffected)
+			return nil
+		})
+	}
+
+	if err := g.Wait(); err != nil {
+		return nil, fmt.Errorf("error while normalizing records: %w", err)
+	}
+
+	// updating metadata with new normalizeBatchID
+	err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID)
+	if err != nil {
+		return nil, err
+	}
+
+	return &model.NormalizeResponse{
+		Done:         true,
+		StartBatchID: batchIDs.NormalizeBatchID + 1,
+		EndBatchID:   batchIDs.SyncBatchID,
+	}, nil
+}
diff --git a/flow/connectors/clickhouse/clickhouse.go b/flow/connectors/clickhouse/clickhouse.go
index de51b0feb9..60c2ba4840 100644
--- a/flow/connectors/clickhouse/clickhouse.go
+++ b/flow/connectors/clickhouse/clickhouse.go
@@ -17,6 +17,7 @@ type ClickhouseConnector struct {
 	database           *sql.DB
 	tableSchemaMapping map[string]*protos.TableSchema
 	logger             slog.Logger
+	config             *protos.ClickhouseConfig
 }
 
 func NewClickhouseConnector(ctx context.Context,
@@ -33,6 +34,7 @@ func NewClickhouseConnector(ctx context.Context,
 		database:           database,
 		tableSchemaMapping: nil,
 		logger:             *slog.With(slog.String(string(shared.FlowNameKey), flowName)),
+		config:             clickhouseProtoConfig,
 	}, nil
 }
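
A note on the metadata design above: SetLastOffset appends a fresh row to the MergeTree-backed PEERDB_MIRROR_JOBS table on every call, so a mirror job accumulates multiple rows over time, and the plain SELECT in GetLastOffset (no ORDER BY) is not guaranteed to return the newest one. The sketch below shows one way to read the latest offset by aggregating instead; it assumes the ClickhouseConnector fields (database, ctx) and the mirrorJobsTableIdentifier constant from the diff, and getLatestOffset is a hypothetical helper, not part of this change set.

	// Sketch only: relies on offsets being monotonically increasing per mirror job,
	// so max(OFFSET) over the appended rows is the most recent value.
	func (c *ClickhouseConnector) getLatestOffset(jobName string) (int64, error) {
		query := fmt.Sprintf("SELECT max(OFFSET) FROM %s WHERE MIRROR_JOB_NAME = ?",
			mirrorJobsTableIdentifier)

		var offset int64
		// Over zero matching rows, ClickHouse's max() yields the type default (0),
		// which matches the "no sync has happened yet" convention in GetLastOffset.
		if err := c.database.QueryRowContext(c.ctx, query, jobName).Scan(&offset); err != nil {
			return 0, fmt.Errorf("error reading latest offset for job %s: %w", jobName, err)
		}
		return offset, nil
	}

The same append-then-aggregate pattern would apply to the batch-ID readers, since ClickHouse favors inserts over in-place UPDATEs for this kind of frequently written metadata.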