Skip to content

Commit

Permalink
Refactor guacgql command to support out-of-tree backends
Browse files Browse the repository at this point in the history
Signed-off-by: robert-cronin <[email protected]>
  • Loading branch information
robert-cronin committed Nov 18, 2024
1 parent cb5677f commit a510df6
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 244 deletions.
55 changes: 0 additions & 55 deletions cmd/guacgql/cmd/ent.go

This file was deleted.

77 changes: 17 additions & 60 deletions cmd/guacgql/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"os"
"strings"

"github.com/guacsec/guac/pkg/assembler/backends"
"github.com/guacsec/guac/pkg/cli"
"github.com/guacsec/guac/pkg/version"
"github.com/spf13/cobra"
Expand All @@ -34,35 +35,6 @@ var flags = struct {
tlsKeyFile string
debug bool
tracegql bool

// Needed only if using neo4j backend
nAddr string
nUser string
nPass string
nRealm string

// Needed only if using ent backend
dbAddress string
dbDriver string
dbDebug bool
dbMigrate bool
dbConnTime string

// Needed only if using arangodb backend
arangoAddr string
arangoUser string
arangoPass string

// Needed only if using neptune backend
neptuneEndpoint string
neptunePort int
neptuneRegion string
neptuneUser string
neptuneRealm string

kvStore string
kvRedis string
kvTiKV string
}{}

var rootCmd = &cobra.Command{
Expand All @@ -77,51 +49,36 @@ var rootCmd = &cobra.Command{
flags.debug = viper.GetBool("gql-debug")
flags.tracegql = viper.GetBool("gql-trace")

flags.nUser = viper.GetString("neo4j-user")
flags.nPass = viper.GetString("neo4j-pass")
flags.nAddr = viper.GetString("neo4j-addr")
flags.nRealm = viper.GetString("neo4j-realm")

// Needed only if using ent backend
flags.dbAddress = viper.GetString("db-address")
flags.dbDriver = viper.GetString("db-driver")
flags.dbDebug = viper.GetBool("db-debug")
flags.dbMigrate = viper.GetBool("db-migrate")
flags.dbConnTime = viper.GetString("db-conn-time")

flags.arangoUser = viper.GetString("arango-user")
flags.arangoPass = viper.GetString("arango-pass")
flags.arangoAddr = viper.GetString("arango-addr")

flags.neptuneEndpoint = viper.GetString("neptune-endpoint")
flags.neptunePort = viper.GetInt("neptune-port")
flags.neptuneRegion = viper.GetString("neptune-region")
flags.neptuneUser = viper.GetString("neptune-user")
flags.neptuneRealm = viper.GetString("neptune-realm")

flags.kvStore = viper.GetString("kv-store")
flags.kvRedis = viper.GetString("kv-redis")
flags.kvTiKV = viper.GetString("kv-tikv")
startServer(cmd)
},
}

func init() {
cobra.OnInitialize(cli.InitConfig)

// Register common flags
set, err := cli.BuildFlags([]string{
"arango-addr", "arango-user", "arango-pass",
"neo4j-addr", "neo4j-user", "neo4j-pass", "neo4j-realm",
"neptune-endpoint", "neptune-port", "neptune-region", "neptune-user", "neptune-realm",
"gql-listen-port", "gql-tls-cert-file", "gql-tls-key-file", "gql-debug", "gql-backend", "gql-trace",
"db-address", "db-driver", "db-debug", "db-migrate", "db-conn-time",
"kv-store", "kv-redis", "kv-tikv", "enable-prometheus",
"gql-listen-port",
"gql-tls-cert-file",
"gql-tls-key-file",
"gql-debug",
"gql-backend",
"gql-trace",
"enable-prometheus",
})
if err != nil {
fmt.Fprintf(os.Stderr, "failed to setup flag: %v", err)
os.Exit(1)
}
rootCmd.Flags().AddFlagSet(set)

// Register backend-specific flags
err = backends.RegisterFlags(rootCmd)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to register backend flags: %v", err)
os.Exit(1)
}

