Skip to content

Commit

Permalink
feat:Migrate to DuckDB Arrow for Query Execution
Browse files Browse the repository at this point in the history
  • Loading branch information
d h authored and d h committed Jan 3, 2025
1 parent cd9c772 commit 4f76af4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 31 deletions.
50 changes: 38 additions & 12 deletions flightsqlserver/sqlite_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ type SQLiteFlightSQLServer struct {
openTransactions sync.Map
}

func NewSQLiteFlightSQLServer(db *sql.DB, conn *duckdb.Conn) (*SQLiteFlightSQLServer, error) {
ret := &SQLiteFlightSQLServer{db: db, conn: conn}
func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) {
ret := &SQLiteFlightSQLServer{db: db}
ret.Alloc = memory.DefaultAllocator
for k, v := range SqlInfoResultMap() {
ret.RegisterSqlInfo(flightsql.SqlInfo(k), v)
Expand Down Expand Up @@ -273,7 +273,7 @@ func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd flightsq
// db = tx.(*sql.Tx)
// }

return doGetQuery(ctx, s.conn, query, nil)
return doGetQuery(ctx, s.db, query, nil)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand Down Expand Up @@ -348,7 +348,16 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTables(_ context.Context, cmd fligh
func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
query := prepareQueryForGetTables(cmd)

arrow, err := duckdb.NewArrowFromConn(s.conn)
conn, err := s.db.Conn(ctx)
var duckConn *duckdb.Conn
err = conn.Raw(func(driverConn any) error {
duckConn = driverConn.(*duckdb.Conn)
return nil
})
if err != nil {
return nil, nil, err
}
arrow, err := duckdb.NewArrowFromConn(duckConn)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -394,7 +403,7 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTableTypes(_ context.Context, desc

func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) {
query := "SELECT DISTINCT type AS table_type FROM sqlite_master"
return doGetQuery(ctx, s.conn, query, schema_ref.TableTypes)
return doGetQuery(ctx, s.db, query, schema_ref.TableTypes)
}

func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) {
Expand Down Expand Up @@ -477,9 +486,15 @@ type dbQueryCtx interface {
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
}

func doGetQuery(ctx context.Context, conn *duckdb.Conn, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) {
func doGetQuery(ctx context.Context, db *sql.DB, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) {

arrow, err := duckdb.NewArrowFromConn(conn)
conn, err := db.Conn(ctx)
var duckConn *duckdb.Conn
err = conn.Raw(func(driverConn any) error {
duckConn = driverConn.(*duckdb.Conn)
return nil
})
arrow, err := duckdb.NewArrowFromConn(duckConn)
if err != nil {
return nil, nil, err
}
Expand All @@ -495,12 +510,23 @@ func doGetQuery(ctx context.Context, conn *duckdb.Conn, query string, schema *ar

func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle()))

if !ok {
return nil, nil, status.Error(codes.InvalidArgument, "prepared statement not found")
}

conn, err := s.db.Conn(ctx)
var duckConn *duckdb.Conn
err = conn.Raw(func(driverConn any) error {
duckConn = driverConn.(*duckdb.Conn)
return nil
})
if err != nil {
return nil, nil, err
}

stmt := val.(Statement)
arrow, err := duckdb.NewArrowFromConn(s.conn)
arrow, err := duckdb.NewArrowFromConn(duckConn)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -688,7 +714,7 @@ func (s *SQLiteFlightSQLServer) DoGetPrimaryKeys(ctx context.Context, cmd flight

fmt.Fprintf(&b, " and table_name LIKE '%s'", cmd.Table)

return doGetQuery(ctx, s.conn, b.String(), schema_ref.PrimaryKeys)
return doGetQuery(ctx, s.db, b.String(), schema_ref.PrimaryKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoImportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -704,7 +730,7 @@ func (s *SQLiteFlightSQLServer) DoGetImportedKeys(ctx context.Context, ref fligh
filter += " AND fk_schema_name = '" + *ref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.conn, query, schema_ref.ImportedKeys)
return doGetQuery(ctx, s.db, query, schema_ref.ImportedKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoExportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -720,7 +746,7 @@ func (s *SQLiteFlightSQLServer) DoGetExportedKeys(ctx context.Context, ref fligh
filter += " AND pk_schema_name = '" + *ref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys)
return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoCrossReference(_ context.Context, _ flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -746,7 +772,7 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx context.Context, cmd fli
filter += " AND fk_schema_name = '" + *fkref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.conn, query, schema_ref.ExportedKeys)
return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys)
}

func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req flightsql.ActionBeginTransactionRequest) (id []byte, err error) {
Expand Down
12 changes: 1 addition & 11 deletions flightsqltest/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"math/rand"
"os"
"strings"
Expand All @@ -33,7 +32,6 @@ import (
"time"

"github.com/apecloud/myduckserver/flightsqlserver"
"github.com/marcboeker/go-duckdb"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

Expand Down Expand Up @@ -98,16 +96,8 @@ func (s *SqlTestSuite) SetupSuite() {
if err != nil {
return nil, "", err
}
conn, err := provider.Connector().Connect(context.Background())
if err != nil {
log.Fatal(err)
}

duckConn, ok := conn.(*duckdb.Conn)
if !ok {
log.Fatal("Failed to get DuckDB connection")
}
sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage(), duckConn)
sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage())
if err != nil {
return nil, "", err
}
Expand Down
9 changes: 1 addition & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/mysql"
"github.com/marcboeker/go-duckdb"
_ "github.com/marcboeker/go-duckdb"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -193,18 +192,12 @@ func main() {
if flightsqlPort > 0 {

db := provider.Storage()
conn, err := provider.Connector().Connect(context.Background())
if err != nil {
log.Fatal(err)
}
defer db.Close()

duckConn, ok := conn.(*duckdb.Conn)
if !ok {
log.Fatal("Failed to get DuckDB connection")
}

srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db, duckConn)
srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db)
if err != nil {
log.Fatal(err)
}
Expand Down

0 comments on commit 4f76af4

Please sign in to comment.