Skip to content

Commit

Permalink
replacing SF tableNameComponents with PG's version
Browse files Browse the repository at this point in the history
  • Loading branch information
heavycrystal committed Dec 26, 2023
1 parent 2f18f76 commit 2be8bf0
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 53 deletions.
3 changes: 2 additions & 1 deletion flow/connectors/snowflake/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/snowflakedb/gosnowflake"

peersql "github.com/PeerDB-io/peer-flow/connectors/sql"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/shared"
Expand Down Expand Up @@ -69,7 +70,7 @@ func NewSnowflakeClient(ctx context.Context, config *protos.SnowflakeConfig) (*S
func (c *SnowflakeConnector) getTableCounts(tables []string) (int64, error) {
var totalRecords int64
for _, table := range tables {
_, err := parseTableName(table)
_, err := utils.ParseSchemaTable(table)
if err != nil {
return 0, fmt.Errorf("failed to parse table name %s: %w", table, err)
}
Expand Down
8 changes: 2 additions & 6 deletions flow/connectors/snowflake/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,21 +278,17 @@ func (c *SnowflakeConnector) CleanupQRepFlow(config *protos.QRepConfig) error {

func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnInformation, error) {
// parse the table name to get the schema and table name
components, err := parseTableName(tableName)
schemaTable, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, fmt.Errorf("failed to parse table name: %w", err)
}

// convert tableIdentifier and schemaIdentifier to upper case
components.tableIdentifier = strings.ToUpper(components.tableIdentifier)
components.schemaIdentifier = strings.ToUpper(components.schemaIdentifier)

//nolint:gosec
queryString := fmt.Sprintf(`
SELECT column_name, data_type
FROM information_schema.columns
WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s'
`, components.tableIdentifier, components.schemaIdentifier)
`, strings.ToUpper(schemaTable.Table), strings.ToUpper(schemaTable.Schema))

rows, err := c.database.QueryContext(c.ctx, queryString)
if err != nil {
Expand Down
25 changes: 4 additions & 21 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ const (
checkSchemaExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?"
)

type tableNameComponents struct {
schemaIdentifier string
tableIdentifier string
}

type SnowflakeConnector struct {
ctx context.Context
database *sql.DB
Expand Down Expand Up @@ -245,12 +240,11 @@ func (c *SnowflakeConnector) GetTableSchema(
}

func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.TableSchema, error) {
tableNameComponents, err := parseTableName(tableName)
schemaTable, err := utils.ParseSchemaTable(tableName)
if err != nil {
return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
}
rows, err := c.database.QueryContext(c.ctx, getTableSchemaSQL, tableNameComponents.schemaIdentifier,
tableNameComponents.tableIdentifier)
rows, err := c.database.QueryContext(c.ctx, getTableSchemaSQL, schemaTable.Schema, schemaTable.Table)
if err != nil {
return nil, fmt.Errorf("error querying Snowflake peer for schema of table %s: %w", tableName, err)
}
Expand Down Expand Up @@ -423,12 +417,11 @@ func (c *SnowflakeConnector) SetupNormalizedTables(
) (*protos.SetupNormalizedTableBatchOutput, error) {
tableExistsMapping := make(map[string]bool)
for tableIdentifier, tableSchema := range req.TableNameSchemaMapping {
normalizedTableNameComponents, err := parseTableName(tableIdentifier)
normalizedSchemaTable, err := utils.ParseSchemaTable(tableIdentifier)
if err != nil {
return nil, fmt.Errorf("error while parsing table schema and name: %w", err)
}
tableAlreadyExists, err := c.checkIfTableExists(normalizedTableNameComponents.schemaIdentifier,
normalizedTableNameComponents.tableIdentifier)
tableAlreadyExists, err := c.checkIfTableExists(normalizedSchemaTable.Schema, normalizedSchemaTable.Table)
if err != nil {
return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err)
}
Expand Down Expand Up @@ -940,16 +933,6 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
return result.RowsAffected()
}

// parseTableName parses a table name into schema and table name.
func parseTableName(tableName string) (*tableNameComponents, error) {
schemaIdentifier, tableIdentifier, hasDot := strings.Cut(tableName, ".")
if !hasDot || strings.ContainsRune(tableIdentifier, '.') {
return nil, fmt.Errorf("invalid table name: %s", tableName)
}

return &tableNameComponents{schemaIdentifier, tableIdentifier}, nil
}

func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) {
var result pgtype.Bool
err := c.database.QueryRowContext(c.ctx,
Expand Down
25 changes: 24 additions & 1 deletion flow/connectors/utils/identifiers.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
package utils

import "fmt"
import (
"fmt"
"strings"
)

func QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, identifier)
}

// SchemaTable is a table in a schema.
type SchemaTable struct {
Schema string
Table string
}

func (t *SchemaTable) String() string {
return fmt.Sprintf(`"%s"."%s"`, t.Schema, t.Table)
}

// ParseSchemaTable parses a table name into schema and table name.
func ParseSchemaTable(tableName string) (*SchemaTable, error) {
schema, table, hasDot := strings.Cut(tableName, ".")
if !hasDot || strings.ContainsRune(table, '.') {
return nil, fmt.Errorf("invalid table name: %s", tableName)
}

return &SchemaTable{schema, table}, nil
}
24 changes: 0 additions & 24 deletions flow/connectors/utils/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net/url"
"strings"

"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/jackc/pgx/v5/pgtype"
Expand Down Expand Up @@ -49,26 +48,3 @@ func GetCustomDataTypes(ctx context.Context, pool *pgxpool.Pool) (map[uint32]str
}
return customTypeMap, nil
}

// SchemaTable is a table in a schema.
type SchemaTable struct {
Schema string
Table string
}

func (t *SchemaTable) String() string {
return fmt.Sprintf(`"%s"."%s"`, t.Schema, t.Table)
}

// ParseSchemaTable parses a table name into schema and table name.
func ParseSchemaTable(tableName string) (*SchemaTable, error) {
parts := strings.Split(tableName, ".")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid table name: %s", tableName)
}

return &SchemaTable{
Schema: parts[0],
Table: parts[1],
}, nil
}

0 comments on commit 2be8bf0

Please sign in to comment.