Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: incorrect reference to catalog #332

Merged
merged 3 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type ConnectionHolder interface {
GetCatalogConn(ctx context.Context) (*stdsql.Conn, error)
GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
TryGetTxn() *stdsql.Tx
GetCurrentCatalog() string
GetCurrentSchema() string
CloseTxn()
CloseConn()
}
Expand Down Expand Up @@ -42,6 +44,14 @@ func TryGetTxn(ctx *sql.Context) *stdsql.Tx {
return ctx.Session.(ConnectionHolder).TryGetTxn()
}

func GetCurrentCatalog(ctx *sql.Context) string {
return ctx.Session.(ConnectionHolder).GetCurrentCatalog()
}

func GetCurrentSchema(ctx *sql.Context) string {
return ctx.Session.(ConnectionHolder).GetCurrentSchema()
}

func CloseTxn(ctx *sql.Context) {
ctx.Session.(ConnectionHolder).CloseTxn()
}
Expand Down
3 changes: 2 additions & 1 deletion backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
stdsql "database/sql"
"fmt"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/catalog"
"github.com/apecloud/myduckserver/transpiler"
"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -124,7 +125,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row

switch node := n.(type) {
case *plan.Use:
useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name())
useStmt := "USE " + catalog.FullSchemaName(adapter.GetCurrentCatalog(ctx), node.Database().Name())
if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil {
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(node.Database().Name())
Expand Down
10 changes: 10 additions & 0 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ func (sess *Session) TryGetTxn() *stdsql.Tx {
return sess.db.Pool().TryGetTxn(sess.ID())
}

// GetCurrentCatalog implements adapter.ConnectionHolder.
func (sess *Session) GetCurrentCatalog() string {
return sess.db.Pool().CurrentCatalog(sess.ID())
}

// GetCurrentSchema implements adapter.ConnectionHolder.
func (sess *Session) GetCurrentSchema() string {
return sess.db.Pool().CurrentSchema(sess.ID())
}

// CloseTxn implements adapter.ConnectionHolder.
func (sess *Session) CloseTxn() {
sess.db.Pool().CloseTxn(sess.ID())
Expand Down
30 changes: 22 additions & 8 deletions catalog/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,14 @@ import (
type ConnectionPool struct {
*stdsql.DB
connector *duckdb.Connector
catalog string
conns sync.Map // concurrent-safe map[uint32]*stdsql.Conn
txns sync.Map // concurrent-safe map[uint32]*stdsql.Tx
}

func NewConnectionPool(catalog string, connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
func NewConnectionPool(connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
return &ConnectionPool{
DB: db,
connector: connector,
catalog: catalog,
}
}

Expand All @@ -57,13 +55,30 @@ func (p *ConnectionPool) CurrentSchema(id uint32) string {
}
conn := entry.(*stdsql.Conn)
var schema string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&schema); err != nil {
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&schema); err != nil {
logrus.WithError(err).Error("Failed to get current schema")
return ""
}
return schema
}

// CurrentCatalog retrieves the current catalog of the connection.
// Returns an empty string if the connection is not established
// or the catalog cannot be retrieved.
func (p *ConnectionPool) CurrentCatalog(id uint32) string {
entry, ok := p.conns.Load(id)
if !ok {
return ""
}
conn := entry.(*stdsql.Conn)
var catalog string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_CATALOG").Scan(&catalog); err != nil {
logrus.WithError(err).Error("Failed to get current catalog")
return ""
}
return catalog
}

func (p *ConnectionPool) GetConn(ctx context.Context, id uint32) (*stdsql.Conn, error) {
var conn *stdsql.Conn
entry, ok := p.conns.Load(id)
Expand All @@ -88,11 +103,11 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema

if schemaName != "" {
var currentSchema string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&currentSchema); err != nil {
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&currentSchema); err != nil {
logrus.WithError(err).Error("Failed to get current schema")
return nil, err
} else if currentSchema != schemaName {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.CurrentCatalog(id), schemaName)); err != nil {
if IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(schemaName)
}
Expand Down Expand Up @@ -187,15 +202,14 @@ func (p *ConnectionPool) Close() error {
return errors.Join(lastErr, p.DB.Close())
}

func (p *ConnectionPool) Reset(catalog string, connector *duckdb.Connector, db *stdsql.DB) error {
func (p *ConnectionPool) Reset(connector *duckdb.Connector, db *stdsql.DB) error {
err := p.Close()
if err != nil {
return fmt.Errorf("failed to close connection pool: %w", err)
}

p.conns.Clear()
p.txns.Clear()
p.catalog = catalog
p.DB = db
p.connector = connector

Expand Down
33 changes: 18 additions & 15 deletions catalog/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
_ "github.com/marcboeker/go-duckdb"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/configuration"
Expand All @@ -27,7 +26,7 @@ type DatabaseProvider struct {
connector *duckdb.Connector
storage *stdsql.DB
pool *ConnectionPool
catalogName string // database name in postgres
defaultCatalogName string // default database name in postgres
dataDir string
dbFile string
dsn string
Expand Down Expand Up @@ -60,11 +59,11 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr

shouldInit := true
if defaultDB == "" || defaultDB == "memory" {
prov.catalogName = "memory"
prov.defaultCatalogName = "memory"
prov.dbFile = ""
prov.dsn = ""
} else {
prov.catalogName = defaultDB
prov.defaultCatalogName = defaultDB
prov.dbFile = defaultDB + ".db"
prov.dsn = filepath.Join(prov.dataDir, prov.dbFile)
_, err = os.Stat(prov.dsn)
Expand All @@ -76,7 +75,7 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr
return nil, err
}
prov.storage = stdsql.OpenDB(prov.connector)
prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage)
prov.pool = NewConnectionPool(prov.connector, prov.storage)

bootQueries := []string{
"INSTALL arrow",
Expand Down Expand Up @@ -353,8 +352,8 @@ func (prov *DatabaseProvider) Pool() *ConnectionPool {
return prov.pool
}

func (prov *DatabaseProvider) CatalogName() string {
return prov.catalogName
func (prov *DatabaseProvider) DefaultCatalogName() string {
return prov.defaultCatalogName
}

func (prov *DatabaseProvider) DataDir() string {
Expand All @@ -380,7 +379,8 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
prov.mu.RLock()
defer prov.mu.RUnlock()

rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", prov.catalogName)
catalogName := adapter.GetCurrentCatalog(ctx)
rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", catalogName)
if err != nil {
panic(ErrDuckDB.New(err))
}
Expand All @@ -398,7 +398,7 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
continue
}

all = append(all, NewDatabase(schemaName, prov.catalogName))
all = append(all, NewDatabase(schemaName, catalogName))
}

sort.Slice(all, func(i, j int) bool {
Expand All @@ -413,13 +413,14 @@ func (prov *DatabaseProvider) Database(ctx *sql.Context, name string) (sql.Datab
prov.mu.RLock()
defer prov.mu.RUnlock()

ok, err := hasDatabase(ctx, prov.catalogName, name)
catalogName := adapter.GetCurrentCatalog(ctx)
ok, err := hasDatabase(ctx, catalogName, name)
if err != nil {
return nil, err
}

if ok {
return NewDatabase(name, prov.catalogName), nil
return NewDatabase(name, catalogName), nil
}
return nil, sql.ErrDatabaseNotFound.New(name)
}
Expand All @@ -429,7 +430,7 @@ func (prov *DatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool {
prov.mu.RLock()
defer prov.mu.RUnlock()

ok, err := hasDatabase(ctx, prov.catalogName, name)
ok, err := hasDatabase(ctx, adapter.GetCurrentCatalog(ctx), name)
if err != nil {
panic(err)
}
Expand All @@ -451,7 +452,8 @@ func (prov *DatabaseProvider) CreateDatabase(ctx *sql.Context, name string) erro
prov.mu.Lock()
defer prov.mu.Unlock()

_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, FullSchemaName(prov.catalogName, name)))
_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`,
FullSchemaName(adapter.GetCurrentCatalog(ctx), name)))
if err != nil {
return ErrDuckDB.New(err)
}
Expand All @@ -464,7 +466,8 @@ func (prov *DatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
prov.mu.Lock()
defer prov.mu.Unlock()

_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, FullSchemaName(prov.catalogName, name)))
_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`,
FullSchemaName(adapter.GetCurrentCatalog(ctx), name)))
if err != nil {
return ErrDuckDB.New(err)
}
Expand Down Expand Up @@ -494,5 +497,5 @@ func (prov *DatabaseProvider) Restart(readOnly bool) error {
prov.connector = connector
prov.storage = storage

return nil
return prov.pool.Reset(connector, storage)
}
15 changes: 2 additions & 13 deletions pgserver/backup_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ func parseBackupSQL(sql string) (*BackupConfig, error) {
}

func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, error) {
// TODO(neo.zty): Add support for backing up multiple databases once MyDuck Server supports multi-database functionality.
if backupConfig.DbName != h.server.Provider.CatalogName() {
return "", fmt.Errorf("backup database name %s does not match server database name %s",
backupConfig.DbName, h.server.Provider.CatalogName())
}

sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "")
if err != nil {
return "", fmt.Errorf("failed to create context for query: %w", err)
Expand All @@ -114,7 +108,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e
}

msg, err := backupConfig.StorageConfig.UploadFile(
h.server.Provider.DataDir(), h.server.Provider.DbFile(), backupConfig.RemotePath)
h.server.Provider.DataDir(), backupConfig.DbName+".db", backupConfig.RemotePath)
if err != nil {
return "", err
}
Expand All @@ -133,12 +127,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e

func (h *ConnectionHandler) restartServer(readOnly bool) error {
provider := h.server.Provider
err := provider.Restart(readOnly)
if err != nil {
return err
}

return h.server.Provider.Pool().Reset(provider.CatalogName(), provider.Connector(), provider.Storage())
return provider.Restart(readOnly)
}

func doCheckpoint(sqlCtx *sql.Context) error {
Expand Down
2 changes: 1 addition & 1 deletion pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start
}
if db == "postgres" || db == "mysql" {
if provider := h.duckHandler.GetCatalogProvider(); provider != nil {
db = provider.CatalogName()
db = provider.DefaultCatalogName()
}
}

Expand Down
Loading