From b8113643cc4ade5409d547e4917fb6a4d4571cf4 Mon Sep 17 00:00:00 2001 From: Blake Rouse Date: Mon, 2 Dec 2024 08:51:56 -0500 Subject: [PATCH] Only run providers that are referenced in the policy (#6169) --- ...un-providers-referenced-in-the-policy.yaml | 32 ++ .../application/coordinator/coordinator.go | 30 ++ .../coordinator/coordinator_test.go | 81 ++++ internal/pkg/agent/cmd/run.go | 7 +- internal/pkg/agent/transpiler/ast.go | 64 +++ internal/pkg/agent/transpiler/ast_test.go | 151 ++++++ internal/pkg/agent/transpiler/vars.go | 97 ++-- internal/pkg/agent/vars/vars.go | 1 - internal/pkg/composable/benchmark_test.go | 10 +- internal/pkg/composable/config.go | 11 +- internal/pkg/composable/controller.go | 452 +++++++++++++----- internal/pkg/composable/controller_test.go | 106 ++-- 12 files changed, 838 insertions(+), 204 deletions(-) create mode 100644 changelog/fragments/1732840106-Only-run-providers-referenced-in-the-policy.yaml diff --git a/changelog/fragments/1732840106-Only-run-providers-referenced-in-the-policy.yaml b/changelog/fragments/1732840106-Only-run-providers-referenced-in-the-policy.yaml new file mode 100644 index 00000000000..a73f073c080 --- /dev/null +++ b/changelog/fragments/1732840106-Only-run-providers-referenced-in-the-policy.yaml @@ -0,0 +1,32 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: enhancement + +# Change summary; a 80ish characters long description of the change. +summary: Only run providers referenced in the policy + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +#description: + +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: elastic-agent + +# PR URL; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +pr: https://github.com/elastic/elastic-agent/pull/6169 + +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +issue: https://github.com/elastic/elastic-agent/issues/3609 diff --git a/internal/pkg/agent/application/coordinator/coordinator.go b/internal/pkg/agent/application/coordinator/coordinator.go index 7198f45bb04..4033f9d6cb7 100644 --- a/internal/pkg/agent/application/coordinator/coordinator.go +++ b/internal/pkg/agent/application/coordinator/coordinator.go @@ -172,6 +172,9 @@ type ConfigManager interface { type VarsManager interface { Runner + // Observe instructs the variables to observe. + Observe([]string) + // Watch returns the chanel to watch for variable changes. Watch() <-chan []*transpiler.Vars } @@ -1235,6 +1238,9 @@ func (c *Coordinator) processConfigAgent(ctx context.Context, cfg *config.Config return err } + // pass the observed vars from the AST to the varsMgr + c.observeASTVars() + // Disabled for 8.8.0 release in order to limit the surface // https://github.com/elastic/security-team/issues/6501 @@ -1313,6 +1319,30 @@ func (c *Coordinator) generateAST(cfg *config.Config) (err error) { return nil } +// observeASTVars identifies the variables that are referenced in the computed AST and passed to +// the varsMgr so it knows what providers are being referenced. If a providers is not being +// referenced then the provider does not need to be running. +func (c *Coordinator) observeASTVars() { + if c.varsMgr == nil { + // No varsMgr (only happens in testing) + return + } + if c.ast == nil { + // No AST; no vars + c.varsMgr.Observe(nil) + return + } + inputs, ok := transpiler.Lookup(c.ast, "inputs") + if !ok { + // No inputs; no vars + c.varsMgr.Observe(nil) + return + } + var vars []string + vars = inputs.Vars(vars) + c.varsMgr.Observe(vars) +} + // processVars updates the transpiler vars in the Coordinator. // Called on the main Coordinator goroutine. func (c *Coordinator) processVars(ctx context.Context, vars []*transpiler.Vars) { diff --git a/internal/pkg/agent/application/coordinator/coordinator_test.go b/internal/pkg/agent/application/coordinator/coordinator_test.go index 36f788d6f24..ea67c0d4ee5 100644 --- a/internal/pkg/agent/application/coordinator/coordinator_test.go +++ b/internal/pkg/agent/application/coordinator/coordinator_test.go @@ -12,6 +12,7 @@ import ( "path/filepath" goruntime "runtime" "strings" + "sync" "testing" "time" @@ -327,6 +328,77 @@ func mustNewStruct(t *testing.T, v map[string]interface{}) *structpb.Struct { return str } +func TestCoordinator_VarsMgr_Observe(t *testing.T) { + coordCh := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + coord, cfgMgr, varsMgr := createCoordinator(t, ctx) + stateChan := coord.StateSubscribe(ctx, 32) + go func() { + err := coord.Run(ctx) + if errors.Is(err, context.Canceled) { + // allowed error + err = nil + } + coordCh <- err + }() + + // wait for it to be in starting state + waitForState(t, stateChan, func(state State) bool { + return state.State == agentclient.Starting && + state.Message == "Waiting for initial configuration and composable variables" + }, 3*time.Second) + + // set vars state should stay same (until config) + varsMgr.Vars(ctx, []*transpiler.Vars{{}}) + + // State changes happen asynchronously in the Coordinator goroutine, so + // wait a little bit to make sure no changes are reported; if the Vars + // call does trigger a change, it should happen relatively quickly. + select { + case <-stateChan: + assert.Fail(t, "Vars call shouldn't cause a state change") + case <-time.After(50 * time.Millisecond): + } + + // set configuration that has variables present + cfg, err := config.NewConfigFrom(map[string]interface{}{ + "inputs": []interface{}{ + map[string]interface{}{ + "type": "filestream", + "paths": []interface{}{ + "${env.filestream_path|env.log_path|'/var/log/syslog'}", + }, + }, + map[string]interface{}{ + "type": "windows", + "condition": "${host.platform} == 'windows'", + }, + }, + }) + require.NoError(t, err) + cfgMgr.Config(ctx, cfg) + + // healthy signals that the configuration has been computed + waitForState(t, stateChan, func(state State) bool { + return state.State == agentclient.Healthy && state.Message == "Running" + }, 3*time.Second) + + // get the set observed vars from the fake vars manager + varsMgr.observedMx.Lock() + observed := varsMgr.observed + varsMgr.observedMx.Unlock() + + // stop the coordinator + cancel() + err = <-coordCh + require.NoError(t, err) + + // verify that the observed vars are the expected vars + assert.Equal(t, []string{"env.filestream_path", "env.log_path", "host.platform"}, observed) +} + func TestCoordinator_State_Starting(t *testing.T) { coordCh := make(chan error) ctx, cancel := context.WithCancel(context.Background()) @@ -1072,6 +1144,9 @@ func (l *configChange) Fail(err error) { type fakeVarsManager struct { varsCh chan []*transpiler.Vars errCh chan error + + observedMx sync.RWMutex + observed []string } func newFakeVarsManager() *fakeVarsManager { @@ -1101,6 +1176,12 @@ func (f *fakeVarsManager) Watch() <-chan []*transpiler.Vars { return f.varsCh } +func (f *fakeVarsManager) Observe(observed []string) { + f.observedMx.Lock() + defer f.observedMx.Unlock() + f.observed = observed +} + func (f *fakeVarsManager) Vars(ctx context.Context, vars []*transpiler.Vars) { select { case <-ctx.Done(): diff --git a/internal/pkg/agent/cmd/run.go b/internal/pkg/agent/cmd/run.go index 21e4c19a28e..43a2ad34ee6 100644 --- a/internal/pkg/agent/cmd/run.go +++ b/internal/pkg/agent/cmd/run.go @@ -284,15 +284,10 @@ func runElasticAgent(ctx context.Context, cancel context.CancelFunc, override cf l.Info("APM instrumentation disabled") } - coord, configMgr, composable, err := application.New(ctx, l, baseLogger, logLvl, agentInfo, rex, tracer, testingMode, fleetInitTimeout, configuration.IsFleetServerBootstrap(cfg.Fleet), modifiers...) + coord, configMgr, _, err := application.New(ctx, l, baseLogger, logLvl, agentInfo, rex, tracer, testingMode, fleetInitTimeout, configuration.IsFleetServerBootstrap(cfg.Fleet), modifiers...) if err != nil { return logReturn(l, err) } - defer func() { - if composable != nil { - composable.Close() - } - }() monitoringServer, err := setupMetrics(l, cfg.Settings.DownloadConfig.OS(), cfg.Settings.MonitoringConfig, tracer, coord) if err != nil { diff --git a/internal/pkg/agent/transpiler/ast.go b/internal/pkg/agent/transpiler/ast.go index 1fae370ce40..149818d502b 100644 --- a/internal/pkg/agent/transpiler/ast.go +++ b/internal/pkg/agent/transpiler/ast.go @@ -58,6 +58,10 @@ type Node interface { // Hash compute a sha256 hash of the current node and recursively call any children. Hash() []byte + // Vars adds to the array with the variables identified in the node. Returns the array in-case + // the capacity of the array had to be changed. + Vars([]string) []string + // Apply apply the current vars, returning the new value for the node. Apply(*Vars) (Node, error) @@ -162,6 +166,15 @@ func (d *Dict) Hash() []byte { return h.Sum(nil) } +// Vars returns a list of all variables referenced in the dictionary. +func (d *Dict) Vars(vars []string) []string { + for _, v := range d.value { + k := v.(*Key) + vars = k.Vars(vars) + } + return vars +} + // Apply applies the vars to all the nodes in the dictionary. func (d *Dict) Apply(vars *Vars) (Node, error) { nodes := make([]Node, 0, len(d.value)) @@ -277,6 +290,14 @@ func (k *Key) Hash() []byte { return h.Sum(nil) } +// Vars returns a list of all variables referenced in the value. +func (k *Key) Vars(vars []string) []string { + if k.value == nil { + return vars + } + return k.value.Vars(vars) +} + // Apply applies the vars to the value. func (k *Key) Apply(vars *Vars) (Node, error) { if k.value == nil { @@ -397,6 +418,14 @@ func (l *List) ShallowClone() Node { return &List{value: nodes} } +// Vars returns a list of all variables referenced in the list. +func (l *List) Vars(vars []string) []string { + for _, v := range l.value { + vars = v.Vars(vars) + } + return vars +} + // Apply applies the vars to all nodes in the list. func (l *List) Apply(vars *Vars) (Node, error) { nodes := make([]Node, 0, len(l.value)) @@ -472,6 +501,16 @@ func (s *StrVal) Hash() []byte { return []byte(s.value) } +// Vars returns a list of all variables referenced in the string. +func (s *StrVal) Vars(vars []string) []string { + // errors are ignored (if there is an error determine the vars it will also error computing the policy) + _, _ = replaceVars(s.value, func(variable string) (Node, Processors, bool) { + vars = append(vars, variable) + return nil, nil, false + }, false) + return vars +} + // Apply applies the vars to the string value. func (s *StrVal) Apply(vars *Vars) (Node, error) { return vars.Replace(s.value) @@ -523,6 +562,11 @@ func (s *IntVal) ShallowClone() Node { return s.Clone() } +// Vars does nothing. Cannot have variable in an IntVal. +func (s *IntVal) Vars(vars []string) []string { + return vars +} + // Apply does nothing. func (s *IntVal) Apply(_ *Vars) (Node, error) { return s, nil @@ -584,6 +628,11 @@ func (s *UIntVal) Hash() []byte { return []byte(s.String()) } +// Vars does nothing. Cannot have variable in an UIntVal. +func (s *UIntVal) Vars(vars []string) []string { + return vars +} + // Apply does nothing. func (s *UIntVal) Apply(_ *Vars) (Node, error) { return s, nil @@ -641,6 +690,11 @@ func (s *FloatVal) Hash() []byte { return []byte(strconv.FormatFloat(s.value, 'f', -1, 64)) } +// Vars does nothing. Cannot have variable in an FloatVal. +func (s *FloatVal) Vars(vars []string) []string { + return vars +} + // Apply does nothing. func (s *FloatVal) Apply(_ *Vars) (Node, error) { return s, nil @@ -703,6 +757,11 @@ func (s *BoolVal) Hash() []byte { return falseVal } +// Vars does nothing. Cannot have variable in an BoolVal. +func (s *BoolVal) Vars(vars []string) []string { + return vars +} + // Apply does nothing. func (s *BoolVal) Apply(_ *Vars) (Node, error) { return s, nil @@ -982,6 +1041,11 @@ func attachProcessors(node Node, processors Processors) Node { // Lookup accept an AST and a selector and return the matching Node at that position. func Lookup(a *AST, selector Selector) (Node, bool) { + // Be defensive and ensure that the ast is usable. + if a == nil || a.root == nil { + return nil, false + } + // Run through the graph and find matching nodes. current := a.root for _, part := range splitPath(selector) { diff --git a/internal/pkg/agent/transpiler/ast_test.go b/internal/pkg/agent/transpiler/ast_test.go index 098b6be9107..cdbaff5df7a 100644 --- a/internal/pkg/agent/transpiler/ast_test.go +++ b/internal/pkg/agent/transpiler/ast_test.go @@ -920,6 +920,157 @@ func TestShallowClone(t *testing.T) { } } +func TestVars(t *testing.T) { + tests := map[string]struct { + input map[string]interface{} + result []string + }{ + "empty": { + input: map[string]interface{}{}, + result: nil, + }, + "badbracket": { + input: map[string]interface{}{ + "badbracket": "${missing.end", + }, + result: nil, + }, + "allconstant": { + input: map[string]interface{}{ + "constant": "${'constant'}", + }, + result: nil, + }, + "escaped": { + input: map[string]interface{}{ + "constant": "$${var1}", + }, + result: nil, + }, + "nested": { + input: map[string]interface{}{ + "novars": map[string]interface{}{ + "list1": []interface{}{ + map[string]interface{}{ + "int": 1, + "float": 1.1234, + "bool": true, + "str": "value1", + }, + }, + "list2": []interface{}{ + map[string]interface{}{ + "int": 2, + "float": 2.3456, + "bool": false, + "str": "value2", + }, + }, + }, + "vars1": map[string]interface{}{ + "list1": []interface{}{ + map[string]interface{}{ + "int": 1, + "float": 1.1234, + "bool": true, + "str": "${var1|var2|'constant'}", + }, + }, + "list2": []interface{}{ + map[string]interface{}{ + "int": 2, + "float": 2.3456, + "bool": false, + "str": "${var3|var1|'constant'}", + }, + }, + }, + "vars2": map[string]interface{}{ + "list1": []interface{}{ + map[string]interface{}{ + "int": 1, + "float": 1.1234, + "bool": true, + "str": "${var5|var6|'constant'}", + }, + }, + "list2": []interface{}{ + map[string]interface{}{ + "int": 2, + "float": 2.3456, + "bool": false, + "str": "${var1}", + }, + }, + }, + }, + result: []string{"var1", "var2", "var3", "var1", "var5", "var6", "var1"}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + ast, err := NewAST(test.input) + require.NoError(t, err) + var vars []string + vars = ast.root.Vars(vars) + assert.Equal(t, test.result, vars) + }) + } +} + +func TestLookup(t *testing.T) { + tests := map[string]struct { + ast *AST + selector Selector + node Node + ok bool + }{ + "nil": { + ast: nil, + selector: "", + node: nil, + ok: false, + }, + "noroot": { + ast: &AST{}, + selector: "", + node: nil, + ok: false, + }, + "notfound": { + ast: &AST{ + root: NewDict([]Node{NewKey("entry", NewDict([]Node{ + NewKey("var1", NewStrVal("value1")), + NewKey("var2", NewStrVal("value2")), + }))}), + }, + selector: "entry.var3", + node: nil, + ok: false, + }, + "found": { + ast: &AST{ + root: NewDict([]Node{NewKey("entry", NewDict([]Node{ + NewKey("var1", NewStrVal("value1")), + NewKey("var2", NewStrVal("value2")), + }))}), + }, + selector: "entry.var2", + node: NewKey("var2", NewStrVal("value2")), + ok: true, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + node, ok := Lookup(test.ast, test.selector) + if assert.Equal(t, test.ok, ok) { + assert.Equal(t, test.node, node) + } + }) + } +} + func mustMakeVars(mapping map[string]interface{}) *Vars { v, err := NewVars("", mapping, nil) if err != nil { diff --git a/internal/pkg/agent/transpiler/vars.go b/internal/pkg/agent/transpiler/vars.go index bcf845b7c6f..71bd8bd4cb6 100644 --- a/internal/pkg/agent/transpiler/vars.go +++ b/internal/pkg/agent/transpiler/vars.go @@ -54,6 +54,56 @@ func NewVarsWithProcessorsFromAst(id string, tree *AST, processorKey string, pro // Replace returns a new value based on variable replacement. func (v *Vars) Replace(value string) (Node, error) { + return replaceVars(value, func(variable string) (Node, Processors, bool) { + var processors Processors + node, ok := v.lookupNode(variable) + if ok && v.processorsKey != "" && varPrefixMatched(variable, v.processorsKey) { + processors = v.processors + } + return node, processors, ok + }, true) +} + +// ID returns the unique ID for the vars. +func (v *Vars) ID() string { + return v.id +} + +// Lookup returns the value from the vars. +func (v *Vars) Lookup(name string) (interface{}, bool) { + // lookup in the AST tree + return v.tree.Lookup(name) +} + +// Map transforms the variables into a map[string]interface{} and will abort and return any errors related +// to type conversion. +func (v *Vars) Map() (map[string]interface{}, error) { + return v.tree.Map() +} + +// lookupNode performs a lookup on the AST, but keeps the result as a `Node`. +// +// This is different from `Lookup` which returns the actual type, not the AST type. +func (v *Vars) lookupNode(name string) (Node, bool) { + // check if the value can be retrieved from a FetchContextProvider + for providerName, provider := range v.fetchContextProviders { + if varPrefixMatched(name, providerName) { + fetchProvider, ok := provider.(composable.FetchContextProvider) + if !ok { + return &StrVal{value: ""}, false + } + fval, found := fetchProvider.Fetch(name) + if found { + return &StrVal{value: fval}, true + } + return &StrVal{value: ""}, false + } + } + // lookup in the AST tree + return Lookup(v.tree, name) +} + +func replaceVars(value string, replacer func(variable string) (Node, Processors, bool), reqMatch bool) (Node, error) { var processors Processors matchIdxs := varsRegex.FindAllSubmatchIndex([]byte(value), -1) if !validBrackets(value, matchIdxs) { @@ -81,11 +131,11 @@ func (v *Vars) Replace(value string) (Node, error) { result += value[lastIndex:r[0]] + val.Value() set = true case *varString: - node, ok := v.lookupNode(val.Value()) + node, nodeProcessors, ok := replacer(val.Value()) if ok { node := nodeToValue(node) - if v.processorsKey != "" && varPrefixMatched(val.Value(), v.processorsKey) { - processors = v.processors + if nodeProcessors != nil { + processors = nodeProcessors } if r[i] == 0 && r[i+1] == len(value) { // possible for complete replacement of object, because the variable @@ -100,7 +150,7 @@ func (v *Vars) Replace(value string) (Node, error) { break } } - if !set { + if !set && reqMatch { return NewStrVal(""), ErrNoMatch } lastIndex = r[1] @@ -109,45 +159,6 @@ func (v *Vars) Replace(value string) (Node, error) { return NewStrValWithProcessors(result+value[lastIndex:], processors), nil } -// ID returns the unique ID for the vars. -func (v *Vars) ID() string { - return v.id -} - -// Lookup returns the value from the vars. -func (v *Vars) Lookup(name string) (interface{}, bool) { - // lookup in the AST tree - return v.tree.Lookup(name) -} - -// Map transforms the variables into a map[string]interface{} and will abort and return any errors related -// to type conversion. -func (v *Vars) Map() (map[string]interface{}, error) { - return v.tree.Map() -} - -// lookupNode performs a lookup on the AST, but keeps the result as a `Node`. -// -// This is different from `Lookup` which returns the actual type, not the AST type. -func (v *Vars) lookupNode(name string) (Node, bool) { - // check if the value can be retrieved from a FetchContextProvider - for providerName, provider := range v.fetchContextProviders { - if varPrefixMatched(name, providerName) { - fetchProvider, ok := provider.(composable.FetchContextProvider) - if !ok { - return &StrVal{value: ""}, false - } - fval, found := fetchProvider.Fetch(name) - if found { - return &StrVal{value: fval}, true - } - return &StrVal{value: ""}, false - } - } - // lookup in the AST tree - return Lookup(v.tree, name) -} - // nodeToValue ensures that the node is an actual value. func nodeToValue(node Node) Node { switch n := node.(type) { diff --git a/internal/pkg/agent/vars/vars.go b/internal/pkg/agent/vars/vars.go index 001dac23c2e..5f99d995bfb 100644 --- a/internal/pkg/agent/vars/vars.go +++ b/internal/pkg/agent/vars/vars.go @@ -26,7 +26,6 @@ func WaitForVariables(ctx context.Context, l *logger.Logger, cfg *config.Config, if err != nil { return nil, fmt.Errorf("failed to create composable controller: %w", err) } - defer composable.Close() hasTimeout := false if wait > time.Duration(0) { diff --git a/internal/pkg/composable/benchmark_test.go b/internal/pkg/composable/benchmark_test.go index fec6e797a0f..913d8d4fbd1 100644 --- a/internal/pkg/composable/benchmark_test.go +++ b/internal/pkg/composable/benchmark_test.go @@ -28,9 +28,9 @@ func BenchmarkGenerateVars100Pods(b *testing.B) { log, err := logger.New("", false) require.NoError(b, err) c := controller{ - contextProviders: make(map[string]*contextProviderState), - dynamicProviders: make(map[string]*dynamicProviderState), - logger: log, + contextProviderStates: make(map[string]*contextProviderState), + dynamicProviderStates: make(map[string]*dynamicProviderState), + logger: log, } podCount := 100 @@ -63,14 +63,14 @@ func BenchmarkGenerateVars100Pods(b *testing.B) { } providerState.mappings[string(podUID)] = podMapping } - c.dynamicProviders[providerName] = providerState + c.dynamicProviderStates[providerName] = providerState } else { providerAst, err := transpiler.NewAST(providerData[providerName]) require.NoError(b, err) providerState := &contextProviderState{ mapping: providerAst, } - c.contextProviders[providerName] = providerState + c.contextProviderStates[providerName] = providerState } } diff --git a/internal/pkg/composable/config.go b/internal/pkg/composable/config.go index 101c95af87e..04f1b38e0a2 100644 --- a/internal/pkg/composable/config.go +++ b/internal/pkg/composable/config.go @@ -4,10 +4,15 @@ package composable -import "github.com/elastic/elastic-agent/internal/pkg/config" +import ( + "time" + + "github.com/elastic/elastic-agent/internal/pkg/config" +) // Config is config for multiple providers. type Config struct { - Providers map[string]*config.Config `config:"providers"` - ProvidersInitialDefault *bool `config:"agent.providers.initial_default"` + Providers map[string]*config.Config `config:"providers"` + ProvidersInitialDefault *bool `config:"agent.providers.initial_default"` + ProvidersRestartInterval *time.Duration `config:"agent.providers.restart_interval"` } diff --git a/internal/pkg/composable/controller.go b/internal/pkg/composable/controller.go index b3ac09f59f6..2743eee9b62 100644 --- a/internal/pkg/composable/controller.go +++ b/internal/pkg/composable/controller.go @@ -22,6 +22,10 @@ import ( "github.com/elastic/elastic-agent/pkg/core/logger" ) +const ( + defaultRetryInterval = 30 * time.Second +) + // Controller manages the state of the providers current context. type Controller interface { // Run runs the controller. @@ -35,18 +39,24 @@ type Controller interface { // Watch returns the channel to watch for variable changes. Watch() <-chan []*transpiler.Vars - // Close closes the controller, allowing for any resource - // cleanup and such. - Close() + // Observe instructs the variables to observe. + Observe([]string) } // controller manages the state of the providers current context. type controller struct { - logger *logger.Logger - ch chan []*transpiler.Vars - errCh chan error - contextProviders map[string]*contextProviderState - dynamicProviders map[string]*dynamicProviderState + logger *logger.Logger + ch chan []*transpiler.Vars + observedCh chan map[string]bool + errCh chan error + restartInterval time.Duration + + managed bool + contextProviderBuilders map[string]contextProvider + dynamicProviderBuilders map[string]dynamicProvider + + contextProviderStates map[string]*contextProviderState + dynamicProviderStates map[string]*dynamicProviderState } // New creates a new controller. @@ -67,56 +77,56 @@ func New(log *logger.Logger, c *config.Config, managed bool) (Controller, error) providersInitialDefault = *providersCfg.ProvidersInitialDefault } + restartInterval := defaultRetryInterval + if providersCfg.ProvidersRestartInterval != nil { + restartInterval = *providersCfg.ProvidersRestartInterval + } + // build all the context providers - contextProviders := map[string]*contextProviderState{} + contextProviders := map[string]contextProvider{} for name, builder := range Providers.contextProviders { pCfg, ok := providersCfg.Providers[name] if (ok && !pCfg.Enabled()) || (!ok && !providersInitialDefault) { // explicitly disabled; skipping continue } - provider, err := builder(l, pCfg, managed) - if err != nil { - return nil, errors.New(err, fmt.Sprintf("failed to build provider '%s'", name), errors.TypeConfig, errors.M("provider", name)) - } - emptyMapping, _ := transpiler.NewAST(nil) - contextProviders[name] = &contextProviderState{ - // Safe for Context to be nil here because it will be filled in - // by (*controller).Run before the provider is started. - provider: provider, - mapping: emptyMapping, + contextProviders[name] = contextProvider{ + builder: builder, + cfg: pCfg, } } // build all the dynamic providers - dynamicProviders := map[string]*dynamicProviderState{} + dynamicProviders := map[string]dynamicProvider{} for name, builder := range Providers.dynamicProviders { pCfg, ok := providersCfg.Providers[name] if (ok && !pCfg.Enabled()) || (!ok && !providersInitialDefault) { // explicitly disabled; skipping continue } - provider, err := builder(l.Named(strings.Join([]string{"providers", name}, ".")), pCfg, managed) - if err != nil { - return nil, errors.New(err, fmt.Sprintf("failed to build provider '%s'", name), errors.TypeConfig, errors.M("provider", name)) - } - dynamicProviders[name] = &dynamicProviderState{ - provider: provider, - mappings: map[string]dynamicProviderMapping{}, + dynamicProviders[name] = dynamicProvider{ + builder: builder, + cfg: pCfg, } } return &controller{ - logger: l, - ch: make(chan []*transpiler.Vars, 1), - errCh: make(chan error), - contextProviders: contextProviders, - dynamicProviders: dynamicProviders, + logger: l, + ch: make(chan []*transpiler.Vars, 1), + observedCh: make(chan map[string]bool, 1), + errCh: make(chan error), + managed: managed, + restartInterval: restartInterval, + contextProviderBuilders: contextProviders, + dynamicProviderBuilders: dynamicProviders, + contextProviderStates: make(map[string]*contextProviderState), + dynamicProviderStates: make(map[string]*dynamicProviderState), }, nil } // Run runs the controller. func (c *controller) Run(ctx context.Context) error { + var wg sync.WaitGroup c.logger.Debugf("Starting controller for composable inputs") defer c.logger.Debugf("Stopped controller for composable inputs") @@ -124,49 +134,13 @@ func (c *controller) Run(ctx context.Context) error { localCtx, cancel := context.WithCancel(ctx) defer cancel() - fetchContextProviders := mapstr.M{} - - var wg sync.WaitGroup - wg.Add(len(c.contextProviders) + len(c.dynamicProviders)) - - // run all the enabled context providers - for name, state := range c.contextProviders { - state.Context = localCtx - state.signal = stateChangedChan - go func(name string, state *contextProviderState) { - defer wg.Done() - err := state.provider.Run(ctx, state) - if err != nil && !errors.Is(err, context.Canceled) { - err = errors.New(err, fmt.Sprintf("failed to run provider '%s'", name), errors.TypeConfig, errors.M("provider", name)) - c.logger.Errorf("%s", err) - } - }(name, state) - if p, ok := state.provider.(corecomp.FetchContextProvider); ok { - _, _ = fetchContextProviders.Put(name, p) - } - } - - // run all the enabled dynamic providers - for name, state := range c.dynamicProviders { - state.Context = localCtx - state.signal = stateChangedChan - go func(name string, state *dynamicProviderState) { - defer wg.Done() - err := state.provider.Run(state) - if err != nil && !errors.Is(err, context.Canceled) { - err = errors.New(err, fmt.Sprintf("failed to run provider '%s'", name), errors.TypeConfig, errors.M("provider", name)) - c.logger.Errorf("%s", err) - } - }(name, state) - } - c.logger.Debugf("Started controller for composable inputs") t := time.NewTimer(100 * time.Millisecond) - cleanupFn := func() { + defer func() { c.logger.Debugf("Stopping controller for composable inputs") t.Stop() - cancel() + cancel() // this cancel will stop all running providers // wait for all providers to stop (but its possible they still send notifications over notify // channel, and we cannot block them sending) @@ -184,7 +158,38 @@ func (c *controller) Run(ctx context.Context) error { close(c.ch) wg.Wait() + }() + + // synchronize the fetch providers through a channel + var fetchProvidersLock sync.RWMutex + var fetchProviders mapstr.M + fetchCh := make(chan fetchProvider) + go func() { + for { + select { + case <-localCtx.Done(): + return + case msg := <-fetchCh: + fetchProvidersLock.Lock() + if msg.fetchProvider == nil { + _ = fetchProviders.Delete(msg.name) + } else { + _, _ = fetchProviders.Put(msg.name, msg.fetchProvider) + } + fetchProvidersLock.Unlock() + } + } + }() + + // send initial vars state + fetchProvidersLock.RLock() + err := c.sendVars(ctx, fetchProviders) + if err != nil { + fetchProvidersLock.RUnlock() + // only error is context cancel, no need to add error message context + return err } + fetchProvidersLock.RUnlock() // performs debounce of notifies; accumulates them into 100 millisecond chunks for { @@ -192,8 +197,15 @@ func (c *controller) Run(ctx context.Context) error { for { select { case <-ctx.Done(): - cleanupFn() return ctx.Err() + case observed := <-c.observedCh: + changed := c.handleObserved(localCtx, &wg, fetchCh, stateChangedChan, observed) + if changed { + t.Reset(100 * time.Millisecond) + c.logger.Debugf("Observed state changed for composable inputs; debounce started") + drainChan(stateChangedChan) + break DEBOUNCE + } case <-stateChangedChan: t.Reset(100 * time.Millisecond) c.logger.Debugf("Variable state changed for composable inputs; debounce started") @@ -205,32 +217,41 @@ func (c *controller) Run(ctx context.Context) error { // notification received, wait for batch select { case <-ctx.Done(): - cleanupFn() return ctx.Err() case <-t.C: drainChan(stateChangedChan) // batching done, gather results } - c.logger.Debugf("Computing new variable state for composable inputs") - - vars := c.generateVars(fetchContextProviders) + // send the vars to the watcher + fetchProvidersLock.RLock() + err := c.sendVars(ctx, fetchProviders) + if err != nil { + fetchProvidersLock.RUnlock() + // only error is context cancel, no need to add error message context + return err + } + fetchProvidersLock.RUnlock() + } +} - UPDATEVARS: - for { +func (c *controller) sendVars(ctx context.Context, fetchContextProviders mapstr.M) error { + c.logger.Debugf("Computing new variable state for composable inputs") + vars := c.generateVars(fetchContextProviders) + for { + select { + case c.ch <- vars: + return nil + case <-ctx.Done(): + // coordinator is handling cancellation it won't drain the channel + return ctx.Err() + default: + // c.ch is size of 1, nothing is reading and there's already a signal select { - case c.ch <- vars: - break UPDATEVARS - case <-ctx.Done(): - // coordinator is handling cancellation it won't drain the channel + case <-c.ch: + // Vars not pushed, cleaning channel default: - // c.ch is size of 1, nothing is reading and there's already a signal - select { - case <-c.ch: - // Vars not pushed, cleaning channel - default: - // already read - } + // already read } } } @@ -246,45 +267,220 @@ func (c *controller) Watch() <-chan []*transpiler.Vars { return c.ch } -// Close closes the controller, allowing for any resource -// cleanup and such. -func (c *controller) Close() { - // Attempt to close all closeable context providers. - for name, state := range c.contextProviders { - cp, ok := state.provider.(corecomp.CloseableProvider) - if !ok { +// Observe sends the observed variables from the AST to the controller. +// +// Based on this information it will determine which providers should even be running. +func (c *controller) Observe(vars []string) { + // only need the top-level variables to determine which providers to run + // + // future: possible that all vars could be organized and then passed to each provider to + // inform the provider on which variables it needs to provide values for. + topLevel := make(map[string]bool) + for _, v := range vars { + vs := strings.SplitN(v, ".", 2) + topLevel[vs[0]] = true + } + // drain the channel first, if the previous vars had not been used yet the new list should be used instead + drainChan(c.observedCh) + c.observedCh <- topLevel +} + +func (c *controller) handleObserved(ctx context.Context, wg *sync.WaitGroup, fetchCh chan fetchProvider, stateChangedChan chan bool, observed map[string]bool) bool { + changed := false + + // get the list of already running, so we can determine a list that needs to be stopped + runningCtx := make(map[string]*contextProviderState, len(c.contextProviderStates)) + runningDyn := make(map[string]*dynamicProviderState, len(c.dynamicProviderStates)) + for name, state := range c.contextProviderStates { + runningCtx[name] = state + } + for name, state := range c.dynamicProviderStates { + runningDyn[name] = state + } + + // loop through the top-level observed variables and start the providers that are current off + for name, enabled := range observed { + if !enabled { + // should always be true, but just in-case + continue + } + _, ok := runningCtx[name] + if ok { + // already running + delete(runningCtx, name) continue } + _, ok = runningDyn[name] + if ok { + // already running + delete(runningDyn, name) + continue + } + + contextInfo, ok := c.contextProviderBuilders[name] + if ok { + state := c.startContextProvider(ctx, wg, fetchCh, stateChangedChan, name, contextInfo) + if state != nil { + changed = true + c.contextProviderStates[name] = state - if err := cp.Close(); err != nil { - c.logger.Errorf("unable to close context provider %q: %s", name, err.Error()) + } + } + dynamicInfo, ok := c.dynamicProviderBuilders[name] + if ok { + state := c.startDynamicProvider(ctx, wg, stateChangedChan, name, dynamicInfo) + if state != nil { + changed = true + c.dynamicProviderStates[name] = state + } } + c.logger.Warnf("provider %q referenced in policy but no provider exists or was explicitly disabled", name) } - // Attempt to close all closeable dynamic providers. - for name, state := range c.dynamicProviders { - cp, ok := state.provider.(corecomp.CloseableProvider) - if !ok { - continue - } + // running remaining need to be stopped + for name, state := range runningCtx { + changed = true + state.logger.Infof("Stopping provider %q", name) + state.canceller() + delete(c.contextProviderStates, name) + } + for name, state := range runningDyn { + changed = true + state.logger.Infof("Stopping dynamic provider %q", name) + state.canceller() + delete(c.dynamicProviderStates, name) + } + + return changed +} + +func (c *controller) startContextProvider(ctx context.Context, wg *sync.WaitGroup, fetchCh chan fetchProvider, stateChangedChan chan bool, name string, info contextProvider) *contextProviderState { + wg.Add(1) + l := c.logger.Named(strings.Join([]string{"providers", name}, ".")) + + ctx, cancel := context.WithCancel(ctx) + emptyMapping, _ := transpiler.NewAST(nil) + state := &contextProviderState{ + Context: ctx, + mapping: emptyMapping, + signal: stateChangedChan, + logger: l, + canceller: cancel, + } + go func() { + defer wg.Done() + for { + l.Infof("Starting context provider %q", name) + + provider, err := info.builder(l, info.cfg, c.managed) + if err != nil { + l.Errorf("provider %q failed to build (will retry in %s): %s", name, c.restartInterval.String(), err) + select { + case <-ctx.Done(): + return + case <-time.After(c.restartInterval): + // wait restart interval and then try again + } + continue + } - if err := cp.Close(); err != nil { - c.logger.Errorf("unable to close dynamic provider %q: %s", name, err.Error()) + fp, fpok := provider.(corecomp.FetchContextProvider) + if fpok { + sendFetchProvider(ctx, fetchCh, name, fp) + } + + err = provider.Run(ctx, state) + closeProvider(l, name, provider) + if errors.Is(err, context.Canceled) { + // valid exit + if fpok { + // turn off fetch provider + sendFetchProvider(ctx, fetchCh, name, nil) + } + return + } + // all other exits are bad, even a nil error + l.Errorf("provider %q failed to run (will retry in %s): %s", name, c.restartInterval.String(), err) + if fpok { + // turn off fetch provider + sendFetchProvider(ctx, fetchCh, name, nil) + } + select { + case <-ctx.Done(): + return + case <-time.After(c.restartInterval): + // wait restart interval and then try again + } } + }() + return state +} + +func sendFetchProvider(ctx context.Context, fetchCh chan fetchProvider, name string, fp corecomp.FetchContextProvider) { + select { + case <-ctx.Done(): + case fetchCh <- fetchProvider{name: name, fetchProvider: fp}: } } +func (c *controller) startDynamicProvider(ctx context.Context, wg *sync.WaitGroup, stateChangedChan chan bool, name string, info dynamicProvider) *dynamicProviderState { + wg.Add(1) + l := c.logger.Named(strings.Join([]string{"providers", name}, ".")) + + ctx, cancel := context.WithCancel(ctx) + state := &dynamicProviderState{ + Context: ctx, + mappings: map[string]dynamicProviderMapping{}, + signal: stateChangedChan, + logger: l, + canceller: cancel, + } + go func() { + defer wg.Done() + for { + l.Infof("Starting dynamic provider %q", name) + + provider, err := info.builder(l, info.cfg, c.managed) + if err != nil { + l.Errorf("provider %q failed to build (will retry in %s): %s", name, c.restartInterval.String(), err) + select { + case <-ctx.Done(): + return + case <-time.After(c.restartInterval): + // wait restart interval and then try again + } + continue + } + + err = provider.Run(state) + closeProvider(l, name, provider) + if errors.Is(err, context.Canceled) { + return + } + // all other exits are bad, even a nil error + l.Errorf("provider %q failed to run (will restart in %s): %s", name, c.restartInterval.String(), err) + select { + case <-ctx.Done(): + return + case <-time.After(c.restartInterval): + // wait restart interval and then try again + } + } + }() + return state +} + func (c *controller) generateVars(fetchContextProviders mapstr.M) []*transpiler.Vars { // build the vars list of mappings vars := make([]*transpiler.Vars, 1) mapping, _ := transpiler.NewAST(map[string]any{}) - for name, state := range c.contextProviders { + for name, state := range c.contextProviderStates { _ = mapping.Insert(state.Current(), name) } vars[0] = transpiler.NewVarsFromAst("", mapping, fetchContextProviders) // add to the vars list for each dynamic providers mappings - for name, state := range c.dynamicProviders { + for name, state := range c.dynamicProviderStates { for _, mappings := range state.Mappings() { local := mapping.ShallowClone() _ = local.Insert(mappings.mapping, name) @@ -296,13 +492,41 @@ func (c *controller) generateVars(fetchContextProviders mapstr.M) []*transpiler. return vars } +func closeProvider(l *logger.Logger, name string, provider interface{}) { + cp, ok := provider.(corecomp.CloseableProvider) + if !ok { + // doesn't implement Close + return + } + if err := cp.Close(); err != nil { + l.Errorf("unable to close context provider %q: %s", name, err) + } +} + +type contextProvider struct { + builder ContextProviderBuilder + cfg *config.Config +} + +type dynamicProvider struct { + builder DynamicProviderBuilder + cfg *config.Config +} + +type fetchProvider struct { + name string + fetchProvider corecomp.FetchContextProvider +} + type contextProviderState struct { context.Context - provider corecomp.ContextProvider - lock sync.RWMutex - mapping *transpiler.AST - signal chan bool + lock sync.RWMutex + mapping *transpiler.AST + signal chan bool + + logger *logger.Logger + canceller context.CancelFunc } // Signal signals that something has changed in the provider. @@ -357,10 +581,12 @@ type dynamicProviderMapping struct { type dynamicProviderState struct { context.Context - provider DynamicProvider lock sync.Mutex mappings map[string]dynamicProviderMapping signal chan bool + + logger *logger.Logger + canceller context.CancelFunc } // AddOrUpdate adds or updates the current mapping for the dynamic provider. @@ -471,7 +697,7 @@ func addToSet(set []int, i int) []int { return append(set, i) } -func drainChan(ch chan bool) { +func drainChan[T any](ch chan T) { for { select { case <-ch: diff --git a/internal/pkg/composable/controller_test.go b/internal/pkg/composable/controller_test.go index b4f0b383e76..3d3f532dc6b 100644 --- a/internal/pkg/composable/controller_test.go +++ b/internal/pkg/composable/controller_test.go @@ -11,16 +11,13 @@ import ( "testing" "time" - "github.com/elastic/elastic-agent/pkg/core/logger" - - "github.com/elastic/elastic-agent/internal/pkg/agent/transpiler" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/elastic/elastic-agent/internal/pkg/agent/transpiler" "github.com/elastic/elastic-agent/internal/pkg/composable" "github.com/elastic/elastic-agent/internal/pkg/config" + "github.com/elastic/elastic-agent/pkg/core/logger" _ "github.com/elastic/elastic-agent/internal/pkg/composable/providers/env" _ "github.com/elastic/elastic-agent/internal/pkg/composable/providers/host" @@ -82,21 +79,28 @@ func TestController(t *testing.T) { c, err := composable.New(log, cfg, false) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - timeoutCtx, timeoutCancel := context.WithTimeout(ctx, 1*time.Second) - defer timeoutCancel() - - var setVars []*transpiler.Vars + var setVars1 []*transpiler.Vars + var setVars2 []*transpiler.Vars + var setVars3 []*transpiler.Vars go func() { - defer cancel() for { select { - case <-timeoutCtx.Done(): + case <-ctx.Done(): return case vars := <-c.Watch(): - setVars = vars + if setVars1 == nil { + setVars1 = vars + c.Observe([]string{"local.vars.key1", "local_dynamic.vars.key1"}) // observed local and local_dynamic + } else if setVars2 == nil { + setVars2 = vars + c.Observe(nil) // no observed (will turn off those providers) + } else { + setVars3 = vars + cancel() + } } } }() @@ -111,52 +115,67 @@ func TestController(t *testing.T) { } require.NoError(t, err) - assert.Len(t, setVars, 3) + assert.Len(t, setVars1, 1) + assert.Len(t, setVars2, 3) + assert.Len(t, setVars3, 1) - _, hostExists := setVars[0].Lookup("host") - assert.True(t, hostExists) - _, envExists := setVars[0].Lookup("env") - assert.False(t, envExists) - local, _ := setVars[0].Lookup("local") + vars1map, err := setVars1[0].Map() + require.NoError(t, err) + assert.Len(t, vars1map, 0) // should be empty on initial + + _, hostExists := setVars2[0].Lookup("host") + assert.False(t, hostExists) // should not exist, not referenced + _, envExists := setVars2[0].Lookup("env") + assert.False(t, envExists) // should not exist, not referenced + local, _ := setVars2[0].Lookup("local") localMap, ok := local.(map[string]interface{}) require.True(t, ok) assert.Equal(t, "value1", localMap["key1"]) - local, _ = setVars[1].Lookup("local_dynamic") + local, _ = setVars2[1].Lookup("local_dynamic") localMap, ok = local.(map[string]interface{}) require.True(t, ok) assert.Equal(t, "value1", localMap["key1"]) - local, _ = setVars[2].Lookup("local_dynamic") + local, _ = setVars2[2].Lookup("local_dynamic") localMap, ok = local.(map[string]interface{}) require.True(t, ok) assert.Equal(t, "value2", localMap["key1"]) + + vars3map, err := setVars3[0].Map() + require.NoError(t, err) + assert.Len(t, vars3map, 0) // should be empty after empty Observe } func TestProvidersDefaultDisabled(t *testing.T) { tests := []struct { - name string - cfg map[string]interface{} - want int + name string + cfg map[string]interface{} + observed []string + context []string + dynamic []string }{ { name: "default disabled", cfg: map[string]interface{}{ "agent.providers.initial_default": "false", }, - want: 0, + observed: []string{"env.var1", "host.name"}, // has observed but explicitly disabled + context: nil, // should have none }, { name: "default enabled", cfg: map[string]interface{}{ "agent.providers.initial_default": "true", }, - want: 1, + observed: []string{"env.var1", "host.name"}, + context: []string{"env", "host"}, }, { - name: "default enabled - no config", - cfg: map[string]interface{}{}, - want: 1, + name: "default enabled - no config", + cfg: map[string]interface{}{}, + observed: nil, // none observed + context: nil, // should have none }, { name: "default enabled - explicit config", @@ -206,7 +225,9 @@ func TestProvidersDefaultDisabled(t *testing.T) { }, }, }, - want: 3, + observed: []string{"local.vars.key1", "local_dynamic.vars.key1"}, + context: []string{"local"}, + dynamic: []string{"local_dynamic", "local_dynamic"}, }, } @@ -220,6 +241,8 @@ func TestProvidersDefaultDisabled(t *testing.T) { c, err := composable.New(log, cfg, false) require.NoError(t, err) + c.Observe(tt.observed) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -249,7 +272,26 @@ func TestProvidersDefaultDisabled(t *testing.T) { } require.NoError(t, err) - assert.Len(t, setVars, tt.want) + if len(tt.context) > 0 { + for _, name := range tt.context { + _, ok := setVars[0].Lookup(name) + assert.Truef(t, ok, "context vars group missing %s", name) + } + } else { + m, err := setVars[0].Map() + if assert.NoErrorf(t, err, "failed to convert context vars to map") { + assert.Len(t, m, 0) // should be empty + } + } + if len(tt.dynamic) > 0 { + for i, name := range tt.dynamic { + _, ok := setVars[i+1].Lookup(name) + assert.Truef(t, ok, "dynamic vars group %d missing %s", i+1, name) + } + } else { + // should not have any dynamic vars + assert.Len(t, setVars, 1) + } }) } } @@ -312,7 +354,6 @@ func TestCancellation(t *testing.T) { t.Run(fmt.Sprintf("test run %d", i), func(t *testing.T) { c, err := composable.New(log, cfg, false) require.NoError(t, err) - defer c.Close() ctx, cancelFn := context.WithTimeout(context.Background(), timeout) defer cancelFn() @@ -328,7 +369,6 @@ func TestCancellation(t *testing.T) { t.Run("immediate cancellation", func(t *testing.T) { c, err := composable.New(log, cfg, false) require.NoError(t, err) - defer c.Close() ctx, cancelFn := context.WithTimeout(context.Background(), 0) cancelFn()