diff --git a/flow/connectors/snowflake/client.go b/flow/connectors/snowflake/client.go index ae2da1a891..beb38a4a04 100644 --- a/flow/connectors/snowflake/client.go +++ b/flow/connectors/snowflake/client.go @@ -9,6 +9,7 @@ import ( "github.com/snowflakedb/gosnowflake" peersql "github.com/PeerDB-io/peer-flow/connectors/sql" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" @@ -69,7 +70,7 @@ func NewSnowflakeClient(ctx context.Context, config *protos.SnowflakeConfig) (*S func (c *SnowflakeConnector) getTableCounts(tables []string) (int64, error) { var totalRecords int64 for _, table := range tables { - _, err := parseTableName(table) + _, err := utils.ParseSchemaTable(table) if err != nil { return 0, fmt.Errorf("failed to parse table name %s: %w", table, err) } diff --git a/flow/connectors/snowflake/qrep.go b/flow/connectors/snowflake/qrep.go index def870c183..1d81392d71 100644 --- a/flow/connectors/snowflake/qrep.go +++ b/flow/connectors/snowflake/qrep.go @@ -278,21 +278,17 @@ func (c *SnowflakeConnector) CleanupQRepFlow(config *protos.QRepConfig) error { func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnInformation, error) { // parse the table name to get the schema and table name - components, err := parseTableName(tableName) + schemaTable, err := utils.ParseSchemaTable(tableName) if err != nil { return nil, fmt.Errorf("failed to parse table name: %w", err) } - // convert tableIdentifier and schemaIdentifier to upper case - components.tableIdentifier = strings.ToUpper(components.tableIdentifier) - components.schemaIdentifier = strings.ToUpper(components.schemaIdentifier) - //nolint:gosec queryString := fmt.Sprintf(` SELECT column_name, data_type FROM information_schema.columns WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s' - `, components.tableIdentifier, components.schemaIdentifier) + `, strings.ToUpper(schemaTable.Table), strings.ToUpper(schemaTable.Schema)) rows, err := c.database.QueryContext(c.ctx, queryString) if err != nil { diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 5d2c7e03b1..55dd4444df 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -81,11 +81,6 @@ const ( checkSchemaExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?" ) -type tableNameComponents struct { - schemaIdentifier string - tableIdentifier string -} - type SnowflakeConnector struct { ctx context.Context database *sql.DB @@ -245,12 +240,11 @@ func (c *SnowflakeConnector) GetTableSchema( } func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.TableSchema, error) { - tableNameComponents, err := parseTableName(tableName) + schemaTable, err := utils.ParseSchemaTable(tableName) if err != nil { return nil, fmt.Errorf("error while parsing table schema and name: %w", err) } - rows, err := c.database.QueryContext(c.ctx, getTableSchemaSQL, tableNameComponents.schemaIdentifier, - tableNameComponents.tableIdentifier) + rows, err := c.database.QueryContext(c.ctx, getTableSchemaSQL, schemaTable.Schema, schemaTable.Table) if err != nil { return nil, fmt.Errorf("error querying Snowflake peer for schema of table %s: %w", tableName, err) } @@ -423,12 +417,11 @@ func (c *SnowflakeConnector) SetupNormalizedTables( ) (*protos.SetupNormalizedTableBatchOutput, error) { tableExistsMapping := make(map[string]bool) for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - normalizedTableNameComponents, err := parseTableName(tableIdentifier) + normalizedSchemaTable, err := utils.ParseSchemaTable(tableIdentifier) if err != nil { return nil, fmt.Errorf("error while parsing table schema and name: %w", err) } - tableAlreadyExists, err := c.checkIfTableExists(normalizedTableNameComponents.schemaIdentifier, - normalizedTableNameComponents.tableIdentifier) + tableAlreadyExists, err := c.checkIfTableExists(normalizedSchemaTable.Schema, normalizedSchemaTable.Table) if err != nil { return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) } @@ -940,16 +933,6 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement( return result.RowsAffected() } -// parseTableName parses a table name into schema and table name. -func parseTableName(tableName string) (*tableNameComponents, error) { - schemaIdentifier, tableIdentifier, hasDot := strings.Cut(tableName, ".") - if !hasDot || strings.ContainsRune(tableIdentifier, '.') { - return nil, fmt.Errorf("invalid table name: %s", tableName) - } - - return &tableNameComponents{schemaIdentifier, tableIdentifier}, nil -} - func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) { var result pgtype.Bool err := c.database.QueryRowContext(c.ctx, diff --git a/flow/connectors/utils/identifiers.go b/flow/connectors/utils/identifiers.go index 0b91b9e4f3..2ae919488d 100644 --- a/flow/connectors/utils/identifiers.go +++ b/flow/connectors/utils/identifiers.go @@ -1,7 +1,30 @@ package utils -import "fmt" +import ( + "fmt" + "strings" +) func QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } + +// SchemaTable is a table in a schema. +type SchemaTable struct { + Schema string + Table string +} + +func (t *SchemaTable) String() string { + return fmt.Sprintf(`"%s"."%s"`, t.Schema, t.Table) +} + +// ParseSchemaTable parses a table name into schema and table name. +func ParseSchemaTable(tableName string) (*SchemaTable, error) { + schema, table, hasDot := strings.Cut(tableName, ".") + if !hasDot || strings.ContainsRune(table, '.') { + return nil, fmt.Errorf("invalid table name: %s", tableName) + } + + return &SchemaTable{schema, table}, nil +} diff --git a/flow/connectors/utils/postgres.go b/flow/connectors/utils/postgres.go index 72aaf86f5e..58cd02f205 100644 --- a/flow/connectors/utils/postgres.go +++ b/flow/connectors/utils/postgres.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/url" - "strings" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/jackc/pgx/v5/pgtype" @@ -49,26 +48,3 @@ func GetCustomDataTypes(ctx context.Context, pool *pgxpool.Pool) (map[uint32]str } return customTypeMap, nil } - -// SchemaTable is a table in a schema. -type SchemaTable struct { - Schema string - Table string -} - -func (t *SchemaTable) String() string { - return fmt.Sprintf(`"%s"."%s"`, t.Schema, t.Table) -} - -// ParseSchemaTable parses a table name into schema and table name. -func ParseSchemaTable(tableName string) (*SchemaTable, error) { - parts := strings.Split(tableName, ".") - if len(parts) != 2 { - return nil, fmt.Errorf("invalid table name: %s", tableName) - } - - return &SchemaTable{ - Schema: parts[0], - Table: parts[1], - }, nil -}