Skip to content

Commit

Permalink
Add support for Options.TLSConfig.ServerName (#310)
Browse files Browse the repository at this point in the history
Additionally bump the shared test cases and adjust error handling around
branch and database conflicts in connection configuration.
  • Loading branch information
fmoor authored Apr 19, 2024
1 parent bdef9e4 commit dabd2d1
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 66 deletions.
141 changes: 78 additions & 63 deletions internal/client/connutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type connConfig struct {
waitUntilAvailable time.Duration
tlsCAData []byte
tlsSecurity string
tlsServerName string
serverSettings *snc.ServerSettings
secretKey string
}
Expand All @@ -88,6 +89,7 @@ func (c *connConfig) tlsConfig() (*tls.Config, error) {
tlsConfig := &tls.Config{
RootCAs: roots,
NextProtos: []string{"edgedb-binary"},
ServerName: c.tlsServerName,
}

switch c.tlsSecurity {
Expand Down Expand Up @@ -128,11 +130,11 @@ type configResolver struct {
host cfgVal // string
port cfgVal // int
database cfgVal // string
branch cfgVal // string
user cfgVal // string
password cfgVal // OptionalStr
tlsCAData cfgVal // []byte
tlsSecurity cfgVal // string
tlsServerName cfgVal // string
waitUntilAvailable cfgVal // time.Duration
serverSettings *snc.ServerSettings
secretKey cfgVal // string
Expand Down Expand Up @@ -217,17 +219,6 @@ func (r *configResolver) setDatabase(val, source string) error {
return nil
}

func (r *configResolver) setBranch(val, source string) error {
if r.branch.val != nil {
return nil
}
if val == "" {
return errors.New(`invalid branch name: ""`)
}
r.branch = cfgVal{val: val, source: source}
return nil
}

func (r *configResolver) setUser(val, source string) error {
if r.user.val != nil {
return nil
Expand Down Expand Up @@ -279,6 +270,15 @@ func (r *configResolver) setTLSSecurity(val string, source string) error {
return nil
}

func (r *configResolver) setTLSServerName(val string, source string) error {
if r.tlsServerName.val != nil {
return nil
}

r.tlsServerName = cfgVal{val: val, source: source}
return nil
}

func (r *configResolver) setWaitUntilAvailable(
val time.Duration,
source string,
Expand Down Expand Up @@ -354,7 +354,7 @@ func (r *configResolver) resolveOptions(
}

if opts.Branch != "" {
if e := r.setBranch(opts.Branch, "Branch options"); e != nil {
if e := r.setDatabase(opts.Branch, "Branch options"); e != nil {
return e
}
}
Expand Down Expand Up @@ -424,6 +424,14 @@ func (r *configResolver) resolveOptions(
"TLSOptions.SecurityMode option")
}

if opts.TLSOptions.ServerName != "" {
secSources = append(secSources, "TLSOptions.ServerName")
err = r.setTLSServerName(
opts.TLSOptions.ServerName,
"TLSOptions.ServerName options",
)
}

if len(secSources) > 1 {
return fmt.Errorf(
"mutually exclusive options set in Options: %v",
Expand Down Expand Up @@ -502,50 +510,24 @@ func (r *configResolver) resolveDSN(
"cannot be present at the same time")
}

if r.database.val != nil {
return fmt.Errorf(
"`branch` in DSN and %s are mutually exclusive options",
r.database.source,
)
}

val, err = popDSNValue(query, db, "branch", r.branch.val == nil)
val, err = popDSNValue(query, db, "branch", r.database.val == nil)
if err != nil {
return err
} else if val.val != nil {
br := strings.TrimPrefix(val.val.(string), "/")
if e := r.setBranch(br, source+val.source); e != nil {
if e := r.setDatabase(br, source+val.source); e != nil {
return e
}
}
} else {
if r.branch.val != nil {
if queryContains("database", query) {
return fmt.Errorf(
"`database` in DSN and %s are mutually exclusive options",
r.branch.source,
)
}

val, err = popDSNValue(query, db, "branch", r.branch.val == nil)
if err != nil {
return err
} else if val.val != nil {
br := strings.TrimPrefix(val.val.(string), "/")
if e := r.setBranch(br, source+val.source); e != nil {
return e
}
}
} else {
val, err = popDSNValue(
query, db, "database", r.database.val == nil)
if err != nil {
return err
} else if val.val != nil {
db := strings.TrimPrefix(val.val.(string), "/")
if e := r.setDatabase(db, source+val.source); e != nil {
return e
}
val, err = popDSNValue(
query, db, "database", r.database.val == nil)
if err != nil {
return err
} else if val.val != nil {
db := strings.TrimPrefix(val.val.(string), "/")
if e := r.setDatabase(db, source+val.source); e != nil {
return e
}
}
}
Expand Down Expand Up @@ -614,6 +596,22 @@ func (r *configResolver) resolveDSN(
}
}

val, err = popDSNValue(
query,
"",
"tls_server_name",
r.tlsServerName.val == nil,
)
if err != nil {
return err
}
if val.val != nil {
err = r.setTLSServerName(val.val.(string), source+val.source)
if err != nil {
return err
}
}

val, err = popDSNValue(
query,
"",
Expand Down Expand Up @@ -707,7 +705,7 @@ func (r *configResolver) applyCredentials(
}

if br, ok := creds.branch.Get(); ok && br != "" {
if e := r.setBranch(br, source); e != nil {
if e := r.setDatabase(br, source); e != nil {
return e
}
}
Expand All @@ -734,15 +732,21 @@ func (r *configResolver) applyCredentials(
}

func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) {
if db, ok := os.LookupEnv("EDGEDB_DATABASE"); ok {
db, dbOk := os.LookupEnv("EDGEDB_DATABASE")
if dbOk {
err := r.setDatabase(db, "EDGEDB_DATABASE environment variable")
if err != nil {
return false, err
}
}

if db, ok := os.LookupEnv("EDGEDB_BRANCH"); ok {
err := r.setBranch(db, "EDGEDB_BRANCH environment variable")
if branch, ok := os.LookupEnv("EDGEDB_BRANCH"); ok {
if dbOk {
return false, errors.New(
"mutually exclusive options EDGEDB_DATABASE and " +
"EDGEDB_BRANCH environment variables are set")
}
err := r.setDatabase(branch, "EDGEDB_BRANCH environment variable")
if err != nil {
return false, err
}
Expand Down Expand Up @@ -784,6 +788,16 @@ func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) {
}
}

if val, ok := os.LookupEnv("EDGEDB_TLS_SERVER_NAME"); ok {
e := r.setTLSServerName(
val,
"EDGEDB_TLS_SERVER_NAME environment variable",
)
if e != nil {
return false, e
}
}

if len(tlsCaSources) > 1 {
return false, fmt.Errorf(
"mutually exclusive environment variables set: %v",
Expand Down Expand Up @@ -946,18 +960,8 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) {
database := "edgedb"
branch := "__default__"
if r.database.val != nil {
if r.branch.val != nil {
return nil, fmt.Errorf(
"%s and %s are mutually exclusive options",
r.database.source,
r.branch.source,
)
}
database = r.database.val.(string)
branch = database
} else if r.branch.val != nil {
branch = r.branch.val.(string)
database = branch
}

user := "edgedb"
Expand All @@ -980,6 +984,11 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) {
tlsSecurity = r.tlsSecurity.val.(string)
}

tlsServerName := ""
if r.tlsServerName.val != nil {
tlsServerName = r.tlsServerName.val.(string)
}

secretKey := ""
if r.secretKey.val != nil {
secretKey = r.secretKey.val.(string)
Expand Down Expand Up @@ -1033,6 +1042,7 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) {
serverSettings: r.serverSettings,
tlsCAData: certData,
tlsSecurity: tlsSecurity,
tlsServerName: tlsServerName,
secretKey: secretKey,
}, nil
}
Expand Down Expand Up @@ -1268,6 +1278,11 @@ var dsnKeyLookup = map[string][]string{
"password": {"password", "password_env", "password_file"},
"tls_ca_file": {"tls_ca_file", "tls_ca_file_env"},
"tls_security": {"tls_security", "tls_security_env", "tls_security_file"},
"tls_server_name": {
"tls_server_name",
"tls_server_name_env",
"tls_server_name_file",
},
"tls_verify_hostname": {
"tls_verify_hostname",
"tls_verify_hostname_env",
Expand Down
6 changes: 6 additions & 0 deletions internal/client/connutils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,8 @@ func TestConnectionParameterResolution(t *testing.T) {
options.TLSOptions.CA = getBytes(t, opts, "tlsCA")
options.TLSOptions.SecurityMode = TLSSecurityMode(
getStr(t, opts, "tlsSecurity"))
options.TLSOptions.ServerName = getStr(
t, opts, "tlsServerName")
if opts["serverSettings"] != nil {
ss := opts["serverSettings"].(map[string]interface{})
options.ServerSettings = make(map[string][]byte, len(ss))
Expand Down Expand Up @@ -673,6 +675,10 @@ func TestConnectionParameterResolution(t *testing.T) {
expectedResult.secretKey = key.(string)
}

if key := res["tlsServerName"]; key != nil {
expectedResult.tlsServerName = key.(string)
}

ss := res["serverSettings"].(map[string]interface{})
for k, v := range ss {
expectedResult.serverSettings.Set(k, []byte(v.(string)))
Expand Down
6 changes: 4 additions & 2 deletions internal/client/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ func validateCredentials(data map[string]interface{}) (*credentials, error) {
result.host.Set(h)
}

if inMap("database", data) && inMap("branch", data) {
if inMap("database", data) &&
inMap("branch", data) &&
data["database"] != data["branch"] {
return nil, errors.New(
"`database` and `branch` are mutually exclusive")
"`database` and `branch` are both set but do not match")
}

if database, ok := data["database"]; ok {
Expand Down
2 changes: 2 additions & 0 deletions internal/client/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ type TLSOptions struct {
CAFile string
// Determines how strict we are with TLS checks
SecurityMode TLSSecurityMode
// Used to verify the hostname on the returned certificates
ServerName string
}

// TLSSecurityMode specifies how strict TLS validation is.
Expand Down
2 changes: 1 addition & 1 deletion shared-client-testcases

0 comments on commit dabd2d1

Please sign in to comment.