Skip to content

Commit

Permalink
Fix latest lint errors and provider test
Browse files Browse the repository at this point in the history
  • Loading branch information
chemamartinez committed Nov 23, 2023
1 parent 9a6b06c commit 7d102a2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
37 changes: 24 additions & 13 deletions x-pack/libbeat/reader/etw/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
package etw

import (
"encoding/binary"
"syscall"
"testing"
"unsafe"
Expand Down Expand Up @@ -61,6 +60,14 @@ func TestGUIDFromProviderName_EmptyProviderList(t *testing.T) {
mockProviderName := "NonExistentProvider"

EnumerateProvidersFunc = func(pBuffer *ProviderEnumerationInfo, pBufferSize *uint32) error {
// Check if the buffer size is sufficient
requiredSize := uint32(unsafe.Sizeof(ProviderEnumerationInfo{})) + uint32(unsafe.Sizeof(TraceProviderInfo{}))*0 // As there are no providers
if *pBufferSize < requiredSize {
// Set the size required and return the error
*pBufferSize = requiredSize
return ERROR_INSUFFICIENT_BUFFER
}

// Empty list of providers
*pBuffer = ProviderEnumerationInfo{
NumberOfProviders: 0,
Expand All @@ -85,20 +92,22 @@ func TestGUIDFromProviderName_GUIDNotFound(t *testing.T) {
mockGUID := GUID{Data1: 1234, Data2: 5678}

EnumerateProvidersFunc = func(pBuffer *ProviderEnumerationInfo, pBufferSize *uint32) error {
// Create and populate a buffer for the provider name
utf16Name, _ := syscall.UTF16FromString(realProviderName)
nameBuffer := make([]byte, len(utf16Name)*2)
for i, v := range utf16Name {
binary.LittleEndian.PutUint16(nameBuffer[i*2:], v)
requiredSize := uint32(unsafe.Sizeof(ProviderEnumerationInfo{})) + uint32(unsafe.Sizeof(TraceProviderInfo{}))*1 // Size for one provider
if *pBufferSize < requiredSize {
*pBufferSize = requiredSize
return ERROR_INSUFFICIENT_BUFFER
}

// Create and populate a buffer for the provider name
utf16Ptr, _ := syscall.UTF16PtrFromString(realProviderName)

// Create and populate the ProviderEnumerationInfo
*pBuffer = ProviderEnumerationInfo{
NumberOfProviders: 1,
TraceProviderInfoArray: [ANYSIZE_ARRAY]TraceProviderInfo{
{
ProviderGuid: mockGUID,
ProviderNameOffset: uint32(uintptr(unsafe.Pointer(&nameBuffer[0]))),
ProviderNameOffset: uint32(uintptr(unsafe.Pointer(utf16Ptr))),
},
},
}
Expand All @@ -120,20 +129,22 @@ func TestGUIDFromProviderName_Success(t *testing.T) {
mockGUID := GUID{Data1: 1234, Data2: 5678}

EnumerateProvidersFunc = func(pBuffer *ProviderEnumerationInfo, pBufferSize *uint32) error {
// Create and populate a buffer for the provider name
utf16Name, _ := syscall.UTF16FromString(mockProviderName)
nameBuffer := make([]byte, len(utf16Name)*2)
for i, v := range utf16Name {
binary.LittleEndian.PutUint16(nameBuffer[i*2:], v)
requiredSize := uint32(unsafe.Sizeof(ProviderEnumerationInfo{})) + uint32(unsafe.Sizeof(TraceProviderInfo{}))*1 // Size for one provider
if *pBufferSize < requiredSize {
*pBufferSize = requiredSize
return ERROR_INSUFFICIENT_BUFFER
}

// Create and populate a buffer for the provider name
utf16Ptr, _ := syscall.UTF16PtrFromString(mockProviderName)

// Create and populate the ProviderEnumerationInfo
*pBuffer = ProviderEnumerationInfo{
NumberOfProviders: 1,
TraceProviderInfoArray: [ANYSIZE_ARRAY]TraceProviderInfo{
{
ProviderGuid: mockGUID,
ProviderNameOffset: uint32(uintptr(unsafe.Pointer(&nameBuffer[0]))),
ProviderNameOffset: uint32(uintptr(unsafe.Pointer(utf16Ptr))),
},
},
}
Expand Down
8 changes: 5 additions & 3 deletions x-pack/libbeat/reader/etw/syscall_advapi32.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package etw

import (
"errors"
"syscall"
"unsafe"

Expand Down Expand Up @@ -278,7 +279,7 @@ func _EnableTraceEx2(traceHandle uintptr,
enableProperty uint32,
enableParameters *EnableTraceParameters) error {
r0, _, _ := enableTraceEx2.Call(
uintptr(traceHandle),
traceHandle,
uintptr(unsafe.Pointer(providerId)),
uintptr(isEnabled),
uintptr(level),
Expand All @@ -298,7 +299,7 @@ func _ControlTrace(traceHandle uintptr,
properties *EventTraceProperties,
controlCode uint32) error {
r0, _, _ := controlTraceW.Call(
uintptr(traceHandle),
traceHandle,
uintptr(unsafe.Pointer(instanceName)),
uintptr(unsafe.Pointer(properties)),
uintptr(controlCode))
Expand All @@ -312,7 +313,8 @@ func _ControlTrace(traceHandle uintptr,
func _OpenTrace(logfile *EventTraceLogfile) (uint64, error) {
r0, _, err := openTraceW.Call(
uintptr(unsafe.Pointer(logfile)))
if err.(syscall.Errno) == 0 {
var errno syscall.Errno
if errors.As(err, &errno) && errno == 0 {
return uint64(r0), nil
}
return uint64(r0), err
Expand Down
15 changes: 7 additions & 8 deletions x-pack/libbeat/reader/etw/syscall_tdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ import (
)

var (
tdh = windows.NewLazySystemDLL("tdh.dll")
tdhEnumerateProviders = tdh.NewProc("TdhEnumerateProviders")
tdhEnumerateProviderFieldInformation = tdh.NewProc("TdhEnumerateProviderFieldInformation")
tdhGetEventInformation = tdh.NewProc("TdhGetEventInformation")
tdhGetEventMapInformation = tdh.NewProc("TdhGetEventMapInformation")
tdhFormatProperty = tdh.NewProc("TdhFormatProperty")
tdhGetProperty = tdh.NewProc("TdhGetProperty")
tdh = windows.NewLazySystemDLL("tdh.dll")
tdhEnumerateProviders = tdh.NewProc("TdhEnumerateProviders")
tdhGetEventInformation = tdh.NewProc("TdhGetEventInformation")
tdhGetEventMapInformation = tdh.NewProc("TdhGetEventMapInformation")
tdhFormatProperty = tdh.NewProc("TdhFormatProperty")
tdhGetProperty = tdh.NewProc("TdhGetProperty")
)

const ANYSIZE_ARRAY = 1 << 25
Expand Down Expand Up @@ -220,7 +219,7 @@ type PropertyDataDescriptor struct {
Reserved uint32
}

// Interface to replace the pointer to the function in unit tests
// EnumerateProvidersFunc is used to replace the pointer to the function in unit tests
var EnumerateProvidersFunc = _TdhEnumerateProviders

// https://learn.microsoft.com/en-us/windows/win32/api/tdh/nf-tdh-tdhenumerateproviders
Expand Down

0 comments on commit 7d102a2

Please sign in to comment.