From 8117a90cbaea453fdee9c7b5e28aa69cb56e249d Mon Sep 17 00:00:00 2001 From: folbrich Date: Mon, 26 Dec 2022 18:05:21 +0100 Subject: [PATCH] Support routing by ListenerID or TLS server name --- cache.go | 2 +- cmd/routedns/config.go | 3 ++ cmd/routedns/main.go | 12 ++++---- cmd/routedns/resolver.go | 6 ++-- dnslistener.go | 17 ++++++++--- doc/configuration.md | 3 ++ dohlistener.go | 10 +++++-- dohlistener_test.go | 6 ++-- doqlistener.go | 7 ++++- dotclient_test.go | 2 +- dotlistener_test.go | 6 ++-- example_test.go | 4 +-- listener.go | 7 +++++ response-collapse.go | 8 ++--- route.go | 64 +++++++++++++++++++++++++++------------- route_test.go | 2 +- router_test.go | 16 +++++----- tls.go | 3 +- 18 files changed, 118 insertions(+), 60 deletions(-) diff --git a/cache.go b/cache.go index b2eef75c..9b9e0d96 100644 --- a/cache.go +++ b/cache.go @@ -307,7 +307,7 @@ func minTTL(answer *dns.Msg) (uint32, bool) { type AnswerShuffleFunc func(*dns.Msg) // Randomly re-order the A/AAAA answer records. -func AnswerShuffleRandon(msg *dns.Msg) { +func AnswerShuffleRandom(msg *dns.Msg) { if len(msg.Answer) < 2 { return } diff --git a/cmd/routedns/config.go b/cmd/routedns/config.go index d9bacc99..2d259ae3 100644 --- a/cmd/routedns/config.go +++ b/cmd/routedns/config.go @@ -46,6 +46,7 @@ type resolver struct { CA string ClientKey string `toml:"client-key"` ClientCrt string `toml:"client-crt"` + ServerName string `toml:"server-name"` // TLS server name presented in the server certificate BootstrapAddr string `toml:"bootstrap-address"` LocalAddr string `toml:"local-address"` EDNS0UDPSize uint16 `toml:"edns0-udp-size"` // UDP resolver option @@ -160,6 +161,8 @@ type route struct { Invert bool // Invert the result of the match DoHPath string `toml:"doh-path"` // DoH query path if received over DoH (regexp) Resolver string + Listener string // ID of the listener that received the original request + TLSServerName string `toml:"servername"` // TLS servername } // LoadConfig reads a config file and returns the decoded structure. diff --git a/cmd/routedns/main.go b/cmd/routedns/main.go index 2b09a10e..b4243b82 100644 --- a/cmd/routedns/main.go +++ b/cmd/routedns/main.go @@ -182,7 +182,7 @@ func start(opt options, args []string) error { resolver, ok := resolvers[l.Resolver] // All Listeners should route queries (except the admin service). if !ok && l.Protocol != "admin" { - return fmt.Errorf("listener '%s' references non-existant resolver, group or router '%s'", id, l.Resolver) + return fmt.Errorf("listener '%s' references non-existent resolver, group or router '%s'", id, l.Resolver) } allowedNet, err := parseCIDRList(l.AllowedNet) if err != nil { @@ -301,7 +301,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er for _, rid := range g.Resolvers { resolver, ok := resolvers[rid] if !ok { - return fmt.Errorf("group '%s' references non-existant resolver or group '%s'", id, rid) + return fmt.Errorf("group '%s' references non-existent resolver or group '%s'", id, rid) } gr = append(gr, resolver) } @@ -547,7 +547,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er switch g.CacheAnswerShuffle { case "": // default case "random": - shuffleFunc = rdns.AnswerShuffleRandon + shuffleFunc = rdns.AnswerShuffleRandom case "round-robin": shuffleFunc = rdns.AnswerShuffleRoundRobin default: @@ -693,7 +693,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er if len(gr) != 1 { return fmt.Errorf("type response-collapse only supports one resolver in '%s'", id) } - opt := rdns.ResponseCollapsOptions{ + opt := rdns.ResponseCollapseOptions{ NullRCode: g.NullRCode, } resolvers[id] = rdns.NewResponseCollapse(id, gr[0], opt) @@ -724,13 +724,13 @@ func instantiateRouter(id string, r router, resolvers map[string]rdns.Resolver) for _, route := range r.Routes { resolver, ok := resolvers[route.Resolver] if !ok { - return fmt.Errorf("router '%s' references non-existant resolver or group '%s'", id, route.Resolver) + return fmt.Errorf("router '%s' references non-existent resolver or group '%s'", id, route.Resolver) } types := route.Types if route.Type != "" { // Support the deprecated "Type" by just adding it to "Types" if defined types = append(types, route.Type) } - r, err := rdns.NewRoute(route.Name, route.Class, types, route.Weekdays, route.Before, route.After, route.Source, route.DoHPath, resolver) + r, err := rdns.NewRoute(route.Name, route.Class, types, route.Weekdays, route.Before, route.After, route.Source, route.DoHPath, route.Listener, route.TLSServerName, resolver) if err != nil { return fmt.Errorf("failure parsing routes for router '%s' : %s", id, err.Error()) } diff --git a/cmd/routedns/resolver.go b/cmd/routedns/resolver.go index 0654d5e2..f4fe0847 100644 --- a/cmd/routedns/resolver.go +++ b/cmd/routedns/resolver.go @@ -15,7 +15,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv case "doq": r.Address = rdns.AddressWithDefault(r.Address, rdns.DoQPort) - tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey) + tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName) if err != nil { return err } @@ -31,7 +31,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv case "dot": r.Address = rdns.AddressWithDefault(r.Address, rdns.DoTPort) - tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey) + tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName) if err != nil { return err } @@ -64,7 +64,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv case "doh": r.Address = rdns.AddressWithDefault(r.Address, rdns.DoHPort) - tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey) + tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName) if err != nil { return err } diff --git a/dnslistener.go b/dnslistener.go index b385828c..8b6e21aa 100644 --- a/dnslistener.go +++ b/dnslistener.go @@ -1,6 +1,7 @@ package rdns import ( + "crypto/tls" "net" "github.com/miekg/dns" @@ -49,10 +50,18 @@ func (s DNSListener) String() string { func listenHandler(id, protocol, addr string, r Resolver, allowedNet []*net.IPNet) dns.HandlerFunc { metrics := NewListenerMetrics("listener", id) return func(w dns.ResponseWriter, req *dns.Msg) { - var ( - ci ClientInfo - err error - ) + var err error + + ci := ClientInfo{ + Listener: id, + } + + if r, ok := w.(interface{ ConnectionState() *tls.ConnectionState }); ok { + connState := r.ConnectionState() + if connState != nil { + ci.TLSServerName = connState.ServerName + } + } switch addr := w.RemoteAddr().(type) { case *net.TCPAddr: diff --git a/doc/configuration.md b/doc/configuration.md index 32089930..1ea831d4 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -1063,6 +1063,8 @@ A route has the following fields: - `before` - Time of day in the format HH:mm before which the rule matches. Uses 24h format. For example `17:30`. - `invert` - Invert the result of the matching if set to `true`. Optional. - `doh-path` - Regexp that matches on the DoH query path the client used. +- `listener` - Regexp that matches on the ID of the listener that first received. +- `servername` - Regexp that matches on the TLS server name used in the TLS handshake with the listener. - `resolver` - The identifier of a resolver, group, or another router. Required. Examples: @@ -1348,6 +1350,7 @@ Secure resolvers such as DoT, DoH, or DoQ offer additional options to configure - `client-crt` - Client certificate file. - `client-key` - Client certificate key file - `ca` - CA certificate to validate server certificates. +- `server-name` - Name of the certificate presented by the server if it does not match the name in the endpoint address. Examples: diff --git a/dohlistener.go b/dohlistener.go index 6db05107..b0281358 100644 --- a/dohlistener.go +++ b/dohlistener.go @@ -240,9 +240,15 @@ func (s *DoHListener) parseAndRespond(b []byte, w http.ResponseWriter, r *http.R http.Error(w, "Invalid RemoteAddr", http.StatusBadRequest) return } + var tlsServerName string + if r.TLS != nil { + tlsServerName = r.TLS.ServerName + } ci := ClientInfo{ - SourceIP: clientIP, - DoHPath: r.URL.Path, + SourceIP: clientIP, + DoHPath: r.URL.Path, + TLSServerName: tlsServerName, + Listener: s.id, } log := Log.WithFields(logrus.Fields{ "id": s.id, diff --git a/dohlistener_test.go b/dohlistener_test.go index 00279a98..206f9291 100644 --- a/dohlistener_test.go +++ b/dohlistener_test.go @@ -28,7 +28,7 @@ func TestDoHListenerSimple(t *testing.T) { time.Sleep(time.Second) // Make a client talking to the listener using POST - tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "") + tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "") require.NoError(t, err) u := "https://" + addr + "/dns-query" cPost, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsConfig, Method: "POST"}) @@ -82,7 +82,7 @@ func TestDoHListenerMutual(t *testing.T) { // Make a client talking to the listener. Need to trust the issuer of the server certificate and // present a client certificate. - tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key") + tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "") require.NoError(t, err) u := "https://" + addr + "/dns-query" c, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsClientConfig}) @@ -116,7 +116,7 @@ func TestDoHListenerMutualQUIC(t *testing.T) { // Make a client talking to the listener. Need to trust the issuer of the server certificate and // present a client certificate. - tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key") + tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "") require.NoError(t, err) u := "https://" + addr + "/dns-query" c, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsClientConfig, Transport: "quic"}) diff --git a/doqlistener.go b/doqlistener.go index b1ea9d0d..d73759ce 100644 --- a/doqlistener.go +++ b/doqlistener.go @@ -105,7 +105,12 @@ func (s DoQListener) Stop() error { } func (s DoQListener) handleConnection(connection quic.Connection) { - var ci ClientInfo + tlsServerName := connection.ConnectionState().TLS.ServerName + + ci := ClientInfo{ + Listener: s.id, + TLSServerName: tlsServerName, + } switch addr := connection.RemoteAddr().(type) { case *net.TCPAddr: ci.SourceIP = addr.IP diff --git a/dotclient_test.go b/dotclient_test.go index c07fd284..1ad0e1e9 100644 --- a/dotclient_test.go +++ b/dotclient_test.go @@ -33,7 +33,7 @@ func TestDoTClientCA(t *testing.T) { conn.Close() // Create a config with CA using the temp file - tlsConfig, err := TLSClientConfig(crtFile, "", "") + tlsConfig, err := TLSClientConfig(crtFile, "", "", "") require.NoError(t, err) // DoT client with valid CA diff --git a/dotlistener_test.go b/dotlistener_test.go index cf573498..634c9a8d 100644 --- a/dotlistener_test.go +++ b/dotlistener_test.go @@ -29,7 +29,7 @@ func TestDoTListenerSimple(t *testing.T) { time.Sleep(time.Second) // Make a client talking to the listener. Need to trust the issue of the server certificate. - tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "") + tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "") require.NoError(t, err) c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsConfig}) @@ -64,7 +64,7 @@ func TestDoTListenerMutual(t *testing.T) { // Make a client talking to the listener. Need to trust the issue of the server certificate and // present a client certificate. - tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key") + tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "") require.NoError(t, err) c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsClientConfig}) @@ -99,7 +99,7 @@ func TestDoTListenerPadding(t *testing.T) { time.Sleep(time.Second) // Make a client talking to the listener. Need to trust the issue of the server certificate. - tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "") + tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "") require.NoError(t, err) c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsConfig}) diff --git a/example_test.go b/example_test.go index 3f07a248..3d8cf100 100644 --- a/example_test.go +++ b/example_test.go @@ -44,8 +44,8 @@ func Example_router() { // Build a router that will send all "*.cloudflare.com" to the cloudflare // resolver while everything else goes to the google resolver (default) - route1, _ := rdns.NewRoute(`\.cloudflare\.com\.$`, "", nil, nil, "", "", "", "", cloudflare) - route2, _ := rdns.NewRoute("", "", nil, nil, "", "", "", "", google) + route1, _ := rdns.NewRoute(`\.cloudflare\.com\.$`, "", nil, nil, "", "", "", "", "", "", cloudflare) + route2, _ := rdns.NewRoute("", "", nil, nil, "", "", "", "", "", "", google) r := rdns.NewRouter("my-router") r.Add(route1, route2) diff --git a/listener.go b/listener.go index d355b8ac..c9201193 100644 --- a/listener.go +++ b/listener.go @@ -20,6 +20,13 @@ type ClientInfo struct { // DoH query path used by the client. Only populated when // the query was received over DoH. DoHPath string + + // TLS SNI server name + TLSServerName string + + // Listener ID of the listener that first received the request. Can be + // used to route queries. + Listener string } // Metrics that are available from listeners and clients. diff --git a/response-collapse.go b/response-collapse.go index 9c84df3a..dc6ee1a8 100644 --- a/response-collapse.go +++ b/response-collapse.go @@ -9,18 +9,18 @@ import ( type ResponseCollapse struct { id string resolver Resolver - ResponseCollapsOptions + ResponseCollapseOptions } -type ResponseCollapsOptions struct { +type ResponseCollapseOptions struct { NullRCode int // Response code when there's nothing left after collapsing the response } var _ Resolver = &ResponseCollapse{} // NewResponseMinimize returns a new instance of a response minimizer. -func NewResponseCollapse(id string, resolver Resolver, opt ResponseCollapsOptions) *ResponseCollapse { - return &ResponseCollapse{id: id, resolver: resolver, ResponseCollapsOptions: opt} +func NewResponseCollapse(id string, resolver Resolver, opt ResponseCollapseOptions) *ResponseCollapse { + return &ResponseCollapse{id: id, resolver: resolver, ResponseCollapseOptions: opt} } // Resolve a DNS query, then collapse the response to remove anything from the diff --git a/route.go b/route.go index 6a81f6a2..a544b62c 100644 --- a/route.go +++ b/route.go @@ -13,20 +13,22 @@ import ( ) type route struct { - types []uint16 - class uint16 - name *regexp.Regexp - source *net.IPNet - weekdays []time.Weekday - before *TimeOfDay - after *TimeOfDay - inverted bool // invert the matching behavior - dohPath *regexp.Regexp - resolver Resolver + types []uint16 + class uint16 + name *regexp.Regexp + source *net.IPNet + weekdays []time.Weekday + before *TimeOfDay + after *TimeOfDay + inverted bool // invert the matching behavior + dohPath *regexp.Regexp + resolver Resolver + listenerID *regexp.Regexp + tlsServerName *regexp.Regexp } // NewRoute initializes a route from string parameters. -func NewRoute(name, class string, types, weekdays []string, before, after, source, dohPath string, resolver Resolver) (*route, error) { +func NewRoute(name, class string, types, weekdays []string, before, after, source, dohPath, listenerID, tlsServerName string, resolver Resolver) (*route, error) { if resolver == nil { return nil, errors.New("no resolver defined for route") } @@ -58,6 +60,14 @@ func NewRoute(name, class string, types, weekdays []string, before, after, sourc if err != nil { return nil, err } + listenerRe, err := regexp.Compile(listenerID) + if err != nil { + return nil, err + } + tlsRe, err := regexp.Compile(tlsServerName) + if err != nil { + return nil, err + } var sNet *net.IPNet if source != "" { _, sNet, err = net.ParseCIDR(source) @@ -66,15 +76,17 @@ func NewRoute(name, class string, types, weekdays []string, before, after, sourc } } return &route{ - types: t, - class: c, - name: re, - weekdays: w, - before: b, - after: a, - source: sNet, - dohPath: dohRe, - resolver: resolver, + types: t, + class: c, + name: re, + weekdays: w, + before: b, + after: a, + source: sNet, + dohPath: dohRe, + listenerID: listenerRe, + tlsServerName: tlsRe, + resolver: resolver, }, nil } @@ -95,6 +107,12 @@ func (r *route) match(q *dns.Msg, ci ClientInfo) bool { if !r.dohPath.MatchString(ci.DoHPath) { return r.inverted } + if !r.listenerID.MatchString(ci.Listener) { + return r.inverted + } + if !r.tlsServerName.MatchString(ci.TLSServerName) { + return r.inverted + } if len(r.weekdays) > 0 || r.before != nil || r.after != nil { now := time.Now().Local() hour := now.Hour() @@ -151,6 +169,12 @@ func (r *route) String() string { if r.dohPath.String() != "" { fragments = append(fragments, "doh-path="+r.dohPath.String()) } + if r.listenerID.String() != "" { + fragments = append(fragments, "listener="+r.listenerID.String()) + } + if r.tlsServerName.String() != "" { + fragments = append(fragments, "servername="+r.tlsServerName.String()) + } if len(r.weekdays) > 0 { fragments = append(fragments, fmt.Sprintf("weekdays=%v", r.weekdays)) } diff --git a/route_test.go b/route_test.go index 3f08a90b..a691fa74 100644 --- a/route_test.go +++ b/route_test.go @@ -79,7 +79,7 @@ func TestRoute(t *testing.T) { }, } for _, test := range tests { - r, err := NewRoute(test.rName, test.rClass, test.rType, nil, "", "", "", "", &TestResolver{}) + r, err := NewRoute(test.rName, test.rClass, test.rType, nil, "", "", "", "", "", "", &TestResolver{}) require.NoError(t, err) r.Invert(test.rInvert) diff --git a/router_test.go b/router_test.go index 95e6bc88..c97693ab 100644 --- a/router_test.go +++ b/router_test.go @@ -14,8 +14,8 @@ func TestRouterType(t *testing.T) { q := new(dns.Msg) var ci ClientInfo - route1, _ := NewRoute("", "", []string{"MX"}, nil, "", "", "", "", r1) - route2, _ := NewRoute("", "", nil, nil, "", "", "", "", r2) + route1, _ := NewRoute("", "", []string{"MX"}, nil, "", "", "", "", "", "", r1) + route2, _ := NewRoute("", "", nil, nil, "", "", "", "", "", "", r2) router := NewRouter("my-router") router.Add(route1, route2) @@ -41,8 +41,8 @@ func TestRouterClass(t *testing.T) { q := new(dns.Msg) var ci ClientInfo - route1, _ := NewRoute("", "ANY", nil, nil, "", "", "", "", r1) - route2, _ := NewRoute("", "", nil, nil, "", "", "", "", r2) + route1, _ := NewRoute("", "ANY", nil, nil, "", "", "", "", "", "", r1) + route2, _ := NewRoute("", "", nil, nil, "", "", "", "", "", "", r2) router := NewRouter("my-router") router.Add(route1, route2) @@ -69,8 +69,8 @@ func TestRouterName(t *testing.T) { q := new(dns.Msg) var ci ClientInfo - route1, _ := NewRoute(`\.acme\.test\.$`, "", nil, nil, "", "", "", "", r1) - route2, _ := NewRoute("", "", nil, nil, "", "", "", "", r2) + route1, _ := NewRoute(`\.acme\.test\.$`, "", nil, nil, "", "", "", "", "", "", r1) + route2, _ := NewRoute("", "", nil, nil, "", "", "", "", "", "", r2) router := NewRouter("my-router") router.Add(route1, route2) @@ -96,8 +96,8 @@ func TestRouterSource(t *testing.T) { q := new(dns.Msg) q.SetQuestion("acme.test.", dns.TypeA) - route1, _ := NewRoute("", "", nil, nil, "", "", "192.168.1.100/32", "", r1) - route2, _ := NewRoute("", "", nil, nil, "", "", "", "", r2) + route1, _ := NewRoute("", "", nil, nil, "", "", "192.168.1.100/32", "", "", "", r1) + route2, _ := NewRoute("", "", nil, nil, "", "", "", "", "", "", r2) router := NewRouter("my-router") router.Add(route1, route2) diff --git a/tls.go b/tls.go index 91ba08da..45c4fb0b 100644 --- a/tls.go +++ b/tls.go @@ -41,9 +41,10 @@ func TLSServerConfig(caFile, crtFile, keyFile string, mutualTLS bool) (*tls.Conf // TLSClientConfig is a convenience function that builds a tls.Config instance for TLS clients // based on common options and certificate+key files. -func TLSClientConfig(caFile, crtFile, keyFile string) (*tls.Config, error) { +func TLSClientConfig(caFile, crtFile, keyFile, serverName string) (*tls.Config, error) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, + ServerName: serverName, } // Add client key/cert if provided