diff --git a/cmd/serve/main.go b/cmd/serve/main.go index 8e2ea0a2..efcd4001 100644 --- a/cmd/serve/main.go +++ b/cmd/serve/main.go @@ -114,7 +114,7 @@ func (cmd *servecmd) Run() { logrus.Infof("wireguard VPN network is %s", network.StringJoinIPNets(vpnip, vpnipv6)) - if err := network.ConfigureForwarding(conf.WireGuard.Interface, conf.VPN.GatewayInterface, conf.VPN.CIDR, conf.VPN.CIDRv6, conf.VPN.NAT66, conf.VPN.AllowedIPs); err != nil { + if err := network.ConfigureForwarding(conf.VPN.GatewayInterface, conf.VPN.CIDR, conf.VPN.CIDRv6, conf.VPN.NAT66, conf.VPN.AllowedIPs); err != nil { logrus.Fatal(err) } } @@ -155,7 +155,11 @@ func (cmd *servecmd) Run() { // Authentication middleware if conf.Auth.IsEnabled() { - router.Use(authnz.NewMiddleware(conf.Auth, claimsMiddleware(conf))) + middleware, err := authnz.NewMiddleware(conf.Auth, claimsMiddleware(conf)) + if err != nil { + logrus.Fatal(errors.Wrap(err, "failed to set up authnz middleware")) + } + router.Use(middleware) } else { logrus.Warn("[DEPRECATION NOTICE] using wg-access-server without an admin user is deprecated and will be removed in an upcoming minor release.") router.Use(func(next http.Handler) http.Handler { diff --git a/internal/dnsproxy/server.go b/internal/dnsproxy/server.go index 6ed650e3..3ca4a53f 100644 --- a/internal/dnsproxy/server.go +++ b/internal/dnsproxy/server.go @@ -74,19 +74,35 @@ func (d *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { m, err := d.Lookup(r) if err != nil { logrus.Errorf("failed lookup record with error: %s\n%s", err.Error(), r) - dns.HandleFailed(w, r) + HandleFailed(w, r) return } m.SetReply(r) - w.WriteMsg(m) + err = w.WriteMsg(m) + if err != nil { + logrus.Errorf("failed write response for client with error: %s\n%s", err.Error(), r) + return + } default: m := &dns.Msg{} m.SetReply(r) - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + logrus.Errorf("failed write response for client with error: %s\n%s", err.Error(), r) + return + } } } +// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets. +func HandleFailed(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeServerFailure) + // does not matter if this write fails + _ = w.WriteMsg(m) +} + func (d *DNSServer) Lookup(m *dns.Msg) (*dns.Msg, error) { key := makekey(m) diff --git a/internal/network/network.go b/internal/network/network.go index c1f3482e..d425a38c 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -56,7 +56,7 @@ func StringJoinIPs(a, b *net.IPNet) string { return "" } -func ConfigureForwarding(wgIface string, gatewayIface string, cidr string, cidrv6 string, nat66 bool, allowedIPs []string) error { +func ConfigureForwarding(gatewayIface string, cidr string, cidrv6 string, nat66 bool, allowedIPs []string) error { // Networking configuration (iptables) configuration // to ensure that traffic from clients of the wireguard interface // is sent to the provided network interface @@ -79,7 +79,7 @@ func ConfigureForwarding(wgIface string, gatewayIface string, cidr string, cidrv } if cidr != ""{ - if err := configureForwardingv4( gatewayIface, cidr, allowedIPv4s); err != nil { + if err := configureForwardingv4(gatewayIface, cidr, allowedIPv4s); err != nil { return err } } @@ -99,16 +99,34 @@ func configureForwardingv4(gatewayIface string, cidr string, allowedIPs []string // Cleanup our chains first so that we don't leak // iptable rules when the network configuration changes. - ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD") - ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING") + err = ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD") + if err != nil { + return errors.Wrap(err, "failed to clear filter chain") + } + err = ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING") + if err != nil { + return errors.Wrap(err, "failed to clear nat chain") + } // Create our own chain for forwarding rules - ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD") - ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD") + err = ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD") + if err != nil { + return errors.Wrap(err, "failed to create filter chain") + } + err = ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD") + if err != nil { + return errors.Wrap(err, "failed to append FORWARD rule to filter chain") + } // Create our own chain for postrouting rules - ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING") - ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING") + err = ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING") + if err != nil { + return errors.Wrap(err, "failed to create nat chain") + } + err = ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING") + if err != nil { + return errors.Wrap(err, "failed to append POSTROUTING rule to nat chain") + } for _, allowedCIDR := range allowedIPs { if err := ipt.AppendUnique("filter", "WG_ACCESS_SERVER_FORWARD", "-s", cidr, "-d", allowedCIDR, "-j", "ACCEPT"); err != nil { @@ -134,14 +152,32 @@ func configureForwardingv6(gatewayIface string, cidrv6 string, nat66 bool, allow return errors.Wrap(err, "failed to init ip6tables") } - ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD") - ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING") + err = ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD") + if err != nil { + return errors.Wrap(err, "failed to clear filter chain") + } + err = ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING") + if err != nil { + return errors.Wrap(err, "failed to clear nat chain") + } - ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD") - ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD") + err = ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD") + if err != nil { + return errors.Wrap(err, "failed to create filter chain") + } + err = ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD") + if err != nil { + return errors.Wrap(err, "failed to append FORWARD rule to filter chain") + } - ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING") - ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING") + err = ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING") + if err != nil { + return errors.Wrap(err, "failed to create nat chain") + } + err = ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING") + if err != nil { + return errors.Wrap(err, "failed to append POSTROUTING rule to nat chain") + } // Accept client traffic for given allowed ips for _, allowedCIDR := range allowedIPs { @@ -175,10 +211,3 @@ func nextIP(ip net.IP) net.IP { } return next } - -func boolToRule(accept bool) string { - if accept { - return "ACCEPT" - } - return "REJECT" -} diff --git a/internal/services/api_router.go b/internal/services/api_router.go index 9b676e06..45efcca7 100644 --- a/internal/services/api_router.go +++ b/internal/services/api_router.go @@ -54,6 +54,5 @@ func ApiRouter(deps *ApiServices) http.Handler { w.WriteHeader(400) fmt.Fprintln(w, "expected grpc request") - return }) } diff --git a/internal/services/converters.go b/internal/services/converters.go index 956f5567..620c4cec 100644 --- a/internal/services/converters.go +++ b/internal/services/converters.go @@ -3,10 +3,10 @@ package services import ( "time" - "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/wrappers" "github.com/sirupsen/logrus" + "google.golang.org/protobuf/types/known/timestamppb" ) func TimestampToTime(value *timestamp.Timestamp) time.Time { @@ -17,10 +17,10 @@ func TimeToTimestamp(value *time.Time) *timestamp.Timestamp { if value == nil { return nil } - t, err := ptypes.TimestampProto(*value) - if err != nil { + t := timestamppb.New(*value) + if t == nil { logrus.Error("bad time value") - t = ptypes.TimestampNow() + t = timestamppb.Now() } return t } diff --git a/internal/traces/traces.go b/internal/traces/traces.go index f5a4dd7d..50afe6a2 100644 --- a/internal/traces/traces.go +++ b/internal/traces/traces.go @@ -8,8 +8,10 @@ import ( "github.com/sirupsen/logrus" ) +type traceContextKey string + const ( - TraceIDKey = "trace.id" + TraceIDKey traceContextKey = "trace.id" ) func WithTraceID(ctx context.Context) context.Context { diff --git a/pkg/authnz/authconfig/basic.go b/pkg/authnz/authconfig/basic.go index 98a115d1..ae670065 100644 --- a/pkg/authnz/authconfig/basic.go +++ b/pkg/authnz/authconfig/basic.go @@ -30,30 +30,27 @@ func (c *BasicAuthConfig) Provider() *authruntime.Provider { func basicAuthLogin(c *BasicAuthConfig, runtime *authruntime.ProviderRuntime) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { u, p, ok := r.BasicAuth() - if !ok { - w.Header().Set("WWW-Authenticate", `Basic realm="site"`) - w.WriteHeader(http.StatusUnauthorized) - fmt.Fprintln(w, "unauthorized") - return - } - - if ok := checkCreds(c.Users, u, p); ok { - runtime.SetSession(w, r, &authsession.AuthSession{ - Identity: &authsession.Identity{ - Provider: "basic", - Subject: u, - Name: u, - Email: "", // basic auth has no email - }, - }) - runtime.Done(w, r) - return + if ok { + if ok := checkCreds(c.Users, u, p); ok { + err := runtime.SetSession(w, r, &authsession.AuthSession{ + Identity: &authsession.Identity{ + Provider: "basic", + Subject: u, + Name: u, + Email: "", // basic auth has no email + }, + }) + if err == nil { + runtime.Done(w, r) + return + } + } } + // If we're here something went wrong, return StatusUnauthorized w.Header().Set("WWW-Authenticate", `Basic realm="site"`) w.WriteHeader(http.StatusUnauthorized) fmt.Fprintln(w, "unauthorized") - return } } diff --git a/pkg/authnz/authconfig/oidc.go b/pkg/authnz/authconfig/oidc.go index 40d97c3d..d101dde4 100644 --- a/pkg/authnz/authconfig/oidc.go +++ b/pkg/authnz/authconfig/oidc.go @@ -71,9 +71,13 @@ func (c *OIDCConfig) Provider() *authruntime.Provider { func (c *OIDCConfig) loginHandler(runtime *authruntime.ProviderRuntime, oauthConfig *oauth2.Config) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { oauthStateString := authutil.RandomString(32) - runtime.SetSession(w, r, &authsession.AuthSession{ + err := runtime.SetSession(w, r, &authsession.AuthSession{ Nonce: &oauthStateString, }) + if err != nil { + http.Error(w, "no session", http.StatusUnauthorized) + return + } url := oauthConfig.AuthCodeURL(oauthStateString) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } @@ -107,7 +111,11 @@ func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauth } oidcProfileData := make(map[string]interface{}) - info.Claims(&oidcProfileData) + err = info.Claims(&oidcProfileData) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } claims := &authsession.Claims{} for claimName, rule := range c.ClaimMapping { @@ -126,7 +134,7 @@ func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauth } } - runtime.SetSession(w, r, &authsession.AuthSession{ + err = runtime.SetSession(w, r, &authsession.AuthSession{ Identity: &authsession.Identity{ Provider: c.Name, Subject: info.Subject, @@ -135,6 +143,10 @@ func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauth Claims: *claims, }, }) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } runtime.Done(w, r) } diff --git a/pkg/authnz/router.go b/pkg/authnz/router.go index 079ec6ba..03fb916b 100644 --- a/pkg/authnz/router.go +++ b/pkg/authnz/router.go @@ -25,7 +25,7 @@ type AuthMiddleware struct { runtime *authruntime.ProviderRuntime } -func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) *AuthMiddleware { +func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) (*AuthMiddleware, error) { router := mux.NewRouter() store := sessions.NewCookieStore([]byte(authutil.RandomString(32))) runtime := authruntime.NewProviderRuntime(store) @@ -33,13 +33,16 @@ func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddle for _, p := range providers { if p.RegisterRoutes != nil { - p.RegisterRoutes(router, runtime) + err := p.RegisterRoutes(router, runtime) + if err != nil { + return nil, err + } } } router.HandleFunc("/signin", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, authtemplates.RenderLoginPage(w, authtemplates.LoginPage{ + _, _ = fmt.Fprint(w, authtemplates.RenderLoginPage(w, authtemplates.LoginPage{ Providers: providers, })) }) @@ -48,7 +51,7 @@ func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddle index, err := strconv.Atoi(mux.Vars(r)["index"]) if err != nil || index < 0 || len(providers) <= index { w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "unknown provider") + _, _ = fmt.Fprintf(w, "unknown provider") return } provider := providers[index] @@ -56,7 +59,7 @@ func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddle }) router.HandleFunc("/signout", func(w http.ResponseWriter, r *http.Request) { - runtime.ClearSession(w, r) + _ = runtime.ClearSession(w, r) runtime.Restart(w, r) }) @@ -65,11 +68,15 @@ func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddle claimsMiddleware, router, runtime, - } + }, nil } -func NewMiddleware(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) mux.MiddlewareFunc { - return New(config, claimsMiddleware).Middleware +func NewMiddleware(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) (mux.MiddlewareFunc, error) { + authMiddleware, err := New(config, claimsMiddleware) + if err != nil { + return nil, err + } + return authMiddleware.Middleware, nil } func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {