Skip to content

Commit

Permalink
fix: fix the incorrect reference to catalog
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException committed Dec 27, 2024
1 parent 0c978f0 commit 584591f
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 43 deletions.
2 changes: 1 addition & 1 deletion backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row

switch node := n.(type) {
case *plan.Use:
useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name())
useStmt := "USE " + catalog.FullSchemaName(b.provider.Pool().CurrentCatalog(ctx.ID()), node.Database().Name())
if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil {
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(node.Database().Name())
Expand Down
42 changes: 29 additions & 13 deletions catalog/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ import (

type ConnectionPool struct {
*stdsql.DB
connector *duckdb.Connector
catalog string
conns sync.Map // concurrent-safe map[uint32]*stdsql.Conn
txns sync.Map // concurrent-safe map[uint32]*stdsql.Tx
connector *duckdb.Connector
defaultCatalogName string
conns sync.Map // concurrent-safe map[uint32]*stdsql.Conn
txns sync.Map // concurrent-safe map[uint32]*stdsql.Tx
}

func NewConnectionPool(catalog string, connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
func NewConnectionPool(defaultCatalogName string, connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
return &ConnectionPool{
DB: db,
connector: connector,
catalog: catalog,
DB: db,
connector: connector,
defaultCatalogName: defaultCatalogName,
}
}

Expand All @@ -57,13 +57,30 @@ func (p *ConnectionPool) CurrentSchema(id uint32) string {
}
conn := entry.(*stdsql.Conn)
var schema string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&schema); err != nil {
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&schema); err != nil {
logrus.WithError(err).Error("Failed to get current schema")
return ""
}
return schema
}

// CurrentCatalog retrieves the current catalog of the connection.
// Returns an empty string if the connection is not established.
// Returns the default catalog name if the catalog cannot be retrieved.
func (p *ConnectionPool) CurrentCatalog(id uint32) string {
entry, ok := p.conns.Load(id)
if !ok {
return ""
}
conn := entry.(*stdsql.Conn)
var catalog string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_CATALOG").Scan(&catalog); err != nil {
logrus.WithError(err).Error("Failed to get current catalog")
return p.defaultCatalogName
}
return catalog
}

func (p *ConnectionPool) GetConn(ctx context.Context, id uint32) (*stdsql.Conn, error) {
var conn *stdsql.Conn
entry, ok := p.conns.Load(id)
Expand All @@ -88,11 +105,11 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema

if schemaName != "" {
var currentSchema string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&currentSchema); err != nil {
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&currentSchema); err != nil {
logrus.WithError(err).Error("Failed to get current schema")
return nil, err
} else if currentSchema != schemaName {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.CurrentCatalog(id), schemaName)); err != nil {
if IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(schemaName)
}
Expand Down Expand Up @@ -187,15 +204,14 @@ func (p *ConnectionPool) Close() error {
return errors.Join(lastErr, p.DB.Close())
}

