Skip to content

Commit

Permalink
config.go: pull unified Config.apply() out of createNewConfig() and u…
Browse files Browse the repository at this point in the history
…pdate()

as a bonus it ensures returned Config object doesn't have any configuration
values missing
  • Loading branch information
nazarewk committed Feb 26, 2024
1 parent 38497ae commit 1b7da57
Showing 1 changed file with 134 additions and 114 deletions.
248 changes: 134 additions & 114 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"net/url"
"os"
"reflect"
"strings"

log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
Expand Down Expand Up @@ -45,6 +47,7 @@ type ConfigInput struct {
RosenpassEnabled *bool
RosenpassPermissive *bool
InterfaceName *string
InterfaceBlacklist []string
WireguardPort *int
DisableAutoConnect *bool
}
Expand All @@ -62,7 +65,12 @@ type Config struct {
DisableIPv6Discovery bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool

// ServerSSHAllowed tells whether SSH server should be enabled
// used to be server-side configuration for pre-0.26
// it is a pointer defaulting to false in createNewConfig()
// it is defaulting to true for existing configurations
ServerSSHAllowed *bool
// SSHKey is a private SSH key in a PEM format
SSHKey string

Expand Down Expand Up @@ -97,6 +105,14 @@ func ReadConfig(configPath string) (*Config, error) {
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
}
// initialize through apply() without changes
if changed, err := config.apply(ConfigInput{}); err != nil {
return nil, err
} else if changed {
if err = WriteOutConfig(configPath, config); err != nil {
return nil, err
}
}

return config, nil
}
Expand Down Expand Up @@ -149,185 +165,189 @@ func WriteOutConfig(path string, config *Config) error {

// createNewConfig creates a new config generating a new Wireguard key and saving to file
func createNewConfig(input ConfigInput) (*Config, error) {
wgKey := generateKey()
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}

config := &Config{
SSHKey: string(pem),
PrivateKey: wgKey,
IFaceBlackList: []string{},
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
CustomDNSAddress: string(input.CustomDNSAddress),
ServerSSHAllowed: util.False(),
DisableAutoConnect: false,
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
}

defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
if _, err := config.apply(input); err != nil {
return nil, err
}

config.ManagementURL = defaultManagementURL
if input.ManagementURL != "" {
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = URL
}

config.WgPort = iface.DefaultWgPort
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
}

config.WgIface = iface.WgInterfaceDefault
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
}

if input.PreSharedKey != nil {
config.PreSharedKey = *input.PreSharedKey
}

if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
}
return config, nil
}

if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
}
func update(input ConfigInput) (*Config, error) {
config := &Config{}

if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
return nil, err
}

defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
updated, err := config.apply(input)
if err != nil {
return nil, err
}

config.AdminURL = defaultAdminURL
if input.AdminURL != "" {
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err
}
config.AdminURL = newURL
}

config.IFaceBlackList = defaultInterfaceBlacklist
return config, nil
}

func update(input ConfigInput) (*Config, error) {
config := &Config{}

if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
return nil, err
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}

refresh := false

if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
log.Infof("new Management URL provided, updated to %s (old value %s)",
input.ManagementURL, config.ManagementURL)
newURL, err := parseURL("Management URL", input.ManagementURL)
if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
log.Infof("new Management URL provided, updated to %#v (old value %#v)",
input.ManagementURL, config.ManagementURL.String())
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
return false, err
}
config.ManagementURL = URL
updated = true
} else if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
config.ManagementURL = newURL
refresh = true
}

if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
input.AdminURL, config.AdminURL)
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err
}
}
if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
input.AdminURL, config.AdminURL.String())
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
return updated, err
}
config.AdminURL = newURL
refresh = true
updated = true
}

if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
refresh = true
if config.PrivateKey == "" {
log.Infof("generated new Wireguard key")
config.PrivateKey = generateKey()
updated = true
}

if config.SSHKey == "" {
log.Infof("generated new SSH key")
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
return false, err
}
config.SSHKey = string(pem)
refresh = true
updated = true
}

if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
refresh = true
}

if input.WireguardPort != nil {
if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
log.Infof("updating Wireguard port %d (old value %d)",
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
refresh = true
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}

if input.InterfaceName != nil {
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
log.Infof("updating Wireguard interface %#v (old value %#v)",
*input.InterfaceName, config.WgIface)
config.WgIface = *input.InterfaceName
refresh = true
updated = true
} else if config.WgIface == "" {
config.WgIface = iface.WgInterfaceDefault
log.Infof("using default Wireguard interface %s", config.WgIface)
updated = true
}

if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
if input.NATExternalIPs != nil && reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
strings.Join(input.NATExternalIPs, " "),
strings.Join(config.NATExternalIPs, " "))
config.NATExternalIPs = input.NATExternalIPs
refresh = true
updated = true
}

if input.CustomDNSAddress != nil {
config.CustomDNSAddress = string(input.CustomDNSAddress)
refresh = true
if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
updated = true
}

if input.RosenpassEnabled != nil {
if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
config.RosenpassEnabled = *input.RosenpassEnabled
refresh = true
updated = true
}

if input.RosenpassPermissive != nil {
if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive {
log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive)
config.RosenpassPermissive = *input.RosenpassPermissive
refresh = true
updated = true
}

if input.DisableAutoConnect != nil {
config.DisableAutoConnect = *input.DisableAutoConnect
refresh = true
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress)
config.CustomDNSAddress = string(input.CustomDNSAddress)
updated = true
}

if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
refresh = true
if input.InterfaceBlacklist != nil && reflect.DeepEqual(input.InterfaceBlacklist, config.IFaceBlackList) {
log.Infof("updating interface blacklist [ %s ] (old value: [ %s ])",
strings.Join(input.InterfaceBlacklist, " "),
strings.Join(config.IFaceBlackList, " "))
config.IFaceBlackList = input.InterfaceBlacklist
updated = true
} else if config.IFaceBlackList == nil {
config.IFaceBlackList = defaultInterfaceBlacklist
updated = true
}

if config.ServerSSHAllowed == nil {
config.ServerSSHAllowed = util.True()
refresh = true
if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect {
if *input.DisableAutoConnect {
log.Infof("turning off automatic connection on startup")
} else {
log.Infof("enabling automatic connection on startup")
}
config.DisableAutoConnect = *input.DisableAutoConnect
updated = true
}

if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
if *input.ServerSSHAllowed {
log.Infof("enabling SSH server")
} else {
log.Infof("disabling SSH server")
}
config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true
} else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
updated = true
}

return config, nil
return updated, nil
}

// parseURL parses and validates a service URL
Expand Down

0 comments on commit 1b7da57

Please sign in to comment.