Skip to content

Commit

Permalink
Merge pull request Place1#106 from DasSkelett/feature/generate-client…
Browse files Browse the repository at this point in the history
…-dns
  • Loading branch information
DasSkelett authored Mar 2, 2022
2 parents 4fae08b + 52be27a commit a2d89d4
Show file tree
Hide file tree
Showing 13 changed files with 567 additions and 174 deletions.
164 changes: 129 additions & 35 deletions cmd/serve/main.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,38 @@
package serve

import (
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/docker/libnetwork/resolvconf"
"github.com/docker/libnetwork/types"
"github.com/place1/wg-access-server/internal/config"
"github.com/place1/wg-access-server/internal/devices"
"github.com/place1/wg-access-server/internal/dnsproxy"
"github.com/place1/wg-access-server/internal/network"
"github.com/place1/wg-access-server/internal/services"
"github.com/place1/wg-access-server/internal/storage"
"github.com/place1/wg-access-server/pkg/authnz"
"github.com/place1/wg-access-server/pkg/authnz/authconfig"
"github.com/place1/wg-access-server/pkg/authnz/authsession"

"github.com/docker/libnetwork/resolvconf"
"github.com/docker/libnetwork/types"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/place1/wg-embed/pkg/wgembed"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/crypto/bcrypt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"gopkg.in/alecthomas/kingpin.v2"
"gopkg.in/yaml.v2"

"github.com/gorilla/mux"
"github.com/place1/wg-embed/pkg/wgembed"

"github.com/pkg/errors"
"github.com/place1/wg-access-server/internal/config"
"github.com/place1/wg-access-server/internal/devices"
"github.com/place1/wg-access-server/internal/dnsproxy"
"github.com/place1/wg-access-server/internal/network"
"github.com/sirupsen/logrus"
)

