From 0840060826f35a2f4ce6310263aa2b03a77439c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Tue, 13 Feb 2024 15:06:06 +0000
Subject: [PATCH 1/3] CreateNormalizedTable: lift loop, make it based on
 separate connector

Also introduce a generalized GetConnectorAs and use it internally for
existing functions
---
 flow/activities/flowable.go             |  57 ++++-
 flow/connectors/bigquery/bigquery.go    | 230 +++++++++---------
 .../bigquery/merge_stmt_generator.go    |   4 +-
 flow/connectors/clickhouse/normalize.go |  69 +++---
 flow/connectors/core.go                 | 172 +++++--------
 flow/connectors/eventhub/eventhub.go    |  10 -
 flow/connectors/postgres/postgres.go    |  83 +++----
 flow/connectors/s3/s3.go                |   8 -
 flow/connectors/snowflake/snowflake.go  |  64 ++---
 flow/workflows/qrep_flow.go             |   3 +-
 flow/workflows/setup_flow.go            |   3 +-
 11 files changed, 332 insertions(+), 371 deletions(-)

diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 22c2fcf99b..d63433a0f5 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log/slog"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/jackc/pglogrepl"
@@ -141,25 +142,69 @@ func (a *FlowableActivity) GetTableSchema(
 	return srcConn.GetTableSchema(ctx, config)
 }
 
-// CreateNormalizedTable creates a normalized table in the destination flowable.
+// CreateNormalizedTable creates normalized tables in the destination.
 func (a *FlowableActivity) CreateNormalizedTable(
 	ctx context.Context,
 	config *protos.SetupNormalizedTableBatchInput,
 ) (*protos.SetupNormalizedTableBatchOutput, error) {
 	ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
-	conn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig)
+	conn, err := connectors.GetConnectorAs[connectors.NormalizedTablesConnector](ctx, config.PeerConnectionConfig)
 	if err != nil {
+		if err == connectors.ErrUnsupportedFunctionality {
+			activity.GetLogger(ctx).Info("Connector does not implement normalized tables")
+			return nil, nil
+		}
 		return nil, fmt.Errorf("failed to get connector: %w", err)
 	}
 	defer connectors.CloseConnector(ctx, conn)
 
-	setupNormalizedTablesOutput, err := conn.SetupNormalizedTables(ctx, config)
+	tx, err := conn.StartSetupNormalizedTables(ctx)
 	if err != nil {
-		a.Alerter.LogFlowError(ctx, config.FlowName, err)
-		return nil, fmt.Errorf("failed to setup normalized tables: %w", err)
+		return nil, fmt.Errorf("failed to setup normalized tables tx: %w", err)
 	}
+	defer conn.AbortSetupNormalizedTables(ctx, tx)
+
+	numTablesSetup := atomic.Uint32{}
+	totalTables := uint32(len(config.TableNameSchemaMapping))
+	shutdown := utils.HeartbeatRoutine(ctx, func() string {
+		return fmt.Sprintf("setting up normalized tables - %d of %d done",
+			numTablesSetup.Load(), totalTables)
+	})
+	defer shutdown()
 
-	return setupNormalizedTablesOutput, nil
+	logger := activity.GetLogger(ctx)
+	tableExistsMapping := make(map[string]bool)
+	for tableIdentifier, tableSchema := range config.TableNameSchemaMapping {
+		existing, err := conn.SetupNormalizedTable(
+			ctx,
+			tx,
+			tableIdentifier,
+			tableSchema,
+			config.SoftDeleteColName,
+			config.SyncedAtColName,
+		)
+		if err != nil {
+			a.Alerter.LogFlowError(ctx, config.FlowName, err)
+			return nil, fmt.Errorf("failed to setup normalized table %s: %w", tableIdentifier, err)
+		}
+		tableExistsMapping[tableIdentifier] = existing
+
+		numTablesSetup.Add(1)
+		if existing {
+			logger.Info(fmt.Sprintf("table already exists %s", tableIdentifier))
+		} else {
+			logger.Info(fmt.Sprintf("created table %s", tableIdentifier))
+		}
+	}
+
+	err =
conn.FinishSetupNormalizedTables(ctx, tx) + if err != nil { + return nil, fmt.Errorf("failed to commit normalized tables tx: %w", err) + } + + return &protos.SetupNormalizedTableBatchOutput{ + TableExistsMapping: tableExistsMapping, + }, nil } func (a *FlowableActivity) StartFlow(ctx context.Context, diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 135bd3f096..edc2546903 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -9,7 +9,6 @@ import ( "reflect" "regexp" "strings" - "sync/atomic" "time" "cloud.google.com/go/bigquery" @@ -521,7 +520,7 @@ func (c *BigQueryConnector) NormalizeRecords(ctx context.Context, req *model.Nor unchangedToastColumns := tableNametoUnchangedToastCols[tableName] dstDatasetTable, _ := c.convertToDatasetTable(tableName) mergeGen := &mergeStmtGenerator{ - rawDatasetTable: &datasetTable{ + rawDatasetTable: datasetTable{ project: c.projectID, dataset: c.datasetID, table: rawTableName, @@ -648,138 +647,131 @@ func (c *BigQueryConnector) CreateRawTable(ctx context.Context, req *protos.Crea }, nil } -// SetupNormalizedTables sets up normalized tables, implementing the Connector interface. +func (c *BigQueryConnector) StartSetupNormalizedTables(_ context.Context) (interface{}, error) { + // needed since CreateNormalizedTable duplicate check isn't accurate enough + return make(map[datasetTable]struct{}), nil +} + +func (c *BigQueryConnector) FinishSetupNormalizedTables(_ context.Context) error { + return nil +} + +func (c *BigQueryConnector) AbortSetupNormalizedTables(_ context.Context, _ interface{}) { +} + // This runs CREATE TABLE IF NOT EXISTS on bigquery, using the schema and table name provided. -func (c *BigQueryConnector) SetupNormalizedTables( +func (c *BigQueryConnector) SetupNormalizedTable( ctx context.Context, - req *protos.SetupNormalizedTableBatchInput, -) (*protos.SetupNormalizedTableBatchOutput, error) { - numTablesSetup := atomic.Uint32{} - totalTables := uint32(len(req.TableNameSchemaMapping)) - - shutdown := utils.HeartbeatRoutine(ctx, func() string { - return fmt.Sprintf("setting up normalized tables - %d of %d done", - numTablesSetup.Load(), totalTables) - }) - defer shutdown() - - tableExistsMapping := make(map[string]bool) - datasetTablesSet := make(map[datasetTable]struct{}) - for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - // only place where we check for parsing errors - datasetTable, err := c.convertToDatasetTable(tableIdentifier) - if err != nil { - return nil, err - } - _, ok := datasetTablesSet[*datasetTable] - if ok { - return nil, fmt.Errorf("invalid mirror: two tables mirror to the same BigQuery table %s", - datasetTable.string()) + tx interface{}, + tableIdentifier string, + tableSchema *protos.TableSchema, + softDeleteColName string, + syncedAtColName string, +) (bool, error) { + datasetTablesSet := tx.(map[datasetTable]struct{}) + + // only place where we check for parsing errors + datasetTable, err := c.convertToDatasetTable(tableIdentifier) + if err != nil { + return false, err + } + _, ok := datasetTablesSet[datasetTable] + if ok { + return false, fmt.Errorf("invalid mirror: two tables mirror to the same BigQuery table %s", + datasetTable.string()) + } + datasetTablesSet[datasetTable] = struct{}{} + dataset := c.client.DatasetInProject(c.projectID, datasetTable.dataset) + _, err = dataset.Metadata(ctx) + // just assume this means dataset don't exist, and create it + if err != nil { + // if err message does not contain 
`notFound`, then other error happened. + if !strings.Contains(err.Error(), "notFound") { + return false, fmt.Errorf("error while checking metadata for BigQuery dataset %s: %w", + datasetTable.dataset, err) } - dataset := c.client.DatasetInProject(c.projectID, datasetTable.dataset) - _, err = dataset.Metadata(ctx) - // just assume this means dataset don't exist, and create it + c.logger.Info(fmt.Sprintf("creating dataset %s...", dataset.DatasetID)) + err = dataset.Create(ctx, nil) if err != nil { - // if err message does not contain `notFound`, then other error happened. - if !strings.Contains(err.Error(), "notFound") { - return nil, fmt.Errorf("error while checking metadata for BigQuery dataset %s: %w", - datasetTable.dataset, err) - } - c.logger.Info(fmt.Sprintf("creating dataset %s...", dataset.DatasetID)) - err = dataset.Create(ctx, nil) - if err != nil { - return nil, fmt.Errorf("failed to create BigQuery dataset %s: %w", dataset.DatasetID, err) - } - } - table := dataset.Table(datasetTable.table) - - // check if the table exists - _, err = table.Metadata(ctx) - if err == nil { - // table exists, go to next table - tableExistsMapping[tableIdentifier] = true - datasetTablesSet[*datasetTable] = struct{}{} - - c.logger.Info(fmt.Sprintf("table already exists %s", tableIdentifier)) - numTablesSetup.Add(1) - continue + return false, fmt.Errorf("failed to create BigQuery dataset %s: %w", dataset.DatasetID, err) } + } + table := dataset.Table(datasetTable.table) - // convert the column names and types to bigquery types - columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.Columns)+2) - for _, column := range tableSchema.Columns { - genericColType := column.Type - if genericColType == "numeric" { - precision, scale := numeric.ParseNumericTypmod(column.TypeModifier) - if column.TypeModifier == -1 || precision > 38 || scale > 37 { - precision = numeric.PeerDBNumericPrecision - scale = numeric.PeerDBNumericScale - } - columns = append(columns, &bigquery.FieldSchema{ - Name: column.Name, - Type: bigquery.BigNumericFieldType, - Repeated: qvalue.QValueKind(genericColType).IsArray(), - Precision: int64(precision), - Scale: int64(scale), - }) - } else { - columns = append(columns, &bigquery.FieldSchema{ - Name: column.Name, - Type: qValueKindToBigQueryType(genericColType), - Repeated: qvalue.QValueKind(genericColType).IsArray(), - }) + // check if the table exists + _, err = table.Metadata(ctx) + if err == nil { + // table exists, go to next table + return true, nil + } + + // convert the column names and types to bigquery types + columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.Columns)+2) + for _, column := range tableSchema.Columns { + genericColType := column.Type + if genericColType == "numeric" { + precision, scale := numeric.ParseNumericTypmod(column.TypeModifier) + if column.TypeModifier == -1 || precision > 38 || scale > 37 { + precision = numeric.PeerDBNumericPrecision + scale = numeric.PeerDBNumericScale } - } - - if req.SoftDeleteColName != "" { columns = append(columns, &bigquery.FieldSchema{ - Name: req.SoftDeleteColName, - Type: bigquery.BooleanFieldType, - Repeated: false, + Name: column.Name, + Type: bigquery.BigNumericFieldType, + Repeated: qvalue.QValueKind(genericColType).IsArray(), + Precision: int64(precision), + Scale: int64(scale), }) - } - - if req.SyncedAtColName != "" { + } else { columns = append(columns, &bigquery.FieldSchema{ - Name: req.SyncedAtColName, - Type: bigquery.TimestampFieldType, - Repeated: false, + Name: column.Name, + Type: 
qValueKindToBigQueryType(genericColType), + Repeated: qvalue.QValueKind(genericColType).IsArray(), }) } + } + + if softDeleteColName != "" { + columns = append(columns, &bigquery.FieldSchema{ + Name: softDeleteColName, + Type: bigquery.BooleanFieldType, + Repeated: false, + }) + } - // create the table using the columns - schema := bigquery.Schema(columns) + if syncedAtColName != "" { + columns = append(columns, &bigquery.FieldSchema{ + Name: syncedAtColName, + Type: bigquery.TimestampFieldType, + Repeated: false, + }) + } - // cluster by the primary key if < 4 columns. - var clustering *bigquery.Clustering - numPkeyCols := len(tableSchema.PrimaryKeyColumns) - if numPkeyCols > 0 && numPkeyCols < 4 { - clustering = &bigquery.Clustering{ - Fields: tableSchema.PrimaryKeyColumns, - } - } + // create the table using the columns + schema := bigquery.Schema(columns) - metadata := &bigquery.TableMetadata{ - Schema: schema, - Name: datasetTable.table, - Clustering: clustering, + // cluster by the primary key if < 4 columns. + var clustering *bigquery.Clustering + numPkeyCols := len(tableSchema.PrimaryKeyColumns) + if numPkeyCols > 0 && numPkeyCols < 4 { + clustering = &bigquery.Clustering{ + Fields: tableSchema.PrimaryKeyColumns, } + } - err = table.Create(ctx, metadata) - if err != nil { - return nil, fmt.Errorf("failed to create table %s: %w", tableIdentifier, err) - } + metadata := &bigquery.TableMetadata{ + Schema: schema, + Name: datasetTable.table, + Clustering: clustering, + } - tableExistsMapping[tableIdentifier] = false - datasetTablesSet[*datasetTable] = struct{}{} - // log that table was created - c.logger.Info(fmt.Sprintf("created table %s", tableIdentifier)) - numTablesSetup.Add(1) + err = table.Create(ctx, metadata) + if err != nil { + return false, fmt.Errorf("failed to create table %s: %w", tableIdentifier, err) } - return &protos.SetupNormalizedTableBatchOutput{ - TableExistsMapping: tableExistsMapping, - }, nil + datasetTablesSet[datasetTable] = struct{}{} + return false, nil } func (c *BigQueryConnector) SyncFlowCleanup(ctx context.Context, jobName string) error { @@ -982,25 +974,25 @@ func (d *datasetTable) string() string { return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table) } -func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) { +func (c *BigQueryConnector) convertToDatasetTable(tableName string) (datasetTable, error) { parts := strings.Split(tableName, ".") if len(parts) == 1 { - return &datasetTable{ + return datasetTable{ dataset: c.datasetID, table: parts[0], }, nil } else if len(parts) == 2 { - return &datasetTable{ + return datasetTable{ dataset: parts[0], table: parts[1], }, nil } else if len(parts) == 3 { - return &datasetTable{ + return datasetTable{ project: parts[0], dataset: parts[1], table: parts[2], }, nil } else { - return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName) + return datasetTable{}, fmt.Errorf("invalid BigQuery table name: %s", tableName) } } diff --git a/flow/connectors/bigquery/merge_stmt_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go index 21350fc76e..eb8ebb6177 100644 --- a/flow/connectors/bigquery/merge_stmt_generator.go +++ b/flow/connectors/bigquery/merge_stmt_generator.go @@ -13,11 +13,11 @@ import ( type mergeStmtGenerator struct { // dataset + raw table - rawDatasetTable *datasetTable + rawDatasetTable datasetTable // destination table name, used to retrieve records from raw table dstTableName string // dataset + destination table - dstDatasetTable 
*datasetTable + dstDatasetTable datasetTable // last synced batchID. syncBatchID int64 // last normalized batchID. diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go index 914da50604..225aa2845c 100644 --- a/flow/connectors/clickhouse/normalize.go +++ b/flow/connectors/clickhouse/normalize.go @@ -19,41 +19,48 @@ const ( versionColType = "Int64" ) -func (c *ClickhouseConnector) SetupNormalizedTables( - ctx context.Context, - req *protos.SetupNormalizedTableBatchInput, -) (*protos.SetupNormalizedTableBatchOutput, error) { - tableExistsMapping := make(map[string]bool) - for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - tableAlreadyExists, err := c.checkIfTableExists(ctx, c.config.Database, tableIdentifier) - if err != nil { - return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) - } - if tableAlreadyExists { - tableExistsMapping[tableIdentifier] = true - continue - } +func (c *ClickhouseConnector) StartSetupNormalizedTables(_ context.Context) (interface{}, error) { + return nil, nil +} - normalizedTableCreateSQL, err := generateCreateTableSQLForNormalizedTable( - tableIdentifier, - tableSchema, - req.SoftDeleteColName, - req.SyncedAtColName, - ) - if err != nil { - return nil, fmt.Errorf("error while generating create table sql for normalized table: %w", err) - } +func (c *ClickhouseConnector) FinishSetupNormalizedTables(_ context.Context) error { + return nil +} - _, err = c.database.ExecContext(ctx, normalizedTableCreateSQL) - if err != nil { - return nil, fmt.Errorf("[sf] error while creating normalized table: %w", err) - } - tableExistsMapping[tableIdentifier] = false +func (c *ClickhouseConnector) AbortSetupNormalizedTables(_ context.Context, _ interface{}) { +} + +func (c *ClickhouseConnector) SetupNormalizedTable( + ctx context.Context, + tx interface{}, + tableIdentifier string, + tableSchema *protos.TableSchema, + softDeleteColName string, + syncedAtColName string, +) (bool, error) { + tableAlreadyExists, err := c.checkIfTableExists(ctx, c.config.Database, tableIdentifier) + if err != nil { + return false, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) + } + if tableAlreadyExists { + return true, nil } - return &protos.SetupNormalizedTableBatchOutput{ - TableExistsMapping: tableExistsMapping, - }, nil + normalizedTableCreateSQL, err := generateCreateTableSQLForNormalizedTable( + tableIdentifier, + tableSchema, + softDeleteColName, + syncedAtColName, + ) + if err != nil { + return false, fmt.Errorf("error while generating create table sql for normalized table: %w", err) + } + + _, err = c.database.ExecContext(ctx, normalizedTableCreateSQL) + if err != nil { + return false, fmt.Errorf("[sf] error while creating normalized table: %w", err) + } + return false, nil } func generateCreateTableSQLForNormalizedTable( diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 49ba9151a0..35330136b0 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -58,6 +58,30 @@ type CDCPullConnector interface { AddTablesToPublication(ctx context.Context, req *protos.AddTablesToPublicationInput) error } +type NormalizedTablesConnector interface { + Connector + + // StartSetupNormalizedTables may be used to have SetupNormalizedTable calls run in a transaction. + StartSetupNormalizedTables(ctx context.Context) (interface{}, error) + + // SetupNormalizedTable sets up the normalized table on the connector. 
+ SetupNormalizedTable( + ctx context.Context, + tx interface{}, + tableIdentifier string, + tableSchema *protos.TableSchema, + softDeleteColName string, + syncedAtColName string, + ) (bool, error) + + // AbortSetupNormalizedTables may be used to rollback transaction started by StartSetupNormalizedTables. + // Calling AbortSetupNormalizedTables after FinishSetupNormalizedTables must be a nop. + AbortSetupNormalizedTables(ctx context.Context, tx interface{}) + + // FinishSetupNormalizedTables may be used to finish transaction started by StartSetupNormalizedTables. + FinishSetupNormalizedTables(ctx context.Context, tx interface{}) error +} + type CDCSyncConnector interface { Connector @@ -84,10 +108,6 @@ type CDCSyncConnector interface { // Connectors which are non-normalizing should implement this as a nop. ReplayTableSchemaDeltas(ctx context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error - // SetupNormalizedTables sets up the normalized table on the connector. - SetupNormalizedTables(ctx context.Context, req *protos.SetupNormalizedTableBatchInput) ( - *protos.SetupNormalizedTableBatchOutput, error) - // SyncRecords pushes records to the destination peer and stores it in PeerDB specific tables. // This method should be idempotent, and should be able to be called multiple times with the same request. SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) @@ -136,19 +156,8 @@ type QRepConsolidateConnector interface { CleanupQRepFlow(ctx context.Context, config *protos.QRepConfig) error } -func GetCDCPullConnector(ctx context.Context, config *protos.Peer) (CDCPullConnector, error) { - inner := config.Config - switch inner.(type) { - case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) - default: - return nil, ErrUnsupportedFunctionality - } -} - -func GetCDCSyncConnector(ctx context.Context, config *protos.Peer) (CDCSyncConnector, error) { - inner := config.Config - switch inner.(type) { +func GetConnector(ctx context.Context, config *protos.Peer) (Connector, error) { + switch config.Config.(type) { case *protos.Peer_PostgresConfig: return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) case *protos.Peer_BigqueryConfig: @@ -161,6 +170,8 @@ func GetCDCSyncConnector(ctx context.Context, config *protos.Peer) (CDCSyncConne return conneventhub.NewEventHubConnector(ctx, config.GetEventhubGroupConfig()) case *protos.Peer_S3Config: return conns3.NewS3Connector(ctx, config.GetS3Config()) + case *protos.Peer_SqlserverConfig: + return connsqlserver.NewSQLServerConnector(ctx, config.GetSqlserverConfig()) case *protos.Peer_ClickhouseConfig: return connclickhouse.NewClickhouseConnector(ctx, config.GetClickhouseConfig()) default: @@ -168,115 +179,42 @@ func GetCDCSyncConnector(ctx context.Context, config *protos.Peer) (CDCSyncConne } } -func GetCDCNormalizeConnector(ctx context.Context, - config *protos.Peer, -) (CDCNormalizeConnector, error) { - inner := config.Config - switch inner.(type) { - case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) - case *protos.Peer_BigqueryConfig: - return connbigquery.NewBigQueryConnector(ctx, config.GetBigqueryConfig()) - case *protos.Peer_SnowflakeConfig: - return connsnowflake.NewSnowflakeConnector(ctx, config.GetSnowflakeConfig()) - case *protos.Peer_ClickhouseConfig: - return connclickhouse.NewClickhouseConnector(ctx, config.GetClickhouseConfig()) - default: - return nil, 
ErrUnsupportedFunctionality +func GetConnectorAs[T Connector](ctx context.Context, config *protos.Peer) (T, error) { + var none T + conn, err := GetConnector(ctx, config) + if err != nil { + return none, err } -} -func GetQRepPullConnector(ctx context.Context, config *protos.Peer) (QRepPullConnector, error) { - inner := config.Config - switch inner.(type) { - case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) - case *protos.Peer_SqlserverConfig: - return connsqlserver.NewSQLServerConnector(ctx, config.GetSqlserverConfig()) - default: - return nil, ErrUnsupportedFunctionality + if conn, ok := conn.(T); ok { + return conn, nil + } else { + return none, ErrUnsupportedFunctionality } } -func GetQRepSyncConnector(ctx context.Context, config *protos.Peer) (QRepSyncConnector, error) { - inner := config.Config - switch inner.(type) { - case *protos.Peer_PostgresConfig: - return connpostgres.NewPostgresConnector(ctx, config.GetPostgresConfig()) - case *protos.Peer_BigqueryConfig: - return connbigquery.NewBigQueryConnector(ctx, config.GetBigqueryConfig()) - case *protos.Peer_SnowflakeConfig: - return connsnowflake.NewSnowflakeConnector(ctx, config.GetSnowflakeConfig()) - case *protos.Peer_S3Config: - return conns3.NewS3Connector(ctx, config.GetS3Config()) - case *protos.Peer_ClickhouseConfig: - return connclickhouse.NewClickhouseConnector(ctx, config.GetClickhouseConfig()) - default: - return nil, ErrUnsupportedFunctionality - } +func GetCDCPullConnector(ctx context.Context, config *protos.Peer) (CDCPullConnector, error) { + return GetConnectorAs[CDCPullConnector](ctx, config) } -func GetConnector(ctx context.Context, peer *protos.Peer) (Connector, error) { - inner := peer.Type - switch inner { - case protos.DBType_POSTGRES: - pgConfig := peer.GetPostgresConfig() - - if pgConfig == nil { - return nil, fmt.Errorf("missing postgres config for %s peer %s", peer.Type.String(), peer.Name) - } - // we can't decide if a PG peer should have replication permissions on it because we don't know - // what the user wants to do with it, so defaulting to being permissive. - // can be revisited in the future or we can use some UI wizardry. 
- return connpostgres.NewPostgresConnector(ctx, pgConfig) - case protos.DBType_BIGQUERY: - bqConfig := peer.GetBigqueryConfig() - if bqConfig == nil { - return nil, fmt.Errorf("missing bigquery config for %s peer %s", peer.Type.String(), peer.Name) - } - return connbigquery.NewBigQueryConnector(ctx, bqConfig) - - case protos.DBType_SNOWFLAKE: - sfConfig := peer.GetSnowflakeConfig() - if sfConfig == nil { - return nil, fmt.Errorf("missing snowflake config for %s peer %s", peer.Type.String(), peer.Name) - } - return connsnowflake.NewSnowflakeConnector(ctx, sfConfig) - case protos.DBType_SQLSERVER: - sqlServerConfig := peer.GetSqlserverConfig() - if sqlServerConfig == nil { - return nil, fmt.Errorf("missing sqlserver config for %s peer %s", peer.Type.String(), peer.Name) - } - return connsqlserver.NewSQLServerConnector(ctx, sqlServerConfig) - case protos.DBType_S3: - s3Config := peer.GetS3Config() - if s3Config == nil { - return nil, fmt.Errorf("missing s3 config for %s peer %s", peer.Type.String(), peer.Name) - } - return conns3.NewS3Connector(ctx, s3Config) - case protos.DBType_CLICKHOUSE: - clickhouseConfig := peer.GetClickhouseConfig() - if clickhouseConfig == nil { - return nil, fmt.Errorf("missing clickhouse config for %s peer %s", peer.Type.String(), peer.Name) - } - return connclickhouse.NewClickhouseConnector(ctx, clickhouseConfig) - default: - return nil, fmt.Errorf("unsupported peer type %s", peer.Type.String()) - } +func GetCDCSyncConnector(ctx context.Context, config *protos.Peer) (CDCSyncConnector, error) { + return GetConnectorAs[CDCSyncConnector](ctx, config) } -func GetQRepConsolidateConnector(ctx context.Context, - config *protos.Peer, -) (QRepConsolidateConnector, error) { - inner := config.Config - switch inner.(type) { - case *protos.Peer_SnowflakeConfig: - return connsnowflake.NewSnowflakeConnector(ctx, config.GetSnowflakeConfig()) - case *protos.Peer_ClickhouseConfig: - return connclickhouse.NewClickhouseConnector(ctx, config.GetClickhouseConfig()) - default: - return nil, ErrUnsupportedFunctionality - } +func GetCDCNormalizeConnector(ctx context.Context, config *protos.Peer) (CDCNormalizeConnector, error) { + return GetConnectorAs[CDCNormalizeConnector](ctx, config) +} + +func GetQRepPullConnector(ctx context.Context, config *protos.Peer) (QRepPullConnector, error) { + return GetConnectorAs[QRepPullConnector](ctx, config) +} + +func GetQRepSyncConnector(ctx context.Context, config *protos.Peer) (QRepSyncConnector, error) { + return GetConnectorAs[QRepSyncConnector](ctx, config) +} + +func GetQRepConsolidateConnector(ctx context.Context, config *protos.Peer) (QRepConsolidateConnector, error) { + return GetConnectorAs[QRepConsolidateConnector](ctx, config) } func CloseConnector(ctx context.Context, conn Connector) { diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 55d2a4b6c4..2641dc5ff2 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -261,16 +261,6 @@ func (c *EventHubConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobNa return nil } -func (c *EventHubConnector) SetupNormalizedTables( - _ context.Context, - req *protos.SetupNormalizedTableBatchInput, -) (*protos.SetupNormalizedTableBatchOutput, error) { - c.logger.Info("normalization for event hub is a no-op") - return &protos.SetupNormalizedTableBatchOutput{ - TableExistsMapping: nil, - }, nil -} - func (c *EventHubConnector) SyncFlowCleanup(ctx context.Context, jobName string) error { return 
c.pgMetadata.DropMetadata(ctx, jobName) } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index ef0169b1b3..f9c6d60190 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -639,60 +639,53 @@ func (c *PostgresConnector) getTableSchemaForTable( }, nil } -// SetupNormalizedTable sets up a normalized table, implementing the Connector interface. -func (c *PostgresConnector) SetupNormalizedTables( - ctx context.Context, - req *protos.SetupNormalizedTableBatchInput, -) (*protos.SetupNormalizedTableBatchOutput, error) { - tableExistsMapping := make(map[string]bool) +func (c *PostgresConnector) StartSetupNormalizedTables(ctx context.Context) (interface{}, error) { // Postgres is cool and supports transactional DDL. So we use a transaction. - createNormalizedTablesTx, err := c.conn.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("error starting transaction for creating raw table: %w", err) - } - - defer func() { - deferErr := createNormalizedTablesTx.Rollback(ctx) - if deferErr != pgx.ErrTxClosed && deferErr != nil { - c.logger.Error("error rolling back transaction for creating raw table", slog.Any("error", err)) - } - }() + return c.conn.Begin(ctx) +} - for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - parsedNormalizedTable, err := utils.ParseSchemaTable(tableIdentifier) - if err != nil { - return nil, fmt.Errorf("error while parsing table schema and name: %w", err) - } - tableAlreadyExists, err := c.tableExists(ctx, parsedNormalizedTable) - if err != nil { - return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) - } - if tableAlreadyExists { - tableExistsMapping[tableIdentifier] = true - continue - } +func (c *PostgresConnector) AbortSetupNormalizedTables(ctx context.Context, tx interface{}) { + err := tx.(pgx.Tx).Rollback(ctx) + if err != pgx.ErrTxClosed && err != nil { + c.logger.Error("error rolling back transaction for creating raw table", slog.Any("error", err)) + } +} - // convert the column names and types to Postgres types - normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable( - parsedNormalizedTable.String(), tableSchema, req.SoftDeleteColName, req.SyncedAtColName) - _, err = createNormalizedTablesTx.Exec(ctx, normalizedTableCreateSQL) - if err != nil { - return nil, fmt.Errorf("error while creating normalized table: %w", err) - } +func (c *PostgresConnector) FinishSetupNormalizedTables(ctx context.Context, tx interface{}) error { + return tx.(pgx.Tx).Commit(ctx) +} - tableExistsMapping[tableIdentifier] = false - c.logger.Info(fmt.Sprintf("created table %s", tableIdentifier)) - utils.RecordHeartbeat(ctx, fmt.Sprintf("created table %s", tableIdentifier)) +func (c *PostgresConnector) SetupNormalizedTable( + ctx context.Context, + tx interface{}, + tableIdentifier string, + tableSchema *protos.TableSchema, + softDeleteColName string, + syncedAtColName string, +) (bool, error) { + createNormalizedTablesTx := tx.(pgx.Tx) + + parsedNormalizedTable, err := utils.ParseSchemaTable(tableIdentifier) + if err != nil { + return false, fmt.Errorf("error while parsing table schema and name: %w", err) + } + tableAlreadyExists, err := c.tableExists(ctx, parsedNormalizedTable) + if err != nil { + return false, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) + } + if tableAlreadyExists { + return true, nil } - err = createNormalizedTablesTx.Commit(ctx) + // convert the column names and types to Postgres types + 
normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable( + parsedNormalizedTable.String(), tableSchema, softDeleteColName, syncedAtColName) + _, err = createNormalizedTablesTx.Exec(ctx, normalizedTableCreateSQL) if err != nil { - return nil, fmt.Errorf("error committing transaction for creating normalized tables: %w", err) + return false, fmt.Errorf("error while creating normalized table: %w", err) } - return &protos.SetupNormalizedTableBatchOutput{ - TableExistsMapping: tableExistsMapping, - }, nil + return false, nil } // ReplayTableSchemaDelta changes a destination table to match the schema at source diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index 930f8f2204..5628a3c3ff 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -208,14 +208,6 @@ func (c *S3Connector) ReplayTableSchemaDeltas(_ context.Context, flowJobName str return nil } -func (c *S3Connector) SetupNormalizedTables(_ context.Context, req *protos.SetupNormalizedTableBatchInput) ( - *protos.SetupNormalizedTableBatchOutput, - error, -) { - c.logger.Info("SetupNormalizedTables for S3 is a no-op") - return nil, nil -} - func (c *S3Connector) SyncFlowCleanup(ctx context.Context, jobName string) error { return c.pgMetadata.DropMetadata(ctx, jobName) } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 4a6d47434b..ab13a52a34 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -329,38 +329,44 @@ func (c *SnowflakeConnector) getTableNameToUnchangedCols( return resultMap, nil } -func (c *SnowflakeConnector) SetupNormalizedTables( - ctx context.Context, - req *protos.SetupNormalizedTableBatchInput, -) (*protos.SetupNormalizedTableBatchOutput, error) { - tableExistsMapping := make(map[string]bool) - for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - normalizedSchemaTable, err := utils.ParseSchemaTable(tableIdentifier) - if err != nil { - return nil, fmt.Errorf("error while parsing table schema and name: %w", err) - } - tableAlreadyExists, err := c.checkIfTableExists(ctx, normalizedSchemaTable.Schema, normalizedSchemaTable.Table) - if err != nil { - return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) - } - if tableAlreadyExists { - tableExistsMapping[tableIdentifier] = true - continue - } +func (c *SnowflakeConnector) StartSetupNormalizedTables(_ context.Context) (interface{}, error) { + return nil, nil +} - normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable( - normalizedSchemaTable, tableSchema, req.SoftDeleteColName, req.SyncedAtColName) - _, err = c.database.ExecContext(ctx, normalizedTableCreateSQL) - if err != nil { - return nil, fmt.Errorf("[sf] error while creating normalized table: %w", err) - } - tableExistsMapping[tableIdentifier] = false - utils.RecordHeartbeat(ctx, fmt.Sprintf("created table %s", tableIdentifier)) +func (c *SnowflakeConnector) FinishSetupNormalizedTables(_ context.Context) error { + return nil +} + +func (c *SnowflakeConnector) AbortSetupNormalizedTables(_ context.Context, _ interface{}) { +} + +func (c *SnowflakeConnector) SetupNormalizedTable( + ctx context.Context, + tx interface{}, + tableIdentifier string, + tableSchema *protos.TableSchema, + softDeleteColName string, + syncedAtColName string, +) (bool, error) { + normalizedSchemaTable, err := utils.ParseSchemaTable(tableIdentifier) + if err != nil { + return false, fmt.Errorf("error while parsing table schema and name: %w", err) 
+ } + tableAlreadyExists, err := c.checkIfTableExists(ctx, normalizedSchemaTable.Schema, normalizedSchemaTable.Table) + if err != nil { + return false, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) + } + if tableAlreadyExists { + return true, nil } - return &protos.SetupNormalizedTableBatchOutput{ - TableExistsMapping: tableExistsMapping, - }, nil + normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable( + normalizedSchemaTable, tableSchema, softDeleteColName, syncedAtColName) + _, err = c.database.ExecContext(ctx, normalizedTableCreateSQL) + if err != nil { + return false, fmt.Errorf("[sf] error while creating normalized table: %w", err) + } + return false, nil } // ReplayTableSchemaDeltas changes a destination table to match the schema at source diff --git a/flow/workflows/qrep_flow.go b/flow/workflows/qrep_flow.go index 5538a5258b..b59eba7430 100644 --- a/flow/workflows/qrep_flow.go +++ b/flow/workflows/qrep_flow.go @@ -150,8 +150,7 @@ func (q *QRepFlowExecution) SetupWatermarkTableOnDestination(ctx workflow.Contex } future := workflow.ExecuteActivity(ctx, flowable.CreateNormalizedTable, setupConfig) - var createNormalizedTablesOutput *protos.SetupNormalizedTableBatchOutput - if err := future.Get(ctx, &createNormalizedTablesOutput); err != nil { + if err := future.Get(ctx, nil); err != nil { q.logger.Error("failed to create watermark table: ", err) return fmt.Errorf("failed to create watermark table: %w", err) } diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index 131a04a1e0..6ee6529dd6 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -239,8 +239,7 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables( } future = workflow.ExecuteActivity(ctx, flowable.CreateNormalizedTable, setupConfig) - var createNormalizedTablesOutput *protos.SetupNormalizedTableBatchOutput - if err := future.Get(ctx, &createNormalizedTablesOutput); err != nil { + if err := future.Get(ctx, nil); err != nil { s.logger.Error("failed to create normalized tables: ", err) return nil, fmt.Errorf("failed to create normalized tables: %w", err) } From 770cdc3089f1e76a79cb92dcc885432b6389c91c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 13 Feb 2024 15:56:46 +0000 Subject: [PATCH 2/3] Fix interface implementations --- flow/activities/flowable.go | 4 ++-- flow/connectors/bigquery/bigquery.go | 2 +- flow/connectors/clickhouse/normalize.go | 2 +- flow/connectors/core.go | 5 +++++ flow/connectors/snowflake/snowflake.go | 2 +- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index d63433a0f5..96166de6ef 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -147,11 +147,12 @@ func (a *FlowableActivity) CreateNormalizedTable( ctx context.Context, config *protos.SetupNormalizedTableBatchInput, ) (*protos.SetupNormalizedTableBatchOutput, error) { + logger := activity.GetLogger(ctx) ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName) conn, err := connectors.GetConnectorAs[connectors.NormalizedTablesConnector](ctx, config.PeerConnectionConfig) if err != nil { if err == connectors.ErrUnsupportedFunctionality { - activity.GetLogger(ctx).Info("Connector does not implement normalized tables") + logger.Info("Connector does not implement normalized tables") return nil, nil } return nil, fmt.Errorf("failed to get connector: %w", err) @@ -172,7 +173,6 @@ func (a *FlowableActivity) 
CreateNormalizedTable(
 	})
 	defer shutdown()
 
-	logger := activity.GetLogger(ctx)
 	tableExistsMapping := make(map[string]bool)
 	for tableIdentifier, tableSchema := range config.TableNameSchemaMapping {
 		existing, err := conn.SetupNormalizedTable(
diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index edc2546903..a7cf31ce9c 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -652,7 +652,7 @@ func (c *BigQueryConnector) StartSetupNormalizedTables(_ context.Context) (inter
 	return make(map[datasetTable]struct{}), nil
 }
 
-func (c *BigQueryConnector) FinishSetupNormalizedTables(_ context.Context) error {
+func (c *BigQueryConnector) FinishSetupNormalizedTables(_ context.Context, _ interface{}) error {
 	return nil
 }
 
diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go
index 225aa2845c..0a43ed0353 100644
--- a/flow/connectors/clickhouse/normalize.go
+++ b/flow/connectors/clickhouse/normalize.go
@@ -23,7 +23,7 @@ func (c *ClickhouseConnector) StartSetupNormalizedTables(_ context.Context) (int
 	return nil, nil
 }
 
-func (c *ClickhouseConnector) FinishSetupNormalizedTables(_ context.Context) error {
+func (c *ClickhouseConnector) FinishSetupNormalizedTables(_ context.Context, _ interface{}) error {
 	return nil
 }
 
diff --git a/flow/connectors/core.go b/flow/connectors/core.go
index 23bbe4e04d..6b2e470ac6 100644
--- a/flow/connectors/core.go
+++ b/flow/connectors/core.go
@@ -244,6 +244,11 @@ var (
 	_ CDCNormalizeConnector = &connsnowflake.SnowflakeConnector{}
 	_ CDCNormalizeConnector = &connclickhouse.ClickhouseConnector{}
 
+	_ NormalizedTablesConnector = &connpostgres.PostgresConnector{}
+	_ NormalizedTablesConnector = &connbigquery.BigQueryConnector{}
+	_ NormalizedTablesConnector = &connsnowflake.SnowflakeConnector{}
+	_ NormalizedTablesConnector = &connclickhouse.ClickhouseConnector{}
+
 	_ QRepPullConnector = &connpostgres.PostgresConnector{}
 	_ QRepPullConnector = &connsqlserver.SQLServerConnector{}
 
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index ab13a52a34..befebcf284 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -333,7 +333,7 @@ func (c *SnowflakeConnector) StartSetupNormalizedTables(_ context.Context) (inte
 	return nil, nil
 }
 
-func (c *SnowflakeConnector) FinishSetupNormalizedTables(_ context.Context) error {
+func (c *SnowflakeConnector) FinishSetupNormalizedTables(_ context.Context, _ interface{}) error {
 	return nil
 }
 

From 4fb2739e693a2124a75fccb0e49bf68b5f4def7b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Tue, 13 Feb 2024 16:11:16 +0000
Subject: [PATCH 3/3] rename Abort to Cleanup
---
 flow/activities/flowable.go             | 2 +-
 flow/connectors/bigquery/bigquery.go    | 2 +-
 flow/connectors/clickhouse/normalize.go | 2 +-
 flow/connectors/core.go                 | 6 +++---
 flow/connectors/postgres/postgres.go    | 2 +-
 flow/connectors/snowflake/snowflake.go  | 2 +-
 6 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 96166de6ef..2086aa8cfd 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -163,7 +163,7 @@ func (a *FlowableActivity) CreateNormalizedTable(
 	if err != nil {
 		return nil, fmt.Errorf("failed to setup normalized tables tx: %w", err)
 	}
-	defer conn.AbortSetupNormalizedTables(ctx, tx)
+	defer conn.CleanupSetupNormalizedTables(ctx, tx)
 
 	numTablesSetup := atomic.Uint32{}
 	totalTables :=
uint32(len(config.TableNameSchemaMapping)) diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index a7cf31ce9c..438994a800 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -656,7 +656,7 @@ func (c *BigQueryConnector) FinishSetupNormalizedTables(_ context.Context, _ int return nil } -func (c *BigQueryConnector) AbortSetupNormalizedTables(_ context.Context, _ interface{}) { +func (c *BigQueryConnector) CleanupSetupNormalizedTables(_ context.Context, _ interface{}) { } // This runs CREATE TABLE IF NOT EXISTS on bigquery, using the schema and table name provided. diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go index 0a43ed0353..3d92ca9e91 100644 --- a/flow/connectors/clickhouse/normalize.go +++ b/flow/connectors/clickhouse/normalize.go @@ -27,7 +27,7 @@ func (c *ClickhouseConnector) FinishSetupNormalizedTables(_ context.Context, _ i return nil } -func (c *ClickhouseConnector) AbortSetupNormalizedTables(_ context.Context, _ interface{}) { +func (c *ClickhouseConnector) CleanupSetupNormalizedTables(_ context.Context, _ interface{}) { } func (c *ClickhouseConnector) SetupNormalizedTable( diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 6b2e470ac6..f83a1a7836 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -74,9 +74,9 @@ type NormalizedTablesConnector interface { syncedAtColName string, ) (bool, error) - // AbortSetupNormalizedTables may be used to rollback transaction started by StartSetupNormalizedTables. - // Calling AbortSetupNormalizedTables after FinishSetupNormalizedTables must be a nop. - AbortSetupNormalizedTables(ctx context.Context, tx interface{}) + // CleanupSetupNormalizedTables may be used to rollback transaction started by StartSetupNormalizedTables. + // Calling CleanupSetupNormalizedTables after FinishSetupNormalizedTables must be a nop. + CleanupSetupNormalizedTables(ctx context.Context, tx interface{}) // FinishSetupNormalizedTables may be used to finish transaction started by StartSetupNormalizedTables. FinishSetupNormalizedTables(ctx context.Context, tx interface{}) error diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index f9c6d60190..6c9d02631f 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -644,7 +644,7 @@ func (c *PostgresConnector) StartSetupNormalizedTables(ctx context.Context) (int return c.conn.Begin(ctx) } -func (c *PostgresConnector) AbortSetupNormalizedTables(ctx context.Context, tx interface{}) { +func (c *PostgresConnector) CleanupSetupNormalizedTables(ctx context.Context, tx interface{}) { err := tx.(pgx.Tx).Rollback(ctx) if err != pgx.ErrTxClosed && err != nil { c.logger.Error("error rolling back transaction for creating raw table", slog.Any("error", err)) diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index befebcf284..1644787119 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -337,7 +337,7 @@ func (c *SnowflakeConnector) FinishSetupNormalizedTables(_ context.Context, _ in return nil } -func (c *SnowflakeConnector) AbortSetupNormalizedTables(_ context.Context, _ interface{}) { +func (c *SnowflakeConnector) CleanupSetupNormalizedTables(_ context.Context, _ interface{}) { } func (c *SnowflakeConnector) SetupNormalizedTable(
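
Taken together, the three patches leave normalized-table setup driven entirely from the activity through the new NormalizedTablesConnector interface, with GetConnectorAs doing the capability check. Below is a minimal sketch of that caller-side lifecycle, mirroring CreateNormalizedTable as of PATCH 3/3; it is not code from the series. The function name setupTables, its parameters, and the import paths are illustrative assumptions, while the connector calls (GetConnectorAs, StartSetupNormalizedTables, SetupNormalizedTable, CleanupSetupNormalizedTables, FinishSetupNormalizedTables, CloseConnector) are the ones these patches introduce.

	package example // hypothetical package, for illustration only

	import (
		"context"
		"fmt"

		"github.com/PeerDB-io/peer-flow/connectors"       // assumed module path
		"github.com/PeerDB-io/peer-flow/generated/protos" // assumed module path
	)

	// setupTables mirrors FlowableActivity.CreateNormalizedTable after the
	// rename: one connector lookup, one Start/Finish pair, and one
	// SetupNormalizedTable call per table, with Cleanup deferred throughout.
	func setupTables(
		ctx context.Context,
		peer *protos.Peer,
		mapping map[string]*protos.TableSchema,
		softDeleteColName string,
		syncedAtColName string,
	) (map[string]bool, error) {
		// Peers whose connector lacks NormalizedTablesConnector surface as
		// ErrUnsupportedFunctionality; normalization is a no-op for them.
		conn, err := connectors.GetConnectorAs[connectors.NormalizedTablesConnector](ctx, peer)
		if err != nil {
			if err == connectors.ErrUnsupportedFunctionality {
				return nil, nil
			}
			return nil, fmt.Errorf("failed to get connector: %w", err)
		}
		defer connectors.CloseConnector(ctx, conn)

		// tx is connector-defined state: a pgx.Tx for Postgres, a
		// duplicate-check set for BigQuery, nil for ClickHouse and Snowflake.
		tx, err := conn.StartSetupNormalizedTables(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to setup normalized tables tx: %w", err)
		}
		// Per the interface contract, this is a nop once Finish has committed.
		defer conn.CleanupSetupNormalizedTables(ctx, tx)

		tableExistsMapping := make(map[string]bool, len(mapping))
		for tableIdentifier, tableSchema := range mapping {
			existing, err := conn.SetupNormalizedTable(
				ctx, tx, tableIdentifier, tableSchema, softDeleteColName, syncedAtColName)
			if err != nil {
				return nil, fmt.Errorf("failed to setup normalized table %s: %w", tableIdentifier, err)
			}
			tableExistsMapping[tableIdentifier] = existing
		}

		return tableExistsMapping, conn.FinishSetupNormalizedTables(ctx, tx)
	}

Keeping FinishSetupNormalizedTables as the final fallible step is what makes the unconditional defer of CleanupSetupNormalizedTables safe: on the success path the transaction is already committed and Cleanup must be a nop, while on every error path it rolls back whatever Start opened.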