From 499c931511726b6caa2a4c86c1e679ed92b00712 Mon Sep 17 00:00:00 2001 From: Noy Date: Mon, 16 Dec 2024 19:15:05 +0800 Subject: [PATCH] fix: resolve compatibility issues with PHP and R clients --- .github/workflows/clients-compatibility.yml | 7 ++--- compatibility/pg/php/pg_test.php | 11 +++++--- compatibility/pg/test.bats | 12 ++++----- pgserver/connection_data.go | 13 ++++----- pgserver/connection_handler.go | 11 ++++---- pgserver/duck_handler.go | 30 +++++++++++++-------- 6 files changed, 50 insertions(+), 34 deletions(-) diff --git a/.github/workflows/clients-compatibility.yml b/.github/workflows/clients-compatibility.yml index 8605e090..b81d7fd8 100644 --- a/.github/workflows/clients-compatibility.yml +++ b/.github/workflows/clients-compatibility.yml @@ -85,12 +85,12 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.13' - name: Install system packages uses: awalsh128/cache-apt-pkgs-action@latest with: - packages: bats cpanminus libpq-dev postgresql-client dotnet-sdk-8.0 dotnet-runtime-8.0 r-base-core r-cran-rpostgresql + packages: bats cpanminus libpq-dev postgresql-client dotnet-sdk-8.0 dotnet-runtime-8.0 r-base-core version: 1.0 - name: Install dependencies @@ -109,7 +109,8 @@ jobs: npm install pg sudo cpanm --notest DBD::Pg pip3 install "psycopg[binary]" pandas pyarrow polars - # sudo R -e "install.packages('RPostgres', repos='http://cran.r-project.org')" + # TODO: Speed up the installation of RPostgres + sudo R -e "install.packages('RPostgres', repos='http://cran.r-project.org')" sudo gem install pg - name: Build diff --git a/compatibility/pg/php/pg_test.php b/compatibility/pg/php/pg_test.php index 4e203b56..3c51bb42 100644 --- a/compatibility/pg/php/pg_test.php +++ b/compatibility/pg/php/pg_test.php @@ -100,9 +100,14 @@ public function run($conn) { return false; } foreach ($rows as $i => $row) { - if ($row != $this->expectedResults[$i]) { - echo "Expected: " . implode(", ", $this->expectedResults[$i]) . ", got: " . implode(", ", $row) . "\n"; - return false; + $row = array_map('strval', $row); + foreach ($row as $j => $value) { + $value = trim($value); + $expectedValue = trim($this->expectedResults[$i][$j]); + if ($value !== $expectedValue) { + echo "Expected: " . $expectedValue . ", got: " . $value . "\n"; + return false; + } } } echo "Returns " . count($rows) . " rows\n"; diff --git a/compatibility/pg/test.bats b/compatibility/pg/test.bats index 6d860830..768cc2f0 100644 --- a/compatibility/pg/test.bats +++ b/compatibility/pg/test.bats @@ -65,17 +65,17 @@ start_process() { start_process perl $BATS_TEST_DIRNAME/perl/pg_test.pl 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data } -# @test "pg-php" { -# start_process php $BATS_TEST_DIRNAME/php/pg_test.php 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data -# } +@test "pg-php" { + start_process php $BATS_TEST_DIRNAME/php/pg_test.php 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data +} @test "pg-python" { start_process python3 $BATS_TEST_DIRNAME/python/pg_test.py 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data } -# @test "pg-r" { -# start_process Rscript $BATS_TEST_DIRNAME/r/PGTest.R 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data -# } +@test "pg-r" { + start_process Rscript $BATS_TEST_DIRNAME/r/PGTest.R 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data +} @test "pg-ruby" { start_process ruby $BATS_TEST_DIRNAME/ruby/pg_test.rb 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go index 1a8d5d3b..f5a596e4 100644 --- a/pgserver/connection_data.go +++ b/pgserver/connection_data.go @@ -86,12 +86,13 @@ type copyFromStdinState struct { } type PortalData struct { - Query ConvertedQuery - IsEmptyQuery bool - Fields []pgproto3.FieldDescription - Stmt *duckdb.Stmt - Vars []any - Closed *atomic.Bool + Query ConvertedQuery + IsEmptyQuery bool + Fields []pgproto3.FieldDescription + ResultFormatCodes []int16 + Stmt *duckdb.Stmt + Vars []any + Closed *atomic.Bool } type PreparedStatementData struct { diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index b24242d6..3b6e3164 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -702,11 +702,12 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error { } h.portals[message.DestinationPortal] = PortalData{ - Query: preparedData.Query, - Fields: fields, - Stmt: preparedData.Stmt, - Closed: preparedData.Closed, - Vars: bindVars, + Query: preparedData.Query, + Fields: fields, + ResultFormatCodes: message.ResultFormatCodes, + Stmt: preparedData.Stmt, + Closed: preparedData.Closed, + Vars: bindVars, } return h.send(&pgproto3.BindComplete{}) } diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index 0b7390c5..2c6475a7 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -102,13 +102,12 @@ func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared Prepa return nil, err } - // TODO(fan): Theoretically, the field descriptions may change after binding. return prepared.ReturnFields, nil } // ComExecuteBound implements the Handler interface. func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error { - err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, ExtendedQueryMode, h.executeBoundPlan, callback) + err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, portal.ResultFormatCodes, ExtendedQueryMode, h.executeBoundPlan, callback) if err != nil { err = sql.CastSQLError(err) } @@ -200,7 +199,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query if err != nil { break } - fields = schemaToFieldDescriptions(sqlCtx, schema, ExtendedQueryMode) + fields = schemaToFieldDescriptions(sqlCtx, schema, nil, ExtendedQueryMode) default: // For other statements, we just return the "affected rows" field. fields = []pgproto3.FieldDescription{ @@ -221,7 +220,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query // ComQuery implements the Handler interface. func (h *DuckHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, callback func(*Result) error) error { - err := h.doQuery(ctx, c, query, parsed, nil, nil, SimpleQueryMode, h.executeQuery, callback) + err := h.doQuery(ctx, c, query, parsed, nil, nil, nil, SimpleQueryMode, h.executeQuery, callback) if err != nil { err = sql.CastSQLError(err) } @@ -298,7 +297,7 @@ func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (stri var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`) -func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, mode QueryMode, queryExec QueryExecutor, callback func(*Result) error) error { +func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, resultFormatCodes []int16, mode QueryMode, queryExec QueryExecutor, callback func(*Result) error) error { sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) if err != nil { return err @@ -358,10 +357,10 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, } else if schema == nil { r, err = resultForEmptyIter(sqlCtx, rowIter) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { - resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode) + resultFields := schemaToFieldDescriptions(sqlCtx, schema, resultFormatCodes, mode) r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields) } else { - resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode) + resultFields := schemaToFieldDescriptions(sqlCtx, schema, resultFormatCodes, mode) r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields) } if err != nil { @@ -569,7 +568,7 @@ func (h *DuckHandler) maybeReleaseAllLocks(c *mysql.Conn) { } } -func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) []pgproto3.FieldDescription { +func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, resultFormatCodes []int16, mode QueryMode) []pgproto3.FieldDescription { fields := make([]pgproto3.FieldDescription, len(s)) for i, c := range s { var oid uint32 @@ -581,9 +580,18 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) [ if mode == SimpleQueryMode { // https://www.postgresql.org/docs/current/protocol-flow.html // > In simple Query mode, the format of retrieved values is always text, except ... - format = pgtype.TextFormatCode + format = pgproto3.TextFormat } else { - format = pgType.PG.Codec.PreferredFormat() + if resultFormatCodes != nil && len(resultFormatCodes) > 0 { + // Specified overall or per-column format codes + if len(resultFormatCodes) == 1 { + format = resultFormatCodes[0] + } else { + format = resultFormatCodes[i] + } + } else { + format = pgType.PG.Codec.PreferredFormat() + } } size = int16(pgType.Size) } else { @@ -592,7 +600,7 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) [ panic(err) } size = int16(c.Type.MaxTextResponseByteLength(ctx)) - format = 0 + format = pgproto3.TextFormat } // "Format" field: The format code being used for the field.