Skip to content

Commit

Permalink
to #297 fix: add initial-data-dir.
Browse files Browse the repository at this point in the history
  • Loading branch information
TianyuZhang1214 committed Dec 25, 2024
1 parent e014bb9 commit 060d492
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 17 deletions.
8 changes: 3 additions & 5 deletions catalog/initial_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ var InitialDatas = struct {
},
}

const INITIAL_DATA_DIR = "initialdata/"

var InitialDataFiles = struct {
PGClass string
PGProc string
PGType string
}{
PGClass: INITIAL_DATA_DIR + "pg_class.csv",
PGProc: INITIAL_DATA_DIR + "pg_proc.csv",
PGType: INITIAL_DATA_DIR + "pg_type.csv",
PGClass: "pg_class.csv",
PGProc: "pg_proc.csv",
PGType: "pg_type.csv",
}
10 changes: 6 additions & 4 deletions catalog/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type DatabaseProvider struct {
dsn string
externalProcedureRegistry sql.ExternalStoredProcedureRegistry
ready bool
initialDataDir string
}

var _ sql.DatabaseProvider = (*DatabaseProvider)(nil)
Expand All @@ -41,20 +42,21 @@ var _ configuration.DataDirProvider = (*DatabaseProvider)(nil)

const readOnlySuffix = "?access_mode=read_only"

func NewInMemoryDBProvider() *DatabaseProvider {
prov, err := NewDBProvider("", ".", "")
func NewInMemoryDBProvider(initialDataDir string) *DatabaseProvider {
prov, err := NewDBProvider("", ".", "", initialDataDir)
if err != nil {
panic(err)
}
return prov
}

func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabaseProvider, err error) {
func NewDBProvider(defaultTimeZone, dataDir, defaultDB, initialDataDir string) (prov *DatabaseProvider, err error) {
prov = &DatabaseProvider{
mu: &sync.RWMutex{},
defaultTimeZone: defaultTimeZone,
externalProcedureRegistry: sql.NewExternalStoredProcedureRegistry(), // This has no effect, just to satisfy the upper layer interface
dataDir: dataDir,
initialDataDir: initialDataDir,
}

shouldInit := true
Expand Down Expand Up @@ -152,7 +154,7 @@ func (prov *DatabaseProvider) initCatalog() error {
if count == 0 {
if _, err := prov.storage.ExecContext(
context.Background(),
fmt.Sprintf("COPY %s FROM '%s' (FORMAT CSV, HEADER)", t.QualifiedName(), t.InitialDataFile),
fmt.Sprintf("COPY %s FROM '%s' (FORMAT CSV, HEADER)", t.QualifiedName(), prov.initialDataDir+t.InitialDataFile),
); err != nil {
return fmt.Errorf("failed to insert initial data from file into internal table %q: %w", t.Name, err)
}
Expand Down
4 changes: 2 additions & 2 deletions flightsqltest/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ import (
"github.com/apache/arrow-go/v18/arrow/memory"
)

const defaultDbName = "mysql"
const defaultTableName = "drivertest"
const dataDirectory = "."
const dbFileName = "mysql.db"

var defaultStatements = map[string]string{
"create table": `
Expand Down Expand Up @@ -92,7 +92,7 @@ func (s *SqlTestSuite) SetupSuite() {
}

s.createServer = func() (flight.Server, string, error) {
provider, err := catalog.NewDBProvider(dataDirectory, dbFileName)
provider, err := catalog.NewDBProvider("", dataDirectory, defaultDbName, "../initialdata/")
if err != nil {
return nil, "", err
}
Expand Down
2 changes: 1 addition & 1 deletion harness/duck_harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ func (m *DuckHarness) getProvider() sql.DatabaseProvider {
}

func (m *DuckHarness) NewDatabaseProvider() sql.MutableDatabaseProvider {
return catalog.NewInMemoryDBProvider()
return catalog.NewInMemoryDBProvider("../initialdata/")
}

func (m *DuckHarness) Provider() *catalog.DatabaseProvider {
Expand Down
10 changes: 6 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ var (
restoreAccessKeyId = ""
restoreSecretAccessKey = ""

flightsqlHost = "localhost"
flightsqlPort = -1 // Disabled by default
flightsqlHost = "localhost"
flightsqlPort = -1 // Disabled by default
initialDataDir = "./initialdata/"
)

func init() {
Expand Down Expand Up @@ -100,6 +101,7 @@ func init() {

flag.StringVar(&flightsqlHost, "flightsql-host", flightsqlHost, "hostname for the Flight SQL service")
flag.IntVar(&flightsqlPort, "flightsql-port", flightsqlPort, "port number for the Flight SQL service")
flag.StringVar(&initialDataDir, "initial-data-dir", initialDataDir, "directory containing initial data files")
}

func ensureSQLTranslate() {
Expand All @@ -123,12 +125,12 @@ func main() {
executeRestoreIfNeeded()

if initMode {
provider := catalog.NewInMemoryDBProvider()
provider := catalog.NewInMemoryDBProvider(initialDataDir)
provider.Close()
return
}

provider, err := catalog.NewDBProvider(defaultTimeZone, dataDirectory, defaultDb)
provider, err := catalog.NewDBProvider(defaultTimeZone, dataDirectory, defaultDb, initialDataDir)
if err != nil {
logrus.Fatalln("Failed to open the database:", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pgtest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pgserver.Server, conn *pgx.Conn, close func() error, err error) {
provider := catalog.NewInMemoryDBProvider()
provider := catalog.NewInMemoryDBProvider("../../initialdata/")

// Postgres tables are created in the `public` schema by default.
// Create the `public` schema if it doesn't exist.
Expand Down

0 comments on commit 060d492

Please sign in to comment.