Skip to content

Commit

Permalink
[service] DNS listener refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
vlabo committed Nov 19, 2024
1 parent 194eac2 commit b3b1615
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 170 deletions.
90 changes: 25 additions & 65 deletions service/firewall/interception/dnslistener/etwlink_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@ import (
"sync"
"sync/atomic"

"github.com/safing/portmaster/service/integration"
"golang.org/x/sys/windows"
)

type ETWSession struct {
dll *windows.DLL

createState *windows.Proc
initializeSession *windows.Proc
startTrace *windows.Proc
flushTrace *windows.Proc
stopTrace *windows.Proc
destroySession *windows.Proc
i integration.ETWFunctions

shutdownGuard atomic.Bool
shutdownMutex sync.Mutex
Expand All @@ -26,70 +20,39 @@ type ETWSession struct {
}

// NewSession creates new ETW event listener and initilizes it. This is a low level interface, make sure to call DestorySession when you are done using it.
func NewSession(dllpath string, callback func(domain string, result string)) (*ETWSession, error) {
etwListener := &ETWSession{}

// Initialize dll functions
var err error
etwListener.dll, err = windows.LoadDLL(dllpath)
if err != nil {
return nil, fmt.Errorf("faild to load dll: %q", err)
}
etwListener.createState, err = etwListener.dll.FindProc("PM_ETWCreateState")
if err != nil {
return nil, fmt.Errorf("faild to load function PM_ETWCreateState: %q", err)
}
etwListener.initializeSession, err = etwListener.dll.FindProc("PM_ETWInitializeSession")
if err != nil {
return nil, fmt.Errorf("faild to load function PM_ETWInitializeSession: %q", err)
}
etwListener.startTrace, err = etwListener.dll.FindProc("PM_ETWStartTrace")
if err != nil {
return nil, fmt.Errorf("faild to load function PM_ETWStartTrace: %q", err)
}
etwListener.flushTrace, err = etwListener.dll.FindProc("PM_ETWFlushTrace")
if err != nil {
return nil, fmt.Errorf("faild to load function PM_ETWFlushTrace: %q", err)
}
etwListener.stopTrace, err = etwListener.dll.FindProc("PM_ETWStopTrace")
if err != nil {
return nil, fmt.Errorf("faild to load function PM_ETWStopTrace: %q", err)
}
etwListener.destroySession, err = etwListener.dll.FindProc("PM_ETWDestroySession")
if err != nil {
return nil, fmt.Errorf("faild to load function PM_ETWDestroySession: %q", err)
func NewSession(etwInterface integration.ETWFunctions, callback func(domain string, result string)) (*ETWSession, error) {
etwSession := &ETWSession{
i: etwInterface,
}

// Make sure session from previous instances are not running.
_ = etwSession.i.StopOldSession()

// Initialize notification activated callback
win32Callback := windows.NewCallback(func(domain *uint16, result *uint16) uintptr {
callback(windows.UTF16PtrToString(domain), windows.UTF16PtrToString(result))
return 0
})
// The function only allocates memory it will not fail.
etwListener.state, _, _ = etwListener.createState.Call(win32Callback)
etwSession.state = etwSession.i.CreateState(win32Callback)

// Make sure DestroySession is called even if caller forgets to call it.
runtime.SetFinalizer(etwListener, func(l *ETWSession) {
_ = l.DestroySession()
runtime.SetFinalizer(etwSession, func(s *ETWSession) {
_ = s.i.DestroySession(s.state)
})

// Initialize session.
var rc uintptr
rc, _, err = etwListener.initializeSession.Call(etwListener.state)
if rc != 0 {
return nil, fmt.Errorf("failed to initialzie session: error code: %q", rc)
err := etwSession.i.InitializeSession(etwSession.state)
if err != nil {
return nil, fmt.Errorf("failed to initialzie session: %q", err)
}

return etwListener, nil
return etwSession, nil
}

// StartTrace starts the tracing session of dns events. This is a blocking call. It will not return until the trace is stopped.
func (l *ETWSession) StartTrace() error {
rc, _, _ := l.startTrace.Call(l.state)
if rc != 0 {
return fmt.Errorf("error code: %q", rc)
}
return nil
return l.i.StartTrace(l.state)
}

// IsRunning returns true if DestroySession has NOT been called.
Expand All @@ -102,20 +65,17 @@ func (l *ETWSession) FlushTrace() error {
l.shutdownMutex.Lock()
defer l.shutdownMutex.Unlock()

rc, _, _ := l.flushTrace.Call(l.state)
if rc != 0 {
return fmt.Errorf("error code: %q", rc)
// Make sure session is still running.
if l.shutdownGuard.Load() {
return nil
}
return nil

return l.i.FlushTrace(l.state)
}

// StopTrace stopes the trace. This will cause StartTrace to return.
func (l *ETWSession) StopTrace() error {
rc, _, _ := l.stopTrace.Call(l.state)
if rc != 0 {
return fmt.Errorf("error code: %q", rc)
}
return nil
return l.i.StopTrace(l.state)
}

// DestroySession closes the session and frees the allocated memory. Listener cannot be used after this function is called.
Expand All @@ -129,9 +89,9 @@ func (l *ETWSession) DestroySession() error {

l.shutdownGuard.Store(true)

rc, _, _ := l.destroySession.Call(l.state)
if rc != 0 {
return fmt.Errorf("error code: %q", rc)
err := l.i.DestroySession(l.state)
if err != nil {
return err
}
l.state = 0
return nil
Expand Down
8 changes: 2 additions & 6 deletions service/firewall/interception/dnslistener/eventlistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@

package dnslistener

import (
"github.com/safing/portmaster/service/mgr"
)

type Listener struct{}

func newListener(_ *mgr.Manager) (*Listener, error) {
func newListener(module *DNSListener) (*Listener, error) {
return &Listener{}, nil
}

func (l *Listener) flish() error {
func (l *Listener) flush() error {
// Nothing to flush
return nil
}
Expand Down
50 changes: 5 additions & 45 deletions service/firewall/interception/dnslistener/eventlistener_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,24 @@ import (
"fmt"
"net"

"github.com/miekg/dns"
"github.com/safing/portmaster/base/log"
"github.com/safing/portmaster/service/mgr"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/resolver"
"github.com/varlink/go/varlink"
)

type Listener struct {
varlinkConn *varlink.Connection
}

func newListener(m *mgr.Manager) (*Listener, error) {
func newListener(module *DNSListener) (*Listener, error) {
// Create the varlink connection with the systemd resolver.
varlinkConn, err := varlink.NewConnection(m.Ctx(), "unix:/run/systemd/resolve/io.systemd.Resolve.Monitor")
varlinkConn, err := varlink.NewConnection(module.mgr.Ctx(), "unix:/run/systemd/resolve/io.systemd.Resolve.Monitor")
if err != nil {
return nil, fmt.Errorf("dnslistener: failed to connect to systemd-resolver varlink service: %w", err)
}

listener := &Listener{varlinkConn: varlinkConn}

m.Go("systemd-resolver-event-listener", func(w *mgr.WorkerCtx) error {
module.mgr.Go("systemd-resolver-event-listener", func(w *mgr.WorkerCtx) error {
// Subscribe to the dns query events
receive, err := listener.varlinkConn.Send(w.Ctx(), "io.systemd.Resolve.Monitor.SubscribeQueryResults", nil, varlink.More)
if err != nil {
Expand Down Expand Up @@ -70,7 +66,7 @@ func newListener(m *mgr.Manager) (*Listener, error) {
return listener, nil
}

func (l *Listener) flish() error {
func (l *Listener) flush() error {
// Nothing to flush
return nil
}
Expand Down Expand Up @@ -107,41 +103,5 @@ func (l *Listener) processAnswer(queryResult *QueryResult) {
}
}

for _, ip := range ips {
// Never save domain attributions for localhost IPs.
if netutils.GetIPScope(ip) == netutils.HostLocal {
continue
}
fqdn := dns.Fqdn(domain)

// Create new record for this IP.
record := resolver.ResolvedDomain{
Domain: fqdn,
Resolver: &ResolverInfo,
DNSRequestContext: &resolver.DNSRequestContext{},
Expires: 0,
}

for {
nextDomain, isCNAME := cnames[domain]
if !isCNAME {
break
}

record.CNAMEs = append(record.CNAMEs, nextDomain)
domain = nextDomain
}

info := resolver.IPInfo{
IP: ip.String(),
}

// Add the new record to the resolved domains for this IP and scope.
info.AddDomain(record)

// Save if the record is new or has been updated.
if err := info.Save(); err != nil {
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
}
saveDomain(domain, ips, cnames)
}
64 changes: 15 additions & 49 deletions service/firewall/interception/dnslistener/eventlistener_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,31 @@ import (
"strings"

"github.com/miekg/dns"
"github.com/safing/portmaster/base/log"
"github.com/safing/portmaster/service/mgr"
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/resolver"
)

type Listener struct {
etw *ETWSession
}

func newListener(m *mgr.Manager) (*Listener, error) {
func newListener(module *DNSListener) (*Listener, error) {
listener := &Listener{}
var err error
listener.etw, err = NewSession("C:/Dev/ETWDNSTrace.dll", listener.processEvent)
// Intialize new dns event session.
listener.etw, err = NewSession(module.instance.OSIntegration().GetETWInterface(), listener.processEvent)
if err != nil {
return nil, err
}

m.Go("etw-dns-event-listener", func(w *mgr.WorkerCtx) error {
// Start lisening for events.
module.mgr.Go("etw-dns-event-listener", func(w *mgr.WorkerCtx) error {
return listener.etw.StartTrace()
})

return listener, nil
}

func (l *Listener) flish() error {
func (l *Listener) flush() error {
return l.etw.FlushTrace()
}

Expand All @@ -44,8 +43,9 @@ func (l *Listener) stop() error {
return fmt.Errorf("listener is nil")
}
if l.etw == nil {
return fmt.Errorf("invalid ewt session")
return fmt.Errorf("invalid etw session")
}
// Stop and destroy trace. Destory should be called even if stop failes for some reason.
err := l.etw.StopTrace()
err2 := l.etw.DestroySession()

Expand All @@ -54,12 +54,13 @@ func (l *Listener) stop() error {
}

if err2 != nil {
return fmt.Errorf("DestorySession failed: %d", err)
return fmt.Errorf("DestorySession failed: %d", err2)
}
return nil
}

func (l *Listener) processEvent(domain string, result string) {
// Ignore empty results
if len(result) == 0 {
return
}
Expand All @@ -69,60 +70,25 @@ func (l *Listener) processEvent(domain string, result string) {

resultArray := strings.Split(result, ";")
for _, r := range resultArray {
// For result different then IP the string starts with "type:"
if strings.HasPrefix(r, "type:") {
dnsValueArray := strings.Split(r, " ")
if len(dnsValueArray) < 3 {
continue
}

if value, err := strconv.ParseInt(dnsValueArray[1], 10, 8); err == nil && value == 5 {
// CNAME
// Ignore evrything else exept CNAME.
if value, err := strconv.ParseInt(dnsValueArray[1], 10, 16); err == nil && value == int64(dns.TypeCNAME) {
cnames[domain] = dnsValueArray[2]
}

} else {
// The events deosn't start with "type:" that means it's an IP address.
ip := net.ParseIP(r)
if ip != nil {
ips = append(ips, ip)
}
}
}

for _, ip := range ips {
// Never save domain attributions for localhost IPs.
if netutils.GetIPScope(ip) == netutils.HostLocal {
continue
}
fqdn := dns.Fqdn(domain)

// Create new record for this IP.
record := resolver.ResolvedDomain{
Domain: fqdn,
Resolver: &ResolverInfo,
DNSRequestContext: &resolver.DNSRequestContext{},
Expires: 0,
}

for {
nextDomain, isCNAME := cnames[domain]
if !isCNAME {
break
}

record.CNAMEs = append(record.CNAMEs, nextDomain)
domain = nextDomain
}

info := resolver.IPInfo{
IP: ip.String(),
}

// Add the new record to the resolved domains for this IP and scope.
info.AddDomain(record)

// Save if the record is new or has been updated.
if err := info.Save(); err != nil {
log.Errorf("nameserver: failed to save IP info record: %s", err)
}
}
saveDomain(domain, ips, cnames)
}
Loading

0 comments on commit b3b1615

Please sign in to comment.