func (p *ConnectionPool) Reset(catalog string, connector *duckdb.Connector, db *stdsql.DB) error {
func (p *ConnectionPool) Reset(connector *duckdb.Connector, db *stdsql.DB) error {
err := p.Close()
if err != nil {
return fmt.Errorf("failed to close connection pool: %w", err)
}

p.conns.Clear()
p.txns.Clear()
p.catalog = catalog
p.DB = db
p.connector = connector

Expand Down
33 changes: 18 additions & 15 deletions catalog/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
_ "github.com/marcboeker/go-duckdb"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/configuration"
Expand All @@ -26,7 +25,7 @@ type DatabaseProvider struct {
connector *duckdb.Connector
storage *stdsql.DB
pool *ConnectionPool
catalogName string // database name in postgres
defaultCatalogName string // default database name in postgres
dataDir string
dbFile string
dsn string
Expand Down Expand Up @@ -59,11 +58,11 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr

shouldInit := true
if defaultDB == "" || defaultDB == "memory" {
prov.catalogName = "memory"
prov.defaultCatalogName = "memory"
prov.dbFile = ""
prov.dsn = ""
} else {
prov.catalogName = defaultDB
prov.defaultCatalogName = defaultDB
prov.dbFile = defaultDB + ".db"
prov.dsn = filepath.Join(prov.dataDir, prov.dbFile)
_, err = os.Stat(prov.dsn)
Expand All @@ -75,7 +74,7 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr
return nil, err
}
prov.storage = stdsql.OpenDB(prov.connector)
prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage)
prov.pool = NewConnectionPool(prov.defaultCatalogName, prov.connector, prov.storage)

bootQueries := []string{
"INSTALL arrow",
Expand Down Expand Up @@ -315,8 +314,8 @@ func (prov *DatabaseProvider) Pool() *ConnectionPool {
return prov.pool
}

func (prov *DatabaseProvider) CatalogName() string {
return prov.catalogName
func (prov *DatabaseProvider) DefaultCatalogName() string {
return prov.defaultCatalogName
}

func (prov *DatabaseProvider) DataDir() string {
Expand All @@ -342,7 +341,8 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
prov.mu.RLock()
defer prov.mu.RUnlock()

rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", prov.catalogName)
catalogName := prov.pool.CurrentCatalog(ctx.ID())
rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", catalogName)
if err != nil {
panic(ErrDuckDB.New(err))
}
Expand All @@ -360,7 +360,7 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
continue
}

all = append(all, NewDatabase(schemaName, prov.catalogName))
all = append(all, NewDatabase(schemaName, catalogName))
}

sort.Slice(all, func(i, j int) bool {
Expand All @@ -375,13 +375,14 @@ func (prov *DatabaseProvider) Database(ctx *sql.Context, name string) (sql.Datab
prov.mu.RLock()
defer prov.mu.RUnlock()

ok, err := hasDatabase(ctx, prov.catalogName, name)
catalogName := prov.pool.CurrentCatalog(ctx.ID())
ok, err := hasDatabase(ctx, catalogName, name)
if err != nil {
return nil, err
}

if ok {
return NewDatabase(name, prov.catalogName), nil
return NewDatabase(name, catalogName), nil
}
return nil, sql.ErrDatabaseNotFound.New(name)
}
Expand All @@ -391,7 +392,7 @@ func (prov *DatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool {
prov.mu.RLock()
defer prov.mu.RUnlock()

ok, err := hasDatabase(ctx, prov.catalogName, name)
ok, err := hasDatabase(ctx, prov.pool.CurrentCatalog(ctx.ID()), name)
if err != nil {
panic(err)
}
Expand All @@ -413,7 +414,8 @@ func (prov *DatabaseProvider) CreateDatabase(ctx *sql.Context, name string) erro
prov.mu.Lock()
defer prov.mu.Unlock()

_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, FullSchemaName(prov.catalogName, name)))
_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`,
FullSchemaName(prov.pool.CurrentCatalog(ctx.ID()), name)))
if err != nil {
return ErrDuckDB.New(err)
}
Expand All @@ -426,7 +428,8 @@ func (prov *DatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
prov.mu.Lock()
defer prov.mu.Unlock()

_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, FullSchemaName(prov.catalogName, name)))
_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`,
FullSchemaName(prov.pool.CurrentCatalog(ctx.ID()), name)))
if err != nil {
return ErrDuckDB.New(err)
}
Expand Down Expand Up @@ -456,5 +459,5 @@ func (prov *DatabaseProvider) Restart(readOnly bool) error {
prov.connector = connector
prov.storage = storage

return nil
return prov.pool.Reset(connector, storage)
}
15 changes: 2 additions & 13 deletions pgserver/backup_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ func parseBackupSQL(sql string) (*BackupConfig, error) {
}

func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, error) {
// TODO(neo.zty): Add support for backing up multiple databases once MyDuck Server supports multi-database functionality.
if backupConfig.DbName != h.server.Provider.CatalogName() {
return "", fmt.Errorf("backup database name %s does not match server database name %s",
backupConfig.DbName, h.server.Provider.CatalogName())
}

sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "")
if err != nil {
return "", fmt.Errorf("failed to create context for query: %w", err)
Expand All @@ -114,7 +108,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e
}

msg, err := backupConfig.StorageConfig.UploadFile(
h.server.Provider.DataDir(), h.server.Provider.DbFile(), backupConfig.RemotePath)
h.server.Provider.DataDir(), backupConfig.DbName+".db", backupConfig.RemotePath)
if err != nil {
return "", err
}
Expand All @@ -133,12 +127,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e

func (h *ConnectionHandler) restartServer(readOnly bool) error {
provider := h.server.Provider
err := provider.Restart(readOnly)
if err != nil {
return err
}

return h.server.Provider.Pool().Reset(provider.CatalogName(), provider.Connector(), provider.Storage())
return provider.Restart(readOnly)
}

func doCheckpoint(sqlCtx *sql.Context) error {
Expand Down
2 changes: 1 addition & 1 deletion pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start
}
if db == "postgres" || db == "mysql" {
if provider := h.duckHandler.GetCatalogProvider(); provider != nil {
db = provider.CatalogName()
db = provider.DefaultCatalogName()
}
}

Expand Down

0 comments on commit 584591f

Please sign in to comment.