if err := viper.BindPFlags(rootCmd.Flags()); err != nil {
fmt.Fprintf(os.Stderr, "failed to bind flags: %v", err)
os.Exit(1)
Expand Down
107 changes: 15 additions & 92 deletions cmd/guacgql/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,47 +26,24 @@ import (
"syscall"
"time"

"github.com/guacsec/guac/pkg/version"
// import all known backends
_ "github.com/guacsec/guac/pkg/assembler/backends/neo4j"
_ "github.com/guacsec/guac/pkg/assembler/backends/neptune"
_ "github.com/guacsec/guac/pkg/assembler/backends/ent/backend"
_ "github.com/guacsec/guac/pkg/assembler/backends/keyvalue"
_ "github.com/guacsec/guac/pkg/assembler/backends/arangodb"

"github.com/99designs/gqlgen/graphql/handler/debug"
"github.com/99designs/gqlgen/graphql/playground"
"github.com/guacsec/guac/pkg/assembler/backends"
"github.com/guacsec/guac/pkg/assembler/backends/arangodb"
_ "github.com/guacsec/guac/pkg/assembler/backends/keyvalue"
"github.com/guacsec/guac/pkg/assembler/backends/neo4j"
"github.com/guacsec/guac/pkg/assembler/backends/neptune"
"github.com/guacsec/guac/pkg/assembler/kv"
"github.com/guacsec/guac/pkg/assembler/kv/redis"
"github.com/guacsec/guac/pkg/assembler/server"
"github.com/guacsec/guac/pkg/logging"
"github.com/guacsec/guac/pkg/metrics"
"github.com/guacsec/guac/pkg/version"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/exp/maps"
)

const (
arango = "arango"
neo4js = "neo4j"
ent = "ent"
neptunes = "neptune"
keyvalue = "keyvalue"
)

type optsFunc func(context.Context) backends.BackendArgs

var getOpts map[string]optsFunc

func init() {
if getOpts == nil {
getOpts = make(map[string]optsFunc)
}
getOpts[arango] = getArango
getOpts[neo4js] = getNeo4j
getOpts[neptunes] = getNeptune
getOpts[keyvalue] = getKeyValue
}

func startServer(cmd *cobra.Command) {
var srvHandler http.Handler
ctx := logging.WithLogger(context.Background())
Expand All @@ -78,9 +55,15 @@ func startServer(cmd *cobra.Command) {
os.Exit(1)
}

backend, err := backends.Get(flags.backend, ctx, getOpts[flags.backend](ctx))
backendArgs, err := backends.GetBackendArgs(ctx, flags.backend)
if err != nil {
logger.Errorf("error creating %v backend: %w", flags.backend, err)
logger.Errorf("failed to parse backend flags with error: %v", err)
os.Exit(1)
}

backend, err := backends.Get(flags.backend, ctx, backendArgs)
if err != nil {
logger.Errorf("Error creating %v backend: %v", flags.backend, err)
os.Exit(1)
}

Expand Down Expand Up @@ -161,15 +144,9 @@ func setupPrometheus(ctx context.Context, name string) (metrics.MetricCollector,
}

func validateFlags() error {
if !slices.Contains(maps.Keys(getOpts), flags.backend) {
return fmt.Errorf("invalid graphql backend specified: %v", flags.backend)
}
if !slices.Contains(backends.List(), flags.backend) {
return fmt.Errorf("invalid graphql backend specified: %v", flags.backend)
}
if !slices.Contains([]string{"memmap", "redis", "tikv"}, flags.kvStore) {
return fmt.Errorf("invalid kv store specified: %v", flags.kvStore)
}
return nil
}

Expand All @@ -183,57 +160,3 @@ func versionHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = fmt.Fprint(w, version.Version)
}

func getArango(_ context.Context) backends.BackendArgs {
return &arangodb.ArangoConfig{
User: flags.arangoUser,
Pass: flags.arangoPass,
DBAddr: flags.arangoAddr,
}
}

func getNeo4j(_ context.Context) backends.BackendArgs {
return &neo4j.Neo4jConfig{
User: flags.nUser,
Pass: flags.nPass,
Realm: flags.nRealm,
DBAddr: flags.nAddr,
}
}

var tikvGS func(context.Context, string) (kv.Store, error)

func getKeyValue(ctx context.Context) backends.BackendArgs {
logger := logging.FromContext(ctx)
switch flags.kvStore {
case "memmap":
// default is memmap
return nil
case "redis":
s, err := redis.GetStore(flags.kvRedis)
if err != nil {
logger.Fatalf("error with Redis: %v", err)
}
return s
case "tikv":
if tikvGS == nil {
logger.Fatal("TiKV not supported on 32-bit")
}
s, err := tikvGS(ctx, flags.kvTiKV)
if err != nil {
logger.Fatalf("error with TiKV: %v", err)
}
return s
}
return nil
}

func getNeptune(_ context.Context) backends.BackendArgs {
return &neptune.NeptuneConfig{
Endpoint: flags.neptuneEndpoint,
Port: flags.neptunePort,
Region: flags.neptuneRegion,
User: flags.neptuneUser,
Realm: flags.neptuneRealm,
}
}
34 changes: 33 additions & 1 deletion pkg/assembler/backends/arangodb/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/spf13/cobra"
"github.com/spf13/viper"

jsoniter "github.com/json-iterator/go"

Expand All @@ -42,6 +44,13 @@ type ArangoConfig struct {
TestData bool
}

// flags holds the command-line flags for ArangoDB configuration
var flags = struct {
addr string
user string
pass string
}{}

type arangoQueryBuilder struct {
query strings.Builder
}
Expand All @@ -59,7 +68,7 @@ type index struct {
}

func init() {
backends.Register("arango", getBackend)
backends.Register("arango", getBackend, registerFlags, parseFlags)
}

func initIndex(name string, fields []string, unique bool) index {
Expand Down Expand Up @@ -117,6 +126,29 @@ func DeleteDatabase(ctx context.Context, args backends.BackendArgs) error {
return nil
}

// registerFlags registers ArangoDB-specific command line flags
func registerFlags(cmd *cobra.Command) error {
flagSet := cmd.Flags()
flagSet.StringVar(&flags.addr, "arango-addr", "http://localhost:8529", "address to arango db")
flagSet.StringVar(&flags.user, "arango-user", "", "arango user to connect to graph db")
flagSet.StringVar(&flags.pass, "arango-pass", "", "arango password to connect to graph db")

if err := viper.BindPFlags(flagSet); err != nil {
return fmt.Errorf("failed to bind flags: %w", err)
}

return nil
}

// parseFlags returns the ArangoDB configuration from parsed flags
func parseFlags(ctx context.Context) (backends.BackendArgs, error) {
return &ArangoConfig{
DBAddr: flags.addr,
User: flags.user,
Pass: flags.pass,
}, nil
}

func getBackend(ctx context.Context, args backends.BackendArgs) (backends.Backend, error) {
config, ok := args.(*ArangoConfig)
if !ok {
Expand Down
Loading

0 comments on commit a510df6

Please sign in to comment.