Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support routing by ListenerID or TLS server name #269

Merged
merged 1 commit into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func minTTL(answer *dns.Msg) (uint32, bool) {
type AnswerShuffleFunc func(*dns.Msg)

// Randomly re-order the A/AAAA answer records.
func AnswerShuffleRandon(msg *dns.Msg) {
func AnswerShuffleRandom(msg *dns.Msg) {
if len(msg.Answer) < 2 {
return
}
Expand Down
3 changes: 3 additions & 0 deletions cmd/routedns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type resolver struct {
CA string
ClientKey string `toml:"client-key"`
ClientCrt string `toml:"client-crt"`
ServerName string `toml:"server-name"` // TLS server name presented in the server certificate
BootstrapAddr string `toml:"bootstrap-address"`
LocalAddr string `toml:"local-address"`
EDNS0UDPSize uint16 `toml:"edns0-udp-size"` // UDP resolver option
Expand Down Expand Up @@ -160,6 +161,8 @@ type route struct {
Invert bool // Invert the result of the match
DoHPath string `toml:"doh-path"` // DoH query path if received over DoH (regexp)
Resolver string
Listener string // ID of the listener that received the original request
TLSServerName string `toml:"servername"` // TLS servername
}

// LoadConfig reads a config file and returns the decoded structure.
Expand Down
12 changes: 6 additions & 6 deletions cmd/routedns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func start(opt options, args []string) error {
resolver, ok := resolvers[l.Resolver]
// All Listeners should route queries (except the admin service).
if !ok && l.Protocol != "admin" {
return fmt.Errorf("listener '%s' references non-existant resolver, group or router '%s'", id, l.Resolver)
return fmt.Errorf("listener '%s' references non-existent resolver, group or router '%s'", id, l.Resolver)
}
allowedNet, err := parseCIDRList(l.AllowedNet)
if err != nil {
Expand Down Expand Up @@ -301,7 +301,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er
for _, rid := range g.Resolvers {
resolver, ok := resolvers[rid]
if !ok {
return fmt.Errorf("group '%s' references non-existant resolver or group '%s'", id, rid)
return fmt.Errorf("group '%s' references non-existent resolver or group '%s'", id, rid)
}
gr = append(gr, resolver)
}
Expand Down Expand Up @@ -547,7 +547,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er
switch g.CacheAnswerShuffle {
case "": // default
case "random":
shuffleFunc = rdns.AnswerShuffleRandon
shuffleFunc = rdns.AnswerShuffleRandom
case "round-robin":
shuffleFunc = rdns.AnswerShuffleRoundRobin
default:
Expand Down Expand Up @@ -693,7 +693,7 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er
if len(gr) != 1 {
return fmt.Errorf("type response-collapse only supports one resolver in '%s'", id)
}
opt := rdns.ResponseCollapsOptions{
opt := rdns.ResponseCollapseOptions{
NullRCode: g.NullRCode,
}
resolvers[id] = rdns.NewResponseCollapse(id, gr[0], opt)
Expand Down Expand Up @@ -724,13 +724,13 @@ func instantiateRouter(id string, r router, resolvers map[string]rdns.Resolver)
for _, route := range r.Routes {
resolver, ok := resolvers[route.Resolver]
if !ok {
return fmt.Errorf("router '%s' references non-existant resolver or group '%s'", id, route.Resolver)
return fmt.Errorf("router '%s' references non-existent resolver or group '%s'", id, route.Resolver)
}
types := route.Types
if route.Type != "" { // Support the deprecated "Type" by just adding it to "Types" if defined
types = append(types, route.Type)
}
r, err := rdns.NewRoute(route.Name, route.Class, types, route.Weekdays, route.Before, route.After, route.Source, route.DoHPath, resolver)
r, err := rdns.NewRoute(route.Name, route.Class, types, route.Weekdays, route.Before, route.After, route.Source, route.DoHPath, route.Listener, route.TLSServerName, resolver)
if err != nil {
return fmt.Errorf("failure parsing routes for router '%s' : %s", id, err.Error())
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/routedns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv
case "doq":
r.Address = rdns.AddressWithDefault(r.Address, rdns.DoQPort)

tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName)
if err != nil {
return err
}
Expand All @@ -31,7 +31,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv
case "dot":
r.Address = rdns.AddressWithDefault(r.Address, rdns.DoTPort)

tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName)
if err != nil {
return err
}
Expand Down Expand Up @@ -64,7 +64,7 @@ func instantiateResolver(id string, r resolver, resolvers map[string]rdns.Resolv
case "doh":
r.Address = rdns.AddressWithDefault(r.Address, rdns.DoHPort)

tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey)
tlsConfig, err := rdns.TLSClientConfig(r.CA, r.ClientCrt, r.ClientKey, r.ServerName)
if err != nil {
return err
}
Expand Down
17 changes: 13 additions & 4 deletions dnslistener.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rdns

import (
"crypto/tls"
"net"

"github.com/miekg/dns"
Expand Down Expand Up @@ -49,10 +50,18 @@ func (s DNSListener) String() string {
func listenHandler(id, protocol, addr string, r Resolver, allowedNet []*net.IPNet) dns.HandlerFunc {
metrics := NewListenerMetrics("listener", id)
return func(w dns.ResponseWriter, req *dns.Msg) {
var (
ci ClientInfo
err error
)
var err error

ci := ClientInfo{
Listener: id,
}

if r, ok := w.(interface{ ConnectionState() *tls.ConnectionState }); ok {
connState := r.ConnectionState()
if connState != nil {
ci.TLSServerName = connState.ServerName
}
}

switch addr := w.RemoteAddr().(type) {
case *net.TCPAddr:
Expand Down
3 changes: 3 additions & 0 deletions doc/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,8 @@ A route has the following fields:
- `before` - Time of day in the format HH:mm before which the rule matches. Uses 24h format. For example `17:30`.
- `invert` - Invert the result of the matching if set to `true`. Optional.
- `doh-path` - Regexp that matches on the DoH query path the client used.
- `listener` - Regexp that matches on the ID of the listener that first received.
- `servername` - Regexp that matches on the TLS server name used in the TLS handshake with the listener.
- `resolver` - The identifier of a resolver, group, or another router. Required.

Examples:
Expand Down Expand Up @@ -1348,6 +1350,7 @@ Secure resolvers such as DoT, DoH, or DoQ offer additional options to configure
- `client-crt` - Client certificate file.
- `client-key` - Client certificate key file
- `ca` - CA certificate to validate server certificates.
- `server-name` - Name of the certificate presented by the server if it does not match the name in the endpoint address.

Examples:

Expand Down
10 changes: 8 additions & 2 deletions dohlistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,15 @@ func (s *DoHListener) parseAndRespond(b []byte, w http.ResponseWriter, r *http.R
http.Error(w, "Invalid RemoteAddr", http.StatusBadRequest)
return
}
var tlsServerName string
if r.TLS != nil {
tlsServerName = r.TLS.ServerName
}
ci := ClientInfo{
SourceIP: clientIP,
DoHPath: r.URL.Path,
SourceIP: clientIP,
DoHPath: r.URL.Path,
TLSServerName: tlsServerName,
Listener: s.id,
}
log := Log.WithFields(logrus.Fields{
"id": s.id,
Expand Down
6 changes: 3 additions & 3 deletions dohlistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestDoHListenerSimple(t *testing.T) {
time.Sleep(time.Second)

// Make a client talking to the listener using POST
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "")
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "")
require.NoError(t, err)
u := "https://" + addr + "/dns-query"
cPost, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsConfig, Method: "POST"})
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestDoHListenerMutual(t *testing.T) {

// Make a client talking to the listener. Need to trust the issuer of the server certificate and
// present a client certificate.
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key")
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "")
require.NoError(t, err)
u := "https://" + addr + "/dns-query"
c, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsClientConfig})
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestDoHListenerMutualQUIC(t *testing.T) {

// Make a client talking to the listener. Need to trust the issuer of the server certificate and
// present a client certificate.
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key")
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "")
require.NoError(t, err)
u := "https://" + addr + "/dns-query"
c, err := NewDoHClient("test-doh", u, DoHClientOptions{TLSConfig: tlsClientConfig, Transport: "quic"})
Expand Down
7 changes: 6 additions & 1 deletion doqlistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,12 @@ func (s DoQListener) Stop() error {
}

func (s DoQListener) handleConnection(connection quic.Connection) {
var ci ClientInfo
tlsServerName := connection.ConnectionState().TLS.ServerName

ci := ClientInfo{
Listener: s.id,
TLSServerName: tlsServerName,
}
switch addr := connection.RemoteAddr().(type) {
case *net.TCPAddr:
ci.SourceIP = addr.IP
Expand Down
2 changes: 1 addition & 1 deletion dotclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestDoTClientCA(t *testing.T) {
conn.Close()

// Create a config with CA using the temp file
tlsConfig, err := TLSClientConfig(crtFile, "", "")
tlsConfig, err := TLSClientConfig(crtFile, "", "", "")
require.NoError(t, err)

// DoT client with valid CA
Expand Down
6 changes: 3 additions & 3 deletions dotlistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestDoTListenerSimple(t *testing.T) {
time.Sleep(time.Second)

// Make a client talking to the listener. Need to trust the issue of the server certificate.
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "")
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "")
require.NoError(t, err)
c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsConfig})

Expand Down Expand Up @@ -64,7 +64,7 @@ func TestDoTListenerMutual(t *testing.T) {

// Make a client talking to the listener. Need to trust the issue of the server certificate and
// present a client certificate.
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key")
tlsClientConfig, err := TLSClientConfig("testdata/ca.crt", "testdata/client.crt", "testdata/client.key", "")
require.NoError(t, err)
c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsClientConfig})

Expand Down Expand Up @@ -99,7 +99,7 @@ func TestDoTListenerPadding(t *testing.T) {
time.Sleep(time.Second)

// Make a client talking to the listener. Need to trust the issue of the server certificate.
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "")
tlsConfig, err := TLSClientConfig("testdata/ca.crt", "", "", "")
require.NoError(t, err)
c, _ := NewDoTClient("test-dot", addr, DoTClientOptions{TLSConfig: tlsConfig})

Expand Down
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func Example_router() {

// Build a router that will send all "*.cloudflare.com" to the cloudflare
// resolver while everything else goes to the google resolver (default)
route1, _ := rdns.NewRoute(`\.cloudflare\.com\.$`, "", nil, nil, "", "", "", "", cloudflare)
route2, _ := rdns.NewRoute("", "", nil, nil, "", "", "", "", google)
route1, _ := rdns.NewRoute(`\.cloudflare\.com\.$`, "", nil, nil, "", "", "", "", "", "", cloudflare)
route2, _ := rdns.NewRoute("", "", nil, nil, "", "", "", "", "", "", google)
r := rdns.NewRouter("my-router")
r.Add(route1, route2)

Expand Down
7 changes: 7 additions & 0 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ type ClientInfo struct {
// DoH query path used by the client. Only populated when
// the query was received over DoH.
DoHPath string

// TLS SNI server name
TLSServerName string

// Listener ID of the listener that first received the request. Can be
// used to route queries.
Listener string
}

// Metrics that are available from listeners and clients.
Expand Down
8 changes: 4 additions & 4 deletions response-collapse.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ import (
type ResponseCollapse struct {
id string
resolver Resolver
ResponseCollapsOptions
ResponseCollapseOptions
}

type ResponseCollapsOptions struct {
type ResponseCollapseOptions struct {
NullRCode int // Response code when there's nothing left after collapsing the response
}

var _ Resolver = &ResponseCollapse{}

// NewResponseMinimize returns a new instance of a response minimizer.
func NewResponseCollapse(id string, resolver Resolver, opt ResponseCollapsOptions) *ResponseCollapse {
return &ResponseCollapse{id: id, resolver: resolver, ResponseCollapsOptions: opt}
func NewResponseCollapse(id string, resolver Resolver, opt ResponseCollapseOptions) *ResponseCollapse {
return &ResponseCollapse{id: id, resolver: resolver, ResponseCollapseOptions: opt}
}

// Resolve a DNS query, then collapse the response to remove anything from the
Expand Down
64 changes: 44 additions & 20 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@ import (
)

type route struct {
types []uint16
class uint16
name *regexp.Regexp
source *net.IPNet
weekdays []time.Weekday
before *TimeOfDay
after *TimeOfDay
inverted bool // invert the matching behavior
dohPath *regexp.Regexp
resolver Resolver
types []uint16
class uint16
name *regexp.Regexp
source *net.IPNet
weekdays []time.Weekday
before *TimeOfDay
after *TimeOfDay
inverted bool // invert the matching behavior
dohPath *regexp.Regexp
resolver Resolver
listenerID *regexp.Regexp
tlsServerName *regexp.Regexp
}

// NewRoute initializes a route from string parameters.
func NewRoute(name, class string, types, weekdays []string, before, after, source, dohPath string, resolver Resolver) (*route, error) {
func NewRoute(name, class string, types, weekdays []string, before, after, source, dohPath, listenerID, tlsServerName string, resolver Resolver) (*route, error) {
if resolver == nil {
return nil, errors.New("no resolver defined for route")
}
Expand Down Expand Up @@ -58,6 +60,14 @@ func NewRoute(name, class string, types, weekdays []string, before, after, sourc
if err != nil {
return nil, err
}
listenerRe, err := regexp.Compile(listenerID)
if err != nil {
return nil, err
}
tlsRe, err := regexp.Compile(tlsServerName)
if err != nil {
return nil, err
}
var sNet *net.IPNet
if source != "" {
_, sNet, err = net.ParseCIDR(source)
Expand All @@ -66,15 +76,17 @@ func NewRoute(name, class string, types, weekdays []string, before, after, sourc
}
}
return &route{
types: t,
class: c,
name: re,
weekdays: w,
before: b,
after: a,
source: sNet,
dohPath: dohRe,
resolver: resolver,
types: t,
class: c,
name: re,
weekdays: w,
before: b,
after: a,
source: sNet,
dohPath: dohRe,
listenerID: listenerRe,
tlsServerName: tlsRe,
resolver: resolver,
}, nil
}

Expand All @@ -95,6 +107,12 @@ func (r *route) match(q *dns.Msg, ci ClientInfo) bool {
if !r.dohPath.MatchString(ci.DoHPath) {
return r.inverted
}
if !r.listenerID.MatchString(ci.Listener) {
return r.inverted
}
if !r.tlsServerName.MatchString(ci.TLSServerName) {
return r.inverted
}
if len(r.weekdays) > 0 || r.before != nil || r.after != nil {
now := time.Now().Local()
hour := now.Hour()
Expand Down Expand Up @@ -151,6 +169,12 @@ func (r *route) String() string {
if r.dohPath.String() != "" {
fragments = append(fragments, "doh-path="+r.dohPath.String())
}
if r.listenerID.String() != "" {
fragments = append(fragments, "listener="+r.listenerID.String())
}
if r.tlsServerName.String() != "" {
fragments = append(fragments, "servername="+r.tlsServerName.String())
}
if len(r.weekdays) > 0 {
fragments = append(fragments, fmt.Sprintf("weekdays=%v", r.weekdays))
}
Expand Down
Loading