diff --git a/cmd/connect.go b/cmd/connect.go index ef6cd29..3afdd08 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -9,6 +9,7 @@ import ( "github.com/c-bata/go-prompt" "github.com/lewissteele/dbat/internal/db" "github.com/lewissteele/dbat/internal/input" + "github.com/lewissteele/dbat/internal/list" "github.com/lewissteele/dbat/internal/table" "github.com/spf13/cobra" ) @@ -17,7 +18,21 @@ var connectCmd = &cobra.Command{ Use: "connect", Short: "connect to saved database", Run: func(cmd *cobra.Command, args []string) { - db.Connect(selectedDB(args)) + var c string + + if len(args) > 0 { + c = args[0] + } + + if len(c) == 0 { + c = list.RenderConnectionSelection() + } + + db.Connect(c) + + if len(db.Selected()) == 0 { + db.Select(list.RenderDatabaseSelection()) + } prompt := prompt.New( executor, diff --git a/cmd/remove.go b/cmd/remove.go index f3553d6..06abc7a 100644 --- a/cmd/remove.go +++ b/cmd/remove.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/lewissteele/dbat/internal/db" + "github.com/lewissteele/dbat/internal/list" "github.com/lewissteele/dbat/internal/model" "github.com/spf13/cobra" ) @@ -12,7 +13,17 @@ var removeCmd = &cobra.Command{ Use: "remove", Short: "remove database connection", Run: func(cmd *cobra.Command, args []string) { - db.LocalDB.Where("name = ?", selectedDB(args)).Delete(&model.Database{}) + var database string + + if len(args) > 0 { + database = args[0] + } + + if len(database) == 0 { + database = list.RenderConnectionSelection() + } + + db.LocalDB.Where("name = ?", database).Delete(&model.Database{}) fmt.Println("removed database connection") }, } diff --git a/cmd/shared.go b/cmd/shared.go index 100b902..d7e6b28 100644 --- a/cmd/shared.go +++ b/cmd/shared.go @@ -1,11 +1,6 @@ package cmd -import ( - "errors" - - "github.com/charmbracelet/huh" - "github.com/lewissteele/dbat/internal/db" -) +import "errors" func isNotBlank(val string) error { if len(val) > 0 { @@ -14,33 +9,3 @@ func isNotBlank(val string) error { return errors.New("value cannot be blank") } - -func selectedDB(args []string) string { - if len(args) > 0 { - return args[0] - } - - var name string - - options := []huh.Option[string]{} - - for _, n := range db.UserDBNames() { - options = append(options, huh.NewOption[string](n, n)) - } - - form := huh.NewForm( - huh.NewGroup( - huh.NewSelect[string]().Title("database").Value(&name).Options( - options..., - ), - ), - ) - - err := form.Run() - - if err != nil { - panic(err) - } - - return name -} diff --git a/internal/db/objects.go b/internal/db/objects.go index bf9f1f7..1dde0dd 100644 --- a/internal/db/objects.go +++ b/internal/db/objects.go @@ -10,25 +10,50 @@ import ( ) var Columns []string -var Databases []string +var databases []string var Tables []string -var cacheConn *gorm.DB +func Databases() []string { + if len(databases) > 0 { + return databases + } -func cacheObjects() { - cacheConn, _ = gorm.Open( + c := NewConn() + c.Raw("show databases").Scan(&databases) + + for _, database := range databases { + if strings.Contains(database, "-") { + continue + } + + databases = append( + databases, + database, + ) + } + + d, _ := c.DB() + d.Close() + + return databases +} + +func NewConn() *gorm.DB { + c, _ := gorm.Open( dialector(UserDB), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), SkipDefaultTransaction: true, }) - db, _ := cacheConn.DB() - db.SetMaxOpenConns(1) + d, _ := c.DB() + d.SetMaxOpenConns(1) - cacheDatabases() + return c +} - for _, d := range Databases { +func cacheObjects() { + for _, d := range Databases() { cacheTables(d) } @@ -41,31 +66,14 @@ func cacheObjects() { slices.Sort(Columns) Columns = slices.Compact(Columns) - - db.Close() -} - -func cacheDatabases() { - var databases []string - cacheConn.Raw("show databases").Scan(&databases) - - for _, database := range databases { - if strings.Contains(database, "-") { - continue - } - - Databases = append( - Databases, - database, - ) - } } func cacheTables(database string) { - cacheConn.Exec(fmt.Sprintf("use `%s`", database)) + c := NewConn() + c.Exec(fmt.Sprintf("use `%s`", database)) var tables []string - cacheConn.Raw("show tables").Scan(&tables) + c.Raw("show tables").Scan(&tables) for _, table := range tables { if Selected() == database { @@ -75,18 +83,26 @@ func cacheTables(database string) { Tables = append(Tables, strings.Join([]string{database, table}, ".")) } + + d, _ := c.DB() + d.Close() } func cacheColumns(table string) { + c := NewConn() + type column struct { Field string } var columns []column - cacheConn.Exec(fmt.Sprintf("use `%s`", Selected())) - cacheConn.Raw("show columns from channels").Scan(&columns) + c.Exec(fmt.Sprintf("use `%s`", Selected())) + c.Raw("show columns from channels").Scan(&columns) - for _, c := range columns { - Columns = append(Columns, c.Field) + for _, column := range columns { + Columns = append(Columns, column.Field) } + + d, _ := c.DB() + d.Close() } diff --git a/internal/input/completer.go b/internal/input/completer.go index 5c1e5ea..62ed4fd 100644 --- a/internal/input/completer.go +++ b/internal/input/completer.go @@ -27,7 +27,7 @@ func Completer(d prompt.Document) []prompt.Suggest { s = append(s, similarity(currentWord, keywords[:])...) s = append(s, similarity(currentWord, db.Columns)...) s = append(s, similarity(currentWord, db.Tables)...) - s = append(s, similarity(currentWord, db.Databases)...) + s = append(s, similarity(currentWord, db.Databases())...) return s } diff --git a/internal/list/connection.go b/internal/list/connection.go new file mode 100644 index 0000000..0252c18 --- /dev/null +++ b/internal/list/connection.go @@ -0,0 +1,31 @@ +package list + +import ( + "github.com/charmbracelet/huh" + "github.com/lewissteele/dbat/internal/db" +) + +func RenderConnectionSelection() string { + var c string + options := []huh.Option[string]{} + + for _, v := range db.UserDBNames() { + options = append(options, huh.NewOption[string](v, v)) + } + + form := huh.NewForm( + huh.NewGroup( + huh.NewSelect[string]().Title("database").Value(&c).Options( + options..., + ), + ), + ) + + err := form.Run() + + if err != nil { + panic(err) + } + + return c +} diff --git a/internal/list/database.go b/internal/list/database.go new file mode 100644 index 0000000..6463ac8 --- /dev/null +++ b/internal/list/database.go @@ -0,0 +1,31 @@ +package list + +import ( + "github.com/charmbracelet/huh" + "github.com/lewissteele/dbat/internal/db" +) + +func RenderDatabaseSelection() string { + var d string + options := []huh.Option[string]{} + + for _, v := range db.Databases() { + options = append(options, huh.NewOption[string](v, v)) + } + + form := huh.NewForm( + huh.NewGroup( + huh.NewSelect[string]().Title("database").Value(&d).Options( + options..., + ), + ), + ) + + err := form.Run() + + if err != nil { + panic(err) + } + + return d +}