diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 963eb60a96..edb3cbdcf3 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -27,12 +27,6 @@ import ( ) const ( - /* - Different batch Ids in code/BigQuery - 1. batchID - identifier in raw table on target to depict which batch a row was inserted. - 3. syncBatchID - batch id that was last synced or will be synced - 4. normalizeBatchID - batch id that was last normalized or will be normalized. - */ // MirrorJobsTable has the following schema: // CREATE TABLE peerdb_mirror_jobs ( // mirror_job_id STRING NOT NULL, diff --git a/flow/connectors/clickhouse/qrep_avro_sync.go b/flow/connectors/clickhouse/qrep_avro_sync.go index 68129a98d5..3d2f179bca 100644 --- a/flow/connectors/clickhouse/qrep_avro_sync.go +++ b/flow/connectors/clickhouse/qrep_avro_sync.go @@ -195,11 +195,11 @@ func (s *ClickhouseAvroSyncMethod) insertMetadata( if err != nil { s.connector.logger.Error("failed to create metadata insert statement", slog.Any("error", err), partitionLog) - return fmt.Errorf("failed to create metadata insert statement: %v", err) + return fmt.Errorf("failed to create metadata insert statement: %w", err) } if _, err := s.connector.database.Exec(insertMetadataStmt); err != nil { - return fmt.Errorf("failed to execute metadata insert statement: %v", err) + return fmt.Errorf("failed to execute metadata insert statement: %w", err) } return nil diff --git a/flow/connectors/external_metadata/store.go b/flow/connectors/external_metadata/store.go index ab5224b2ee..3be58019b8 100644 --- a/flow/connectors/external_metadata/store.go +++ b/flow/connectors/external_metadata/store.go @@ -4,10 +4,12 @@ import ( "context" "fmt" "log/slog" + "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "google.golang.org/protobuf/encoding/protojson" connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -18,6 +20,7 @@ import ( const ( lastSyncStateTableName = "last_sync_state" + qrepTableName = "qrep_metadata" ) type Querier interface { @@ -118,29 +121,49 @@ func (p *PostgresMetadataStore) SetupMetadata() error { // create the last sync state table _, err = p.conn.Exec(p.ctx, ` - CREATE TABLE IF NOT EXISTS `+p.QualifyTable(lastSyncStateTableName)+` ( + CREATE TABLE IF NOT EXISTS `+p.QualifyTable(lastSyncStateTableName)+`( job_name TEXT PRIMARY KEY NOT NULL, last_offset BIGINT NOT NULL, - updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), sync_batch_id BIGINT NOT NULL, normalize_batch_id BIGINT - ) - `) + )`) if err != nil && !utils.IsUniqueError(err) { p.logger.Error("failed to create last sync state table", slog.Any("error", err)) return err } + _, err = p.conn.Exec(p.ctx, ` + CREATE TABLE IF NOT EXISTS `+p.QualifyTable(qrepTableName)+`( + job_name TEXT NOT NULL, + partition_id TEXT NOT NULL, + sync_partition JSON NOT NULL, + sync_start_time TIMESTAMPTZ NOT NULL, + sync_finish_time TIMESTAMPTZ NOT NULL DEFAULT NOW() + )`) + if err != nil && !utils.IsUniqueError(err) { + p.logger.Error("failed to create qrep metadata table", slog.Any("error", err)) + return err + } + + _, err = p.conn.Exec(p.ctx, + `CREATE INDEX IF NOT EXISTS ix_qrep_metadata_partition_id ON `+ + p.QualifyTable(qrepTableName)+ + ` USING hash (partition_id)`) + if err != nil && !utils.IsUniqueError(err) { + p.logger.Error("failed to create qrep metadata index", slog.Any("error", err)) + return err + } + p.logger.Info(fmt.Sprintf("created external metadata table %s.%s", p.schemaName, lastSyncStateTableName)) return nil } func (p *PostgresMetadataStore) FetchLastOffset(jobName string) (int64, error) { - row := p.conn.QueryRow(p.ctx, ` - SELECT last_offset - FROM `+p.QualifyTable(lastSyncStateTableName)+` - WHERE job_name = $1 - `, jobName) + row := p.conn.QueryRow(p.ctx, + `SELECT last_offset FROM `+ + p.QualifyTable(lastSyncStateTableName)+ + ` WHERE job_name = $1`, jobName) var offset pgtype.Int8 err := row.Scan(&offset) if err != nil { @@ -158,11 +181,10 @@ func (p *PostgresMetadataStore) FetchLastOffset(jobName string) (int64, error) { } func (p *PostgresMetadataStore) GetLastBatchID(jobName string) (int64, error) { - row := p.conn.QueryRow(p.ctx, ` - SELECT sync_batch_id - FROM `+p.QualifyTable(lastSyncStateTableName)+` - WHERE job_name = $1 - `, jobName) + row := p.conn.QueryRow(p.ctx, + `SELECT sync_batch_id FROM `+ + p.QualifyTable(lastSyncStateTableName)+ + ` WHERE job_name = $1`, jobName) var syncBatchID pgtype.Int8 err := row.Scan(&syncBatchID) @@ -181,11 +203,10 @@ func (p *PostgresMetadataStore) GetLastBatchID(jobName string) (int64, error) { } func (p *PostgresMetadataStore) GetLastNormalizeBatchID(jobName string) (int64, error) { - rows := p.conn.QueryRow(p.ctx, ` - SELECT normalize_batch_id - FROM `+p.schemaName+`.`+lastSyncStateTableName+` - WHERE job_name = $1 - `, jobName) + rows := p.conn.QueryRow(p.ctx, + `SELECT normalize_batch_id FROM `+ + p.QualifyTable(lastSyncStateTableName)+ + ` WHERE job_name = $1`, jobName) var normalizeBatchID pgtype.Int8 err := rows.Scan(&normalizeBatchID) @@ -242,10 +263,9 @@ func (p *PostgresMetadataStore) FinishBatch(jobName string, syncBatchID int64, o func (p *PostgresMetadataStore) UpdateNormalizeBatchID(jobName string, batchID int64) error { p.logger.Info("updating normalize batch id for job") - _, err := p.conn.Exec(p.ctx, ` - UPDATE `+p.schemaName+`.`+lastSyncStateTableName+` - SET normalize_batch_id=$2 WHERE job_name=$1 - `, jobName, batchID) + _, err := p.conn.Exec(p.ctx, + `UPDATE `+p.QualifyTable(lastSyncStateTableName)+ + ` SET normalize_batch_id=$2 WHERE job_name=$1`, jobName, batchID) if err != nil { p.logger.Error("failed to update normalize batch id", slog.Any("error", err)) return err @@ -254,10 +274,51 @@ func (p *PostgresMetadataStore) UpdateNormalizeBatchID(jobName string, batchID i return nil } -func (p *PostgresMetadataStore) DropMetadata(jobName string) error { - _, err := p.conn.Exec(p.ctx, ` - DELETE FROM `+p.QualifyTable(lastSyncStateTableName)+` - WHERE job_name = $1 - `, jobName) +func (p *PostgresMetadataStore) FinishQrepPartition( + partition *protos.QRepPartition, + jobName string, + startTime time.Time, +) error { + pbytes, err := protojson.Marshal(partition) + if err != nil { + return fmt.Errorf("failed to marshal partition to json: %w", err) + } + partitionJSON := string(pbytes) + + _, err = p.conn.Exec(p.ctx, + `INSERT INTO `+p.QualifyTable(qrepTableName)+ + `(job_name, partition_id, sync_partition, sync_start_time) VALUES ($1, $2, $3, $4)`, + jobName, partition.PartitionId, partitionJSON, startTime) return err } + +func (p *PostgresMetadataStore) IsQrepPartitionSynced(partitionID string) (bool, error) { + var count int64 + err := p.conn.QueryRow(p.ctx, + `SELECT COUNT(*) FROM `+ + p.QualifyTable(qrepTableName)+ + ` WHERE partition_id = $1`, + partitionID).Scan(&count) + if err != nil { + return false, fmt.Errorf("failed to execute query: %w", err) + } + return count > 0, nil +} + +func (p *PostgresMetadataStore) DropMetadata(jobName string) error { + _, err := p.conn.Exec(p.ctx, + `DELETE FROM `+p.QualifyTable(lastSyncStateTableName)+ + ` WHERE job_name = $1`, jobName) + if err != nil { + return err + } + + _, err = p.conn.Exec(p.ctx, + `DELETE FROM `+p.QualifyTable(qrepTableName)+ + ` WHERE job_name = $1`, jobName) + if err != nil { + return err + } + + return nil +} diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index df031c9b3b..ec5c35e28c 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -155,13 +155,7 @@ func (c *S3Connector) NeedsSetupMetadataTables() bool { } func (c *S3Connector) SetupMetadataTables() error { - err := c.pgMetadata.SetupMetadata() - if err != nil { - c.logger.Error("failed to setup metadata tables", slog.Any("error", err)) - return err - } - - return nil + return c.pgMetadata.SetupMetadata() } func (c *S3Connector) GetLastSyncBatchID(jobName string) (int64, error) { @@ -172,15 +166,8 @@ func (c *S3Connector) GetLastOffset(jobName string) (int64, error) { return c.pgMetadata.FetchLastOffset(jobName) } -// update offset for a job func (c *S3Connector) SetLastOffset(jobName string, offset int64) error { - err := c.pgMetadata.UpdateLastOffset(jobName, offset) - if err != nil { - c.logger.Error("failed to update last offset: ", slog.Any("error", err)) - return err - } - - return nil + return c.pgMetadata.UpdateLastOffset(jobName, offset) } func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { diff --git a/flow/connectors/snowflake/qrep.go b/flow/connectors/snowflake/qrep.go index 4361b17881..e3037d298e 100644 --- a/flow/connectors/snowflake/qrep.go +++ b/flow/connectors/snowflake/qrep.go @@ -5,12 +5,10 @@ import ( "fmt" "log/slog" "strings" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/jackc/pgx/v5/pgtype" - "google.golang.org/protobuf/encoding/protojson" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -18,8 +16,6 @@ import ( "github.com/PeerDB-io/peer-flow/shared" ) -const qRepMetadataTableName = "_peerdb_query_replication_metadata" - func (c *SnowflakeConnector) SyncQRepRecords( config *protos.QRepConfig, partition *protos.QRepPartition, @@ -37,7 +33,7 @@ func (c *SnowflakeConnector) SyncQRepRecords( } c.logger.Info("Called QRep sync function and obtained table schema", flowLog) - done, err := c.isPartitionSynced(partition.PartitionId) + done, err := c.pgMetadata.IsQrepPartitionSynced(partition.PartitionId) if err != nil { return 0, fmt.Errorf("failed to check if partition %s is synced: %w", partition.PartitionId, err) } @@ -51,30 +47,6 @@ func (c *SnowflakeConnector) SyncQRepRecords( return avroSync.SyncQRepRecords(config, partition, tblSchema, stream) } -func (c *SnowflakeConnector) createMetadataInsertStatement( - partition *protos.QRepPartition, - jobName string, - startTime time.Time, -) (string, error) { - // marshal the partition to json using protojson - pbytes, err := protojson.Marshal(partition) - if err != nil { - return "", fmt.Errorf("failed to marshal partition to json: %v", err) - } - - // convert the bytes to string - partitionJSON := string(pbytes) - - insertMetadataStmt := fmt.Sprintf( - `INSERT INTO %s.%s - (flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime) - VALUES ('%s', '%s', '%s', '%s'::timestamp, CURRENT_TIMESTAMP);`, - c.metadataSchema, qRepMetadataTableName, jobName, partition.PartitionId, - partitionJSON, startTime.Format(time.RFC3339)) - - return insertMetadataStmt, nil -} - func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) { schematable, err := utils.ParseSchemaTable(tableName) if err != nil { @@ -99,49 +71,13 @@ func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType return columnTypes, nil } -func (c *SnowflakeConnector) isPartitionSynced(partitionID string) (bool, error) { - //nolint:gosec - queryString := fmt.Sprintf(` - SELECT COUNT(*) - FROM %s.%s - WHERE partitionID = '%s' - `, c.metadataSchema, qRepMetadataTableName, partitionID) - - row := c.database.QueryRow(queryString) - - var count pgtype.Int8 - if err := row.Scan(&count); err != nil { - return false, fmt.Errorf("failed to execute query: %w", err) - } - - return count.Int64 > 0, nil -} - func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig) error { - // NOTE that Snowflake does not support transactional DDL - createMetadataTablesTx, err := c.database.BeginTx(c.ctx, nil) - if err != nil { - return fmt.Errorf("unable to begin transaction for creating metadata tables: %w", err) - } - // in case we return after error, ensure transaction is rolled back - defer func() { - deferErr := createMetadataTablesTx.Rollback() - if deferErr != sql.ErrTxDone && deferErr != nil { - c.logger.Error("error while rolling back transaction for creating metadata tables", - slog.Any("error", deferErr)) - } - }() - err = c.createPeerDBInternalSchema(createMetadataTablesTx) - if err != nil { - return err - } - err = c.createQRepMetadataTable(createMetadataTablesTx) + _, err := c.database.ExecContext(c.ctx, fmt.Sprintf(createSchemaSQL, c.rawSchema)) if err != nil { return err } stageName := c.getStageNameForJob(config.FlowJobName) - err = c.createStage(stageName, config) if err != nil { return err @@ -154,35 +90,6 @@ func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig) } } - err = createMetadataTablesTx.Commit() - if err != nil { - return fmt.Errorf("unable to commit transaction for creating metadata tables: %w", err) - } - - return nil -} - -func (c *SnowflakeConnector) createQRepMetadataTable(createMetadataTableTx *sql.Tx) error { - // Define the schema - schemaStatement := ` - CREATE TABLE IF NOT EXISTS %s.%s ( - flowJobName STRING, - partitionID STRING, - syncPartition STRING, - syncStartTime TIMESTAMP_LTZ, - syncFinishTime TIMESTAMP_LTZ - ); - ` - queryString := fmt.Sprintf(schemaStatement, c.metadataSchema, qRepMetadataTableName) - - _, err := createMetadataTableTx.Exec(queryString) - if err != nil { - c.logger.Error(fmt.Sprintf("failed to create table %s.%s", c.metadataSchema, qRepMetadataTableName), - slog.Any("error", err)) - return fmt.Errorf("failed to create table %s.%s: %w", c.metadataSchema, qRepMetadataTableName, err) - } - - c.logger.Info(fmt.Sprintf("Created table %s", qRepMetadataTableName)) return nil } @@ -371,5 +278,5 @@ func (c *SnowflakeConnector) dropStage(stagingPath string, job string) error { } func (c *SnowflakeConnector) getStageNameForJob(job string) string { - return fmt.Sprintf("%s.peerdb_stage_%s", c.metadataSchema, job) + return fmt.Sprintf("%s.peerdb_stage_%s", c.rawSchema, job) } diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 3c330d636c..e72b0bc434 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -290,17 +290,9 @@ func (s *SnowflakeAvroSyncHandler) insertMetadata( startTime time.Time, ) error { partitionLog := slog.String(string(shared.PartitionIDKey), partition.PartitionId) - insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime) + err := s.connector.pgMetadata.FinishQrepPartition(partition, flowJobName, startTime) if err != nil { - s.connector.logger.Error("failed to create metadata insert statement", - slog.Any("error", err), partitionLog) - return fmt.Errorf("failed to create metadata insert statement: %v", err) - } - - if _, err := s.connector.database.ExecContext(s.connector.ctx, insertMetadataStmt); err != nil { - s.connector.logger.Error("failed to execute metadata insert statement "+insertMetadataStmt, - slog.Any("error", err), partitionLog) - return fmt.Errorf("failed to execute metadata insert statement: %v", err) + return fmt.Errorf("failed to execute metadata insert statement: %w", err) } s.connector.logger.Info("inserted metadata for partition", partitionLog) diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index dd479b6640..7f9b3c33fe 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -17,6 +17,7 @@ import ( "go.temporal.io/sdk/activity" "golang.org/x/sync/errgroup" + metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -26,11 +27,9 @@ import ( const ( mirrorJobsTableIdentifier = "PEERDB_MIRROR_JOBS" - createMirrorJobsTableSQL = `CREATE TABLE IF NOT EXISTS %s.%s(MIRROR_JOB_NAME STRING NOT NULL,OFFSET INT NOT NULL, - SYNC_BATCH_ID INT NOT NULL,NORMALIZE_BATCH_ID INT NOT NULL)` - rawTablePrefix = "_PEERDB_RAW" - createSchemaSQL = "CREATE TRANSIENT SCHEMA IF NOT EXISTS %s" - createRawTableSQL = `CREATE TABLE IF NOT EXISTS %s.%s(_PEERDB_UID STRING NOT NULL, + rawTablePrefix = "_PEERDB_RAW" + createSchemaSQL = "CREATE TRANSIENT SCHEMA IF NOT EXISTS %s" + createRawTableSQL = `CREATE TABLE IF NOT EXISTS %s.%s(_PEERDB_UID STRING NOT NULL, _PEERDB_TIMESTAMP INT NOT NULL,_PEERDB_DESTINATION_TABLE_NAME STRING NOT NULL,_PEERDB_DATA STRING NOT NULL, _PEERDB_RECORD_TYPE INTEGER NOT NULL, _PEERDB_MATCH_DATA STRING,_PEERDB_BATCH_ID INT, _PEERDB_UNCHANGED_TOAST_COLUMNS STRING)` @@ -55,19 +54,13 @@ const ( WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN %s` getDistinctDestinationTableNames = `SELECT DISTINCT _PEERDB_DESTINATION_TABLE_NAME FROM %s.%s WHERE _PEERDB_BATCH_ID > %d AND _PEERDB_BATCH_ID <= %d` - getTableNametoUnchangedColsSQL = `SELECT _PEERDB_DESTINATION_TABLE_NAME, + getTableNameToUnchangedColsSQL = `SELECT _PEERDB_DESTINATION_TABLE_NAME, ARRAY_AGG(DISTINCT _PEERDB_UNCHANGED_TOAST_COLUMNS) FROM %s.%s WHERE _PEERDB_BATCH_ID > %d AND _PEERDB_BATCH_ID <= %d AND _PEERDB_RECORD_TYPE != 2 GROUP BY _PEERDB_DESTINATION_TABLE_NAME` getTableSchemaSQL = `SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE UPPER(TABLE_SCHEMA)=? AND UPPER(TABLE_NAME)=? ORDER BY ORDINAL_POSITION` - insertJobMetadataSQL = "INSERT INTO %s.%s VALUES (?,?,?,?)" - - updateMetadataForSyncRecordsSQL = `UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?), SYNC_BATCH_ID=? - WHERE MIRROR_JOB_NAME=?` - updateMetadataForNormalizeRecordsSQL = "UPDATE %s.%s SET NORMALIZE_BATCH_ID=? WHERE MIRROR_JOB_NAME=?" - checkIfTableExistsSQL = `SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=? and TABLE_NAME=?` checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?" @@ -78,14 +71,14 @@ const ( dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s" deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=?" dropSchemaIfExistsSQL = "DROP SCHEMA IF EXISTS %s" - checkSchemaExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?" ) type SnowflakeConnector struct { - ctx context.Context - database *sql.DB - metadataSchema string - logger slog.Logger + ctx context.Context + database *sql.DB + pgMetadata *metadataStore.PostgresMetadataStore + rawSchema string + logger slog.Logger } // creating this to capture array results from snowflake. @@ -206,17 +199,23 @@ func NewSnowflakeConnector(ctx context.Context, return nil, fmt.Errorf("could not validate snowflake peer: %w", err) } - metadataSchema := "_PEERDB_INTERNAL" + rawSchema := "_PEERDB_INTERNAL" if snowflakeProtoConfig.MetadataSchema != nil { - metadataSchema = *snowflakeProtoConfig.MetadataSchema + rawSchema = *snowflakeProtoConfig.MetadataSchema + } + + pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx, nil, "peerdb_sf_metadata") + if err != nil { + return nil, fmt.Errorf("could not connect to metadata store: %w", err) } flowName, _ := ctx.Value(shared.FlowNameKey).(string) return &SnowflakeConnector{ - ctx: ctx, - database: database, - metadataSchema: metadataSchema, - logger: *slog.With(slog.String(string(shared.FlowNameKey), flowName)), + ctx: ctx, + database: database, + pgMetadata: pgMetadata, + rawSchema: rawSchema, + logger: *slog.With(slog.String(string(shared.FlowNameKey), flowName)), }, nil } @@ -243,43 +242,11 @@ func (c *SnowflakeConnector) ConnectionActive() error { } func (c *SnowflakeConnector) NeedsSetupMetadataTables() bool { - result, err := c.checkIfTableExists(c.metadataSchema, mirrorJobsTableIdentifier) - if err != nil { - return true - } - return !result + return c.pgMetadata.NeedsSetupMetadata() } func (c *SnowflakeConnector) SetupMetadataTables() error { - // NOTE that Snowflake does not support transactional DDL - createMetadataTablesTx, err := c.database.BeginTx(c.ctx, nil) - if err != nil { - return fmt.Errorf("unable to begin transaction for creating metadata tables: %w", err) - } - // in case we return after error, ensure transaction is rolled back - defer func() { - deferErr := createMetadataTablesTx.Rollback() - if deferErr != sql.ErrTxDone && deferErr != nil { - c.logger.Error("error while rolling back transaction for creating metadata tables", - slog.Any("error", deferErr)) - } - }() - - err = c.createPeerDBInternalSchema(createMetadataTablesTx) - if err != nil { - return err - } - _, err = createMetadataTablesTx.ExecContext(c.ctx, fmt.Sprintf(createMirrorJobsTableSQL, - c.metadataSchema, mirrorJobsTableIdentifier)) - if err != nil { - return fmt.Errorf("error while setting up mirror jobs table: %w", err) - } - err = createMetadataTablesTx.Commit() - if err != nil { - return fmt.Errorf("unable to commit transaction for creating metadata tables: %w", err) - } - - return nil + return c.pgMetadata.SetupMetadata() } // only used for testing atm. doesn't return info about pkey or ReplicaIdentity [which is PG specific anyway]. @@ -324,58 +291,19 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T } func (c *SnowflakeConnector) GetLastOffset(jobName string) (int64, error) { - var result pgtype.Int8 - err := c.database.QueryRowContext(c.ctx, fmt.Sprintf(getLastOffsetSQL, - c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) - if err != nil { - if err == sql.ErrNoRows { - c.logger.Warn("No row found, returning 0") - return 0, nil - } - return 0, fmt.Errorf("error while reading result row: %w", err) - } - if result.Int64 == 0 { - c.logger.Warn("Assuming zero offset means no sync has happened") - return 0, nil - } - return result.Int64, nil + return c.pgMetadata.FetchLastOffset(jobName) } -func (c *SnowflakeConnector) SetLastOffset(jobName string, lastOffset int64) error { - _, err := c.database.ExecContext(c.ctx, fmt.Sprintf(setLastOffsetSQL, - c.metadataSchema, mirrorJobsTableIdentifier), lastOffset, jobName) - if err != nil { - return fmt.Errorf("error querying Snowflake peer for last syncedID: %w", err) - } - return nil +func (c *SnowflakeConnector) SetLastOffset(jobName string, offset int64) error { + return c.pgMetadata.UpdateLastOffset(jobName, offset) } func (c *SnowflakeConnector) GetLastSyncBatchID(jobName string) (int64, error) { - var result pgtype.Int8 - err := c.database.QueryRowContext(c.ctx, fmt.Sprintf(getLastSyncBatchID_SQL, c.metadataSchema, - mirrorJobsTableIdentifier), jobName).Scan(&result) - if err != nil { - if err == sql.ErrNoRows { - c.logger.Warn("No row found, returning 0") - return 0, nil - } - return 0, fmt.Errorf("error while reading result row: %w", err) - } - return result.Int64, nil + return c.pgMetadata.GetLastBatchID(jobName) } func (c *SnowflakeConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { - var normBatchID pgtype.Int8 - err := c.database.QueryRowContext(c.ctx, fmt.Sprintf(getLastNormalizeBatchID_SQL, c.metadataSchema, - mirrorJobsTableIdentifier), jobName).Scan(&normBatchID) - if err != nil { - if err == sql.ErrNoRows { - c.logger.Warn("No row found, returning 0") - return 0, nil - } - return 0, fmt.Errorf("error while reading result row: %w", err) - } - return normBatchID.Int64, nil + return c.pgMetadata.GetLastNormalizeBatchID(jobName) } func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, @@ -383,7 +311,7 @@ func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, sy ) ([]string, error) { rawTableIdentifier := getRawTableIdentifier(flowJobName) - rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getDistinctDestinationTableNames, c.metadataSchema, + rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getDistinctDestinationTableNames, c.rawSchema, rawTableIdentifier, normalizeBatchID, syncBatchID)) if err != nil { return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err) @@ -407,12 +335,12 @@ func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, sy return destinationTableNames, nil } -func (c *SnowflakeConnector) getTableNametoUnchangedCols(flowJobName string, syncBatchID int64, +func (c *SnowflakeConnector) getTableNameToUnchangedCols(flowJobName string, syncBatchID int64, normalizeBatchID int64, ) (map[string][]string, error) { rawTableIdentifier := getRawTableIdentifier(flowJobName) - rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getTableNametoUnchangedColsSQL, c.metadataSchema, + rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getTableNameToUnchangedColsSQL, c.rawSchema, rawTableIdentifier, normalizeBatchID, syncBatchID)) if err != nil { return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err) @@ -533,27 +461,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. return nil, err } - // transaction for SyncRecords - syncRecordsTx, err := c.database.BeginTx(c.ctx, nil) - if err != nil { - return nil, err - } - // in case we return after error, ensure transaction is rolled back - defer func() { - deferErr := syncRecordsTx.Rollback() - if deferErr != sql.ErrTxDone && deferErr != nil { - c.logger.Error("error while rolling back transaction for SyncRecords: %v", - slog.Any("error", deferErr), slog.Int64("syncBatchID", req.SyncBatchID)) - } - }() - - // updating metadata with new offset and syncBatchID - err = c.updateSyncMetadata(req.FlowJobName, res.LastSyncedCheckpointID, req.SyncBatchID, syncRecordsTx) - if err != nil { - return nil, err - } - // transaction commits - err = syncRecordsTx.Commit() + err = c.pgMetadata.FinishBatch(req.FlowJobName, req.SyncBatchID, res.LastSyncedCheckpointID) if err != nil { return nil, err } @@ -576,7 +484,7 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( qrepConfig := &protos.QRepConfig{ StagingPath: "", FlowJobName: req.FlowJobName, - DestinationTableIdentifier: strings.ToLower(fmt.Sprintf("%s.%s", c.metadataSchema, + DestinationTableIdentifier: strings.ToLower(fmt.Sprintf("%s.%s", c.rawSchema, rawTableIdentifier)), } avroSyncer := NewSnowflakeAvroSyncHandler(qrepConfig, c) @@ -625,12 +533,12 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest }, nil } - jobMetadataExists, err := c.jobMetadataExists(req.FlowJobName) + rawSchemaExists, err := c.rawSchemaExists(req.FlowJobName) if err != nil { return nil, err } // sync hasn't created job metadata yet, chill. - if !jobMetadataExists { + if !rawSchemaExists { return &model.NormalizeResponse{ Done: false, }, nil @@ -644,7 +552,7 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest return nil, err } - tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, req.SyncBatchID, normBatchID) + tableNameToUnchangedToastCols, err := c.getTableNameToUnchangedCols(req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err) } @@ -663,7 +571,7 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest syncBatchID: req.SyncBatchID, normalizeBatchID: normBatchID, normalizedTableSchema: req.TableNameSchemaMapping[tableName], - unchangedToastColumns: tableNametoUnchangedToastCols[tableName], + unchangedToastColumns: tableNameToUnchangedToastCols[tableName], peerdbCols: &protos.PeerDBColumns{ SoftDelete: req.SoftDelete, SoftDeleteColName: req.SoftDeleteColName, @@ -706,8 +614,7 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest return nil, fmt.Errorf("error while normalizing records: %w", err) } - // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID) + err = c.pgMetadata.UpdateNormalizeBatchID(req.FlowJobName, req.SyncBatchID) if err != nil { return nil, err } @@ -720,20 +627,20 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } func (c *SnowflakeConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { - rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) + _, err := c.database.ExecContext(c.ctx, fmt.Sprintf(createSchemaSQL, c.rawSchema)) + if err != nil { + return nil, err + } createRawTableTx, err := c.database.BeginTx(c.ctx, nil) if err != nil { return nil, fmt.Errorf("unable to begin transaction for creation of raw table: %w", err) } - err = c.createPeerDBInternalSchema(createRawTableTx) - if err != nil { - return nil, err - } // there is no easy way to check if a table has the same schema in Snowflake, // so just executing the CREATE TABLE IF NOT EXISTS blindly. + rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) _, err = createRawTableTx.ExecContext(c.ctx, - fmt.Sprintf(createRawTableSQL, c.metadataSchema, rawTableIdentifier)) + fmt.Sprintf(createRawTableSQL, c.rawSchema, rawTableIdentifier)) if err != nil { return nil, fmt.Errorf("unable to create raw table: %w", err) } @@ -754,6 +661,11 @@ func (c *SnowflakeConnector) CreateRawTable(req *protos.CreateRawTableInput) (*p } func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error { + err := c.pgMetadata.DropMetadata(jobName) + if err != nil { + return fmt.Errorf("unable to clear metadata for sync flow cleanup: %w", err) + } + syncFlowCleanupTx, err := c.database.BeginTx(c.ctx, nil) if err != nil { return fmt.Errorf("unable to begin transaction for sync flow cleanup: %w", err) @@ -765,31 +677,6 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error { } }() - row := syncFlowCleanupTx.QueryRowContext(c.ctx, checkSchemaExistsSQL, c.metadataSchema) - var schemaExists pgtype.Bool - err = row.Scan(&schemaExists) - if err != nil { - return fmt.Errorf("unable to check if internal schema exists: %w", err) - } - - if schemaExists.Bool { - _, err = syncFlowCleanupTx.ExecContext(c.ctx, fmt.Sprintf(dropTableIfExistsSQL, c.metadataSchema, - getRawTableIdentifier(jobName))) - if err != nil { - return fmt.Errorf("unable to drop raw table: %w", err) - } - _, err = syncFlowCleanupTx.ExecContext(c.ctx, - fmt.Sprintf(deleteJobMetadataSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName) - if err != nil { - return fmt.Errorf("unable to delete job metadata: %w", err) - } - } - - err = syncFlowCleanupTx.Commit() - if err != nil { - return fmt.Errorf("unable to commit transaction for sync flow cleanup: %w", err) - } - err = c.dropStage("", jobName) if err != nil { return err @@ -861,92 +748,16 @@ func getRawTableIdentifier(jobName string) string { return fmt.Sprintf("%s_%s", rawTablePrefix, jobName) } -func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) { +func (c *SnowflakeConnector) rawSchemaExists(jobName string) (bool, error) { var result pgtype.Bool err := c.database.QueryRowContext(c.ctx, - fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) - if err != nil { - return false, fmt.Errorf("error reading result row: %w", err) - } - return result.Bool, nil -} - -func (c *SnowflakeConnector) jobMetadataExistsTx(tx *sql.Tx, jobName string) (bool, error) { - var result pgtype.Bool - err := tx.QueryRowContext(c.ctx, - fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) + fmt.Sprintf(checkIfJobMetadataExistsSQL, c.rawSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) if err != nil { return false, fmt.Errorf("error reading result row: %w", err) } return result.Bool, nil } -func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64, - syncBatchID int64, syncRecordsTx *sql.Tx, -) error { - jobMetadataExists, err := c.jobMetadataExistsTx(syncRecordsTx, flowJobName) - if err != nil { - return fmt.Errorf("failed to get sync status for flow job: %w", err) - } - - if !jobMetadataExists { - _, err := syncRecordsTx.ExecContext(c.ctx, - fmt.Sprintf(insertJobMetadataSQL, c.metadataSchema, mirrorJobsTableIdentifier), - flowJobName, lastCP, syncBatchID, 0) - if err != nil { - return fmt.Errorf("failed to insert flow job status: %w", err) - } - } else { - _, err := syncRecordsTx.ExecContext(c.ctx, - fmt.Sprintf(updateMetadataForSyncRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier), - lastCP, syncBatchID, flowJobName) - if err != nil { - return fmt.Errorf("failed to update flow job status: %w", err) - } - } - - return nil -} - -func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64) error { - jobMetadataExists, err := c.jobMetadataExists(flowJobName) - if err != nil { - return fmt.Errorf("failed to get sync status for flow job: %w", err) - } - if !jobMetadataExists { - return fmt.Errorf("job metadata does not exist, unable to update") - } - - stmt := fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier) - _, err = c.database.ExecContext(c.ctx, stmt, normalizeBatchID, flowJobName) - if err != nil { - return fmt.Errorf("failed to update metadata for NormalizeTables: %w", err) - } - - return nil -} - -func (c *SnowflakeConnector) createPeerDBInternalSchema(createSchemaTx *sql.Tx) error { - // check if the internal schema exists - row := createSchemaTx.QueryRowContext(c.ctx, checkSchemaExistsSQL, c.metadataSchema) - var schemaExists pgtype.Bool - err := row.Scan(&schemaExists) - if err != nil { - return fmt.Errorf("error while reading result row: %w", err) - } - - if schemaExists.Bool { - c.logger.Info(fmt.Sprintf("internal schema %s already exists", c.metadataSchema)) - return nil - } - - _, err = createSchemaTx.ExecContext(c.ctx, fmt.Sprintf(createSchemaSQL, c.metadataSchema)) - if err != nil { - return fmt.Errorf("error while creating internal schema for PeerDB: %w", err) - } - return nil -} - func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { renameTablesTx, err := c.database.BeginTx(c.ctx, nil) if err != nil {