Skip to content

Commit

Permalink
snowflake and postgres fix
Browse files: browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed Dec 5, 2023
1 parent 7c3b604 commit a0be7d4
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 23 deletions.
4 changes: 2 additions & 2 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable)
return pkCols, nil
}

func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (*[]protos.TableColumn, error) {
func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) ([]protos.TableColumn, error) {
rows, err := c.pool.Query(c.ctx,
`SELECT COLUMN_NAME, DATA_TYPE FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2`,
schemaTable.Schema,
Expand All @@ -161,7 +161,7 @@ func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (*[]prot
columns = append(columns, protos.TableColumn{Name: colName, Type: colType})
}

return &columns, nil
return columns, nil
}

// checkSlotAndPublication checks if the replication slot and publication exist.
Expand Down
24 changes: 23 additions & 1 deletion flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"regexp"
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
Expand Down Expand Up @@ -121,7 +122,7 @@ func (c *PostgresConnector) NeedsSetupMetadataTables() bool {
if err != nil {
return true
}
return len(*columns) != 0
return len(columns) != 0
}

// SetupMetadataTables sets up the metadata tables.
Expand Down Expand Up @@ -645,7 +646,28 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab
return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err)
}
if destinationColumns != nil {
sourceColumns := req.TableNameSchemaMapping[tableIdentifier].Columns
tableExistsMapping[tableIdentifier] = true
log.Infoln("found existing normalized table, checking if it matches the desired schema")
if len(destinationColumns) != len(sourceColumns) {
return nil, fmt.Errorf("failed to setup normalized table: schemas on both sides differ")
}
for id := range destinationColumns {
column := &destinationColumns[id]
existingName := strings.ToLower(column.Name)
sourceType, ok := sourceColumns[existingName]
if !ok {
return nil, fmt.Errorf("failed to setup normalized table:"+
"non-matching column name: %v",
existingName)
}
sourceTypeConverted := strings.ToLower(qValueKindToPostgresType(sourceType))
if sourceTypeConverted != column.Type {
return nil, fmt.Errorf("failed to setup normalized table: mismatched column %v "+
"with destination type %v and source type %v",
existingName, column.Type, sourceTypeConverted)
}
}
continue
}

Expand Down
2 changes: 1 addition & 1 deletion flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ func (c *PostgresConnector) SyncQRepRecords(
return 0, fmt.Errorf("failed to check if table exists: %w", err)
}

if len(*sourceColumns) == 0 {
if len(sourceColumns) == 0 {
return 0, fmt.Errorf("table %s does not exist, used schema: %s", dstTable.Table, dstTable.Schema)
}

Expand Down
69 changes: 50 additions & 19 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ const (
updateMetadataForSyncRecordsSQL = "UPDATE %s.%s SET OFFSET=?, SYNC_BATCH_ID=? WHERE MIRROR_JOB_NAME=?"
updateMetadataForNormalizeRecordsSQL = "UPDATE %s.%s SET NORMALIZE_BATCH_ID=? WHERE MIRROR_JOB_NAME=?"

checkIfTableExistsSQL = `SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA=? and TABLE_NAME=?`
checkIfTableExistsSQL = `SELECT COLUMN_NAME, DATA_TYPE
FROM information_schema.columns
WHERE table_schema=? and table_name=?`
checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?"
getLastOffsetSQL = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?"
getLastSyncBatchID_SQL = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?"
Expand Down Expand Up @@ -184,7 +185,7 @@ func (c *SnowflakeConnector) NeedsSetupMetadataTables() bool {
if err != nil {
return true
}
return !result
return result != nil && len(result) == 0
}

func (c *SnowflakeConnector) SetupMetadataTables() error {
Expand Down Expand Up @@ -403,12 +404,39 @@ func (c *SnowflakeConnector) SetupNormalizedTables(
if err != nil {
return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
}
tableAlreadyExists, err := c.checkIfTableExists(normalizedTableNameComponents.schemaIdentifier,
destinationColumns, err := c.checkIfTableExists(normalizedTableNameComponents.schemaIdentifier,
normalizedTableNameComponents.tableIdentifier)
if err != nil {
return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err)
}
if tableAlreadyExists {
if destinationColumns != nil {
sourceColumns := req.TableNameSchemaMapping[tableIdentifier].Columns
log.Infoln("found existing normalized table, checking if it matches the desired schema")
if len(destinationColumns) != len(sourceColumns) {
return nil, fmt.Errorf("failed to setup normalized table: schemas on both sides differ")
}
for id := range destinationColumns {
column := &destinationColumns[id]
existingName := strings.ToLower(column.Name)
sourceType, ok := sourceColumns[existingName]
if !ok {
return nil, fmt.Errorf("failed to setup normalized table:"+
"non-matching column name: %v",
existingName)
}
sourceTypeConverted, typeErr := qValueKindToSnowflakeType(qvalue.QValueKind(sourceType))
if typeErr != nil {
return nil, fmt.Errorf("failed to convert type in schema check: %w", typeErr)
}
exception := (sourceTypeConverted == "INTEGER" && column.Type == "NUMBER") ||
(sourceTypeConverted == "STRING" && column.Type == "TEXT")
if sourceTypeConverted != column.Type && !exception {
return nil, fmt.Errorf("failed to setup normalized table: mismatched column %v"+
" with destination type %v and source type %v",
existingName, column.Type, sourceTypeConverted)
}
}

tableExistsMapping[tableIdentifier] = true
continue
}
Expand Down Expand Up @@ -456,13 +484,13 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string,
}

for _, addedColumn := range schemaDelta.AddedColumns {
sfColtype, err := qValueKindToSnowflakeType(qvalue.QValueKind(addedColumn.ColumnType))
sfType, err := qValueKindToSnowflakeType(qvalue.QValueKind(addedColumn.ColumnType))
if err != nil {
return fmt.Errorf("failed to convert column type %s to snowflake type: %w",
addedColumn.ColumnType, err)
}
_, err = tableSchemaModifyTx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s",
schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype))
schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfType))
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
schemaDelta.DstTableName, err)
Expand Down Expand Up @@ -738,20 +766,23 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error {
return nil
}

func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) (bool, error) {
rows, err := c.database.QueryContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier)
func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) ([]protos.TableColumn, error) {
rows, err := c.database.QueryContext(c.ctx, checkIfTableExistsSQL,
strings.ToUpper(schemaIdentifier), strings.ToUpper(tableIdentifier))
if err != nil {
return false, err
return nil, err
}

// this query is guaranteed to return exactly one row
var result bool
rows.Next()
err = rows.Scan(&result)
if err != nil {
return false, fmt.Errorf("error while reading result row: %w", err)
var columns []protos.TableColumn
for rows.Next() {
var colName, Type string
err = rows.Scan(&colName, &Type)
if err != nil {
return nil, fmt.Errorf("error while checking for existing table: %w", err)
}
columns = append(columns, protos.TableColumn{Name: colName, Type: Type})
}
return result, nil
return columns, nil
}

func generateCreateTableSQLForNormalizedTable(
Expand All @@ -763,12 +794,12 @@ func generateCreateTableSQLForNormalizedTable(
createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
for columnName, genericColumnType := range sourceTableSchema.Columns {
columnNameUpper := strings.ToUpper(columnName)
sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
sfType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
if err != nil {
log.Warnf("failed to convert column type %s to snowflake type: %v", genericColumnType, err)
continue
}
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, sfColType))
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper, sfType))
}

// add a _peerdb_is_deleted column to the normalized table
Expand Down

0 comments on commit a0be7d4

Please sign in to comment.