-
Notifications
You must be signed in to change notification settings - Fork 63
/
response-blocklist-name.go
136 lines (120 loc) · 3.87 KB
/
response-blocklist-name.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package rdns
import (
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
// ResponseBlocklistName is a resolver that filters by matching the strings in CNAME, MX,
// NS, PTR, SRV, HTTPS, SVCB, TXT and SOA response records against a blocklist.
type ResponseBlocklistName struct {
	// Identifier of this resolver; returned by String and attached to log entries.
	id string
	ResponseBlocklistNameOptions
	// Upstream resolver that is queried first; its responses are then filtered.
	resolver Resolver
	// Held for writing while the refresh loop replaces BlocklistDB.
	mu sync.RWMutex
}
// Compile-time check that *ResponseBlocklistName implements Resolver.
var _ Resolver = (*ResponseBlocklistName)(nil)
// ResponseBlocklistNameOptions holds the configuration of a ResponseBlocklistName resolver.
type ResponseBlocklistNameOptions struct {
	// Optional, if the response is found to match the blocklist, send the query to this resolver.
	BlocklistResolver Resolver

	// Blocklist that the names found in response records are matched against.
	BlocklistDB BlocklistDB

	// Refresh period for the blocklist. Disabled if 0.
	BlocklistRefresh time.Duration

	// Inverted behavior, only allow responses that can be found on at least one list.
	Inverted bool

	// Optional, allows specifying extended errors to be used in the
	// response when blocking.
	EDNS0EDETemplate *EDNS0EDETemplate
}
// NewResponseBlocklistName builds a response blocklist resolver with the given
// id that filters the responses of resolver according to opt. When opt carries
// both a blocklist and a positive refresh period, a background goroutine is
// started that periodically reloads the blocklist. The returned error is
// always nil.
func NewResponseBlocklistName(id string, resolver Resolver, opt ResponseBlocklistNameOptions) (*ResponseBlocklistName, error) {
	r := &ResponseBlocklistName{
		id:                           id,
		ResponseBlocklistNameOptions: opt,
		resolver:                     resolver,
	}

	// Nothing to refresh without a list or without a refresh period.
	if r.BlocklistDB == nil || r.BlocklistRefresh <= 0 {
		return r, nil
	}

	go r.refreshLoopBlocklist(r.BlocklistRefresh)
	return r, nil
}
// Resolve forwards the query to the upstream resolver and then runs the
// response through the blocklist filter. Upstream errors and nil responses
// are handed back to the caller untouched. Responds with NXDOMAIN if the
// response matches the filter.
func (r *ResponseBlocklistName) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	upstream, err := r.resolver.Resolve(q, ci)
	if err != nil {
		return upstream, err
	}
	if upstream == nil {
		// No response to inspect.
		return nil, nil
	}
	return r.blockIfMatch(q, upstream, ci)
}
// String returns the id this resolver was created with.
func (r *ResponseBlocklistName) String() string { return r.id }
// refreshLoopBlocklist reloads the blocklist every refresh interval, forever.
// A successfully reloaded list replaces r.BlocklistDB while holding the write
// lock. If reloading fails, the error is logged and the current list is kept
// until the next attempt.
func (r *ResponseBlocklistName) refreshLoopBlocklist(refresh time.Duration) {
	for {
		time.Sleep(refresh)
		log := Log.WithField("id", r.id)
		log.Debug("reloading blocklist")
		// This goroutine is the only one that assigns r.BlocklistDB, so
		// reading the field here does not need the lock.
		db, err := r.BlocklistDB.Reload()
		if err != nil {
			// Use the scoped logger so the failure can be attributed to
			// this resolver's id.
			log.WithError(err).Error("failed to load rules")
			continue
		}
		r.mu.Lock()
		r.BlocklistDB = db
		r.mu.Unlock()
	}
}
// blockIfMatch checks the names carried by the records in the answer,
// authority and additional sections of answer against the blocklist.
//
// The name taken from each record type is: the target of CNAME and SRV, the
// exchange of MX, the nameserver of NS and SOA, the pointer of PTR, the joined
// strings of TXT, and the svcbString form of SVCB and HTTPS. Other record
// types are skipped.
//
// On the first record whose match result differs from r.Inverted, the query is
// either forwarded to BlocklistResolver (if configured) or answered with the
// result of nxdomain(query) after applying the EDNS0 EDE template. If no
// record triggers, answer is returned unchanged.
func (r *ResponseBlocklistName) blockIfMatch(query, answer *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	// Take a snapshot of the blocklist under the read lock. The refresh loop
	// replaces r.BlocklistDB under the write lock, so reading the field
	// without the lock would be a data race.
	r.mu.RLock()
	db := r.BlocklistDB
	r.mu.RUnlock()

	for _, records := range [][]dns.RR{answer.Answer, answer.Ns, answer.Extra} {
		for _, rr := range records {
			var name string
			switch rec := rr.(type) {
			case *dns.CNAME:
				name = rec.Target
			case *dns.MX:
				name = rec.Mx
			case *dns.NS:
				name = rec.Ns
			case *dns.PTR:
				name = rec.Ptr
			case *dns.SRV:
				name = rec.Target
			case *dns.HTTPS:
				name = svcbString(&rec.SVCB)
			case *dns.TXT:
				name = strings.Join(rec.Txt, " ")
			case *dns.SVCB:
				name = svcbString(rec)
			case *dns.SOA:
				name = rec.Ns
			default:
				continue
			}

			_, _, rule, ok := db.Match(dns.Question{Name: name})
			if ok == r.Inverted {
				// Normal mode: no match. Inverted mode: name is on a list.
				continue
			}

			log := logger(r.id, query, ci).WithField("rule", rule.GetRule())
			if r.BlocklistResolver != nil {
				log.WithField("resolver", r.BlocklistResolver).Debug("blocklist match, forwarding to blocklist-resolver")
				return r.BlocklistResolver.Resolve(query, ci)
			}
			log.Debug("blocking response")
			answer = nxdomain(query)
			if err := r.EDNS0EDETemplate.Apply(answer, EDNS0EDEInput{query, rule}); err != nil {
				// A template failure is logged only; the response is still blocked.
				log.WithError(err).Error("failed to apply edns0ede template")
			}
			return answer, nil
		}
	}
	return answer, nil
}
// svcbString renders an SVCB (or HTTPS) record as its target followed by the
// record's key/value pairs, all separated by single spaces:
// "TARGET key1=value1 key2=value2". For example: ". alpn=h2,h3"
func svcbString(rr *dns.SVCB) string {
	fields := make([]string, 0, len(rr.Value)+1)
	fields = append(fields, rr.Target)
	for _, kv := range rr.Value {
		fields = append(fields, kv.Key().String()+"="+kv.String())
	}
	return strings.Join(fields, " ")
}