Skip to content

Commit

Permalink
Load types using a single SQL query
Browse files Browse the repository at this point in the history
When loading even a single type into pgx's type map, multiple SQL
queries are performed in series. Over a slow link, this is not ideal.
Worse, if multiple types are being registered, this is repeated multiple
times.

This commit changes the internal implementation of LoadType to use a
single SQL query. It also added a LoadTypes, which can retrieve type
mapping information for multiple types in a single SQL call.
Additionally, LoadTypes will recursively load any related types,
avoiding the need to explicitly list everything.
  • Loading branch information
nicois committed Jun 17, 2024
1 parent 9907b87 commit 924f798
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 12 deletions.
232 changes: 229 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -107,8 +108,10 @@ var (
ErrTooManyRows = errors.New("too many rows in result set")
)

var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
var (
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
)

// Connect establishes a connection with a PostgreSQL server with a connection string. See
// pgconn.Connect for details.
Expand Down Expand Up @@ -843,7 +846,6 @@ func (c *Conn) getStatementDescription(
mode QueryExecMode,
sql string,
) (sd *pgconn.StatementDescription, err error) {

switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
Expand Down Expand Up @@ -1393,3 +1395,227 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error

return nil
}

/*
buildLoadTypesSQL generates the correct query for retrieving type information.
pgVersion: the major version of the PostgreSQL server
typeNames: the names of the types to load. If nil, load all types.
*/
func buildLoadTypesSQL(pgVersion int64, typeNames []string) string {
supportsMultirange := (pgVersion >= 14)
var typeNamesClause string
if typeNames == nil {
typeNamesClause = "IS NOT NULL"
} else {
typeNamesClause = "= ANY($1)"
}
parts := make([]string, 0, 10)

parts = append(parts, `
WITH RECURSIVE
selected_classes(oid,reltype) AS (
SELECT pg_class.oid, pg_class.reltype
FROM pg_catalog.pg_class
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
AND relname `, typeNamesClause, `
UNION ALL
SELECT pg_class.oid, pg_class.reltype
FROM pg_class
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
WHERE nspname || '.' || relname `, typeNamesClause, `
),
selected_types(oid) AS (
SELECT reltype AS oid
FROM selected_classes
UNION ALL
SELECT oid
FROM pg_type
WHERE typname `, typeNamesClause, `
),
pc(parent, child) AS (
SELECT parent.oid, parent.typelem
FROM pg_type parent
WHERE parent.typtype = 'b' AND parent.typelem != 0
UNION ALL
SELECT parent.oid, parent.typbasetype
FROM pg_type parent
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
UNION ALL
SELECT pg_type.oid, atttypid
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
),
relationships(parent, child, depth) AS (
SELECT DISTINCT 0::OID, selected_types.oid, 0
FROM selected_types
UNION ALL
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
FROM selected_classes c
inner join pg_type ON (c.reltype = pg_type.oid)
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
UNION ALL
SELECT pc.parent, pc.child, relationships.depth + 1
FROM pc
INNER JOIN relationships ON (pc.parent = relationships.child)
),
composite AS (
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
GROUP BY pg_type.oid
)
SELECT typname,
typtype,
typbasetype,
typelem,
pg_type.oid,`)
if supportsMultirange {
parts = append(parts, `
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
} else {
parts = append(parts, `
0 AS rngtypid,`)
}
parts = append(parts, `
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
attnames, atttypids
FROM relationships
INNER JOIN pg_type ON (pg_type.oid IN ( relationships.child,relationships.parent) )
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
if supportsMultirange {
parts = append(parts, `
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
}

parts = append(parts, `
LEFT OUTER JOIN composite USING (oid)
WHERE NOT (typtype = 'b' AND typelem = 0)`)
parts = append(parts, `
GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
if supportsMultirange {
parts = append(parts, `
multirange.rngtypid,`)
}
parts = append(parts, `
attnames, atttypids
ORDER BY MAX(depth) desc, typname;`)
return strings.Join(parts, "")
}

// LoadAndRegisterTypes inspects the database for []typeNames and automatically registers all discovered
// types. Any types referenced by these will also be included in the registration.
func (c *Conn) LoadAndRegisterTypes(ctx context.Context, typeNames []string) error {
if typeNames == nil || len(typeNames) == 0 {
return fmt.Errorf("No type names were supplied.")
}
return c.loadAndRegisterTypes(ctx, typeNames, c.TypeMap())
}

func (c *Conn) loadAndRegisterTypes(ctx context.Context, typeNames []string, registerWith *pgtype.Map) error {
if registerWith == nil {
return fmt.Errorf("Type map must be supplied")
}
serverVersion, err := c.ServerVersion()
if err != nil {
return fmt.Errorf("Unexpected server version error: %w", err)
}
sql := buildLoadTypesSQL(serverVersion, typeNames)
var rows Rows
if typeNames == nil {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
} else {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
}
if err != nil {
return fmt.Errorf("While generating load types query: %w", err)
}
defer rows.Close()
for rows.Next() {
var oid uint32
var typeName, typtype string
var typbasetype, typelem uint32
var rngsubtype, rngtypid uint32
attnames := make([]string, 0, 0)
atttypids := make([]uint32, 0, 0)
err = rows.Scan(&typeName, &typtype, &typbasetype, &typelem, &oid, &rngtypid, &rngsubtype, &attnames, &atttypids)
if err != nil {
return fmt.Errorf("While scanning type information: %w", err)
}

switch typtype {
case "b": // array
dt, ok := c.TypeMap().TypeForOID(typelem)
if !ok {
return fmt.Errorf("array element OID %v not registered while loading for %v", typelem, typeName)
}
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}})
case "c": // composite
var fields []pgtype.CompositeCodecField
for i, fieldName := range attnames {
//if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil {
// return nil, fmt.Errorf("While extracting OID used in composite field: %w", err)
//}
dt, ok := c.TypeMap().TypeForOID(atttypids[i])
if !ok {
return fmt.Errorf("unknown composite type field OID %v (%v)", atttypids[i], fieldName)
}
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
}
if err != nil {
return fmt.Errorf("While parsing %v: %w", typeName, err)
}

registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}})
case "d": // domain
dt, ok := c.TypeMap().TypeForOID(typbasetype)
if !ok {
return fmt.Errorf("domain base type OID %v was not already registered, needed for %v", typbasetype, typeName)
}

registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec})
case "e": // enum
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}})
case "r": // range
dt, ok := c.TypeMap().TypeForOID(rngsubtype)
if !ok {
return fmt.Errorf("range element OID %v not registered for %v", rngsubtype, typeName)
}

registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}})
case "m": // multirange
dt, ok := c.TypeMap().TypeForOID(rngtypid)
if !ok {
return fmt.Errorf("multirange element OID %v not registered while loading %v", rngtypid, typeName)
}

registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}})
default:
return fmt.Errorf("unknown typtype %v for %v", typtype, typeName)
}
}
return nil
}

// ServerVersion returns the postgresql server version.
func (conn *Conn) ServerVersion() (int64, error) {
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
// if not PostgreSQL do nothing
if serverVersionStr == "" {
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
}

serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
}
return serverVersion, nil
}
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
Expand Down
51 changes: 46 additions & 5 deletions pgtype/composite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,56 @@ import (
"github.com/stretchr/testify/require"
)

func TestCompositeCodecTranscode(t *testing.T) {
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop domain if exists anotheruint64;
drop type if exists ct_test;
create domain anotheruint64 as numeric(20,0);
create type ct_test as (
a text,
b int4,
c anotheruint64
);`)
require.NoError(t, err)
defer conn.Exec(ctx, "drop type ct_test")
defer conn.Exec(ctx, "drop domain anotheruint64")

err = conn.LoadAndRegisterTypes(ctx, []string{"ct_test"})
require.NoError(t, err)

formats := []struct {
name string
code int16
}{
{name: "TextFormat", code: pgx.TextFormatCode},
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
}

for _, format := range formats {
var a string
var b int32
var c uint64

err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code},
pgtype.CompositeFields{"hi", int32(42), uint64(123)},
).Scan(
pgtype.CompositeFields{&a, &b, &c},
)
require.NoErrorf(t, err, "%v", format.name)
require.EqualValuesf(t, "hi", a, "%v", format.name)
require.EqualValuesf(t, 42, b, "%v", format.name)
require.EqualValuesf(t, 123, c, "%v", format.name)
}
})
}

func TestCompositeCodecTranscode(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists ct_test;
create type ct_test as (
Expand Down Expand Up @@ -94,7 +139,6 @@ func TestCompositeCodecTranscodeStruct(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as (
Expand Down Expand Up @@ -131,7 +175,6 @@ func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as (
Expand Down Expand Up @@ -172,7 +215,6 @@ func TestCompositeCodecDecodeValue(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as (
Expand Down Expand Up @@ -217,7 +259,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")

defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

_, err := conn.Exec(ctx, `drop table if exists point3d;
create table point3d (
Expand Down

0 comments on commit 924f798

Please sign in to comment.