From 1b7da57f02ebdcd509f3d1de51d8bb94102e7429 Mon Sep 17 00:00:00 2001 From: Krzysztof Nazarewski Date: Thu, 15 Feb 2024 12:02:37 +0100 Subject: [PATCH] config.go: pull unified Config.apply() out of createNewConfig() and update() as a bonus it ensures returned Config object doesn't have any configuration values missing --- client/internal/config.go | 248 ++++++++++++++++++++------------------ 1 file changed, 134 insertions(+), 114 deletions(-) diff --git a/client/internal/config.go b/client/internal/config.go index 2f69582350e..049deed10a5 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -5,6 +5,8 @@ import ( "fmt" "net/url" "os" + "reflect" + "strings" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -45,6 +47,7 @@ type ConfigInput struct { RosenpassEnabled *bool RosenpassPermissive *bool InterfaceName *string + InterfaceBlacklist []string WireguardPort *int DisableAutoConnect *bool } @@ -62,7 +65,12 @@ type Config struct { DisableIPv6Discovery bool RosenpassEnabled bool RosenpassPermissive bool - ServerSSHAllowed *bool + + // ServerSSHAllowed tells whether SSH server should be enabled + // used to be server-side configuration for pre-0.26 + // it is a pointer defaulting to false in createNewConfig() + // it is defaulting to true for existing configurations + ServerSSHAllowed *bool // SSHKey is a private SSH key in a PEM format SSHKey string @@ -97,6 +105,14 @@ func ReadConfig(configPath string) (*Config, error) { if _, err := util.ReadJson(configPath, config); err != nil { return nil, err } + // initialize through apply() without changes + if changed, err := config.apply(ConfigInput{}); err != nil { + return nil, err + } else if changed { + if err = WriteOutConfig(configPath, config); err != nil { + return nil, err + } + } return config, nil } @@ -149,185 +165,189 @@ func WriteOutConfig(path string, config *Config) error { // createNewConfig creates a new config generating a new Wireguard key and saving to file func createNewConfig(input ConfigInput) (*Config, error) { - wgKey := generateKey() - pem, err := ssh.GeneratePrivateKey(ssh.ED25519) - if err != nil { - return nil, err - } - config := &Config{ - SSHKey: string(pem), - PrivateKey: wgKey, - IFaceBlackList: []string{}, - DisableIPv6Discovery: false, - NATExternalIPs: input.NATExternalIPs, - CustomDNSAddress: string(input.CustomDNSAddress), - ServerSSHAllowed: util.False(), - DisableAutoConnect: false, + // defaults to false only for new (post 0.26) configurations + ServerSSHAllowed: util.False(), } - defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL) - if err != nil { + if _, err := config.apply(input); err != nil { return nil, err } - config.ManagementURL = defaultManagementURL - if input.ManagementURL != "" { - URL, err := parseURL("Management URL", input.ManagementURL) - if err != nil { - return nil, err - } - config.ManagementURL = URL - } - - config.WgPort = iface.DefaultWgPort - if input.WireguardPort != nil { - config.WgPort = *input.WireguardPort - } - - config.WgIface = iface.WgInterfaceDefault - if input.InterfaceName != nil { - config.WgIface = *input.InterfaceName - } - - if input.PreSharedKey != nil { - config.PreSharedKey = *input.PreSharedKey - } - - if input.RosenpassEnabled != nil { - config.RosenpassEnabled = *input.RosenpassEnabled - } + return config, nil +} - if input.RosenpassPermissive != nil { - config.RosenpassPermissive = *input.RosenpassPermissive - } +func update(input ConfigInput) (*Config, error) { + config := &Config{} - if input.ServerSSHAllowed != nil { - config.ServerSSHAllowed = input.ServerSSHAllowed + if _, err := util.ReadJson(input.ConfigPath, config); err != nil { + return nil, err } - defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL) + updated, err := config.apply(input) if err != nil { return nil, err } - config.AdminURL = defaultAdminURL - if input.AdminURL != "" { - newURL, err := parseURL("Admin Panel URL", input.AdminURL) - if err != nil { + if updated { + if err := util.WriteJson(input.ConfigPath, config); err != nil { return nil, err } - config.AdminURL = newURL } - config.IFaceBlackList = defaultInterfaceBlacklist return config, nil } -func update(input ConfigInput) (*Config, error) { - config := &Config{} - - if _, err := util.ReadJson(input.ConfigPath, config); err != nil { - return nil, err +func (config *Config) apply(input ConfigInput) (updated bool, err error) { + if config.ManagementURL == nil { + log.Infof("using default Management URL %s", DefaultManagementURL) + config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL) + if err != nil { + return false, err + } } - - refresh := false - - if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL { - log.Infof("new Management URL provided, updated to %s (old value %s)", - input.ManagementURL, config.ManagementURL) - newURL, err := parseURL("Management URL", input.ManagementURL) + if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() { + log.Infof("new Management URL provided, updated to %#v (old value %#v)", + input.ManagementURL, config.ManagementURL.String()) + URL, err := parseURL("Management URL", input.ManagementURL) if err != nil { - return nil, err + return false, err + } + config.ManagementURL = URL + updated = true + } else if config.ManagementURL == nil { + log.Infof("using default Management URL %s", DefaultManagementURL) + config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL) + if err != nil { + return false, err } - config.ManagementURL = newURL - refresh = true } - if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) { - log.Infof("new Admin Panel URL provided, updated to %s (old value %s)", - input.AdminURL, config.AdminURL) + if config.AdminURL == nil { + log.Infof("using default Admin URL %s", DefaultManagementURL) + config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL) + if err != nil { + return false, err + } + } + if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() { + log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)", + input.AdminURL, config.AdminURL.String()) newURL, err := parseURL("Admin Panel URL", input.AdminURL) if err != nil { - return nil, err + return updated, err } config.AdminURL = newURL - refresh = true + updated = true } - if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey { - log.Infof("new pre-shared key provided, replacing old key") - config.PreSharedKey = *input.PreSharedKey - refresh = true + if config.PrivateKey == "" { + log.Infof("generated new Wireguard key") + config.PrivateKey = generateKey() + updated = true } if config.SSHKey == "" { + log.Infof("generated new SSH key") pem, err := ssh.GeneratePrivateKey(ssh.ED25519) if err != nil { - return nil, err + return false, err } config.SSHKey = string(pem) - refresh = true + updated = true } - if config.WgPort == 0 { - config.WgPort = iface.DefaultWgPort - refresh = true - } - - if input.WireguardPort != nil { + if input.WireguardPort != nil && *input.WireguardPort != config.WgPort { + log.Infof("updating Wireguard port %d (old value %d)", + *input.WireguardPort, config.WgPort) config.WgPort = *input.WireguardPort - refresh = true + updated = true + } else if config.WgPort == 0 { + config.WgPort = iface.DefaultWgPort + log.Infof("using default Wireguard port %d", config.WgPort) + updated = true } - if input.InterfaceName != nil { + if input.InterfaceName != nil && *input.InterfaceName != config.WgIface { + log.Infof("updating Wireguard interface %#v (old value %#v)", + *input.InterfaceName, config.WgIface) config.WgIface = *input.InterfaceName - refresh = true + updated = true + } else if config.WgIface == "" { + config.WgIface = iface.WgInterfaceDefault + log.Infof("using default Wireguard interface %s", config.WgIface) + updated = true } - if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) { + if input.NATExternalIPs != nil && reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) { + log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])", + strings.Join(input.NATExternalIPs, " "), + strings.Join(config.NATExternalIPs, " ")) config.NATExternalIPs = input.NATExternalIPs - refresh = true + updated = true } - if input.CustomDNSAddress != nil { - config.CustomDNSAddress = string(input.CustomDNSAddress) - refresh = true + if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey { + log.Infof("new pre-shared key provided, replacing old key") + config.PreSharedKey = *input.PreSharedKey + updated = true } - if input.RosenpassEnabled != nil { + if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled { + log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled) config.RosenpassEnabled = *input.RosenpassEnabled - refresh = true + updated = true } - if input.RosenpassPermissive != nil { + if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive { + log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive) config.RosenpassPermissive = *input.RosenpassPermissive - refresh = true + updated = true } - if input.DisableAutoConnect != nil { - config.DisableAutoConnect = *input.DisableAutoConnect - refresh = true + if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress { + log.Infof("updating custom DNS address %#v (old value %#v)", + string(input.CustomDNSAddress), config.CustomDNSAddress) + config.CustomDNSAddress = string(input.CustomDNSAddress) + updated = true } - if input.ServerSSHAllowed != nil { - config.ServerSSHAllowed = input.ServerSSHAllowed - refresh = true + if input.InterfaceBlacklist != nil && reflect.DeepEqual(input.InterfaceBlacklist, config.IFaceBlackList) { + log.Infof("updating interface blacklist [ %s ] (old value: [ %s ])", + strings.Join(input.InterfaceBlacklist, " "), + strings.Join(config.IFaceBlackList, " ")) + config.IFaceBlackList = input.InterfaceBlacklist + updated = true + } else if config.IFaceBlackList == nil { + config.IFaceBlackList = defaultInterfaceBlacklist + updated = true } - if config.ServerSSHAllowed == nil { - config.ServerSSHAllowed = util.True() - refresh = true + if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect { + if *input.DisableAutoConnect { + log.Infof("turning off automatic connection on startup") + } else { + log.Infof("enabling automatic connection on startup") + } + config.DisableAutoConnect = *input.DisableAutoConnect + updated = true } - if refresh { - // since we have new management URL, we need to update config file - if err := util.WriteJson(input.ConfigPath, config); err != nil { - return nil, err + if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed { + if *input.ServerSSHAllowed { + log.Infof("enabling SSH server") + } else { + log.Infof("disabling SSH server") } + config.ServerSSHAllowed = input.ServerSSHAllowed + updated = true + } else if config.ServerSSHAllowed == nil { + // enables SSH for configs from old versions to preserve backwards compatibility + log.Infof("falling back to enabled SSH server for pre-existing configuration") + config.ServerSSHAllowed = util.True() + updated = true } - - return config, nil + return updated, nil } // parseURL parses and validates a service URL