func Register(app *kingpin.Application) *servecmd {
Expand All @@ -52,7 +56,8 @@ func Register(app *kingpin.Application) *servecmd {
cli.Flag("vpn-gateway-interface", "The gateway network interface (i.e. eth0)").Envar("WG_VPN_GATEWAY_INTERFACE").Default(detectDefaultInterface()).StringVar(&cmd.AppConfig.VPN.GatewayInterface)
cli.Flag("vpn-allowed-ips", "A list of networks that VPN clients will be allowed to connect to via the VPN").Envar("WG_VPN_ALLOWED_IPS").Default("0.0.0.0/0", "::/0").StringsVar(&cmd.AppConfig.VPN.AllowedIPs)
cli.Flag("dns-enabled", "Enable or disable the embedded dns proxy server (useful for development)").Envar("WG_DNS_ENABLED").Default("true").BoolVar(&cmd.AppConfig.DNS.Enabled)
cli.Flag("dns-upstream", "An upstream DNS server to proxy DNS traffic to. Defaults to resolveconf with Cloudflare DNS as fallback").Envar("WG_DNS_UPSTREAM").StringsVar(&cmd.AppConfig.DNS.Upstream)
cli.Flag("dns-upstream", "An upstream DNS server to proxy DNS traffic to. Defaults to resolvconf with Cloudflare DNS as fallback").Envar("WG_DNS_UPSTREAM").StringsVar(&cmd.AppConfig.DNS.Upstream)
cli.Flag("dns-domain", "A domain to serve configured device names authoritatively").Envar("WG_DNS_DOMAIN").StringVar(&cmd.AppConfig.DNS.Domain)
return cmd
}

Expand All @@ -74,6 +79,9 @@ func (cmd *servecmd) Run() {
if conf.VPN.CIDRv6 == "0" {
conf.VPN.CIDRv6 = ""
}
if conf.DNS.Domain == "0" {
conf.DNS.Domain = ""
}

// Get the server's IP addresses within the VPN
var vpnip, vpnipv6 *net.IPNet
Expand All @@ -98,6 +106,13 @@ func (cmd *servecmd) Run() {
conf.VPN.AllowedIPs = append(conf.VPN.AllowedIPs, fmt.Sprintf("%s/128", vpnipv6.IP.String()))
vpnipstrings = append(vpnipstrings, vpnipv6.String())
}
vpnips := make([]net.IP, 0, 2)
if vpnip != nil {
vpnips = append(vpnips, vpnip.IP)
}
if vpnipv6 != nil {
vpnips = append(vpnips, vpnipv6.IP)
}

// WireGuard Server
wg := wgembed.NewNoOpInterface()
Expand All @@ -120,44 +135,76 @@ func (cmd *servecmd) Run() {
}

if err := wg.LoadConfig(wgconfig); err != nil {
logrus.Fatal(errors.Wrap(err, "failed to load wireguard config"))
logrus.Error(errors.Wrap(err, "failed to load wireguard config"))
return
}

logrus.Infof("wireguard VPN network is %s", network.StringJoinIPNets(vpnip, vpnipv6))

if err := network.ConfigureForwarding(conf.VPN.GatewayInterface, conf.VPN.CIDR, conf.VPN.CIDRv6, conf.VPN.NAT44, conf.VPN.NAT66, conf.VPN.AllowedIPs); err != nil {
logrus.Fatal(err)
logrus.Error(err)
return
}
}

// Storage
storageBackend, err := storage.NewStorage(conf.Storage)
if err != nil {
logrus.Error(errors.Wrap(err, "failed to create storage backend"))
return
}
if err := storageBackend.Open(); err != nil {
logrus.Error(errors.Wrap(err, "failed to connect/open storage backend"))
return
}
defer storageBackend.Close()

// Device manager
deviceManager := devices.New(wg, storageBackend, conf.VPN.CIDR, conf.VPN.CIDRv6)

// DNS Server
if conf.DNS.Enabled {
if conf.DNS.Upstream == nil {
if conf.DNS.Upstream == nil || len(conf.DNS.Upstream) <= 0 {
conf.DNS.Upstream = detectDNSUpstream(conf.VPN.CIDR != "", conf.VPN.CIDRv6 != "")
}
listenAddr := make([]string, 0, 2)
for _, addr := range vpnips {
listenAddr = append(listenAddr, net.JoinHostPort(addr.String(), "53"))
}
dns, err := dnsproxy.New(dnsproxy.DNSServerOpts{
Upstream: conf.DNS.Upstream,
Upstream: conf.DNS.Upstream,
Domain: conf.DNS.Domain,
ListenAddr: listenAddr,
})
if err != nil {
logrus.Fatal(errors.Wrap(err, "failed to start dns server"))
logrus.Error(errors.Wrap(err, "failed to start dns server"))
return
}
defer dns.Close()
if conf.DNS.Domain != "" {
// Generate initial DNS zone for registered devices
zone := generateZone(deviceManager, vpnips)
dns.PushAuthZone(zone)
// Update the zone in the background whenever a device changes
storageBackend.OnAdd(
func(_ *storage.Device) {
zone := generateZone(deviceManager, vpnips)
dns.PushAuthZone(zone)
},
)
storageBackend.OnDelete(
func(_ *storage.Device) {
zone := generateZone(deviceManager, vpnips)
dns.PushAuthZone(zone)
},
)
}
}

// Storage
storageBackend, err := storage.NewStorage(conf.Storage)
if err != nil {
logrus.Fatal(errors.Wrap(err, "failed to create storage backend"))
}
if err := storageBackend.Open(); err != nil {
logrus.Fatal(errors.Wrap(err, "failed to connect/open storage backend"))
}
defer storageBackend.Close()

// Services
deviceManager := devices.New(wg, storageBackend, conf.VPN.CIDR, conf.VPN.CIDRv6)
if err := deviceManager.StartSync(conf.DisableMetadata); err != nil {
logrus.Fatal(errors.Wrap(err, "failed to sync"))
logrus.Error(errors.Wrap(err, "failed to sync"))
return
}

router := mux.NewRouter()
Expand All @@ -171,7 +218,8 @@ func (cmd *servecmd) Run() {
if conf.Auth.IsEnabled() {
middleware, err := authnz.NewMiddleware(conf.Auth, claimsMiddleware(conf))
if err != nil {
logrus.Fatal(errors.Wrap(err, "failed to set up authnz middleware"))
logrus.Error(errors.Wrap(err, "failed to set up authnz middleware"))
return
}
router.Use(middleware)
} else {
Expand Down Expand Up @@ -203,17 +251,38 @@ func (cmd *servecmd) Run() {

publicRouter := router

signalChan := make(chan os.Signal, 2)
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
errChan := make(chan error)

// Listen
address := fmt.Sprintf(":%d", conf.Port)
srv := &http.Server{
Addr: address,
Handler: publicRouter,
}

// Start Web server
logrus.Infof("web ui listening on %v", address)
if err := srv.ListenAndServe(); err != nil {
logrus.Fatal(errors.Wrap(err, "unable to start http server"))
go func() {
// Start Web server
logrus.Infof("web ui listening on %v", address)
err := srv.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- errors.Wrap(err, "unable to start http server")
}
}()

select {
case <-signalChan:
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
err = srv.Shutdown(ctx)
if err != nil {
logrus.Error(err)
}
cancel() // always call cancel to clean up the context
case err = <-errChan:
logrus.Error(err)
return
}
}

Expand Down Expand Up @@ -318,6 +387,31 @@ func detectDefaultInterface() string {
return ""
}

func generateZone(deviceManager *devices.DeviceManager, vpnips []net.IP) dnsproxy.Zone {
devs, err := deviceManager.ListAllDevices()
if err != nil {
logrus.Error(errors.Wrap(err, "could not query devices to generate the DNS zone"))
}

zone := make(dnsproxy.Zone)
for _, device := range devs {
owner := device.Owner
name := device.Name
addressStrings := network.SplitAddresses(device.Address)
addresses := make([]net.IP, 0, 2)
for _, str := range addressStrings {
addr, _, err := net.ParseCIDR(str)
if err != nil {
continue
}
addresses = append(addresses, addr)
}
zone[dnsproxy.ZoneKey{Owner: owner, Name: name}] = addresses
}
zone[dnsproxy.ZoneKey{}] = vpnips
return zone
}

var missingPrivateKey = `missing wireguard private key:
create a key:
Expand Down
Loading

0 comments on commit a2d89d4

Please sign in to comment.