diff --git a/pkg/services/dns/dns.go b/pkg/services/dns/dns.go index e50bbc52..9168dff0 100644 --- a/pkg/services/dns/dns.go +++ b/pkg/services/dns/dns.go @@ -16,17 +16,16 @@ import ( ) type dnsHandler struct { - zones []types.Zone - zonesLock sync.RWMutex - udpClient *dns.Client - tcpClient *dns.Client - hostsFile *HostsFile - nameservers []string + zones []types.Zone + zonesLock sync.RWMutex + udpClient *dns.Client + tcpClient *dns.Client + hostsFile *HostsFile + dnsConfig *dnsConfig } func newDNSHandler(zones []types.Zone) (*dnsHandler, error) { - - nameservers, err := getDNSHostAndPort() + dnsConfig, err := newDNSConfig() if err != nil { return nil, err } @@ -37,13 +36,12 @@ func newDNSHandler(zones []types.Zone) (*dnsHandler, error) { } return &dnsHandler{ - zones: zones, - tcpClient: &dns.Client{Net: "tcp"}, - udpClient: &dns.Client{Net: "udp"}, - nameservers: nameservers, - hostsFile: hostsFile, + zones: zones, + tcpClient: &dns.Client{Net: "tcp"}, + udpClient: &dns.Client{Net: "udp"}, + dnsConfig: dnsConfig, + hostsFile: hostsFile, }, nil - } func (h *dnsHandler) handle(w dns.ResponseWriter, dnsClient *dns.Client, r *dns.Msg, responseMessageSize int) { @@ -145,7 +143,7 @@ func (h *dnsHandler) addAnswers(dnsClient *dns.Client, r *dns.Msg) *dns.Msg { return m } } - for _, nameserver := range h.nameservers { + for _, nameserver := range h.dnsConfig.Nameservers() { msg := r.Copy() r, _, err := dnsClient.Exchange(msg, nameserver) // return first good answer diff --git a/pkg/services/dns/dns_config.go b/pkg/services/dns/dns_config.go new file mode 100644 index 00000000..1abdfe0c --- /dev/null +++ b/pkg/services/dns/dns_config.go @@ -0,0 +1,24 @@ +package dns + +import "sync" + +type dnsConfig struct { + mu sync.RWMutex + nameservers []string +} + +func newDNSConfig() (*dnsConfig, error) { + r := &dnsConfig{nameservers: []string{}} + if err := r.init(); err != nil { + return nil, err + } + + return r, nil +} + +func (r *dnsConfig) Nameservers() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.nameservers +} diff --git a/pkg/services/dns/dns_config_unix.go b/pkg/services/dns/dns_config_unix.go index ef66be39..9d795a83 100644 --- a/pkg/services/dns/dns_config_unix.go +++ b/pkg/services/dns/dns_config_unix.go @@ -7,18 +7,53 @@ import ( "net" "net/netip" + "github.com/containers/gvisor-tap-vsock/pkg/utils" "github.com/miekg/dns" log "github.com/sirupsen/logrus" ) +func (r *dnsConfig) init() error { + if err := r.refreshNameservers(); err != nil { + return err + } + + w, err := utils.NewFileWatcher(etcResolvConfPath) + if err != nil { + return err + } + + if err := w.Start(func() { _ = r.refreshNameservers() }); err != nil { + return err + } + + return nil +} + +func (r *dnsConfig) refreshNameservers() error { + nsList, err := getDNSHostAndPort(etcResolvConfPath) + if err != nil { + log.Errorf("can't load dns nameservers: %v", err) + return err + } + + log.Infof("reloading dns nameservers to %v", nsList) + + r.mu.Lock() + r.nameservers = nsList + r.mu.Unlock() + return nil +} + +const etcResolvConfPath = "/etc/resolv.conf" + var errEmptyResolvConf = errors.New("no DNS servers configured in /etc/resolv.conf") -func getDNSHostAndPort() ([]string, error) { - conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") +func getDNSHostAndPort(path string) ([]string, error) { + conf, err := dns.ClientConfigFromFile(path) if err != nil { return []string{}, err } - var hosts = make([]string, len(conf.Servers)) + hosts := make([]string, 0, len(conf.Servers)) for _, server := range conf.Servers { dnsIP, err := netip.ParseAddr(server) if err != nil { diff --git a/pkg/services/dns/dns_config_windows.go b/pkg/services/dns/dns_config_windows.go index 5f1166d1..431727df 100644 --- a/pkg/services/dns/dns_config_windows.go +++ b/pkg/services/dns/dns_config_windows.go @@ -9,6 +9,16 @@ import ( qdmDns "github.com/qdm12/dns/v2/pkg/nameserver" ) +func (r *dnsConfig) init() error { + nsList, err := getDNSHostAndPort() + if err != nil { + return err + } + + r.nameservers = nsList + return nil +} + func getDNSHostAndPort() ([]string, error) { nameservers := qdmDns.GetDNSServers() @@ -21,5 +31,4 @@ func getDNSHostAndPort() ([]string, error) { } return dnsServers, nil - } diff --git a/pkg/services/dns/hosts_file.go b/pkg/services/dns/hosts_file.go index d8d8c7d8..106129ee 100644 --- a/pkg/services/dns/hosts_file.go +++ b/pkg/services/dns/hosts_file.go @@ -2,11 +2,10 @@ package dns import ( "net" - "path/filepath" "sync" "github.com/areYouLazy/libhosty" - "github.com/fsnotify/fsnotify" + "github.com/containers/gvisor-tap-vsock/pkg/utils" log "github.com/sirupsen/logrus" ) @@ -23,48 +22,31 @@ func NewHostsFile(hostsPath string) (*HostsFile, error) { if err != nil { return nil, err } - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, err - } h := &HostsFile{ hostsFile: hostsFile, hostsFilePath: hostsFile.Config.FilePath, } - go func() { - h.startWatch(watcher) - }() + if err := h.startWatch(); err != nil { + return nil, err + } + return h, nil } -func (h *HostsFile) startWatch(w *fsnotify.Watcher) { - err := w.Add(filepath.Dir(h.hostsFilePath)) - +func (h *HostsFile) startWatch() error { + watcher, err := utils.NewFileWatcher(h.hostsFilePath) if err != nil { - log.Errorf("Hosts file adding watcher error:%s", err) - return + log.Errorf("Hosts file adding watcher error: %s", err) + return err } - for { - select { - case err, ok := <-w.Errors: - if !ok { - return - } - log.Errorf("Hosts file watcher error:%s", err) - case event, ok := <-w.Events: - if !ok { - return - } - if event.Name == h.hostsFilePath && event.Op&fsnotify.Write == fsnotify.Write { - err := h.updateHostsFile() - if err != nil { - log.Errorf("Hosts file read error:%s", err) - return - } - } - } + + if err := watcher.Start(h.updateHostsFile); err != nil { + log.Errorf("Hosts file adding watcher error: %s", err) + return err } + + return nil } func (h *HostsFile) LookupByHostname(name string) (net.IP, error) { @@ -75,17 +57,17 @@ func (h *HostsFile) LookupByHostname(name string) (net.IP, error) { return ip, err } -func (h *HostsFile) updateHostsFile() error { +func (h *HostsFile) updateHostsFile() { newHosts, err := readHostsFile(h.hostsFilePath) if err != nil { - return err + log.Errorf("Hosts file read error:%s", err) + return } h.hostsReadLock.Lock() defer h.hostsReadLock.Unlock() h.hostsFile = newHosts - return nil } func readHostsFile(hostsFilePath string) (*libhosty.HostsFile, error) { diff --git a/pkg/utils/filewatcher.go b/pkg/utils/filewatcher.go new file mode 100644 index 00000000..0993d38e --- /dev/null +++ b/pkg/utils/filewatcher.go @@ -0,0 +1,84 @@ +package utils + +import ( + "fmt" + "path/filepath" + "time" + + "github.com/fsnotify/fsnotify" +) + +// FileWatcher is an utility that +type FileWatcher struct { + w *fsnotify.Watcher + path string + + writeGracePeriod time.Duration + timer *time.Timer +} + +func NewFileWatcher(path string) (*FileWatcher, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + return &FileWatcher{w: watcher, path: path, writeGracePeriod: 200 * time.Millisecond}, nil +} + +func (fw *FileWatcher) Start(changeHandler func()) error { + // Ensure that the target that we're watching is not a symlink as we won't get any events when we're watching + // a symlink. + fileRealPath, err := filepath.EvalSymlinks(fw.path) + if err != nil { + return fmt.Errorf("adding watcher failed: %s", err) + } + + // watch the directory instead of the individual file to ensure the notification still works when the file is modified + // through moving/renaming rather than writing into it directly (like what most modern editor does by default). + // ref: https://github.com/fsnotify/fsnotify/blob/a9bc2e01792f868516acf80817f7d7d7b3315409/README.md#watching-a-file-doesnt-work-well + if err = fw.w.Add(filepath.Dir(fileRealPath)); err != nil { + return fmt.Errorf("adding watcher failed: %s", err) + } + + go func() { + for { + select { + case _, ok := <-fw.w.Errors: + if !ok { + return // watcher is closed. + } + case event, ok := <-fw.w.Events: + if !ok { + return // watcher is closed. + } + + if event.Name != fileRealPath { + continue // we don't care about this file. + } + + // Create may not always followed by Write e.g. when we replace the file with mv. + if event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Write) { + // as per the documentation, receiving Write does not mean that the write is finished. + // we try our best here to ignore "unfinished" write by assuming that after [writeGracePeriod] of + // inactivity the write has been finished. + fw.debounce(changeHandler) + } + } + } + }() + + return nil +} + +func (fw *FileWatcher) debounce(fn func()) { + if fw.timer != nil { + fw.timer.Stop() + } + + fw.timer = time.AfterFunc(fw.writeGracePeriod, fn) +} + +func (fw *FileWatcher) Stop() error { + return fw.w.Close() +} diff --git a/pkg/utils/filewatcher_test.go b/pkg/utils/filewatcher_test.go new file mode 100644 index 00000000..a13950b9 --- /dev/null +++ b/pkg/utils/filewatcher_test.go @@ -0,0 +1,67 @@ +package utils + +import ( + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestFileWatcher(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + _ = os.WriteFile(path, []byte("1"), 0o600) + + fw, err := NewFileWatcher(path) + fw.writeGracePeriod = 50 * time.Millisecond // reduce the delay to make the test runs faster. + assert.NoError(t, err) + _ = fw.w.Add(path) + + var numTriggered atomic.Int64 + assertNumTriggered := func(expected int) { + time.Sleep(fw.writeGracePeriod + 50*time.Millisecond) + assert.Equal(t, int64(expected), numTriggered.Load()) + } + + _ = fw.Start(func() { + numTriggered.Add(1) + }) + + // CASE: can detect changes to the file. + if err := os.WriteFile(path, []byte("2"), 0o600); err != nil { + panic(err) + } + assertNumTriggered(1) + + // CASE: can detect "swap"-based file modification. + tmpFile := filepath.Join(dir, "tmp.txt") + if err := os.WriteFile(tmpFile, []byte("lol"), 0o600); err != nil { + panic(err) + } + if err := os.Rename(tmpFile, path); err != nil { + panic(err) + } + assertNumTriggered(2) + + // CASE: combine multiple partial writes into single event. + fd, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + panic(err) + } + // we assume these writes happens in less than 50ms. + _, _ = fd.Write([]byte("a")) + _ = fd.Sync() + _, _ = fd.Write([]byte("b")) + fd.Close() + assertNumTriggered(3) + + // CASE: closed file watcher should not call the callback after Stop() is called. + assert.NoError(t, fw.Stop()) + if err := os.WriteFile(path, []byte("2"), 0o600); err != nil { + panic(err) + } + assertNumTriggered(3) // does not change. +}