-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdres.go
136 lines (117 loc) · 3.33 KB
/
dres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package main
import (
"errors"
"fmt"
"log"
"net"
"sort"
)
import "github.com/miekg/dns"
// Network associates a human-readable range name with an IP network.
type Network struct {
	Name string
	Net  net.IPNet
}

// Networks is a list of Network values implementing sort.Interface.
// Sorting it places entries with longer prefixes (more specific ranges)
// before entries with shorter prefixes.
type Networks []Network

// Len reports how many networks the list holds.
func (ns Networks) Len() int {
	return len(ns)
}

// Less reports whether the network at i has a strictly longer prefix than
// the network at j, so that more specific ranges sort first.
func (ns Networks) Less(i, j int) bool {
	prefixI, _ := ns[i].Net.Mask.Size()
	prefixJ, _ := ns[j].Net.Mask.Size()
	return prefixJ < prefixI
}

// Swap exchanges the networks at positions i and j.
func (ns Networks) Swap(i, j int) {
	ns[i], ns[j] = ns[j], ns[i]
}
// Dres holds the routing state of the server: the known client networks and,
// keyed by network name, the resolvers to consult for clients in that network.
type Dres struct {
	// Networks is searched in order; Load sorts it most-specific-first.
	Networks Networks
	// Resolvers maps a network name to its resolvers, tried in slice order.
	Resolvers map[string][]Resolver
}
// Load builds a Dres from config.
//
// It constructs every resolver declared in config.Resolvers, parses every
// CIDR in config.CIDRS into a named Network, and sorts the networks so that
// longer (more specific) prefixes are matched first. It then maps each range
// in config.Configuration to its ordered list of resolvers.
//
// Load terminates the process via log.Fatalf in three cases:
//   - a resolver cannot be constructed;
//   - a CIDR cannot be parsed;
//   - a range refers to a resolver name absent from config.Resolvers.
func Load(config Config) Dres {
	resolverByName := make(map[string]Resolver, len(config.Resolvers))
	for resolverName, resolverConfig := range config.Resolvers {
		resolver, err := LoadResolver(resolverName, resolverConfig)
		if err != nil {
			log.Fatalf("Error constructing resolver: %s", err)
		}
		resolverByName[resolverName] = resolver
		log.Printf("Loaded resolver %s of type %s", resolverName, resolverConfig.Type)
	}

	networks := make(Networks, 0, len(config.CIDRS))
	for rangeName, cidr := range config.CIDRS {
		_, ipNet, err := net.ParseCIDR(cidr)
		if err != nil {
			log.Fatalf("Error reading CIDR: %s", err)
		}
		networks = append(networks, Network{
			Name: rangeName,
			Net:  *ipNet,
		})
		log.Printf("Loaded network %s: %s", rangeName, cidr)
	}
	// Most specific prefix first, so lookups return the longest match.
	sort.Sort(networks)

	resolvers := make(map[string][]Resolver, len(config.Configuration))
	for rangeName, resolverNames := range config.Configuration {
		log.Printf("Range %s has following resolvers:", rangeName)
		rangeResolvers := make([]Resolver, len(resolverNames))
		for i, resolverName := range resolverNames {
			resolver, ok := resolverByName[resolverName]
			if !ok {
				// Without this check a zero Resolver would be stored and the
				// mistake would only surface when a query is routed to it
				// (presumably a nil-interface panic — Resolver appears to be
				// an interface given how it is used).
				log.Fatalf("Range %s references unknown resolver %s", rangeName, resolverName)
			}
			rangeResolvers[i] = resolver
			log.Printf(" - %s", resolverName)
		}
		resolvers[rangeName] = rangeResolvers
	}

	return Dres{
		Networks:  networks,
		Resolvers: resolvers,
	}
}
// GetResolvers returns the resolvers configured for the network containing
// addr, in the order they should be tried. If addr matches no known network,
// the lookup failure is logged and an empty, non-nil slice is returned.
func (dres Dres) GetResolvers(addr net.Addr) []Resolver {
	name, lookupErr := dres.GetNetworkName(addr)
	if lookupErr == nil {
		return dres.Resolvers[name]
	}
	log.Printf("Network for address %s not found: %s", addr, lookupErr)
	return []Resolver{}
}
// HandleFunc answers a DNS query by trying each resolver configured for the
// client's network, in order.
//
// The first resolver that answers without error and whose reply is written
// successfully ends the request. If a resolver fails, or writing its reply
// fails, the next resolver is tried. If every resolver fails, or the client's
// network is unknown, the client receives a SERVFAIL reply instead of silence,
// so it does not have to wait for a timeout.
func (dres Dres) HandleFunc(writer dns.ResponseWriter, msg *dns.Msg) {
	// The close error is ignored: there is nothing useful to do with it here.
	defer func() { _ = writer.Close() }()

	log.Printf("Request from %s", writer.RemoteAddr())
	for _, question := range msg.Question {
		log.Printf(" Question: %s", question.String())
	}

	for _, resolver := range dres.GetResolvers(writer.RemoteAddr()) {
		response, err := resolver.Handle(msg)
		if err != nil {
			log.Printf(" Resolver %s failed to handle query: %s", resolver.GetName(), err)
			continue
		}
		log.Printf(" Answer from resolver %s", resolver.GetName())
		response.Compress = true
		if writeErr := writer.WriteMsg(response); writeErr != nil {
			// Report the write failure itself; err is nil on this path.
			log.Printf(" Unable to response. See error %s", writeErr)
			continue
		}
		return
	}

	log.Printf(" Query from %s not handled", writer.RemoteAddr())
	failure := new(dns.Msg)
	failure.SetRcode(msg, dns.RcodeServerFailure)
	if writeErr := writer.WriteMsg(failure); writeErr != nil {
		log.Printf(" Unable to send SERVFAIL to %s: %s", writer.RemoteAddr(), writeErr)
	}
}
// ErrNetworkNotFound is wrapped by Dres.GetNetworkName when no configured
// network contains the given address; test for it with errors.Is.
var ErrNetworkNotFound = errors.New("unable to find network")

// GetNetworkName returns the name of the first network in dres.Networks whose
// range contains the IP of addr.
//
// Networks are checked in slice order, so when the slice is sorted
// most-specific-first (as Load does) the longest matching prefix wins. If no
// network matches, the returned error wraps ErrNetworkNotFound and reads
// "unable to find network for <addr>".
func (dres Dres) GetNetworkName(addr net.Addr) (string, error) {
	// The client IP does not change between iterations; extract it once.
	ip := GetIP(addr)
	for _, network := range dres.Networks {
		if network.Net.Contains(ip) {
			return network.Name, nil
		}
	}
	return "", fmt.Errorf("%w for %s", ErrNetworkNotFound, addr)
}
func main() {
dres := Load(LoadConfig())
server := dns.Server{Addr: ":53", Net: "udp"}
dns.HandleFunc(".", dres.HandleFunc)
err := server.ListenAndServe()
if err != nil {
log.Fatalf("Unable to start dres. See error: %s", err)
}
}