Skip to content

Commit

Permalink
fix: resolve compatibility issues with PHP and R clients (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException authored Dec 16, 2024
1 parent 0023897 commit 167d6de
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 34 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/clients-compatibility.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.13'

- name: Install system packages
uses: awalsh128/cache-apt-pkgs-action@latest
with:
packages: bats cpanminus libpq-dev postgresql-client dotnet-sdk-8.0 dotnet-runtime-8.0 r-base-core r-cran-rpostgresql
packages: bats cpanminus libpq-dev postgresql-client dotnet-sdk-8.0 dotnet-runtime-8.0 r-base-core
version: 1.0

- name: Install dependencies
Expand All @@ -109,7 +109,8 @@ jobs:
npm install pg
sudo cpanm --notest DBD::Pg
pip3 install "psycopg[binary]" pandas pyarrow polars
# sudo R -e "install.packages('RPostgres', repos='http://cran.r-project.org')"
# TODO: Speed up the installation of RPostgres
sudo R -e "install.packages('RPostgres', repos='http://cran.r-project.org')"
sudo gem install pg
- name: Build
Expand Down
11 changes: 8 additions & 3 deletions compatibility/pg/php/pg_test.php
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,14 @@ public function run($conn) {
return false;
}
foreach ($rows as $i => $row) {
if ($row != $this->expectedResults[$i]) {
echo "Expected: " . implode(", ", $this->expectedResults[$i]) . ", got: " . implode(", ", $row) . "\n";
return false;
$row = array_map('strval', $row);
foreach ($row as $j => $value) {
$value = trim($value);
$expectedValue = trim($this->expectedResults[$i][$j]);
if ($value !== $expectedValue) {
echo "Expected: " . $expectedValue . ", got: " . $value . "\n";
return false;
}
}
}
echo "Returns " . count($rows) . " rows\n";
Expand Down
12 changes: 6 additions & 6 deletions compatibility/pg/test.bats
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ start_process() {
start_process perl $BATS_TEST_DIRNAME/perl/pg_test.pl 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data
}

# @test "pg-php" {
# start_process php $BATS_TEST_DIRNAME/php/pg_test.php 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data
# }
@test "pg-php" {
start_process php $BATS_TEST_DIRNAME/php/pg_test.php 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data
}

@test "pg-python" {
start_process python3 $BATS_TEST_DIRNAME/python/pg_test.py 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data
}

# @test "pg-r" {
# start_process Rscript $BATS_TEST_DIRNAME/r/PGTest.R 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data
# }
@test "pg-r" {
start_process Rscript $BATS_TEST_DIRNAME/r/PGTest.R 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data
}

@test "pg-ruby" {
start_process ruby $BATS_TEST_DIRNAME/ruby/pg_test.rb 127.0.0.1 5432 postgres "" $BATS_TEST_DIRNAME/test.data
Expand Down
13 changes: 7 additions & 6 deletions pgserver/connection_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ type copyFromStdinState struct {
}

type PortalData struct {
Query ConvertedQuery
IsEmptyQuery bool
Fields []pgproto3.FieldDescription
Stmt *duckdb.Stmt
Vars []any
Closed *atomic.Bool
Query ConvertedQuery
IsEmptyQuery bool
Fields []pgproto3.FieldDescription
ResultFormatCodes []int16
Stmt *duckdb.Stmt
Vars []any
Closed *atomic.Bool
}

type PreparedStatementData struct {
Expand Down
11 changes: 6 additions & 5 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,12 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error {
}

h.portals[message.DestinationPortal] = PortalData{
Query: preparedData.Query,
Fields: fields,
Stmt: preparedData.Stmt,
Closed: preparedData.Closed,
Vars: bindVars,
Query: preparedData.Query,
Fields: fields,
ResultFormatCodes: message.ResultFormatCodes,
Stmt: preparedData.Stmt,
Closed: preparedData.Closed,
Vars: bindVars,
}
return h.send(&pgproto3.BindComplete{})
}
Expand Down
30 changes: 19 additions & 11 deletions pgserver/duck_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,12 @@ func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared Prepa
return nil, err
}

// TODO(fan): Theoretically, the field descriptions may change after binding.
return prepared.ReturnFields, nil
}

// ComExecuteBound implements the Handler interface.
func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error {
err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, ExtendedQueryMode, h.executeBoundPlan, callback)
err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, portal.ResultFormatCodes, ExtendedQueryMode, h.executeBoundPlan, callback)
if err != nil {
err = sql.CastSQLError(err)
}
Expand Down Expand Up @@ -200,7 +199,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query
if err != nil {
break
}
fields = schemaToFieldDescriptions(sqlCtx, schema, ExtendedQueryMode)
fields = schemaToFieldDescriptions(sqlCtx, schema, nil, ExtendedQueryMode)
default:
// For other statements, we just return the "affected rows" field.
fields = []pgproto3.FieldDescription{
Expand All @@ -221,7 +220,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query

// ComQuery implements the Handler interface.
func (h *DuckHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, callback func(*Result) error) error {
err := h.doQuery(ctx, c, query, parsed, nil, nil, SimpleQueryMode, h.executeQuery, callback)
err := h.doQuery(ctx, c, query, parsed, nil, nil, nil, SimpleQueryMode, h.executeQuery, callback)
if err != nil {
err = sql.CastSQLError(err)
}
Expand Down Expand Up @@ -298,7 +297,7 @@ func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (stri

var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`)

func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, mode QueryMode, queryExec QueryExecutor, callback func(*Result) error) error {
func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, resultFormatCodes []int16, mode QueryMode, queryExec QueryExecutor, callback func(*Result) error) error {
sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
if err != nil {
return err
Expand Down Expand Up @@ -358,10 +357,10 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string,
} else if schema == nil {
r, err = resultForEmptyIter(sqlCtx, rowIter)
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode)
resultFields := schemaToFieldDescriptions(sqlCtx, schema, resultFormatCodes, mode)
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields)
} else {
resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode)
resultFields := schemaToFieldDescriptions(sqlCtx, schema, resultFormatCodes, mode)
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields)
}
if err != nil {
Expand Down Expand Up @@ -569,7 +568,7 @@ func (h *DuckHandler) maybeReleaseAllLocks(c *mysql.Conn) {
}
}

func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) []pgproto3.FieldDescription {
func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, resultFormatCodes []int16, mode QueryMode) []pgproto3.FieldDescription {
fields := make([]pgproto3.FieldDescription, len(s))
for i, c := range s {
var oid uint32
Expand All @@ -581,9 +580,18 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) [
if mode == SimpleQueryMode {
// https://www.postgresql.org/docs/current/protocol-flow.html
// > In simple Query mode, the format of retrieved values is always text, except ...
format = pgtype.TextFormatCode
format = pgproto3.TextFormat
} else {
format = pgType.PG.Codec.PreferredFormat()
if resultFormatCodes != nil && len(resultFormatCodes) > 0 {
// Specified overall or per-column format codes
if len(resultFormatCodes) == 1 {
format = resultFormatCodes[0]
} else {
format = resultFormatCodes[i]
}
} else {
format = pgType.PG.Codec.PreferredFormat()
}
}
size = int16(pgType.Size)
} else {
Expand All @@ -592,7 +600,7 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) [
panic(err)
}
size = int16(c.Type.MaxTextResponseByteLength(ctx))
format = 0
format = pgproto3.TextFormat
}

// "Format" field: The format code being used for the field.
Expand Down

0 comments on commit 167d6de

Please sign in to comment.