diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index f14f8045b..816eecfb2 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -126,19 +126,15 @@ func (c *PostgresConnector) getPrimaryKeyColumns(schemaTable *utils.SchemaTable) return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) } defer rows.Close() - // 0 rows returned, table has no primary keys - if !rows.Next() { - return nil, fmt.Errorf("table %s has no primary keys", schemaTable) - } for { + if !rows.Next() { + break + } err = rows.Scan(&pkCol) if err != nil { return nil, fmt.Errorf("error scanning primary key column for table %s: %w", schemaTable, err) } pkCols = append(pkCols, pkCol) - if !rows.Next() { - break - } } return pkCols, nil @@ -314,13 +310,15 @@ func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string, } // add composite primary key to the table - primaryKeyColsQuoted := make([]string, 0) - for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { - primaryKeyColsQuoted = append(primaryKeyColsQuoted, - fmt.Sprintf(`"%s"`, primaryKeyCol)) + if len(sourceTableSchema.PrimaryKeyColumns) > 0 { + primaryKeyColsQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns)) + for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { + primaryKeyColsQuoted = append(primaryKeyColsQuoted, + fmt.Sprintf(`"%s"`, primaryKeyCol)) + } + createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", + strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ","))) } - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", - strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ","))) return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier, strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ",")) diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 4f6a4427f..ccabcfdb8 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -573,6 +573,11 @@ func (c *PostgresConnector) getTableSchemaForTable( return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr) } + pKeyCols, err := c.getPrimaryKeyColumns(schemaTable) + if err != nil { + return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) + } + // Get the column names and types rows, err := c.pool.Query(c.ctx, fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, schemaTable.String()), @@ -582,13 +587,6 @@ func (c *PostgresConnector) getTableSchemaForTable( } defer rows.Close() - pKeyCols, err := c.getPrimaryKeyColumns(schemaTable) - if err != nil { - if !isFullReplica { - return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) - } - } - res := &protos.TableSchema{ TableIdentifier: tableName, Columns: make(map[string]string), @@ -744,6 +742,21 @@ func (c *PostgresConnector) EnsurePullability(req *protos.EnsurePullabilityBatch return nil, err } + isFullReplica, replErr := c.isTableFullReplica(schemaTable) + if replErr != nil { + return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr) + } + + pKeyCols, err := c.getPrimaryKeyColumns(schemaTable) + if err != nil { + return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) + } + + // we only allow no primary key if the table has REPLICA IDENTITY FULL + if len(pKeyCols) == 0 && !isFullReplica { + return nil, fmt.Errorf("table %s has no primary keys and does not have REPLICA IDENTITY FULL", schemaTable) + } + tableIdentifierMapping[tableName] = &protos.TableIdentifier{ TableIdentifier: &protos.TableIdentifier_PostgresTableIdentifier{ PostgresTableIdentifier: &protos.PostgresTableIdentifier{ diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index f04b9be38..4cd48aed0 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -783,13 +783,15 @@ func generateCreateTableSQLForNormalizedTable( } // add composite primary key to the table - primaryKeyColsUpperQuoted := make([]string, 0) - for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { - primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted, - fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol))) + if len(sourceTableSchema.PrimaryKeyColumns) > 0 { + primaryKeyColsUpperQuoted := make([]string, 0, len(sourceTableSchema.PrimaryKeyColumns)) + for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns { + primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted, + fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol))) + } + createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", + strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ","))) } - createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),", - strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ","))) return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier, strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))