diff --git a/extension/agenthealth/handler/stats/client/client.go b/extension/agenthealth/handler/stats/client/client.go index 40567d929e..025df31598 100644 --- a/extension/agenthealth/handler/stats/client/client.go +++ b/extension/agenthealth/handler/stats/client/client.go @@ -75,11 +75,20 @@ func (csh *clientStatsHandler) HandleRequest(ctx context.Context, r *http.Reques } requestID := csh.getRequestID(ctx) recorder := &requestRecorder{start: time.Now()} - if r.GetBody != nil { - body, err := r.GetBody() - if err == nil { - recorder.payloadBytes, err = io.Copy(io.Discard, body) + if r.ContentLength != 0 { + recorder.payloadBytes = r.ContentLength + } else if r.Body != nil { + rsc, ok := r.Body.(aws.ReaderSeekerCloser) + if !ok { + rsc = aws.ReadSeekCloser(r.Body) } + length, _ := aws.SeekerLen(rsc) + if length == -1 { + if body, err := r.GetBody(); err == nil { + length, _ = io.Copy(io.Discard, body) + } + } + recorder.payloadBytes = length } csh.requestCache.Set(requestID, recorder, ttlcache.DefaultTTL) } diff --git a/extension/agenthealth/handler/stats/client/client_test.go b/extension/agenthealth/handler/stats/client/client_test.go index 649d4b340a..35ac023e0e 100644 --- a/extension/agenthealth/handler/stats/client/client_test.go +++ b/extension/agenthealth/handler/stats/client/client_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/amazon-contributing/opentelemetry-collector-contrib/extension/awsmiddleware" + "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,6 +29,7 @@ func TestHandle(t *testing.T) { body := []byte("test payload size") req, err := http.NewRequest("", "localhost", bytes.NewBuffer(body)) require.NoError(t, err) + req.ContentLength = 20 ctx := context.Background() handler.HandleRequest(ctx, req) got := handler.Stats(operation) @@ -41,6 +43,25 @@ func TestHandle(t *testing.T) { assert.NotNil(t, got.PayloadBytes) assert.NotNil(t, got.StatusCode) assert.Equal(t, http.StatusOK, *got.StatusCode) - assert.Equal(t, 17, *got.PayloadBytes) + assert.Equal(t, 20, *got.PayloadBytes) assert.GreaterOrEqual(t, *got.LatencyMillis, int64(1)) + + // without content length + req.ContentLength = 0 + handler.HandleRequest(ctx, req) + handler.HandleResponse(ctx, &http.Response{StatusCode: http.StatusOK}) + got = handler.Stats(operation) + assert.NotNil(t, got.PayloadBytes) + assert.Equal(t, 17, *got.PayloadBytes) + + // with seeker + body = append(body, " with seeker"...) + req, err = http.NewRequest("", "localhost", aws.ReadSeekCloser(bytes.NewReader(body))) + require.NoError(t, err) + req.ContentLength = 0 + handler.HandleRequest(ctx, req) + handler.HandleResponse(ctx, &http.Response{StatusCode: http.StatusOK}) + got = handler.Stats(operation) + assert.NotNil(t, got.PayloadBytes) + assert.Equal(t, 29, *got.PayloadBytes) }