Skip to content

Commit

Permalink
add tls flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Schmidt committed Jun 13, 2024
1 parent 36324bd commit 3880ef4
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cmd/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@ const (
flagNameBatchMaxRecords = "hnsw-batch-max-records"
flagNameBatchInterval = "hnsw-batch-interval"
flagNameBatchEnabled = "hnsw-batch-enabled"
flagNameTLSProtocols = "tls-protocols"
flagNameTLSCaFile = "tls-cafile"
flagNameTLSCaPath = "tls-capath"
flagNameTLSCertFile = "tls-certfile"
flagNameTLSKeyFile = "tls-keyfile"
flagNameTLSKeyFilePass = "tls-keyfile-password"
)
3 changes: 3 additions & 0 deletions cmd/createIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var createIndexFlags = &struct {
hnswBatchInterval flags.Uint32OptionalFlag
hnswBatchEnabled flags.BoolOptionalFlag
timeout time.Duration
tls *flags.TLSFlags
}{
host: flags.NewDefaultHostPortFlag(),
seeds: &flags.SeedsSliceFlag{},
Expand All @@ -51,6 +52,7 @@ var createIndexFlags = &struct {
hnswBatchMaxRecords: flags.Uint32OptionalFlag{},
hnswBatchInterval: flags.Uint32OptionalFlag{},
hnswBatchEnabled: flags.BoolOptionalFlag{},
tls: &flags.TLSFlags{},
}

func newCreateIndexFlagSet() *pflag.FlagSet {
Expand All @@ -74,6 +76,7 @@ func newCreateIndexFlagSet() *pflag.FlagSet {
flagSet.Var(&createIndexFlags.hnswBatchMaxRecords, flagNameBatchMaxRecords, commonFlags.DefaultWrapHelpString("Maximum number of records to fit in a batch. The default value is 10000.")) //nolint:lll // For readability
flagSet.Var(&createIndexFlags.hnswBatchInterval, flagNameBatchInterval, commonFlags.DefaultWrapHelpString("The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000.")) //nolint:lll // For readability
flagSet.Var(&createIndexFlags.hnswBatchEnabled, flagNameBatchEnabled, commonFlags.DefaultWrapHelpString("Enables batching for index updates. Default is true meaning batching is enabled.")) //nolint:lll // For readability
flagSet.AddFlagSet(createIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString))

return flagSet
}
Expand Down
3 changes: 3 additions & 0 deletions cmd/dropIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ var dropIndexFlags = &struct {
sets []string
indexName string
timeout time.Duration
tls *flags.TLSFlags
}{
host: flags.NewDefaultHostPortFlag(),
seeds: &flags.SeedsSliceFlag{},
tls: &flags.TLSFlags{},
}

func newDropIndexFlagSet() *pflag.FlagSet {
Expand All @@ -40,6 +42,7 @@ func newDropIndexFlagSet() *pflag.FlagSet {
flagSet.StringArrayVarP(&dropIndexFlags.sets, flagNameSets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability
flagSet.StringVarP(&dropIndexFlags.indexName, flagNameIndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability
flagSet.DurationVar(&dropIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(dropIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString))

return flagSet
}
Expand Down
61 changes: 61 additions & 0 deletions cmd/flags/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package flags

import (
"crypto/tls"

commonClient "github.com/aerospike/tools-common-go/client"
commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/pflag"
)

type TLSFlags struct {
TLSProtocols commonFlags.TLSProtocolsFlag
TLSRootCAFile commonFlags.CertFlag
TLSRootCAPath commonFlags.CertPathFlag
TLSCertFile commonFlags.CertFlag
TLSKeyFile commonFlags.CertFlag
TLSKeyFilePass commonFlags.PasswordFlag
}

func NewTLSFlags() *TLSFlags {
return &TLSFlags{
TLSProtocols: commonFlags.NewDefaultTLSProtocolsFlag(),
}
}

// NewAerospikeFlagSet returns a new pflag.FlagSet with Aerospike flags defined.
// Values set in the returned FlagSet will be stored in the AerospikeFlags argument.
func (tf *TLSFlags) NewFlagSet(fmtUsage commonFlags.UsageFormatter) *pflag.FlagSet {
f := &pflag.FlagSet{}

f.Var(&tf.TLSRootCAFile, "tls-cafile", fmtUsage("The CA used when connecting to AVS."))
f.Var(&tf.TLSRootCAPath, "tls-capath", fmtUsage("A path containing CAs for connecting to AVS."))
f.Var(&tf.TLSCertFile, "tls-certfile", fmtUsage("The certificate file for mutual TLS authentication with AVS."))
f.Var(&tf.TLSKeyFile, "tls-keyfile", fmtUsage("The key file used for mutual TLS authentication with AVS."))
f.Var(&tf.TLSKeyFilePass, "tls-keyfile-password", fmtUsage("The password used to decrypt the key-file if encrypted."))
f.Var(&tf.TLSProtocols, "tls-protocols", fmtUsage(
"Set the TLS protocol selection criteria. This format is the same as"+
" Apache's SSLProtocol documented at https://httpd.apache.org/docs/current/mod/mod_ssl.html#ssl protocol.",
))

return f
}

func (tf *TLSFlags) NewTLSConfig() (*tls.Config, error) {
rootCA := [][]byte{}

if len(tf.TLSRootCAFile) != 0 {
rootCA = append(rootCA, tf.TLSRootCAFile)
}

rootCA = append(rootCA, tf.TLSRootCAPath...)

return commonClient.NewTLSConfig(
rootCA,
tf.TLSCertFile,
tf.TLSKeyFile,
tf.TLSKeyFilePass,
tf.TLSProtocols.Min,
tf.TLSProtocols.Max,
).NewGoTLSConfig()
}
3 changes: 3 additions & 0 deletions cmd/listIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ var listIndexFlags = &struct {
listenerName flags.StringOptionalFlag
verbose bool
timeout time.Duration
tls *flags.TLSFlags
}{
host: flags.NewDefaultHostPortFlag(),
seeds: &flags.SeedsSliceFlag{},
tls: &flags.TLSFlags{},
}

func newListIndexFlagSet() *pflag.FlagSet {
Expand All @@ -37,6 +39,7 @@ func newListIndexFlagSet() *pflag.FlagSet {
flagSet.VarP(&listIndexFlags.listenerName, flagNameListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability
flagSet.BoolVarP(&listIndexFlags.verbose, flagNameVerbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability
flagSet.DurationVar(&listIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(listIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString))

return flagSet
}
Expand Down

0 comments on commit 3880ef4

Please sign in to comment.