From 924f798683e7ca9b7234f60daae9814db13940c7 Mon Sep 17 00:00:00 2001 From: Nick Farrell Date: Sat, 15 Jun 2024 09:36:34 +1000 Subject: [PATCH] Load types using a single SQL query When loading even a single type into pgx's type map, multiple SQL queries are performed in series. Over a slow link, this is not ideal. Worse, if multiple types are being registered, this is repeated multiple times. This commit changes the internal implementation of LoadType to use a single SQL query. It also added a LoadTypes, which can retrieve type mapping information for multiple types in a single SQL call. Additionally, LoadTypes will recursively load any related types, avoiding the need to explicitly list everything. --- conn.go | 232 ++++++++++++++++++++++++++++++++++++++- go.sum | 4 - pgtype/composite_test.go | 51 ++++++++- 3 files changed, 275 insertions(+), 12 deletions(-) diff --git a/conn.go b/conn.go index 311721459..65ee087c6 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "regexp" "strconv" "strings" "time" @@ -107,8 +108,10 @@ var ( ErrTooManyRows = errors.New("too many rows in result set") ) -var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") -var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +var ( + errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +) // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. @@ -843,7 +846,6 @@ func (c *Conn) getStatementDescription( mode QueryExecMode, sql string, ) (sd *pgconn.StatementDescription, err error) { - switch mode { case QueryExecModeCacheStatement: if c.statementCache == nil { @@ -1393,3 +1395,227 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error return nil } + +/* +buildLoadTypesSQL generates the correct query for retrieving type information. + + pgVersion: the major version of the PostgreSQL server + typeNames: the names of the types to load. If nil, load all types. +*/ +func buildLoadTypesSQL(pgVersion int64, typeNames []string) string { + supportsMultirange := (pgVersion >= 14) + var typeNamesClause string + if typeNames == nil { + typeNamesClause = "IS NOT NULL" + } else { + typeNamesClause = "= ANY($1)" + } + parts := make([]string, 0, 10) + + parts = append(parts, ` +WITH RECURSIVE +selected_classes(oid,reltype) AS ( + SELECT pg_class.oid, pg_class.reltype + FROM pg_catalog.pg_class + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace + WHERE pg_catalog.pg_table_is_visible(pg_class.oid) + AND relname `, typeNamesClause, ` +UNION ALL + SELECT pg_class.oid, pg_class.reltype + FROM pg_class + INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid) + WHERE nspname || '.' || relname `, typeNamesClause, ` +), +selected_types(oid) AS ( + SELECT reltype AS oid + FROM selected_classes +UNION ALL + SELECT oid + FROM pg_type + WHERE typname `, typeNamesClause, ` +), +pc(parent, child) AS ( + SELECT parent.oid, parent.typelem + FROM pg_type parent + WHERE parent.typtype = 'b' AND parent.typelem != 0 +UNION ALL + SELECT parent.oid, parent.typbasetype + FROM pg_type parent + WHERE parent.typtypmod = -1 AND parent.typbasetype != 0 +UNION ALL + SELECT pg_type.oid, atttypid + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 +), +relationships(parent, child, depth) AS ( + SELECT DISTINCT 0::OID, selected_types.oid, 0 + FROM selected_types +UNION ALL + SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1 + FROM selected_classes c + inner join pg_type ON (c.reltype = pg_type.oid) + inner join pg_attribute on (c.oid = pg_attribute.attrelid) +UNION ALL + SELECT pc.parent, pc.child, relationships.depth + 1 + FROM pc + INNER JOIN relationships ON (pc.parent = relationships.child) +), +composite AS ( + SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 + GROUP BY pg_type.oid +) +SELECT typname, + typtype, + typbasetype, + typelem, + pg_type.oid,`) + if supportsMultirange { + parts = append(parts, ` + COALESCE(multirange.rngtypid, 0) AS rngtypid,`) + } else { + parts = append(parts, ` + 0 AS rngtypid,`) + } + parts = append(parts, ` + COALESCE(pg_range.rngsubtype, 0) AS rngsubtype, + attnames, atttypids + FROM relationships + INNER JOIN pg_type ON (pg_type.oid IN ( relationships.child,relationships.parent) ) + LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`) + if supportsMultirange { + parts = append(parts, ` + LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`) + } + + parts = append(parts, ` + LEFT OUTER JOIN composite USING (oid) + WHERE NOT (typtype = 'b' AND typelem = 0)`) + parts = append(parts, ` + GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`) + if supportsMultirange { + parts = append(parts, ` + multirange.rngtypid,`) + } + parts = append(parts, ` + attnames, atttypids + ORDER BY MAX(depth) desc, typname;`) + return strings.Join(parts, "") +} + +// LoadAndRegisterTypes inspects the database for []typeNames and automatically registers all discovered +// types. Any types referenced by these will also be included in the registration. +func (c *Conn) LoadAndRegisterTypes(ctx context.Context, typeNames []string) error { + if typeNames == nil || len(typeNames) == 0 { + return fmt.Errorf("No type names were supplied.") + } + return c.loadAndRegisterTypes(ctx, typeNames, c.TypeMap()) +} + +func (c *Conn) loadAndRegisterTypes(ctx context.Context, typeNames []string, registerWith *pgtype.Map) error { + if registerWith == nil { + return fmt.Errorf("Type map must be supplied") + } + serverVersion, err := c.ServerVersion() + if err != nil { + return fmt.Errorf("Unexpected server version error: %w", err) + } + sql := buildLoadTypesSQL(serverVersion, typeNames) + var rows Rows + if typeNames == nil { + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol) + } else { + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) + } + if err != nil { + return fmt.Errorf("While generating load types query: %w", err) + } + defer rows.Close() + for rows.Next() { + var oid uint32 + var typeName, typtype string + var typbasetype, typelem uint32 + var rngsubtype, rngtypid uint32 + attnames := make([]string, 0, 0) + atttypids := make([]uint32, 0, 0) + err = rows.Scan(&typeName, &typtype, &typbasetype, &typelem, &oid, &rngtypid, &rngsubtype, &attnames, &atttypids) + if err != nil { + return fmt.Errorf("While scanning type information: %w", err) + } + + switch typtype { + case "b": // array + dt, ok := c.TypeMap().TypeForOID(typelem) + if !ok { + return fmt.Errorf("array element OID %v not registered while loading for %v", typelem, typeName) + } + registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}) + case "c": // composite + var fields []pgtype.CompositeCodecField + for i, fieldName := range attnames { + //if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil { + // return nil, fmt.Errorf("While extracting OID used in composite field: %w", err) + //} + dt, ok := c.TypeMap().TypeForOID(atttypids[i]) + if !ok { + return fmt.Errorf("unknown composite type field OID %v (%v)", atttypids[i], fieldName) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + } + if err != nil { + return fmt.Errorf("While parsing %v: %w", typeName, err) + } + + registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}) + case "d": // domain + dt, ok := c.TypeMap().TypeForOID(typbasetype) + if !ok { + return fmt.Errorf("domain base type OID %v was not already registered, needed for %v", typbasetype, typeName) + } + + registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}) + case "e": // enum + registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}) + case "r": // range + dt, ok := c.TypeMap().TypeForOID(rngsubtype) + if !ok { + return fmt.Errorf("range element OID %v not registered for %v", rngsubtype, typeName) + } + + registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}) + case "m": // multirange + dt, ok := c.TypeMap().TypeForOID(rngtypid) + if !ok { + return fmt.Errorf("multirange element OID %v not registered while loading %v", rngtypid, typeName) + } + + registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}) + default: + return fmt.Errorf("unknown typtype %v for %v", typtype, typeName) + } + } + return nil +} + +// ServerVersion returns the postgresql server version. +func (conn *Conn) ServerVersion() (int64, error) { + serverVersionStr := conn.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr) + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("postgres version parsing failed: %w", err) + } + return serverVersion, nil +} diff --git a/go.sum b/go.sum index 4b02a0365..29fe452b2 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index a049b448e..6168a94fb 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -10,11 +10,56 @@ import ( "github.com/stretchr/testify/require" ) -func TestCompositeCodecTranscode(t *testing.T) { +func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop domain if exists anotheruint64; +drop type if exists ct_test; +create domain anotheruint64 as numeric(20,0); + +create type ct_test as ( + a text, + b int4, + c anotheruint64 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type ct_test") + defer conn.Exec(ctx, "drop domain anotheruint64") + + err = conn.LoadAndRegisterTypes(ctx, []string{"ct_test"}) + require.NoError(t, err) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b int32 + var c uint64 + err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code}, + pgtype.CompositeFields{"hi", int32(42), uint64(123)}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c}, + ) + require.NoErrorf(t, err, "%v", format.name) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + require.EqualValuesf(t, 123, c, "%v", format.name) + } + }) +} + +func TestCompositeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { _, err := conn.Exec(ctx, `drop type if exists ct_test; create type ct_test as ( @@ -94,7 +139,6 @@ func TestCompositeCodecTranscodeStruct(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -131,7 +175,6 @@ func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -172,7 +215,6 @@ func TestCompositeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -217,7 +259,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop table if exists point3d; create table point3d (