From 6fd94b24a3ded68cb19ce5cce1e22592f4f69107 Mon Sep 17 00:00:00 2001 From: Jeffrey Chien Date: Mon, 22 Jul 2024 19:49:04 -0400 Subject: [PATCH] Add unit tests. --- .../windows_event_log/service_monitor.go | 14 ++++- .../windows_event_log/service_monitor_test.go | 62 +++++++++++++++++++ .../wineventlog/wineventlog.go | 23 ++++--- 3 files changed, 86 insertions(+), 13 deletions(-) create mode 100644 plugins/inputs/windows_event_log/service_monitor_test.go diff --git a/plugins/inputs/windows_event_log/service_monitor.go b/plugins/inputs/windows_event_log/service_monitor.go index 1e21076c05..4812e4c5d8 100644 --- a/plugins/inputs/windows_event_log/service_monitor.go +++ b/plugins/inputs/windows_event_log/service_monitor.go @@ -7,7 +7,7 @@ package windows_event_log import ( - "fmt" + "errors" "log" "time" @@ -21,6 +21,14 @@ const ( serviceName = "eventlog" ) +var ( + errServiceNotRunning = errors.New("service is not running") +) + +type statusChecker interface { + Query() (svc.Status, error) +} + type serviceMonitor struct { listeners []chan struct{} done chan struct{} @@ -82,7 +90,7 @@ func (m *serviceMonitor) notify() { } } -func getPID(service *mgr.Service) (uint32, error) { +func getPID(service statusChecker) (uint32, error) { status, err := service.Query() if err != nil { return 0, err @@ -90,5 +98,5 @@ func getPID(service *mgr.Service) (uint32, error) { if status.State == svc.Running { return status.ProcessId, nil } - return 0, fmt.Errorf("service is not running") + return 0, errServiceNotRunning } diff --git a/plugins/inputs/windows_event_log/service_monitor_test.go b/plugins/inputs/windows_event_log/service_monitor_test.go new file mode 100644 index 0000000000..08f9612a8e --- /dev/null +++ b/plugins/inputs/windows_event_log/service_monitor_test.go @@ -0,0 +1,62 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +//go:build windows +// +build windows + +package windows_event_log + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/windows/svc" +) + +type mockStatusCheck struct { + status svc.Status + err error +} + +func (m *mockStatusCheck) Query() (svc.Status, error) { + return m.status, m.err +} + +func TestGetPID(t *testing.T) { + testErr := errors.New("test error") + testCases := map[string]struct { + status svc.Status + err error + wantPID uint32 + wantErr error + }{ + "WithQueryError": { + err: testErr, + wantPID: 0, + wantErr: testErr, + }, + "WithStoppedService": { + status: svc.Status{ + State: svc.Stopped, + ProcessId: 0, + }, + wantPID: 0, + wantErr: errServiceNotRunning, + }, + "WithRunningService": { + status: svc.Status{ + State: svc.Running, + ProcessId: 123, + }, + wantPID: 123, + }, + } + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + gotPID, gotErr := getPID(&mockStatusCheck{status: testCase.status, err: testCase.err}) + assert.Equal(t, testCase.wantPID, gotPID) + assert.Equal(t, testCase.wantErr, gotErr) + }) + } +} diff --git a/plugins/inputs/windows_event_log/wineventlog/wineventlog.go b/plugins/inputs/windows_event_log/wineventlog/wineventlog.go index a33fdd61bd..234ad30d1d 100644 --- a/plugins/inputs/windows_event_log/wineventlog/wineventlog.go +++ b/plugins/inputs/windows_event_log/wineventlog/wineventlog.go @@ -89,15 +89,7 @@ func NewEventLog(name string, levels []string, logGroupName, logStreamName, rend func (w *windowsEventLog) Init() error { go w.runSaveState() w.eventOffset = w.loadState() - // Subscribe for events. - // This will fail if the eventlog name has not been registered. - // However returning an error would mean the plugin won't monitor other eventlogs. - err := w.Open() - if werr, ok := err.(*wevtAPIError); ok && werr.api == apiEvtSubscribe { - log.Printf("W! [wineventlog] %v", err) - return nil - } - return err + return w.Open() } func (w *windowsEventLog) SetOutput(fn func(logs.LogEvent)) { @@ -177,7 +169,18 @@ func (w *windowsEventLog) run() { } } +// Open subscription for events. Instead of failing the subscription if the eventlog name has not been registered, +// log the error. func (w *windowsEventLog) Open() error { + err := w.open() + if werr, ok := err.(*wevtAPIError); ok && werr.api == apiEvtSubscribe { + log.Printf("W! [wineventlog] %v", err) + return nil + } + return err +} + +func (w *windowsEventLog) open() error { bookmark, err := CreateBookmark(w.name, w.eventOffset) if err != nil { return err @@ -218,7 +221,7 @@ func (w *windowsEventLog) resubscribe() error { } } w.eventHandle = EvtHandle(0) - return w.Open() + return w.open() } func (w *windowsEventLog) LogGroupName() string {