diff --git a/service/firewall/interception/dnsmonitor/etwlink_windows.go b/service/firewall/interception/dnsmonitor/etwlink_windows.go index d014bbab1..cb9d8675f 100644 --- a/service/firewall/interception/dnsmonitor/etwlink_windows.go +++ b/service/firewall/interception/dnsmonitor/etwlink_windows.go @@ -14,7 +14,7 @@ import ( ) type ETWSession struct { - i integration.ETWFunctions + i *integration.ETWFunctions shutdownGuard atomic.Bool shutdownMutex sync.Mutex @@ -23,7 +23,10 @@ type ETWSession struct { } // NewSession creates new ETW event listener and initilizes it. This is a low level interface, make sure to call DestorySession when you are done using it. -func NewSession(etwInterface integration.ETWFunctions, callback func(domain string, result string)) (*ETWSession, error) { +func NewSession(etwInterface *integration.ETWFunctions, callback func(domain string, result string)) (*ETWSession, error) { + if etwInterface == nil { + return nil, fmt.Errorf("etw interface was nil") + } etwSession := &ETWSession{ i: etwInterface, } @@ -47,7 +50,7 @@ func NewSession(etwInterface integration.ETWFunctions, callback func(domain stri // Initialize session. err := etwSession.i.InitializeSession(etwSession.state) if err != nil { - return nil, fmt.Errorf("failed to initialzie session: %q", err) + return nil, fmt.Errorf("failed to initialize session: %q", err) } return etwSession, nil @@ -65,6 +68,10 @@ func (l *ETWSession) IsRunning() bool { // FlushTrace flushes the trace buffer. func (l *ETWSession) FlushTrace() error { + if l.i == nil { + return fmt.Errorf("session not initialized") + } + l.shutdownMutex.Lock() defer l.shutdownMutex.Unlock() @@ -83,6 +90,9 @@ func (l *ETWSession) StopTrace() error { // DestroySession closes the session and frees the allocated memory. Listener cannot be used after this function is called. func (l *ETWSession) DestroySession() error { + if l.i == nil { + return fmt.Errorf("session not initialized") + } l.shutdownMutex.Lock() defer l.shutdownMutex.Unlock() diff --git a/service/firewall/interception/dnsmonitor/eventlistener_windows.go b/service/firewall/interception/dnsmonitor/eventlistener_windows.go index b6a39fd8d..a46e8cc6b 100644 --- a/service/firewall/interception/dnsmonitor/eventlistener_windows.go +++ b/service/firewall/interception/dnsmonitor/eventlistener_windows.go @@ -23,22 +23,38 @@ func newListener(module *DNSMonitor) (*Listener, error) { ResolverInfo.Source = resolver.ServerSourceETW listener := &Listener{} - var err error // Initialize new dns event session. - listener.etw, err = NewSession(module.instance.OSIntegration().GetETWInterface(), listener.processEvent) + err := initializeSessions(module, listener) if err != nil { - return nil, err + // Listen for event if the dll has been loaded + module.instance.OSIntegration().OnInitializedEvent.AddCallback("loader-listener", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + err = initializeSessions(module, listener) + if err != nil { + return false, err + } + return true, nil + }) } + return listener, nil +} - // Start listening for events. +func initializeSessions(module *DNSMonitor, listener *Listener) error { + var err error + listener.etw, err = NewSession(module.instance.OSIntegration().GetETWInterface(), listener.processEvent) + if err != nil { + return err + } + // Start listener module.mgr.Go("etw-dns-event-listener", func(w *mgr.WorkerCtx) error { return listener.etw.StartTrace() }) - - return listener, nil + return nil } func (l *Listener) flush() error { + if l.etw == nil { + return fmt.Errorf("etw not initialized") + } return l.etw.FlushTrace() } diff --git a/service/integration/etw_windows.go b/service/integration/etw_windows.go index eac3ad8f4..d655967a7 100644 --- a/service/integration/etw_windows.go +++ b/service/integration/etw_windows.go @@ -19,8 +19,8 @@ type ETWFunctions struct { stopOldSession *windows.Proc } -func initializeETW(dll *windows.DLL) (ETWFunctions, error) { - var functions ETWFunctions +func initializeETW(dll *windows.DLL) (*ETWFunctions, error) { + functions := &ETWFunctions{} var err error functions.createState, err = dll.FindProc("PM_ETWCreateState") if err != nil { diff --git a/service/integration/integration_windows.go b/service/integration/integration_windows.go index 80b1fd74c..b297b3638 100644 --- a/service/integration/integration_windows.go +++ b/service/integration/integration_windows.go @@ -5,24 +5,55 @@ package integration import ( "fmt" + "sync" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" "golang.org/x/sys/windows" ) type OSSpecific struct { dll *windows.DLL - etwFunctions ETWFunctions + etwFunctions *ETWFunctions } // Initialize loads the dll and finds all the needed functions from it. func (i *OSIntegration) Initialize() error { + // Try to load dll + err := i.loadDLL() + if err != nil { + log.Errorf("integration: failed to load dll: %s", err) + + callbackLock := sync.Mutex{} + // listen for event from the updater and try to load again if any. + i.instance.Updates().EventResourcesUpdated.AddCallback("core-dll-loader", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + // Make sure no multiple callas are executed at the same time. + callbackLock.Lock() + defer callbackLock.Unlock() + + // Try to load again. + err = i.loadDLL() + if err != nil { + log.Errorf("integration: failed to load dll: %s", err) + } else { + log.Info("integration: initialize successful after updater event") + } + return false, nil + }) + + } else { + log.Info("integration: initialize successful") + } + return nil +} + +func (i *OSIntegration) loadDLL() error { // Find path to the dll. file, err := updates.GetPlatformFile("dll/portmaster-core.dll") if err != nil { return err } - // Load the DLL. i.os.dll, err = windows.LoadDLL(file.Path()) if err != nil { @@ -35,10 +66,13 @@ func (i *OSIntegration) Initialize() error { return err } + // Notify listeners + i.OnInitializedEvent.Submit(struct{}{}) + return nil } -// CleanUp releases any resourses allocated during initializaion. +// CleanUp releases any resources allocated during initialization. func (i *OSIntegration) CleanUp() error { if i.os.dll != nil { return i.os.dll.Release() @@ -46,7 +80,7 @@ func (i *OSIntegration) CleanUp() error { return nil } -// GetETWInterface return struct containing all the ETW related functions. -func (i *OSIntegration) GetETWInterface() ETWFunctions { +// GetETWInterface return struct containing all the ETW related functions, and nil if it was not loaded yet +func (i *OSIntegration) GetETWInterface() *ETWFunctions { return i.os.etwFunctions } diff --git a/service/integration/module.go b/service/integration/module.go index 0e43798ae..6a2374349 100644 --- a/service/integration/module.go +++ b/service/integration/module.go @@ -7,8 +7,9 @@ import ( // OSIntegration module provides special integration with the OS. type OSIntegration struct { - m *mgr.Manager - states *mgr.StateMgr + m *mgr.Manager + + OnInitializedEvent *mgr.EventMgr[struct{}] //nolint:unused os OSSpecific @@ -20,10 +21,9 @@ type OSIntegration struct { func New(instance instance) (*OSIntegration, error) { m := mgr.New("OSIntegration") module := &OSIntegration{ - m: m, - states: m.NewStateMgr(), - - instance: instance, + m: m, + OnInitializedEvent: mgr.NewEventMgr[struct{}]("on-initialized", m), + instance: instance, } return module, nil diff --git a/service/network/connection.go b/service/network/connection.go index b3dd70fcc..1c1bbf198 100644 --- a/service/network/connection.go +++ b/service/network/connection.go @@ -550,7 +550,11 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { if module.instance.Resolver().IsDisabled() && conn.shouldWaitForDomain() { // Flush the dns listener buffer and try again. for i := range 4 { - _ = module.instance.DNSMonitor().Flush() + err = module.instance.DNSMonitor().Flush() + if err != nil { + // Error flushing, dont try again. + break + } ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) if err == nil { log.Tracer(pkt.Ctx()).Debugf("network: found domain from dnsmonitor after %d tries", i+1)