diff --git a/jmx/jmx.go b/jmx/jmx.go index 6cb16fd9..e4750ea6 100644 --- a/jmx/jmx.go +++ b/jmx/jmx.go @@ -323,6 +323,7 @@ func doQuery(ctx context.Context, out chan []byte, queryErrC chan error, querySt queryErrC <- fmt.Errorf("reading nrjmx stdout: %s", err.Error()) } out <- b + return } } @@ -341,6 +342,7 @@ func Query(objectPattern string, timeoutMillis int) (result map[string]interface // receiveResult checks for channels to receive result from nrjmx command. func receiveResult(lineC chan []byte, queryErrC chan error, cancelFn context.CancelFunc, objectPattern string, timeout time.Duration) (result map[string]interface{}, err error) { + defer logAvailableWarnings(cmdWarnC) var warn string for { select { @@ -361,12 +363,11 @@ func receiveResult(lineC chan []byte, queryErrC chan error, cancelFn context.Can for k, v := range r { result[k] = v } + return case warn = <-cmdWarnC: - // change on the API is required to return warnings log.Warn(warn) - return case err = <-cmdErrC: return @@ -383,3 +384,15 @@ func receiveResult(lineC chan []byte, queryErrC chan error, cancelFn context.Can } } } + +func logAvailableWarnings(channel chan string) { + var warn string + for { + select { + case warn = <-channel: + log.Warn(warn) + default: + return + } + } +} diff --git a/jmx/jmx_test.go b/jmx/jmx_test.go index 7e4634ec..0b168562 100644 --- a/jmx/jmx_test.go +++ b/jmx/jmx_test.go @@ -2,6 +2,7 @@ package jmx import ( "bufio" + "bytes" "context" "flag" "fmt" @@ -10,6 +11,7 @@ import ( "testing" "time" + "github.com/newrelic/infra-integrations-sdk/v4/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -179,17 +181,25 @@ func openWaitWithSSL(hostname, port, username, password, keyStore, keyStorePassw } func Test_receiveResult_warningsDoNotBreakResultReception(t *testing.T) { + + var buf bytes.Buffer + log.SetOutput(&buf) + _, cancelFn := context.WithCancel(context.Background()) resultCh := make(chan []byte, 1) queryErrCh := make(chan error) outTimeout := time.Duration(timeoutMillis) * time.Millisecond + warningMessage := fmt.Sprint("WARNING foo bar") + cmdWarnC <- warningMessage - _, _ = receiveResult(resultCh, queryErrCh, cancelFn, "empty", outTimeout) + resultCh <- []byte("{\"foo\":1}") - cmdErrC <- fmt.Errorf("WARNING foo bar") - assert.Equal(t, <-cmdErrC, fmt.Errorf("WARNING foo bar")) + result, err := receiveResult(resultCh, queryErrCh, cancelFn, "foo", outTimeout) - resultCh <- []byte("{foo}") - assert.Equal(t, string(<-resultCh), "{foo}") + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{ + "foo": 1., + }, result) + assert.Equal(t, fmt.Sprintf("[WARN] %s\n", warningMessage), buf.String()) } diff --git a/log/log.go b/log/log.go index 6bd9996a..35b82a36 100644 --- a/log/log.go +++ b/log/log.go @@ -81,6 +81,11 @@ func SetupLogging(verbose bool) { } } +// SetOutput sets output stream +func SetOutput(w io.Writer) { + globalLogger.logger.SetOutput(w) +} + // Debug logs a formatted message at level Debug. func Debug(format string, args ...interface{}) { globalLogger.Debugf(format, args...)