From a0be7d4cdbaac6237e5598782971eaebb3b200b0 Mon Sep 17 00:00:00 2001 From: Amogh-Bharadwaj Date: Tue, 5 Dec 2023 01:10:40 +0530 Subject: [PATCH] snowflake and postgres fix --- flow/connectors/postgres/client.go | 4 +- flow/connectors/postgres/postgres.go | 24 ++++++++- flow/connectors/postgres/qrep.go | 2 +- flow/connectors/snowflake/snowflake.go | 69 +++++++++++++++++++------- 4 files changed, 76 insertions(+), 23 deletions(-) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 00df76491a..8f552c57df 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -140,7 +140,7 @@ func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable) return pkCols, nil } -func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (*[]protos.TableColumn, error) { +func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) ([]protos.TableColumn, error) { rows, err := c.pool.Query(c.ctx, `SELECT COLUMN_NAME, DATA_TYPE FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2`, schemaTable.Schema, @@ -161,7 +161,7 @@ func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (*[]prot columns = append(columns, protos.TableColumn{Name: colName, Type: colType}) } - return &columns, nil + return columns, nil } // checkSlotAndPublication checks if the replication slot and publication exist. diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 60af566a94..34247812a6 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "regexp" + "strings" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -121,7 +122,7 @@ func (c *PostgresConnector) NeedsSetupMetadataTables() bool { if err != nil { return true } - return len(*columns) != 0 + return len(columns) != 0 } // SetupMetadataTables sets up the metadata tables. @@ -645,7 +646,28 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) } if destinationColumns != nil { + sourceColumns := req.TableNameSchemaMapping[tableIdentifier].Columns tableExistsMapping[tableIdentifier] = true + log.Infoln("found existing normalized table, checking if it matches the desired schema") + if len(destinationColumns) != len(sourceColumns) { + return nil, fmt.Errorf("failed to setup normalized table: schemas on both sides differ") + } + for id := range destinationColumns { + column := &destinationColumns[id] + existingName := strings.ToLower(column.Name) + sourceType, ok := sourceColumns[existingName] + if !ok { + return nil, fmt.Errorf("failed to setup normalized table:"+ + "non-matching column name: %v", + existingName) + } + sourceTypeConverted := strings.ToLower(qValueKindToPostgresType(sourceType)) + if sourceTypeConverted != column.Type { + return nil, fmt.Errorf("failed to setup normalized table: mismatched column %v "+ + "with destination type %v and source type %v", + existingName, column.Type, sourceTypeConverted) + } + } continue } diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index af0ba13807..cde9678a5e 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -471,7 +471,7 @@ func (c *PostgresConnector) SyncQRepRecords( return 0, fmt.Errorf("failed to check if table exists: %w", err) } - if len(*sourceColumns) == 0 { + if len(sourceColumns) == 0 { return 0, fmt.Errorf("table %s does not exist, used schema: %s", dstTable.Table, dstTable.Schema) } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index c624234923..00f5cb4476 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -67,8 +67,9 @@ const ( updateMetadataForSyncRecordsSQL = "UPDATE %s.%s SET OFFSET=?, SYNC_BATCH_ID=? WHERE MIRROR_JOB_NAME=?" updateMetadataForNormalizeRecordsSQL = "UPDATE %s.%s SET NORMALIZE_BATCH_ID=? WHERE MIRROR_JOB_NAME=?" - checkIfTableExistsSQL = `SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA=? and TABLE_NAME=?` + checkIfTableExistsSQL = `SELECT COLUMN_NAME, DATA_TYPE + FROM information_schema.columns + WHERE table_schema=? and table_name=?` checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?" getLastOffsetSQL = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?" getLastSyncBatchID_SQL = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" @@ -184,7 +185,7 @@ func (c *SnowflakeConnector) NeedsSetupMetadataTables() bool { if err != nil { return true } - return !result + return result != nil && len(result) == 0 } func (c *SnowflakeConnector) SetupMetadataTables() error { @@ -403,12 +404,39 @@ func (c *SnowflakeConnector) SetupNormalizedTables( if err != nil { return nil, fmt.Errorf("error while parsing table schema and name: %w", err) } - tableAlreadyExists, err := c.checkIfTableExists(normalizedTableNameComponents.schemaIdentifier, + destinationColumns, err := c.checkIfTableExists(normalizedTableNameComponents.schemaIdentifier, normalizedTableNameComponents.tableIdentifier) if err != nil { return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) } - if tableAlreadyExists { + if destinationColumns != nil { + sourceColumns := req.TableNameSchemaMapping[tableIdentifier].Columns + log.Infoln("found existing normalized table, checking if it matches the desired schema") + if len(destinationColumns) != len(sourceColumns) { + return nil, fmt.Errorf("failed to setup normalized table: schemas on both sides differ") + } + for id := range destinationColumns { + column := &destinationColumns[id] + existingName := strings.ToLower(column.Name) + sourceType, ok := sourceColumns[existingName] + if !ok { + return nil, fmt.Errorf("failed to setup normalized table:"+ + "non-matching column name: %v", + existingName) + } + sourceTypeConverted, typeErr := qValueKindToSnowflakeType(qvalue.QValueKind(sourceType)) + if typeErr != nil { + return nil, fmt.Errorf("failed to convert type in schema check: %w", typeErr) + } + exception := (sourceTypeConverted == "INTEGER" && column.Type == "NUMBER") || + (sourceTypeConverted == "STRING" && column.Type == "TEXT") + if sourceTypeConverted != column.Type && !exception { + return nil, fmt.Errorf("failed to setup normalized table: mismatched column %v"+ + " with destination type %v and source type %v", + existingName, column.Type, sourceTypeConverted) + } + } + tableExistsMapping[tableIdentifier] = true continue } @@ -456,13 +484,13 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string, } for _, addedColumn := range schemaDelta.AddedColumns { - sfColtype, err := qValueKindToSnowflakeType(qvalue.QValueKind(addedColumn.ColumnType)) + sfType, err := qValueKindToSnowflakeType(qvalue.QValueKind(addedColumn.ColumnType)) if err != nil { return fmt.Errorf("failed to convert column type %s to snowflake type: %w", addedColumn.ColumnType, err) } _, err = tableSchemaModifyTx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s", - schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype)) + schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfType)) if err != nil { return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName, schemaDelta.DstTableName, err) @@ -738,20 +766,23 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error { return nil } -func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) (bool, error) { - rows, err := c.database.QueryContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier) +func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) ([]protos.TableColumn, error) { + rows, err := c.database.QueryContext(c.ctx, checkIfTableExistsSQL, + strings.ToUpper(schemaIdentifier), strings.ToUpper(tableIdentifier)) if err != nil { - return false, err + return nil, err } - // this query is guaranteed to return exactly one row - var result bool - rows.Next() - err = rows.Scan(&result) - if err != nil { - return false, fmt.Errorf("error while reading result row: %w", err) + var columns []protos.TableColumn + for rows.Next() { + var colName, Type string + err = rows.Scan(&colName, &Type) + if err != nil { + return nil, fmt.Errorf("error while checking for existing table: %w", err) + } + columns = append(columns, protos.TableColumn{Name: colName, Type: Type}) } - return result, nil + return columns, nil } func generateCreateTableSQLForNormalizedTable( @@ -763,12 +794,12 @@ func generateCreateTableSQLForNormalizedTable( createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)) for columnName, genericColumnType := range sourceTableSchema.Columns { columnNameUpper := strings.ToUpper(columnName) - sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType)) + sfType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType)) if err != nil { log.Warnf("failed to convert column type %s to snowflake type: %v", genericColumnType, err) continue } - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, sfColType)) + createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, sfType)) } // add a _peerdb_is_deleted column to the normalized table