From 44421eaeb098cf51e5739d564a4fb70d23729bae Mon Sep 17 00:00:00 2001 From: Noy Date: Tue, 17 Dec 2024 17:55:10 +0800 Subject: [PATCH] Revert "feat: support multiple statements in one query (`implicit transaction block` not implemented)" This reverts commit ba451a64fe8ea234d634520576627a021fcaa33f. --- compatibility/pg-pytools/psycopg_test.py | 17 +- compatibility/pg/test.bats | 2 - pgserver/connection_data.go | 14 +- pgserver/connection_handler.go | 250 +++++++++++------------ pgserver/duck_handler.go | 2 +- pgserver/pg_catalog_handler.go | 70 +++---- 6 files changed, 174 insertions(+), 181 deletions(-) diff --git a/compatibility/pg-pytools/psycopg_test.py b/compatibility/pg-pytools/psycopg_test.py index b08b41d1..0578c912 100644 --- a/compatibility/pg-pytools/psycopg_test.py +++ b/compatibility/pg-pytools/psycopg_test.py @@ -1,3 +1,4 @@ +from psycopg import sql import psycopg rows = [ @@ -12,14 +13,16 @@ with psycopg.connect("dbname=postgres user=postgres host=127.0.0.1 port=5432", autocommit=True) as conn: # Open a cursor to perform database operations with conn.cursor() as cur: + cur.execute("DROP SCHEMA IF EXISTS test CASCADE") + cur.execute("CREATE SCHEMA test") + cur.execute(""" - DROP SCHEMA IF EXISTS test CASCADE; - CREATE SCHEMA test; - CREATE TABLE test.tb1 ( - id integer PRIMARY KEY, - num integer, - data text) - """) + CREATE TABLE test.tb1 ( + id integer PRIMARY KEY, + num integer, + data text) + """) + # Pass data to fill a query placeholders and let Psycopg perform the correct conversion cur.execute( diff --git a/compatibility/pg/test.bats b/compatibility/pg/test.bats index a71e8fba..768cc2f0 100644 --- a/compatibility/pg/test.bats +++ b/compatibility/pg/test.bats @@ -41,8 +41,6 @@ start_process() { start_process $BATS_TEST_DIRNAME/c/pg_test 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data } -# Failed because of the following error: -# > Catalog Error: Table with name pg_range does not exist! # @test "pg-csharp" { # set_custom_teardown "sudo pkill -f dotnet" # start_process dotnet build $BATS_TEST_DIRNAME/csharp/PGTest.csproj -o $BATS_TEST_DIRNAME/csharp/bin diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go index d9864557..f5a596e4 100644 --- a/pgserver/connection_data.go +++ b/pgserver/connection_data.go @@ -48,14 +48,14 @@ const ( ReadyForQueryTransactionIndicator_FailedTransactionBlock ReadyForQueryTransactionIndicator = 'E' ) -// ConvertedStatement represents a statement that has been converted from the Postgres representation to the Vitess -// representation. String may contain the string version of the converted statement. AST will contain the tree -// version of the converted statement, and is the recommended form to use. If AST is nil, then use the String version, +// ConvertedQuery represents a query that has been converted from the Postgres representation to the Vitess +// representation. String may contain the string version of the converted query. AST will contain the tree +// version of the converted query, and is the recommended form to use. If AST is nil, then use the String version, // otherwise always prefer to AST. -type ConvertedStatement struct { +type ConvertedQuery struct { String string AST tree.Statement - Tag string + StatementTag string PgParsable bool SubscriptionConfig *SubscriptionConfig BackupConfig *BackupConfig @@ -86,7 +86,7 @@ type copyFromStdinState struct { } type PortalData struct { - Statement ConvertedStatement + Query ConvertedQuery IsEmptyQuery bool Fields []pgproto3.FieldDescription ResultFormatCodes []int16 @@ -96,7 +96,7 @@ type PortalData struct { } type PreparedStatementData struct { - Statement ConvertedStatement + Query ConvertedQuery ReturnFields []pgproto3.FieldDescription BindVarTypes []uint32 Stmt *duckdb.Stmt diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 80cc4ac6..3b6e3164 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -470,7 +470,7 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages return true, err } - statements, err := h.convertQuery(message.String) + query, err := h.convertQuery(message.String) if err != nil { return true, err } @@ -479,64 +479,53 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages h.deletePreparedStatement("") h.deletePortal("") - for _, statement := range statements { - // Certain statement types get handled directly by the handler instead of being passed to the engine - handled, endOfMessages, err = h.handleStatementOutsideEngine(statement) - if handled { - if err != nil { - h.logger.Warnf("Failed to handle statement %v outside engine: %v", statement, err) - return true, err - } - } else { - if err != nil { - h.logger.Warnf("Failed to handle statement %v outside engine: %v", statement, err) - } - endOfMessages, err = true, h.runStatement(statement) - if err != nil { - return true, err - } - } + // Certain statement types get handled directly by the handler instead of being passed to the engine + handled, endOfMessages, err = h.handleQueryOutsideEngine(query) + if handled { + return endOfMessages, err + } else if err != nil { + h.logger.Warnf("Failed to handle query %v outside engine: %v", query, err) } - return endOfMessages, nil + return true, h.query(query) } -// handleStatementOutsideEngine handles any queries that should be handled by the handler directly, rather than being +// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being // passed to the engine. The response parameter |handled| is true if the query was handled, |endOfMessages| is true // if no more messages are expected for this query and server should send the client a READY FOR QUERY message, // and any error that occurred while handling the query. -func (h *ConnectionHandler) handleStatementOutsideEngine(statement ConvertedStatement) (handled bool, endOfMessages bool, err error) { - switch stmt := statement.AST.(type) { +func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (handled bool, endOfMessages bool, err error) { + switch stmt := query.AST.(type) { case *tree.Deallocate: // TODO: handle ALL keyword - return true, true, h.deallocatePreparedStatement(stmt.Name.String(), h.preparedStatements, statement, h.Conn()) + return true, true, h.deallocatePreparedStatement(stmt.Name.String(), h.preparedStatements, query, h.Conn()) case *tree.Discard: - return true, true, h.discardAll(statement) + return true, true, h.discardAll(query) case *tree.CopyFrom: // When copying data from STDIN, the data is sent to the server as CopyData messages // We send endOfMessages=false since the server will be in COPY DATA mode and won't // be ready for more queries util COPY DATA mode is completed. if stmt.Stdin { - return true, false, h.handleCopyFromStdinQuery(statement, stmt, "") + return true, false, h.handleCopyFromStdinQuery(query, stmt, "") } case *tree.CopyTo: - return true, true, h.handleCopyToStdout(statement, stmt, "" /* unused */, stmt.Options.CopyFormat, "") + return true, true, h.handleCopyToStdout(query, stmt, "" /* unused */, stmt.Options.CopyFormat, "") } - if statement.Tag == "COPY" { - if target, format, options, ok := ParseCopyFrom(statement.String); ok { + if query.StatementTag == "COPY" { + if target, format, options, ok := ParseCopyFrom(query.String); ok { stmt, err := parser.ParseOne("COPY " + target + " FROM STDIN") if err != nil { return false, true, err } copyFrom := stmt.AST.(*tree.CopyFrom) copyFrom.Options.CopyFormat = format - return true, false, h.handleCopyFromStdinQuery(statement, copyFrom, options) + return true, false, h.handleCopyFromStdinQuery(query, copyFrom, options) } - if subquery, format, options, ok := ParseCopyTo(statement.String); ok { + if subquery, format, options, ok := ParseCopyTo(query.String); ok { if strings.HasPrefix(subquery, "(") && strings.HasSuffix(subquery, ")") { // subquery may be richer than Postgres supports, so we just pass it as a string - return true, true, h.handleCopyToStdout(statement, nil, subquery, format, options) + return true, true, h.handleCopyToStdout(query, nil, subquery, format, options) } // subquery is "table [(column_list)]", so we can parse it and pass the AST stmt, err := parser.ParseOne("COPY " + subquery + " TO STDOUT") @@ -545,11 +534,11 @@ func (h *ConnectionHandler) handleStatementOutsideEngine(statement ConvertedStat } copyTo := stmt.AST.(*tree.CopyTo) copyTo.Options.CopyFormat = format - return true, true, h.handleCopyToStdout(statement, copyTo, "", format, options) + return true, true, h.handleCopyToStdout(query, copyTo, "", format, options) } } - handled, err = h.handlePgCatalogQueries(statement) + handled, err = h.handlePgCatalogQueries(query) if handled || err != nil { return true, true, err } @@ -562,28 +551,26 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { h.waitForSync = true // TODO: "Named prepared statements must be explicitly closed before they can be redefined by another Parse message, but this is not required for the unnamed statement" - statements, err := h.convertQuery(message.Query) + query, err := h.convertQuery(message.Query) if err != nil { return err } - // TODO(Noy): handle multiple statements - statement := statements[0] - if statement.AST == nil { + if query.AST == nil { // special case: empty query h.preparedStatements[message.Name] = PreparedStatementData{ - Statement: statement, + Query: query, } return nil } - handledOutsideEngine, err := shouldQueryBeHandledInPlace(statement) + handledOutsideEngine, err := shouldQueryBeHandledInPlace(query) if err != nil { return err } if handledOutsideEngine { h.preparedStatements[message.Name] = PreparedStatementData{ - Statement: statement, + Query: query, ReturnFields: nil, BindVarTypes: nil, Stmt: nil, @@ -592,13 +579,13 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { return h.send(&pgproto3.ParseComplete{}) } - stmt, params, fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, statement.String, statement.AST) + stmt, params, fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST) if err != nil { return err } - if !statement.PgParsable { - statement.Tag = GetStatementTag(stmt) + if !query.PgParsable { + query.StatementTag = GetStatementTag(stmt) } // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY @@ -619,7 +606,7 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { } } h.preparedStatements[message.Name] = PreparedStatementData{ - Statement: statement, + Query: query, ReturnFields: fields, BindVarTypes: bindVarTypes, Stmt: stmt, @@ -652,7 +639,7 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { } bindvarTypes = preparedStatementData.BindVarTypes - tag = preparedStatementData.Statement.Tag + tag = preparedStatementData.Query.StatementTag } if bindvarTypes == nil { @@ -666,7 +653,7 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { if portalData.Stmt != nil { fields = portalData.Fields - tag = portalData.Statement.Tag + tag = portalData.Query.StatementTag } } @@ -687,18 +674,18 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error { if preparedData.Stmt == nil { h.portals[message.DestinationPortal] = PortalData{ - Statement: preparedData.Statement, - Fields: nil, - Stmt: nil, - Vars: nil, + Query: preparedData.Query, + Fields: nil, + Stmt: nil, + Vars: nil, } return h.send(&pgproto3.BindComplete{}) } - if preparedData.Statement.AST == nil { + if preparedData.Query.AST == nil { // special case: empty query h.portals[message.DestinationPortal] = PortalData{ - Statement: preparedData.Statement, + Query: preparedData.Query, IsEmptyQuery: true, } return h.send(&pgproto3.BindComplete{}) @@ -715,7 +702,7 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error { } h.portals[message.DestinationPortal] = PortalData{ - Statement: preparedData.Statement, + Query: preparedData.Query, Fields: fields, ResultFormatCodes: message.ResultFormatCodes, Stmt: preparedData.Stmt, @@ -736,14 +723,14 @@ func (h *ConnectionHandler) handleExecute(message *pgproto3.Execute) error { } logrus.Tracef("executing portal %s with contents %v", message.Portal, portalData) - query := portalData.Statement + query := portalData.Query if portalData.IsEmptyQuery { return h.send(&pgproto3.EmptyQueryResponse{}) } // Certain statement types get handled directly by the handler instead of being passed to the engine - handled, _, err := h.handleStatementOutsideEngine(query) + handled, _, err := h.handleQueryOutsideEngine(query) if handled { return err } @@ -751,13 +738,13 @@ func (h *ConnectionHandler) handleExecute(message *pgproto3.Execute) error { // |rowsAffected| gets altered by the callback below rowsAffected := int32(0) - callback := h.spoolRowsCallback(query.Tag, &rowsAffected, true) + callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, true) err = h.duckHandler.ComExecuteBound(context.Background(), h.mysqlConn, portalData, callback) if err != nil { return err } - return h.send(makeCommandComplete(query.Tag, rowsAffected)) + return h.send(makeCommandComplete(query.StatementTag, rowsAffected)) } func makeCommandComplete(tag string, rows int32) *pgproto3.CommandComplete { @@ -925,7 +912,7 @@ func (h *ConnectionHandler) handleCopyFail(_ *pgproto3.CopyFail) (stop bool, end return false, true, nil } -func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedStatement, conn net.Conn) error { +func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedQuery, conn net.Conn) error { _, ok := preparedStatements[name] if !ok { return fmt.Errorf("prepared statement %s does not exist", name) @@ -933,7 +920,7 @@ func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedSta h.deletePreparedStatement(name) return h.send(&pgproto3.CommandComplete{ - CommandTag: []byte(query.Tag), + CommandTag: []byte(query.StatementTag), }) } @@ -978,27 +965,27 @@ func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes [] return vars, nil } -// runStatement runs the given query and sends a CommandComplete message to the client -func (h *ConnectionHandler) runStatement(statement ConvertedStatement) error { - h.logger.Tracef("running statement %v", statement) +// query runs the given query and sends a CommandComplete message to the client +func (h *ConnectionHandler) query(query ConvertedQuery) error { + h.logger.Tracef("running query %v", query) // |rowsAffected| gets altered by the callback below rowsAffected := int32(0) - // Get the accurate statement tag for the statement - if !statement.PgParsable && !IsWellKnownStatementTag(statement.Tag) { - tag, err := h.duckHandler.getStatementTag(h.mysqlConn, statement.String) + // Get the accurate statement tag for the query + if !query.PgParsable && !IsWellKnownStatementTag(query.StatementTag) { + tag, err := h.duckHandler.getStatementTag(h.mysqlConn, query.String) if err != nil { return err } - h.logger.Tracef("getting statement tag for statement %v via preparing in DuckDB: %s", statement, tag) - statement.Tag = tag + h.logger.Tracef("getting statement tag for query %v via preparing in DuckDB: %s", query, tag) + query.StatementTag = tag } - if statement.SubscriptionConfig != nil { - return h.executeSubscriptionSQL(statement.SubscriptionConfig) - } else if statement.BackupConfig != nil { - msg, err := h.executeBackup(statement.BackupConfig) + if query.SubscriptionConfig != nil { + return h.executeSubscriptionSQL(query.SubscriptionConfig) + } else if query.BackupConfig != nil { + msg, err := h.executeBackup(query.BackupConfig) if err != nil { return err } @@ -1007,18 +994,18 @@ func (h *ConnectionHandler) runStatement(statement ConvertedStatement) error { }) } - callback := h.spoolRowsCallback(statement.Tag, &rowsAffected, false) + callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false) if err := h.duckHandler.ComQuery( context.Background(), h.mysqlConn, - statement.String, - statement.AST, + query.String, + query.AST, callback, ); err != nil { - return fmt.Errorf("fallback statement execution failed: %w", err) + return fmt.Errorf("fallback query execution failed: %w", err) } - return h.send(makeCommandComplete(statement.Tag, rowsAffected)) + return h.send(makeCommandComplete(query.StatementTag, rowsAffected)) } // spoolRowsCallback returns a callback function that will send RowDescription message, @@ -1089,7 +1076,7 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) if err != nil { return false, err } - return true, h.runStatement(query[0]) + return true, h.query(query) } // Command: \l on psql 16 if statement == "select\n d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n null as \"icu rules\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" { @@ -1097,27 +1084,27 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) if err != nil { return false, err } - return true, h.runStatement(query[0]) + return true, h.query(query) } // Command: \dt if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'table' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'table' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, + StatementTag: "SELECT", }) } // Command: \d if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, + StatementTag: "SELECT", }) } // Alternate \d for psql 14 if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 's' then 'special' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, + StatementTag: "SELECT", }) } // Command: \d table_name @@ -1128,31 +1115,31 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) } // Command: \dn if statement == "select n.nspname as \"name\",\n pg_catalog.pg_get_userbyid(n.nspowner) as \"owner\"\nfrom pg_catalog.pg_namespace n\nwhere n.nspname !~ '^pg_' and n.nspname <> 'information_schema'\norder by 1;" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT 'public' AS "Name", 'pg_database_owner' AS "Owner";`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT 'public' AS "Name", 'pg_database_owner' AS "Owner";`, + StatementTag: "SELECT", }) } // Command: \df if statement == "select n.nspname as \"schema\",\n p.proname as \"name\",\n pg_catalog.pg_get_function_result(p.oid) as \"result data type\",\n pg_catalog.pg_get_function_arguments(p.oid) as \"argument data types\",\n case p.prokind\n when 'a' then 'agg'\n when 'w' then 'window'\n when 'p' then 'proc'\n else 'func'\n end as \"type\"\nfrom pg_catalog.pg_proc p\n left join pg_catalog.pg_namespace n on n.oid = p.pronamespace\nwhere pg_catalog.pg_function_is_visible(p.oid)\n and n.nspname <> 'pg_catalog'\n and n.nspname <> 'information_schema'\norder by 1, 2, 4;" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT '' AS "Schema", '' AS "Name", '' AS "Result data type", '' AS "Argument data types", '' AS "Type" LIMIT 0;`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT '' AS "Schema", '' AS "Name", '' AS "Result data type", '' AS "Argument data types", '' AS "Type" LIMIT 0;`, + StatementTag: "SELECT", }) } // Command: \dv if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relkind in ('v','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'view' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'VIEW' ORDER BY 2;`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'view' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'VIEW' ORDER BY 2;`, + StatementTag: "SELECT", }) } // Command: \du if statement == "select r.rolname, r.rolsuper, r.rolinherit,\n r.rolcreaterole, r.rolcreatedb, r.rolcanlogin,\n r.rolconnlimit, r.rolvaliduntil,\n array(select b.rolname\n from pg_catalog.pg_auth_members m\n join pg_catalog.pg_roles b on (m.roleid = b.oid)\n where m.member = r.oid) as memberof\n, r.rolreplication\n, r.rolbypassrls\nfrom pg_catalog.pg_roles r\nwhere r.rolname !~ '^pg_'\norder by 1;" { // We don't support users yet, so we'll just return nothing for now - return true, h.runStatement(ConvertedStatement{ - String: `SELECT '' FROM dual LIMIT 0;`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT '' FROM dual LIMIT 0;`, + StatementTag: "SELECT", }) } return false, nil @@ -1162,15 +1149,15 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) func (h *ConnectionHandler) handledWorkbenchCommands(statement string) (bool, error) { lower := strings.ToLower(statement) if lower == "select * from current_schema()" || lower == "select * from current_schema();" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT search_path AS "current_schema";`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT search_path AS "current_schema";`, + StatementTag: "SELECT", }) } if lower == "select * from current_database()" || lower == "select * from current_database();" { - return true, h.runStatement(ConvertedStatement{ - String: `SELECT DATABASE() AS "current_database";`, - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: `SELECT DATABASE() AS "current_database";`, + StatementTag: "SELECT", }) } return false, nil @@ -1206,64 +1193,69 @@ func (h *ConnectionHandler) sendError(err error) { } } -// convertQuery takes the given Postgres query, and converts it as a list of ast.ConvertedStatement that will work with the handler. -func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifier) ([]ConvertedStatement, error) { +// convertQuery takes the given Postgres query, and converts it as an ast.ConvertedQuery that will work with the handler. +func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifier) (ConvertedQuery, error) { for _, modifier := range modifiers { query = modifier(query) } + parsable := true + // Check if the query is a subscription query, and if so, parse it as a subscription query. subscriptionConfig, err := parseSubscriptionSQL(query) if subscriptionConfig != nil && err == nil { - return []ConvertedStatement{{ + return ConvertedQuery{ String: query, PgParsable: true, SubscriptionConfig: subscriptionConfig, - }}, nil + }, nil } // Check if the query is a backup query, and if so, parse it as a backup query. backupConfig, err := parseBackupSQL(query) if backupConfig != nil && err == nil { - return []ConvertedStatement{{ + return ConvertedQuery{ String: query, PgParsable: true, BackupConfig: backupConfig, - }}, nil + }, nil } stmts, err := parser.Parse(query) if err != nil { // DuckDB syntax is not fully compatible with PostgreSQL, so we need to handle some queries differently. + parsable = false stmts, _ = parser.Parse("SELECT 'SQL syntax is incompatible with PostgreSQL' AS error") - return []ConvertedStatement{{ - String: query, - AST: stmts[0].AST, - Tag: GuessStatementTag(query), - PgParsable: false, - }}, nil } + if len(stmts) > 1 { + return ConvertedQuery{}, fmt.Errorf("only a single statement at a time is currently supported") + } if len(stmts) == 0 { - return []ConvertedStatement{{String: query}}, nil + return ConvertedQuery{String: query}, nil } - convertedStmts := make([]ConvertedStatement, len(stmts)) - for i, stmt := range stmts { - convertedStmts[i].String = stmt.SQL - convertedStmts[i].AST = stmt.AST - convertedStmts[i].Tag = stmt.AST.StatementTag() - convertedStmts[i].PgParsable = true + var stmtTag string + if parsable { + stmtTag = stmts[0].AST.StatementTag() + } else { + stmtTag = GuessStatementTag(query) } - return convertedStmts, nil + + return ConvertedQuery{ + String: query, + AST: stmts[0].AST, + StatementTag: stmtTag, + PgParsable: parsable, + }, nil } // discardAll handles the DISCARD ALL command -func (h *ConnectionHandler) discardAll(query ConvertedStatement) error { +func (h *ConnectionHandler) discardAll(query ConvertedQuery) error { h.closeBackendConn() return h.send(&pgproto3.CommandComplete{ - CommandTag: []byte(query.Tag), + CommandTag: []byte(query.StatementTag), }) } @@ -1271,7 +1263,7 @@ func (h *ConnectionHandler) discardAll(query ConvertedStatement) error { // COPY FROM STDIN can't be handled directly by the GMS engine, since COPY FROM STDIN relies on multiple messages sent // over the wire. func (h *ConnectionHandler) handleCopyFromStdinQuery( - query ConvertedStatement, copyFrom *tree.CopyFrom, + query ConvertedQuery, copyFrom *tree.CopyFrom, rawOptions string, // For non-PG-parseable COPY FROM ) error { sqlCtx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String) @@ -1335,7 +1327,7 @@ func returnsRow(tag string) bool { } } -func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo *tree.CopyTo, subquery string, format tree.CopyFormat, rawOptions string) error { +func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo, subquery string, format tree.CopyFormat, rawOptions string) error { ctx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String) if err != nil { return err diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index 7ce5cf85..2c6475a7 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -107,7 +107,7 @@ func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared Prepa // ComExecuteBound implements the Handler interface. func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error { - err := h.doQuery(ctx, conn, portal.Statement.String, portal.Statement.AST, portal.Stmt, portal.Vars, portal.ResultFormatCodes, ExtendedQueryMode, h.executeBoundPlan, callback) + err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, portal.ResultFormatCodes, ExtendedQueryMode, h.executeBoundPlan, callback) if err != nil { err = sql.CastSQLError(err) } diff --git a/pgserver/pg_catalog_handler.go b/pgserver/pg_catalog_handler.go index 17d0171f..026b4be4 100644 --- a/pgserver/pg_catalog_handler.go +++ b/pgserver/pg_catalog_handler.go @@ -130,9 +130,9 @@ func (h *ConnectionHandler) handleIsInRecovery() (bool, error) { if err != nil { return false, err } - return true, h.runStatement(ConvertedStatement{ - String: fmt.Sprintf(`SELECT '%s' AS "pg_is_in_recovery";`, isInRecovery), - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: fmt.Sprintf(`SELECT '%s' AS "pg_is_in_recovery";`, isInRecovery), + StatementTag: "SELECT", }) } @@ -142,14 +142,14 @@ func (h *ConnectionHandler) handleWALSN() (bool, error) { if err != nil { return false, err } - return true, h.runStatement(ConvertedStatement{ - String: fmt.Sprintf(`SELECT '%s' AS "%s";`, lsnStr, "pg_current_wal_lsn"), - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: fmt.Sprintf(`SELECT '%s' AS "%s";`, lsnStr, "pg_current_wal_lsn"), + StatementTag: "SELECT", }) } // handler for currentSetting -func (h *ConnectionHandler) handleCurrentSetting(query ConvertedStatement) (bool, error) { +func (h *ConnectionHandler) handleCurrentSetting(query ConvertedQuery) (bool, error) { sql := RemoveComments(query.String) matches := currentSettingRegex.FindStringSubmatch(sql) if len(matches) != 3 { @@ -159,38 +159,38 @@ func (h *ConnectionHandler) handleCurrentSetting(query ConvertedStatement) (bool if err != nil { return false, err } - return true, h.runStatement(ConvertedStatement{ - String: fmt.Sprintf(`SELECT '%s' AS "current_setting";`, fmt.Sprintf("%v", setting)), - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: fmt.Sprintf(`SELECT '%s' AS "current_setting";`, fmt.Sprintf("%v", setting)), + StatementTag: "SELECT", }) } // handler for pgCatalog -func (h *ConnectionHandler) handlePgCatalog(query ConvertedStatement) (bool, error) { +func (h *ConnectionHandler) handlePgCatalog(query ConvertedQuery) (bool, error) { sql := RemoveComments(query.String) - return true, h.runStatement(ConvertedStatement{ - String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__.$1"), - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__.$1"), + StatementTag: "SELECT", }) } type PGCatalogHandler struct { // HandledInPlace is a function that determines if the query should be handled in place and not passed to the engine. - HandledInPlace func(ConvertedStatement) (bool, error) - Handler func(*ConnectionHandler, ConvertedStatement) (bool, error) + HandledInPlace func(ConvertedQuery) (bool, error) + Handler func(*ConnectionHandler, ConvertedQuery) (bool, error) } -func isPgIsInRecovery(query ConvertedStatement) bool { +func isPgIsInRecovery(query ConvertedQuery) bool { sql := RemoveComments(query.String) return pgIsInRecoveryRegex.MatchString(sql) } -func isPgWALSN(query ConvertedStatement) bool { +func isPgWALSN(query ConvertedQuery) bool { sql := RemoveComments(query.String) return pgWALLSNRegex.MatchString(sql) } -func isPgCurrentSetting(query ConvertedStatement) bool { +func isPgCurrentSetting(query ConvertedQuery) bool { sql := RemoveComments(query.String) if !currentSettingRegex.MatchString(sql) { return false @@ -206,7 +206,7 @@ func isPgCurrentSetting(query ConvertedStatement) bool { return true } -func isSpecialPgCatalog(query ConvertedStatement) bool { +func isSpecialPgCatalog(query ConvertedQuery) bool { sql := RemoveComments(query.String) return pgCatalogRegex.MatchString(sql) } @@ -214,7 +214,7 @@ func isSpecialPgCatalog(query ConvertedStatement) bool { // The key is the statement tag of the query. var pgCatalogHandlers = map[string]PGCatalogHandler{ "SELECT": { - HandledInPlace: func(query ConvertedStatement) (bool, error) { + HandledInPlace: func(query ConvertedQuery) (bool, error) { // TODO(sean): Evaluate the conditions by iterating over the AST. if isPgIsInRecovery(query) { return true, nil @@ -230,7 +230,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { if isPgIsInRecovery(query) { return h.handleIsInRecovery() } @@ -248,14 +248,14 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ }, }, "SHOW": { - HandledInPlace: func(query ConvertedStatement) (bool, error) { + HandledInPlace: func(query ConvertedQuery) (bool, error) { switch query.AST.(type) { case *tree.ShowVar: return true, nil } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { showVar, ok := query.AST.(*tree.ShowVar) if !ok { return false, nil @@ -266,9 +266,9 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ if err != nil { return false, err } - return true, h.runStatement(ConvertedStatement{ - String: fmt.Sprintf(`SELECT '%s' AS "%s";`, fmt.Sprintf("%v", setting), key), - Tag: "SELECT", + return true, h.query(ConvertedQuery{ + String: fmt.Sprintf(`SELECT '%s' AS "%s";`, fmt.Sprintf("%v", setting), key), + StatementTag: "SELECT", }) } // TODO(sean): Implement SHOW ALL @@ -281,7 +281,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ }, }, "SET": { - HandledInPlace: func(query ConvertedStatement) (bool, error) { + HandledInPlace: func(query ConvertedQuery) (bool, error) { switch stmt := query.AST.(type) { case *tree.SetVar: key := strings.ToLower(stmt.Name) @@ -301,7 +301,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { setVar, ok := query.AST.(*tree.SetVar) if !ok { return false, fmt.Errorf("error: invalid set statement: %v", query.String) @@ -337,7 +337,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ }, }, "RESET": { - HandledInPlace: func(query ConvertedStatement) (bool, error) { + HandledInPlace: func(query ConvertedQuery) (bool, error) { switch stmt := query.AST.(type) { case *tree.SetVar: if !stmt.Reset && !stmt.ResetAll { @@ -351,7 +351,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { resetVar, ok := query.AST.(*tree.SetVar) if !ok || (!resetVar.Reset && !resetVar.ResetAll) { return false, fmt.Errorf("error: invalid reset statement: %v", query.String) @@ -378,8 +378,8 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ // shouldQueryBeHandledInPlace determines whether a query should be handled in place, rather than being // passed to the engine. This is useful for queries that are not supported by the engine, or that require // special handling. -func shouldQueryBeHandledInPlace(sql ConvertedStatement) (bool, error) { - handler, ok := pgCatalogHandlers[sql.Tag] +func shouldQueryBeHandledInPlace(sql ConvertedQuery) (bool, error) { + handler, ok := pgCatalogHandlers[sql.StatementTag] if !ok { return false, nil } @@ -392,8 +392,8 @@ func shouldQueryBeHandledInPlace(sql ConvertedStatement) (bool, error) { // TODO(sean): This is a temporary work around for clients that query the views from schema 'pg_catalog'. // Remove this once we add the views for 'pg_catalog'. -func (h *ConnectionHandler) handlePgCatalogQueries(sql ConvertedStatement) (bool, error) { - handler, ok := pgCatalogHandlers[sql.Tag] +func (h *ConnectionHandler) handlePgCatalogQueries(sql ConvertedQuery) (bool, error) { + handler, ok := pgCatalogHandlers[sql.StatementTag] if !ok { return false, nil }