diff --git a/dns/dhcp.go b/dns/dhcp.go index 7932a03..9d8f34d 100644 --- a/dns/dhcp.go +++ b/dns/dhcp.go @@ -8,14 +8,12 @@ import ( "github.com/wweir/netboot/dhcp4" ) -func GetDefaultDNSServer() string { - xid := make([]byte, 4) - rand.Read(xid) +var xid = make([]byte, 4) +func GetDefaultDNSServer() string { pack := &dhcp4.Packet{ - Type: dhcp4.MsgDiscover, - TransactionID: xid, - Broadcast: true, + Type: dhcp4.MsgDiscover, + Broadcast: true, } options := map[dhcp4.Option][]byte{ dhcp4.OptRequestedOptions: []byte{byte(dhcp4.OptDNSServers)}, @@ -24,12 +22,14 @@ func GetDefaultDNSServer() string { ifaces := mustGetInterfaces() for _, iface := range ifaces { conn, err := dhcp4.NewConn(iface.IP.String() + ":68") - if err != nil { - glog.Errorln(err) + if err != nil { // maybe in use + glog.V(1).Infoln(err) continue } defer conn.Close() + rand.Read(xid) + pack.TransactionID = xid pack.HardwareAddr = iface.Interface.HardwareAddr options[dhcp4.OptClientIdentifier] = iface.Interface.HardwareAddr pack.Options = dhcp4.Options(options) @@ -45,12 +45,12 @@ func GetDefaultDNSServer() string { continue } - ip, err := pack.Options.IP(dhcp4.OptDNSServers) + ips, err := pack.Options.IPs(dhcp4.OptDNSServers) if err != nil { glog.Errorln(err) continue } - return ip.String() + return ips[0].String() // if len(ips) == 0, err should not be wrong size } return "" } diff --git a/dns/dns.go b/dns/dns.go index ab7b4a7..41d4728 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -18,8 +18,23 @@ func StartDNS(dnsServer, listenIP string) { ip := net.ParseIP(listenIP) suggest := &intelliSuggest{listenIP, 2 * time.Second, []string{"80", "443"}} mem.DefaultCache = mem.New(time.Hour) + var dhcpCh chan struct{} if dnsServer != "" { dnsServer = net.JoinHostPort(dnsServer, "53") + } else { + dhcpCh = make(chan struct{}) + go func() { + for { + <-dhcpCh + host := GetDefaultDNSServer() + if host == "" { + continue + } + // atomic action + dnsServer = net.JoinHostPort(host, "53") + glog.Infoln("set dns server to", host) + } + }() } dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { @@ -39,15 +54,15 @@ func StartDNS(dnsServer, listenIP string) { domain = domain[:idx] } - matchAndServe(w, r, domain, listenIP, &dnsServer, ip, suggest) + matchAndServe(w, r, domain, listenIP, dnsServer, dhcpCh, ip, suggest) }) server := &dns.Server{Addr: net.JoinHostPort(listenIP, "53"), Net: "udp"} glog.Fatalln(server.ListenAndServe()) } -func matchAndServe(w dns.ResponseWriter, r *dns.Msg, domain, listenIP string, - dnsServer *string, ipNet net.IP, suggest *intelliSuggest) { +func matchAndServe(w dns.ResponseWriter, r *dns.Msg, domain, listenIP, dnsServer string, + dhcpCh chan struct{}, ipNet net.IP, suggest *intelliSuggest) { inWriteList := whiteList.Match(domain) @@ -61,18 +76,12 @@ func matchAndServe(w dns.ResponseWriter, r *dns.Msg, domain, listenIP string, go mem.Remember(suggest, domain) } - if *dnsServer == "" { - host := GetDefaultDNSServer() - if host == "" { - return + msg, err := dns.Exchange(r, dnsServer) + if err != nil && dhcpCh != nil { + select { + case dhcpCh <- struct{}{}: + default: } - *dnsServer = net.JoinHostPort(host, "53") - glog.Infoln("set dns server to", host) - } - - msg, err := dns.Exchange(r, *dnsServer) - if err != nil { - *dnsServer = "" } if msg == nil { // expose any response except nil glog.V(1).Infof("get dns of %s fail: %s", domain, err)