diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go index 6725032411..0577fd6d3f 100644 --- a/flow/connectors/postgres/qrep_sync_method.go +++ b/flow/connectors/postgres/qrep_sync_method.go @@ -12,6 +12,7 @@ import ( "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/shared" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "google.golang.org/protobuf/encoding/protojson" ) @@ -44,16 +45,21 @@ func (s *QRepStagingTableSync) SyncQRepRecords( ) partitionID := partition.PartitionId startTime := time.Now() - - pool := s.connector.pool schema, err := stream.Schema() if err != nil { slog.Error("failed to get schema from stream", slog.Any("error", err), syncLog) return 0, fmt.Errorf("failed to get schema from stream: %w", err) } + txConfig := s.connector.pool.poolConfig.Copy() + txConfig.AfterConnect = utils.RegisterCustomTypesForConnection + txPool, err := pgxpool.NewWithConfig(s.connector.pool.ctx, txConfig) + if err != nil { + return 0, fmt.Errorf("failed to create tx pool: %v", err) + } + // Second transaction - to handle rest of the processing - tx, err := pool.Begin(context.Background()) + tx, err := txPool.Begin(context.Background()) if err != nil { return 0, fmt.Errorf("failed to begin transaction: %v", err) } diff --git a/flow/connectors/utils/postgres.go b/flow/connectors/utils/postgres.go index b043698943..9a9742fef3 100644 --- a/flow/connectors/utils/postgres.go +++ b/flow/connectors/utils/postgres.go @@ -8,6 +8,7 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" @@ -56,3 +57,22 @@ func GetCustomDataTypes(ctx context.Context, pool *pgxpool.Pool) (map[uint32]str } return customTypeMap, nil } + +func RegisterCustomTypesForConnection(ctx context.Context, conn *pgx.Conn) error { + typeNames := []string{"hstore", "geometry", "geography"} + typeOIDs := make(map[string]uint32) + + for _, typeName := range typeNames { + err := conn.QueryRow(ctx, `SELECT oid FROM pg_type WHERE typname = $1`, typeName).Scan(typeOIDs[typeName]) + if err != nil { + return err + } + } + + typeMap := conn.TypeMap() + for typeName, typeOID := range typeOIDs { + typeMap.RegisterType(&pgtype.Type{Name: typeName, OID: typeOID}) + } + + return nil +}