Skip to content

Commit

Permalink
Address linter warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
DasSkelett committed Nov 18, 2021
1 parent dbe1ed2 commit a7aa16d
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 62 deletions.
8 changes: 6 additions & 2 deletions cmd/serve/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (cmd *servecmd) Run() {

logrus.Infof("wireguard VPN network is %s", network.StringJoinIPNets(vpnip, vpnipv6))

if err := network.ConfigureForwarding(conf.WireGuard.Interface, conf.VPN.GatewayInterface, conf.VPN.CIDR, conf.VPN.CIDRv6, conf.VPN.NAT66, conf.VPN.AllowedIPs); err != nil {
if err := network.ConfigureForwarding(conf.VPN.GatewayInterface, conf.VPN.CIDR, conf.VPN.CIDRv6, conf.VPN.NAT66, conf.VPN.AllowedIPs); err != nil {
logrus.Fatal(err)
}
}
Expand Down Expand Up @@ -155,7 +155,11 @@ func (cmd *servecmd) Run() {

// Authentication middleware
if conf.Auth.IsEnabled() {
router.Use(authnz.NewMiddleware(conf.Auth, claimsMiddleware(conf)))
middleware, err := authnz.NewMiddleware(conf.Auth, claimsMiddleware(conf))
if err != nil {
logrus.Fatal(errors.Wrap(err, "failed to set up authnz middleware"))
}
router.Use(middleware)
} else {
logrus.Warn("[DEPRECATION NOTICE] using wg-access-server without an admin user is deprecated and will be removed in an upcoming minor release.")
router.Use(func(next http.Handler) http.Handler {
Expand Down
22 changes: 19 additions & 3 deletions internal/dnsproxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,35 @@ func (d *DNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m, err := d.Lookup(r)
if err != nil {
logrus.Errorf("failed lookup record with error: %s\n%s", err.Error(), r)
dns.HandleFailed(w, r)
HandleFailed(w, r)
return
}
m.SetReply(r)
w.WriteMsg(m)
err = w.WriteMsg(m)
if err != nil {
logrus.Errorf("failed write response for client with error: %s\n%s", err.Error(), r)
return
}
default:
m := &dns.Msg{}
m.SetReply(r)
w.WriteMsg(m)
err := w.WriteMsg(m)
if err != nil {
logrus.Errorf("failed write response for client with error: %s\n%s", err.Error(), r)
return
}
}

}

// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
func HandleFailed(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
// does not matter if this write fails
_ = w.WriteMsg(m)
}

func (d *DNSServer) Lookup(m *dns.Msg) (*dns.Msg, error) {
key := makekey(m)

Expand Down
71 changes: 50 additions & 21 deletions internal/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func StringJoinIPs(a, b *net.IPNet) string {
return ""
}

func ConfigureForwarding(wgIface string, gatewayIface string, cidr string, cidrv6 string, nat66 bool, allowedIPs []string) error {
func ConfigureForwarding(gatewayIface string, cidr string, cidrv6 string, nat66 bool, allowedIPs []string) error {
// Networking configuration (iptables) configuration
// to ensure that traffic from clients of the wireguard interface
// is sent to the provided network interface
Expand All @@ -79,7 +79,7 @@ func ConfigureForwarding(wgIface string, gatewayIface string, cidr string, cidrv
}

if cidr != ""{
if err := configureForwardingv4( gatewayIface, cidr, allowedIPv4s); err != nil {
if err := configureForwardingv4(gatewayIface, cidr, allowedIPv4s); err != nil {
return err
}
}
Expand All @@ -99,16 +99,34 @@ func configureForwardingv4(gatewayIface string, cidr string, allowedIPs []string

// Cleanup our chains first so that we don't leak
// iptable rules when the network configuration changes.
ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD")
ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
err = ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD")
if err != nil {
return errors.Wrap(err, "failed to clear filter chain")
}
err = ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
if err != nil {
return errors.Wrap(err, "failed to clear nat chain")
}

// Create our own chain for forwarding rules
ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD")
ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD")
err = ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD")
if err != nil {
return errors.Wrap(err, "failed to create filter chain")
}
err = ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD")
if err != nil {
return errors.Wrap(err, "failed to append FORWARD rule to filter chain")
}

// Create our own chain for postrouting rules
ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING")
err = ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
if err != nil {
return errors.Wrap(err, "failed to create nat chain")
}
err = ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING")
if err != nil {
return errors.Wrap(err, "failed to append POSTROUTING rule to nat chain")
}

for _, allowedCIDR := range allowedIPs {
if err := ipt.AppendUnique("filter", "WG_ACCESS_SERVER_FORWARD", "-s", cidr, "-d", allowedCIDR, "-j", "ACCEPT"); err != nil {
Expand All @@ -134,14 +152,32 @@ func configureForwardingv6(gatewayIface string, cidrv6 string, nat66 bool, allow
return errors.Wrap(err, "failed to init ip6tables")
}

ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD")
ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
err = ipt.ClearChain("filter", "WG_ACCESS_SERVER_FORWARD")
if err != nil {
return errors.Wrap(err, "failed to clear filter chain")
}
err = ipt.ClearChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
if err != nil {
return errors.Wrap(err, "failed to clear nat chain")
}

ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD")
ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD")
err = ipt.NewChain("filter", "WG_ACCESS_SERVER_FORWARD")
if err != nil {
return errors.Wrap(err, "failed to create filter chain")
}
err = ipt.AppendUnique("filter", "FORWARD", "-j", "WG_ACCESS_SERVER_FORWARD")
if err != nil {
return errors.Wrap(err, "failed to append FORWARD rule to filter chain")
}

ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING")
err = ipt.NewChain("nat", "WG_ACCESS_SERVER_POSTROUTING")
if err != nil {
return errors.Wrap(err, "failed to create nat chain")
}
err = ipt.AppendUnique("nat", "POSTROUTING", "-j", "WG_ACCESS_SERVER_POSTROUTING")
if err != nil {
return errors.Wrap(err, "failed to append POSTROUTING rule to nat chain")
}

// Accept client traffic for given allowed ips
for _, allowedCIDR := range allowedIPs {
Expand Down Expand Up @@ -175,10 +211,3 @@ func nextIP(ip net.IP) net.IP {
}
return next
}

func boolToRule(accept bool) string {
if accept {
return "ACCEPT"
}
return "REJECT"
}
1 change: 0 additions & 1 deletion internal/services/api_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,5 @@ func ApiRouter(deps *ApiServices) http.Handler {

w.WriteHeader(400)
fmt.Fprintln(w, "expected grpc request")
return
})
}
8 changes: 4 additions & 4 deletions internal/services/converters.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package services
import (
"time"

"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
)

func TimestampToTime(value *timestamp.Timestamp) time.Time {
Expand All @@ -17,10 +17,10 @@ func TimeToTimestamp(value *time.Time) *timestamp.Timestamp {
if value == nil {
return nil
}
t, err := ptypes.TimestampProto(*value)
if err != nil {
t := timestamppb.New(*value)
if t == nil {
logrus.Error("bad time value")
t = ptypes.TimestampNow()
t = timestamppb.Now()
}
return t
}
Expand Down
4 changes: 3 additions & 1 deletion internal/traces/traces.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import (
"github.com/sirupsen/logrus"
)

type traceContextKey string

const (
TraceIDKey = "trace.id"
TraceIDKey traceContextKey = "trace.id"
)

func WithTraceID(ctx context.Context) context.Context {
Expand Down
35 changes: 16 additions & 19 deletions pkg/authnz/authconfig/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,27 @@ func (c *BasicAuthConfig) Provider() *authruntime.Provider {
func basicAuthLogin(c *BasicAuthConfig, runtime *authruntime.ProviderRuntime) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="site"`)
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprintln(w, "unauthorized")
return
}

if ok := checkCreds(c.Users, u, p); ok {
runtime.SetSession(w, r, &authsession.AuthSession{
Identity: &authsession.Identity{
Provider: "basic",
Subject: u,
Name: u,
Email: "", // basic auth has no email
},
})
runtime.Done(w, r)
return
if ok {
if ok := checkCreds(c.Users, u, p); ok {
err := runtime.SetSession(w, r, &authsession.AuthSession{
Identity: &authsession.Identity{
Provider: "basic",
Subject: u,
Name: u,
Email: "", // basic auth has no email
},
})
if err == nil {
runtime.Done(w, r)
return
}
}
}

// If we're here something went wrong, return StatusUnauthorized
w.Header().Set("WWW-Authenticate", `Basic realm="site"`)
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprintln(w, "unauthorized")
return
}
}

Expand Down
18 changes: 15 additions & 3 deletions pkg/authnz/authconfig/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ func (c *OIDCConfig) Provider() *authruntime.Provider {
func (c *OIDCConfig) loginHandler(runtime *authruntime.ProviderRuntime, oauthConfig *oauth2.Config) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
oauthStateString := authutil.RandomString(32)
runtime.SetSession(w, r, &authsession.AuthSession{
err := runtime.SetSession(w, r, &authsession.AuthSession{
Nonce: &oauthStateString,
})
if err != nil {
http.Error(w, "no session", http.StatusUnauthorized)
return
}
url := oauthConfig.AuthCodeURL(oauthStateString)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
}
Expand Down Expand Up @@ -107,7 +111,11 @@ func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauth
}

oidcProfileData := make(map[string]interface{})
info.Claims(&oidcProfileData)
err = info.Claims(&oidcProfileData)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}

claims := &authsession.Claims{}
for claimName, rule := range c.ClaimMapping {
Expand All @@ -126,7 +134,7 @@ func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauth
}
}

runtime.SetSession(w, r, &authsession.AuthSession{
err = runtime.SetSession(w, r, &authsession.AuthSession{
Identity: &authsession.Identity{
Provider: c.Name,
Subject: info.Subject,
Expand All @@ -135,6 +143,10 @@ func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauth
Claims: *claims,
},
})
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}

runtime.Done(w, r)
}
Expand Down
23 changes: 15 additions & 8 deletions pkg/authnz/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,24 @@ type AuthMiddleware struct {
runtime *authruntime.ProviderRuntime
}

func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) *AuthMiddleware {
func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) (*AuthMiddleware, error) {
router := mux.NewRouter()
store := sessions.NewCookieStore([]byte(authutil.RandomString(32)))
runtime := authruntime.NewProviderRuntime(store)
providers := config.Providers()

for _, p := range providers {
if p.RegisterRoutes != nil {
p.RegisterRoutes(router, runtime)
err := p.RegisterRoutes(router, runtime)
if err != nil {
return nil, err
}
}
}

router.HandleFunc("/signin", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, authtemplates.RenderLoginPage(w, authtemplates.LoginPage{
_, _ = fmt.Fprint(w, authtemplates.RenderLoginPage(w, authtemplates.LoginPage{
Providers: providers,
}))
})
Expand All @@ -48,15 +51,15 @@ func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddle
index, err := strconv.Atoi(mux.Vars(r)["index"])
if err != nil || index < 0 || len(providers) <= index {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "unknown provider")
_, _ = fmt.Fprintf(w, "unknown provider")
return
}
provider := providers[index]
provider.Invoke(w, r, runtime)
})

router.HandleFunc("/signout", func(w http.ResponseWriter, r *http.Request) {
runtime.ClearSession(w, r)
_ = runtime.ClearSession(w, r)
runtime.Restart(w, r)
})

Expand All @@ -65,11 +68,15 @@ func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddle
claimsMiddleware,
router,
runtime,
}
}, nil
}

func NewMiddleware(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) mux.MiddlewareFunc {
return New(config, claimsMiddleware).Middleware
func NewMiddleware(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) (mux.MiddlewareFunc, error) {
authMiddleware, err := New(config, claimsMiddleware)
if err != nil {
return nil, err
}
return authMiddleware.Middleware, nil
}

func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
Expand Down

0 comments on commit a7aa16d

Please sign in to comment.