Skip to content

Commit

Permalink
Fix client stats payload measurement (#931)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefchien authored Oct 26, 2023
1 parent a6034e6 commit 52a4af3
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 62 deletions.
9 changes: 9 additions & 0 deletions extension/agenthealth/handler/stats/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ func (s *Stats) Merge(other Stats) {
if other.EnhancedContainerInsights != nil {
s.EnhancedContainerInsights = other.EnhancedContainerInsights
}
if other.RunningInContainer != nil {
s.RunningInContainer = other.RunningInContainer
}
if other.RegionType != nil {
s.RegionType = other.RegionType
}
if other.Mode != nil {
s.Mode = other.Mode
}
}

func (s *Stats) Marshal() (string, error) {
Expand Down
14 changes: 11 additions & 3 deletions extension/agenthealth/handler/stats/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,23 @@ func TestMerge(t *testing.T) {
assert.EqualValues(t, 1.3, *stats.CpuPercent)
assert.EqualValues(t, 123, *stats.MemoryBytes)
stats.Merge(Stats{
CpuPercent: aws.Float64(1.5),
MemoryBytes: aws.Uint64(133),
FileDescriptorCount: aws.Int32(456),
ThreadCount: aws.Int32(789),
LatencyMillis: aws.Int64(1234),
PayloadBytes: aws.Int(5678),
StatusCode: aws.Int(200),
ImdsFallbackSucceed: aws.Int(1),
SharedConfigFallback: aws.Int(1),
ImdsFallbackSucceed: aws.Int(1),
AppSignals: aws.Int(1),
EnhancedContainerInsights: aws.Int(1),
RunningInContainer: aws.Int(0),
RegionType: aws.String("RegionType"),
Mode: aws.String("Mode"),
})
assert.EqualValues(t, 1.3, *stats.CpuPercent)
assert.EqualValues(t, 123, *stats.MemoryBytes)
assert.EqualValues(t, 1.5, *stats.CpuPercent)
assert.EqualValues(t, 133, *stats.MemoryBytes)
assert.EqualValues(t, 456, *stats.FileDescriptorCount)
assert.EqualValues(t, 789, *stats.ThreadCount)
assert.EqualValues(t, 1234, *stats.LatencyMillis)
Expand All @@ -42,6 +47,9 @@ func TestMerge(t *testing.T) {
assert.EqualValues(t, 1, *stats.SharedConfigFallback)
assert.EqualValues(t, 1, *stats.AppSignals)
assert.EqualValues(t, 1, *stats.EnhancedContainerInsights)
assert.EqualValues(t, 0, *stats.RunningInContainer)
assert.EqualValues(t, "RegionType", *stats.RegionType)
assert.EqualValues(t, "Mode", *stats.Mode)
}

func TestMarshal(t *testing.T) {
Expand Down
38 changes: 20 additions & 18 deletions extension/agenthealth/handler/stats/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
)

const (
handlerID = "cloudwatchagent.ClientStatsHandler"
handlerID = "cloudwatchagent.ClientStats"
ttlDuration = 10 * time.Second
cacheSize = 1000
)
Expand All @@ -35,13 +35,11 @@ type requestRecorder struct {
}

type clientStatsHandler struct {
mu sync.Mutex

filter agent.OperationsFilter
getOperationName func(ctx context.Context) string
getRequestID func(ctx context.Context) string

statsByOperation map[string]agent.Stats
statsByOperation sync.Map
requestCache *ttlcache.Cache[string, *requestRecorder]
}

Expand All @@ -59,7 +57,6 @@ func NewHandler(filter agent.OperationsFilter) Stats {
getOperationName: awsmiddleware.GetOperationName,
getRequestID: awsmiddleware.GetRequestID,
requestCache: requestCache,
statsByOperation: make(map[string]agent.Stats),
}
}

Expand All @@ -68,19 +65,22 @@ func (csh *clientStatsHandler) ID() string {
}

func (csh *clientStatsHandler) Position() awsmiddleware.HandlerPosition {
return awsmiddleware.Before
return awsmiddleware.After
}

func (csh *clientStatsHandler) HandleRequest(ctx context.Context, r *http.Request) {
operation := csh.getOperationName(ctx)
if !csh.filter.IsAllowed(operation) {
return
}
csh.mu.Lock()
defer csh.mu.Unlock()
requestID := csh.getRequestID(ctx)
recorder := &requestRecorder{start: time.Now()}
recorder.payloadBytes, _ = io.Copy(io.Discard, r.Body)
if r.GetBody != nil {
body, err := r.GetBody()
if err == nil {
recorder.payloadBytes, err = io.Copy(io.Discard, body)
}
}
csh.requestCache.Set(requestID, recorder, ttlcache.DefaultTTL)
}

Expand All @@ -89,8 +89,6 @@ func (csh *clientStatsHandler) HandleResponse(ctx context.Context, r *http.Respo
if !csh.filter.IsAllowed(operation) {
return
}
csh.mu.Lock()
defer csh.mu.Unlock()
requestID := csh.getRequestID(ctx)
item, ok := csh.requestCache.GetAndDelete(requestID)
if !ok {
Expand All @@ -101,15 +99,19 @@ func (csh *clientStatsHandler) HandleResponse(ctx context.Context, r *http.Respo
PayloadBytes: aws.Int(int(recorder.payloadBytes)),
StatusCode: aws.Int(r.StatusCode),
}
latency := time.Since(recorder.start).Milliseconds()
stats.LatencyMillis = aws.Int64(latency)
csh.statsByOperation[operation] = stats
latency := time.Since(recorder.start)
stats.LatencyMillis = aws.Int64(latency.Milliseconds())
csh.statsByOperation.Store(operation, stats)
}

func (csh *clientStatsHandler) Stats(operation string) agent.Stats {
csh.mu.Lock()
defer csh.mu.Unlock()
stats := csh.statsByOperation[operation]
csh.statsByOperation[operation] = agent.Stats{}
value, ok := csh.statsByOperation.Load(operation)
if !ok {
return agent.Stats{}
}
stats, ok := value.(agent.Stats)
if !ok {
return agent.Stats{}
}
return stats
}
2 changes: 1 addition & 1 deletion extension/agenthealth/handler/stats/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestHandle(t *testing.T) {
handler.(*clientStatsHandler).getOperationName = func(context.Context) string {
return operation
}
assert.Equal(t, awsmiddleware.Before, handler.Position())
assert.Equal(t, awsmiddleware.After, handler.Position())
assert.Equal(t, handlerID, handler.ID())
body := []byte("test payload size")
req, err := http.NewRequest("", "localhost", bytes.NewBuffer(body))
Expand Down
4 changes: 2 additions & 2 deletions extension/agenthealth/handler/stats/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ import (
)

const (
handlerID = "cloudwatchagent.StatsHandler"
handlerID = "cloudwatchagent.AgentStats"
headerKeyAgentStats = "X-Amz-Agent-Stats"
)

func NewHandlers(logger *zap.Logger, cfg agent.StatsConfig) ([]awsmiddleware.RequestHandler, []awsmiddleware.ResponseHandler) {
filter := agent.NewOperationsFilter(cfg.Operations...)
clientStats := client.NewHandler(filter)
stats := newStatsHandler(logger, filter, []agent.StatsProvider{clientStats, provider.GetProcessStats(), provider.GetFlagsStats()})
return []awsmiddleware.RequestHandler{clientStats, stats}, []awsmiddleware.ResponseHandler{clientStats}
return []awsmiddleware.RequestHandler{stats, clientStats}, []awsmiddleware.ResponseHandler{clientStats}
}

type statsHandler struct {
Expand Down
29 changes: 11 additions & 18 deletions extension/agenthealth/handler/stats/provider/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,15 @@ const (
flagGetInterval = 5 * time.Minute
)

type BoolFlag int
type Flag int

const (
FlagIMDSFallbackSucceed BoolFlag = iota
FlagIMDSFallbackSucceed Flag = iota
FlagSharedConfigFallback
FlagAppSignal
FlagEnhancedContainerInsights
FlagRunningInContainer
)

type StringFlag int

const (
FlagMode StringFlag = iota
FlagMode
FlagRegionType
)

Expand All @@ -41,8 +36,8 @@ var (

type FlagStats interface {
agent.StatsProvider
SetFlag(flag BoolFlag)
SetFlagWithValue(flag StringFlag, value string)
SetFlag(flag Flag)
SetFlagWithValue(flag Flag, value string)
}

type flagStats struct {
Expand All @@ -54,20 +49,18 @@ type flagStats struct {
var _ FlagStats = (*flagStats)(nil)

func (p *flagStats) update() {
p.mu.Lock()
defer p.mu.Unlock()
p.stats = agent.Stats{
p.stats.Store(agent.Stats{
ImdsFallbackSucceed: p.getIntFlag(FlagIMDSFallbackSucceed, false),
SharedConfigFallback: p.getIntFlag(FlagSharedConfigFallback, false),
AppSignals: p.getIntFlag(FlagAppSignal, false),
EnhancedContainerInsights: p.getIntFlag(FlagEnhancedContainerInsights, false),
RunningInContainer: p.getIntFlag(FlagRunningInContainer, true),
Mode: p.getStringFlag(FlagMode),
RegionType: p.getStringFlag(FlagRegionType),
}
})
}

func (p *flagStats) getIntFlag(flag BoolFlag, missingAsZero bool) *int {
func (p *flagStats) getIntFlag(flag Flag, missingAsZero bool) *int {
if _, ok := p.flags.Load(flag); ok {
return aws.Int(1)
}
Expand All @@ -77,7 +70,7 @@ func (p *flagStats) getIntFlag(flag BoolFlag, missingAsZero bool) *int {
return nil
}

func (p *flagStats) getStringFlag(flag StringFlag) *string {
func (p *flagStats) getStringFlag(flag Flag) *string {
value, ok := p.flags.Load(flag)
if !ok {
return nil
Expand All @@ -90,14 +83,14 @@ func (p *flagStats) getStringFlag(flag StringFlag) *string {
return aws.String(str)
}

func (p *flagStats) SetFlag(flag BoolFlag) {
func (p *flagStats) SetFlag(flag Flag) {
if _, ok := p.flags.Load(flag); !ok {
p.flags.Store(flag, true)
p.update()
}
}

func (p *flagStats) SetFlagWithValue(flag StringFlag, value string) {
func (p *flagStats) SetFlagWithValue(flag Flag, value string) {
if _, ok := p.flags.Load(flag); !ok {
p.flags.Store(flag, value)
p.update()
Expand Down
8 changes: 4 additions & 4 deletions extension/agenthealth/handler/stats/provider/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ import (
func TestFlagStats(t *testing.T) {
t.Setenv(envconfig.RunInContainer, envconfig.TrueValue)
provider := newFlagStats(time.Microsecond)
got := provider.stats
got := provider.getStats()
assert.Nil(t, got.ImdsFallbackSucceed)
assert.Nil(t, got.SharedConfigFallback)
assert.NotNil(t, got.RunningInContainer)
assert.Equal(t, 1, *got.RunningInContainer)
provider.SetFlag(FlagIMDSFallbackSucceed)
assert.Nil(t, got.ImdsFallbackSucceed)
got = provider.stats
got = provider.getStats()
assert.NotNil(t, got.ImdsFallbackSucceed)
assert.Equal(t, 1, *got.ImdsFallbackSucceed)
assert.Nil(t, got.SharedConfigFallback)
provider.SetFlag(FlagSharedConfigFallback)
got = provider.stats
got = provider.getStats()
assert.NotNil(t, got.SharedConfigFallback)
assert.Equal(t, 1, *got.SharedConfigFallback)
provider.SetFlagWithValue(FlagMode, "test")
got = provider.stats
got = provider.getStats()
assert.NotNil(t, got.Mode)
assert.Equal(t, "test", *got.Mode)
}
19 changes: 14 additions & 5 deletions extension/agenthealth/handler/stats/provider/interval.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package provider

import (
"sync"
"sync/atomic"
"time"

"github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent"
Expand All @@ -13,29 +14,37 @@ import (
// intervalStats restricts the Stats get function to once
// per interval.
type intervalStats struct {
mu sync.Mutex
mu sync.RWMutex
interval time.Duration

getOnce *sync.Once
lastGet time.Time

stats agent.Stats
stats atomic.Value
}

var _ agent.StatsProvider = (*intervalStats)(nil)

func (p *intervalStats) Stats(string) agent.Stats {
p.mu.Lock()
defer p.mu.Unlock()
p.mu.RLock()
defer p.mu.RUnlock()
var stats agent.Stats
p.getOnce.Do(func() {
p.lastGet = time.Now()
stats = p.stats
stats = p.getStats()
go p.allowNextGetAfter(p.interval)
})
return stats
}

func (p *intervalStats) getStats() agent.Stats {
var stats agent.Stats
if value := p.stats.Load(); value != nil {
stats = value.(agent.Stats)
}
return stats
}

func (p *intervalStats) allowNextGetAfter(interval time.Duration) {
time.Sleep(interval)
p.mu.Lock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"

"github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent"
)

func TestIntervalStats(t *testing.T) {
s := newIntervalStats(time.Millisecond)
s.stats.ThreadCount = aws.Int32(2)
s.stats.Store(agent.Stats{
ThreadCount: aws.Int32(2),
})
got := s.Stats("")
assert.NotNil(t, got.ThreadCount)
got = s.Stats("")
Expand Down
6 changes: 2 additions & 4 deletions extension/agenthealth/handler/stats/provider/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,12 @@ func (p *processStats) updateLoop() {
}

func (p *processStats) refresh() {
p.mu.Lock()
defer p.mu.Unlock()
p.stats = agent.Stats{
p.stats.Store(agent.Stats{
CpuPercent: p.cpuPercent(),
MemoryBytes: p.memoryBytes(),
FileDescriptorCount: p.fileDescriptorCount(),
ThreadCount: p.threadCount(),
}
})
}

func newProcessStats(proc processMetrics, interval time.Duration) *processStats {
Expand Down
4 changes: 2 additions & 2 deletions extension/agenthealth/handler/stats/provider/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestProcessStats(t *testing.T) {
testErr := errors.New("test error")
mock := &mockProcessMetrics{}
provider := newProcessStats(mock, time.Millisecond)
got := provider.stats
got := provider.getStats()
assert.NotNil(t, got.CpuPercent)
assert.NotNil(t, got.MemoryBytes)
assert.NotNil(t, got.FileDescriptorCount)
Expand All @@ -61,7 +61,7 @@ func TestProcessStats(t *testing.T) {
assert.EqualValues(t, 4, *got.ThreadCount)
mock.err = testErr
time.Sleep(2 * time.Millisecond)
got = provider.stats
got = provider.getStats()
assert.Nil(t, got.CpuPercent)
assert.Nil(t, got.MemoryBytes)
assert.Nil(t, got.FileDescriptorCount)
Expand Down
Loading

0 comments on commit 52a4af3

Please sign in to comment.