From 644782c63cdaf7961248a31737b12a0335cee53f Mon Sep 17 00:00:00 2001
From: Ivan Milchev
Date: Tue, 6 Feb 2024 14:56:43 +0200
Subject: [PATCH 1/5] coordinator v2

Signed-off-by: Ivan Milchev

---
 apps/cnquery/cmd/login.go                     |   2 +-
 apps/cnquery/cmd/logout.go                    |   2 +-
 apps/cnquery/cmd/plugin.go                    |   2 +-
 apps/cnquery/cmd/shell.go                     |   2 +-
 apps/cnquery/cmd/status.go                    |   2 +-
 cli/providers/providers.go                    |  16 +--
 explorer/scan/discovery.go                    |  78 +++++++++----
 explorer/scan/local_scanner.go                |  40 +++----
 providers/builtin.go                          |   2 +-
 providers/defaults_shared.go                  |   2 +-
 .../{coordinator.go => global_coordinator.go} |  33 +++++-
 providers/local_coordinator.go                | 104 ++++++++++++++++++
 providers/mock.go                             |   2 +-
 providers/providers.go                        |   2 +-
 providers/runtime.go                          |   6 +-
 15 files changed, 232 insertions(+), 63 deletions(-)
 rename providers/{coordinator.go => global_coordinator.go} (93%)
 create mode 100644 providers/local_coordinator.go

diff --git a/apps/cnquery/cmd/login.go b/apps/cnquery/cmd/login.go
index 2699654ef8..22bd37ac53 100644
--- a/apps/cnquery/cmd/login.go
+++ b/apps/cnquery/cmd/login.go
@@ -52,7 +52,7 @@ You remain logged in until you explicitly log out using the 'logout' subcommand.
 		viper.BindPFlag("name", cmd.Flags().Lookup("name"))
 	},
 	RunE: func(cmd *cobra.Command, args []string) error {
-		defer cnquery_providers.Coordinator.Shutdown()
+		defer cnquery_providers.GlobalCoordinator.Shutdown()
 		token, _ := cmd.Flags().GetString("token")
 		annotations, _ := cmd.Flags().GetStringToString("annotation")
 		return register(token, annotations)
diff --git a/apps/cnquery/cmd/logout.go b/apps/cnquery/cmd/logout.go
index 08dd8c0045..9c68d1297e 100644
--- a/apps/cnquery/cmd/logout.go
+++ b/apps/cnquery/cmd/logout.go
@@ -36,7 +36,7 @@ ensure the credentials cannot be used in the future.
 		viper.BindPFlag("force", cmd.Flags().Lookup("force"))
 	},
 	RunE: func(cmd *cobra.Command, args []string) error {
-		defer cnquery_providers.Coordinator.Shutdown()
+		defer cnquery_providers.GlobalCoordinator.Shutdown()
 		var err error

 		// its perfectly fine not to have a config here, therefore we ignore errors
diff --git a/apps/cnquery/cmd/plugin.go b/apps/cnquery/cmd/plugin.go
index c216f0939e..31c13c9eec 100644
--- a/apps/cnquery/cmd/plugin.go
+++ b/apps/cnquery/cmd/plugin.go
@@ -129,7 +129,7 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru
 			return err
 		}

-		connectAssetRuntime, err := providers.Coordinator.RuntimeFor(asset, runtime)
+		connectAssetRuntime, err := providers.GlobalCoordinator.RuntimeFor(asset, runtime)
 		if err != nil {
 			return err
 		}
diff --git a/apps/cnquery/cmd/shell.go b/apps/cnquery/cmd/shell.go
index 6db9ee4a10..4c5607c156 100644
--- a/apps/cnquery/cmd/shell.go
+++ b/apps/cnquery/cmd/shell.go
@@ -153,7 +153,7 @@ func StartShell(runtime *providers.Runtime, conf *ShellConfig) error {
 	// when we close the shell, we need to close the backend and store the recording
 	onCloseHandler := func() {
 		runtime.Close()
-		providers.Coordinator.Shutdown()
+		providers.GlobalCoordinator.Shutdown()
 	}

 	shellOptions := []shell.ShellOption{}
diff --git a/apps/cnquery/cmd/status.go b/apps/cnquery/cmd/status.go
index 841d0eed4d..177d09cf83 100644
--- a/apps/cnquery/cmd/status.go
+++ b/apps/cnquery/cmd/status.go
@@ -43,7 +43,7 @@ Status sends a ping to Mondoo Platform to verify the credentials.
 		viper.BindPFlag("output", cmd.Flags().Lookup("output"))
 	},
 	RunE: func(cmd *cobra.Command, args []string) error {
-		defer providers.Coordinator.Shutdown()
+		defer providers.GlobalCoordinator.Shutdown()
 		opts, optsErr := config.Read()
 		if optsErr != nil {
 			return cli_errors.NewCommandError(errors.Wrap(optsErr, "could not load configuration"), 1)
diff --git a/cli/providers/providers.go b/cli/providers/providers.go
index 425cb6d4e0..e115c723ff 100644
--- a/cli/providers/providers.go
+++ b/cli/providers/providers.go
@@ -424,7 +424,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
 		}

 		// TODO: add flag to set timeout and then use RuntimeWithShutdownTimeout
-		runtime := providers.Coordinator.NewRuntime()
+		runtime := providers.GlobalCoordinator.NewRuntime()
 		if err = providers.SetDefaultRuntime(runtime); err != nil {
 			log.Error().Msg(err.Error())
 		}
@@ -440,12 +440,12 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
 		}

 		if err := runtime.UseProvider(provider.ID); err != nil {
-			providers.Coordinator.Shutdown()
+			providers.GlobalCoordinator.Shutdown()
 			log.Fatal().Err(err).Msg("failed to start provider " + provider.Name)
 		}

 		if record != "" && useRecording != "" {
-			providers.Coordinator.Shutdown()
+			providers.GlobalCoordinator.Shutdown()
 			log.Fatal().Msg("please only use --record or --use-recording, but not both at the same time")
 		}
 		recordingPath := record
@@ -459,7 +459,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
 			PrettyPrintJSON: pretty,
 		})
 		if err != nil {
-			providers.Coordinator.Shutdown()
+			providers.GlobalCoordinator.Shutdown()
 			log.Fatal().Msg(err.Error())
 		}
 		runtime.SetRecording(recording)
@@ -471,13 +471,13 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
 		})
 		if err != nil {
 			runtime.Close()
-			providers.Coordinator.Shutdown()
+			providers.GlobalCoordinator.Shutdown()
 			log.Fatal().Err(err).Msg("failed to parse cli arguments")
 		}

 		if cliRes == nil {
 			runtime.Close()
-			providers.Coordinator.Shutdown()
+			providers.GlobalCoordinator.Shutdown()
 			log.Fatal().Msg("failed to process CLI arguments, nothing was returned")
 			return // adding this here as a compiler hint to stop warning about nil-dereferences
 		}
@@ -485,7 +485,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
 		if cliRes.Asset == nil {
 			log.Warn().Err(err).Msg("failed to discover assets after processing CLI arguments")
 		} else {
-			assetRuntime, err := providers.Coordinator.RuntimeFor(cliRes.Asset, runtime)
+			assetRuntime, err := providers.GlobalCoordinator.RuntimeFor(cliRes.Asset, runtime)
 			if err != nil {
 				log.Warn().Err(err).Msg("failed to get runtime for an asset that was detected after parsing the CLI")
 			} else {
@@ -495,7 +495,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu

 		run(cc, runtime, cliRes)
 		runtime.Close()
-		providers.Coordinator.Shutdown()
+		providers.GlobalCoordinator.Shutdown()
 	}

 	attachFlags(cmd.Flags(), allFlags)
diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go
index 129b0bb702..7991890257 100644
--- a/explorer/scan/discovery.go
+++ b/explorer/scan/discovery.go
@@ -18,6 +18,12 @@ import (
 	"go.mondoo.com/cnquery/v10/providers-sdk/v1/upstream"
 )

+type RootAsset struct {
+	Coordinator providers.Coordinator
+	Asset *inventory.Asset
+	Children []*AssetWithRuntime
+}
+
 type AssetWithRuntime struct {
 	Asset *inventory.Asset
 	Runtime *providers.Runtime
@@ -30,13 +36,26 @@ type AssetWithError struct {

 type DiscoveredAssets struct {
 	platformIds map[string]struct{}
-	Assets []*AssetWithRuntime
-	Errors []*AssetWithError
+	// Assets []*AssetWithRuntime
+	Errors []*AssetWithError
+
+	RootAssets map[*inventory.Asset]*RootAsset
+}
+
+func (d *DiscoveredAssets) AddRoot(root *inventory.Asset, coordinator providers.Coordinator, runtime *providers.Runtime) bool {
+	if _, ok := d.RootAssets[root]; ok {
+		return false
+	}
+	d.RootAssets[root] = &RootAsset{
+		Coordinator: coordinator,
+		Asset: root,
+	}
+	return true
 }

 // Add adds an asset and its runtime to the discovered assets list. It returns true if the
 // asset has been added, false if it is a duplicate
-func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool {
+func (d *DiscoveredAssets) Add(root, asset *inventory.Asset, runtime *providers.Runtime) bool {
 	isDuplicate := false
 	for _, platformId := range asset.PlatformIds {
 		if _, ok := d.platformIds[platformId]; ok {
@@ -49,7 +68,7 @@ func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtim
 		return false
 	}

-	d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime})
+	d.RootAssets[root].Children = append(d.RootAssets[root].Children, &AssetWithRuntime{Asset: asset, Runtime: runtime})
 	return true
 }

@@ -57,13 +76,23 @@ func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) {
 	d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err})
 }

+func (d *DiscoveredAssets) GetFlattenedChildren() []*AssetWithRuntime {
+	var assets []*AssetWithRuntime
+	for _, a := range d.RootAssets {
+		assets = append(assets, a.Children...)
+	}
+	return assets
+}
+
 func (d *DiscoveredAssets) GetAssetsByPlatformID(platformID string) []*inventory.Asset {
 	var assets []*inventory.Asset
-	for _, a := range d.Assets {
-		for _, p := range a.Asset.PlatformIds {
-			if platformID == "" || p == platformID {
-				assets = append(assets, a.Asset)
-				break
+	for _, a := range d.RootAssets {
+		for _, c := range a.Children {
+			for _, p := range c.Asset.PlatformIds {
+				if platformID == "" || p == platformID {
+					assets = append(assets, a.Asset)
+					break
+				}
 			}
 		}
 	}
@@ -91,7 +120,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
 		runtimeLabels = runtimeEnv.Labels()
 	}

-	discoveredAssets := &DiscoveredAssets{platformIds: map[string]struct{}{}}
+	discoveredAssets := &DiscoveredAssets{platformIds: map[string]struct{}{}, RootAssets: map[*inventory.Asset]*RootAsset{}}

 	// we connect and perform discovery for each asset in the job inventory
 	for _, rootAsset := range invAssets {
@@ -100,20 +129,29 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
 			return nil, err
 		}

+		coordinator := providers.NewLocalCoordinator(providers.GlobalCoordinator)
+
 		// create runtime for root asset
-		rootAssetWithRuntime, err := createRuntimeForAsset(resolvedRootAsset, upstream, recording)
+		rootAssetWithRuntime, err := createRuntimeForAsset(resolvedRootAsset, coordinator, upstream, recording)
 		if err != nil {
 			log.Error().Err(err).Str("asset", resolvedRootAsset.Name).Msg("unable to create runtime for asset")
 			discoveredAssets.AddError(rootAssetWithRuntime.Asset, err)
+			coordinator.Shutdown()
 			continue
 		}

 		resolvedRootAsset = rootAssetWithRuntime.Asset // to ensure we get all the information the connect call gave us

+		if !discoveredAssets.AddRoot(resolvedRootAsset, coordinator, rootAssetWithRuntime.Runtime) {
+			rootAssetWithRuntime.Runtime.Close()
+			coordinator.Shutdown()
+			continue
+		}
+
 		// If the root asset has platform IDs, then it is a scannable asset, so we need to add it
 		if len(resolvedRootAsset.PlatformIds) > 0 {
 			prepareAsset(resolvedRootAsset, resolvedRootAsset, runtimeLabels)
-			if !discoveredAssets.Add(rootAssetWithRuntime.Asset, rootAssetWithRuntime.Runtime) {
+			if !discoveredAssets.Add(resolvedRootAsset, rootAssetWithRuntime.Asset, rootAssetWithRuntime.Runtime) {
 				rootAssetWithRuntime.Runtime.Close()
 			}
 		}
@@ -126,7 +164,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
 		// for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset
 		for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets {
 			// create runtime for root asset
-			assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording)
+			assetWithRuntime, err := createRuntimeForAsset(a, coordinator, upstream, recording)
 			if err != nil {
 				log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset")
 				discoveredAssets.AddError(assetWithRuntime.Asset, err)
@@ -137,7 +175,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
 			prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels)

 			// If the asset has been already added, we should close its runtime
-			if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) {
+			if !discoveredAssets.Add(resolvedRootAsset, resolvedAsset, assetWithRuntime.Runtime) {
 				assetWithRuntime.Runtime.Close()
 			}
 		}
@@ -146,15 +184,15 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
 	// if there is exactly one asset, assure that the --asset-name is used
 	// TODO: make it so that the --asset-name is set for the root asset only even if multiple assets are there
 	// This is a temporary fix that only works if there is only one asset
-	if len(discoveredAssets.Assets) == 1 && invAssets[0].Name != "" && invAssets[0].Name != discoveredAssets.Assets[0].Asset.Name {
-		log.Debug().Str("asset", discoveredAssets.Assets[0].Asset.Name).Msg("Overriding asset name with --asset-name flag")
-		discoveredAssets.Assets[0].Asset.Name = invAssets[0].Name
-	}
+	// if len(discoveredAssets.Assets) == 1 && invAssets[0].Name != "" && invAssets[0].Name != discoveredAssets.Assets[0].Asset.Name {
+	// 	log.Debug().Str("asset", discoveredAssets.Assets[0].Asset.Name).Msg("Overriding asset name with --asset-name flag")
+	// 	discoveredAssets.Assets[0].Asset.Name = invAssets[0].Name
+	// }

 	return discoveredAssets, nil
 }

-func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) {
+func createRuntimeForAsset(asset *inventory.Asset, coordinator providers.Coordinator, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) {
 	var runtime *providers.Runtime
 	var err error
 	// Close the runtime if an error occured
@@ -164,7 +202,7 @@ func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamCo
 		}
 	}()

-	runtime, err = providers.Coordinator.RuntimeFor(asset, providers.DefaultRuntime())
+	runtime, err = coordinator.RuntimeFor(asset, providers.DefaultRuntime())
 	if err != nil {
 		return nil, err
 	}
diff --git a/explorer/scan/local_scanner.go b/explorer/scan/local_scanner.go
index cc998ce0fc..2b4197d3d2 100644
--- a/explorer/scan/local_scanner.go
+++ b/explorer/scan/local_scanner.go
@@ -176,8 +176,9 @@ func CreateProgressBar(discoveredAssets *DiscoveredAssets, disableProgressBar bo
 	if isatty.IsTerminal(os.Stdout.Fd()) && !disableProgressBar && !strings.EqualFold(logger.GetLevel(), "debug") && !strings.EqualFold(logger.GetLevel(), "trace") {
 		progressBarElements := map[string]string{}
 		orderedKeys := []string{}
-		for i := range discoveredAssets.Assets {
-			asset := discoveredAssets.Assets[i].Asset
+		assets := discoveredAssets.GetFlattenedChildren()
+		for i := range assets {
+			asset := assets[i].Asset
 			// this shouldn't happen, but might
 			// it normally indicates a bug in the provider
 			if presentAsset, present := progressBarElements[asset.PlatformIds[0]]; present {
@@ -201,9 +202,6 @@ func CreateProgressBar(discoveredAssets *DiscoveredAssets, disableProgressBar bo
 func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*explorer.ReportCollection, error) {
 	log.Info().Msgf("discover related assets for %d asset(s)", len(job.Inventory.Spec.Assets))

-	// Always shut down the coordinator, to make sure providers are killed
-	defer providers.Coordinator.Shutdown()
-
 	discoveredAssets, err := DiscoverAssets(ctx, job.Inventory, upstream, s.recording)
 	if err != nil {
 		return nil, err
@@ -213,10 +211,10 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
 	// Within this process, we set up a catch-all deferred function, that shuts
 	// down all runtimes, in case we exit early.
 	defer func() {
-		for _, asset := range discoveredAssets.Assets {
+		for _, asset := range discoveredAssets.RootAssets {
 			// we can call close multiple times and it will only execute once
-			if asset.Runtime != nil {
-				asset.Runtime.Close()
+			if asset.Coordinator != nil {
+				asset.Coordinator.Shutdown()
 			}
 		}
 	}()
@@ -228,7 +226,8 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
 		reporter.AddScanError(discoveredAssets.Errors[i].Asset, discoveredAssets.Errors[i].Err)
 	}

-	if len(discoveredAssets.Assets) == 0 {
+	assets := discoveredAssets.GetFlattenedChildren()
+	if len(assets) == 0 {
 		return reporter.Reports(), nil
 	}

@@ -246,7 +245,13 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
 		}
 	}()

-	assetBatches := slicesx.Batch(discoveredAssets.Assets, 100)
+	// if a bundle was provided check that it matches the filter, bundles can also be downloaded
+	// later therefore we do not want to stop execution here
+	if job.Bundle != nil && job.Bundle.FilterQueryPacks(job.QueryPackFilters) {
+		return nil, errors.New("all available packs filtered out. nothing to do")
+	}
+
+	assetBatches := slicesx.Batch(assets, 100)

 	for i := range assetBatches {
 		batch := assetBatches[i]
@@ -304,20 +309,17 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
 				}
 			}
 		}
+	}

-		// if a bundle was provided check that it matches the filter, bundles can also be downloaded
-		// later therefore we do not want to stop execution here
-		if job.Bundle != nil && job.Bundle.FilterQueryPacks(job.QueryPackFilters) {
-			return nil, errors.New("all available packs filtered out. nothing to do")
-		}
-
+	for k := range discoveredAssets.RootAssets {
+		root := discoveredAssets.RootAssets[k]
 		wg := sync.WaitGroup{}
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
-			for i := range batch {
-				asset := batch[i].Asset
-				runtime := batch[i].Runtime
+			for i := range root.Children {
+				asset := root.Children[i].Asset
+				runtime := root.Children[i].Runtime
 				// Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function,
 				// we need to solve this at a different level.
diff --git a/providers/builtin.go b/providers/builtin.go
index b9714687c4..81c4bad31a 100644
--- a/providers/builtin.go
+++ b/providers/builtin.go
@@ -35,7 +35,7 @@ var builtinProviders = map[string]*builtinProvider{
 		Runtime: &RunningProvider{
 			Name: mockProvider.Name,
 			ID: mockProvider.ID,
-			Plugin: &mockProviderService{coordinator: &Coordinator},
+			Plugin: &mockProviderService{coordinator: GlobalCoordinator},
 			isClosed: false,
 		},
 		Config: mockProvider.Provider,
diff --git a/providers/defaults_shared.go b/providers/defaults_shared.go
index d5d85e3426..2817fccf22 100644
--- a/providers/defaults_shared.go
+++ b/providers/defaults_shared.go
@@ -21,7 +21,7 @@ var defaultRuntime *Runtime

 func DefaultRuntime() *Runtime {
 	if defaultRuntime == nil {
-		defaultRuntime = Coordinator.NewRuntime()
+		defaultRuntime = GlobalCoordinator.NewRuntime()
 	}
 	return defaultRuntime
 }
diff --git a/providers/coordinator.go b/providers/global_coordinator.go
similarity index 93%
rename from providers/coordinator.go
rename to providers/global_coordinator.go
index fea0969bd9..f7a43ad4e7 100644
--- a/providers/coordinator.go
+++ b/providers/global_coordinator.go
@@ -24,9 +24,21 @@ import (
 	"google.golang.org/grpc/status"
 )

+type Coordinator interface {
+	Start(id string, isEphemeral bool, update UpdateProvidersConfig) (*RunningProvider, error)
+	Stop(provider *RunningProvider, isEphemeral bool) error
+	NewRuntime() *Runtime
+	RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error)
+	GetRunningProviderById(id string) *RunningProvider
+	GetProviders() Providers
+	SetProviders(providers Providers)
+	LoadSchema(name string) (*resources.Schema, error)
+	Shutdown()
+}
+
 var BuiltinCoreID = coreconf.Config.ID

-var Coordinator = coordinator{
+var GlobalCoordinator Coordinator = &coordinator{
 	RunningByID: map[string]*RunningProvider{},
 	RunningEphemeral: map[*RunningProvider]struct{}{},
 	runtimes: map[string]*Runtime{},
@@ -262,6 +274,20 @@ func (c *coordinator) Start(id string, isEphemeral bool, update UpdateProvidersC
 	return res, nil
 }

+func (c *coordinator) GetRunningProviderById(id string) *RunningProvider {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	return c.RunningByID[id]
+}
+
+func (c *coordinator) GetProviders() Providers {
+	return c.Providers
+}
+
+func (c *coordinator) SetProviders(providers Providers) {
+	c.Providers = providers
+}
+
 type ProviderVersions struct {
 	Providers []ProviderVersion `json:"providers"`
 }
@@ -376,7 +402,7 @@ func (c *coordinator) NewRuntimeFrom(parent *Runtime) *Runtime {
 	return res
 }

-// RuntimFor an asset will return a new or existing runtime for a given asset.
+// RuntimeFor an asset will return a new or existing runtime for a given asset.
 // If a runtime for this asset already exists, it will re-use it. If the runtime
 // is new, it will create it and detect the provider.
 // The asset and parent must be defined.
@@ -468,6 +494,7 @@ func (c *coordinator) Stop(provider *RunningProvider, isEphemeral bool) error {

 func (c *coordinator) Shutdown() {
 	c.mutex.Lock()
+	defer c.mutex.Unlock()

 	for cur := range c.RunningEphemeral {
 		if err := cur.Shutdown(); err != nil {
@@ -489,8 +516,6 @@ func (c *coordinator) Shutdown() {
 	c.runtimes = map[string]*Runtime{}
 	c.runtimeCnt = 0
 	c.unprocessedRuntimes = []*Runtime{}
-
-	c.mutex.Unlock()
 }

 // LoadSchema for a given provider. Providers also cache their Schemas, so
diff --git a/providers/local_coordinator.go b/providers/local_coordinator.go
new file mode 100644
index 0000000000..e24a84a624
--- /dev/null
+++ b/providers/local_coordinator.go
@@ -0,0 +1,104 @@
+// Copyright (c) Mondoo, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package providers
+
+import (
+	"sync"
+
+	"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
+	"go.mondoo.com/cnquery/v10/providers-sdk/v1/resources"
+)
+
+type localCoordinator struct {
+	parent Coordinator
+
+	runningByID map[string]*RunningProvider
+	runningEphemeral map[*RunningProvider]struct{}
+	mutex sync.Mutex
+}
+
+func NewLocalCoordinator(parent Coordinator) Coordinator {
+	return &localCoordinator{
+		parent: parent,
+		runningByID: map[string]*RunningProvider{},
+		runningEphemeral: map[*RunningProvider]struct{}{},
+	}
+}
+
+func (lc *localCoordinator) Start(id string, isEphemeral bool, update UpdateProvidersConfig) (*RunningProvider, error) {
+	// From the parent's perspective, all providers from its children are ephemeral
+	provider, err := lc.parent.Start(id, true, update)
+	if err != nil {
+		return nil, err
+	}
+
+	lc.mutex.Lock()
+	if isEphemeral {
+		lc.runningEphemeral[provider] = struct{}{}
+	} else {
+		lc.runningByID[provider.ID] = provider
+	}
+	lc.mutex.Unlock()
+	return provider, nil
+}
+
+func (lc *localCoordinator) Stop(provider *RunningProvider, isEphemeral bool) error {
+	if provider == nil {
+		return nil
+	}
+
+	lc.mutex.Lock()
+	defer lc.mutex.Unlock()
+
+	if isEphemeral {
+		delete(lc.runningEphemeral, provider)
+	} else {
+		found := lc.runningByID[provider.ID]
+		if found != nil {
+			delete(lc.runningByID, provider.ID)
+		}
+	}
+	return lc.parent.Stop(provider, true)
+}
+
+func (lc *localCoordinator) NewRuntime() *Runtime {
+	return lc.parent.NewRuntime()
+}
+
+func (lc *localCoordinator) RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error) {
+	return lc.parent.RuntimeFor(asset, parent)
+}
+
+func (lc *localCoordinator) GetRunningProviderById(id string) *RunningProvider {
+	lc.mutex.Lock()
+	defer lc.mutex.Unlock()
+	return lc.runningByID[id]
+}
+
+func (lc *localCoordinator) GetProviders() Providers {
+	return lc.parent.GetProviders()
+}
+
+func (lc *localCoordinator) SetProviders(providers Providers) {
+	lc.parent.SetProviders(providers)
+}
+
+func (lc *localCoordinator) LoadSchema(name string) (*resources.Schema, error) {
+	return lc.parent.LoadSchema(name)
+}
+
+func (lc *localCoordinator) Shutdown() {
+	lc.mutex.Lock()
+	defer lc.mutex.Unlock()
+
+	for cur := range lc.runningEphemeral {
+		lc.parent.Stop(cur, true)
+	}
+	lc.runningEphemeral = map[*RunningProvider]struct{}{}
+
+	for _, runtime := range lc.runningByID {
+		lc.parent.Stop(runtime, true)
+	}
+	lc.runningByID = map[string]*RunningProvider{}
+}
diff --git a/providers/mock.go b/providers/mock.go
index 10fb847ec0..32b86c02e7 100644
--- a/providers/mock.go
+++ b/providers/mock.go
@@ -23,7 +23,7 @@ var mockProvider = Provider{
 }

 type mockProviderService struct {
-	coordinator *coordinator
+	coordinator Coordinator
 	initialized bool
 	runtime *Runtime
 }
diff --git a/providers/providers.go b/providers/providers.go
index 63f493bafd..a49a52fc4f 100644
--- a/providers/providers.go
+++ b/providers/providers.go
@@ -206,7 +206,7 @@ func ListActive() (Providers, error) {
 	}

 	// useful for caching; even if the structure gets updated with new providers
-	Coordinator.Providers = res
+	GlobalCoordinator.SetProviders(res)
 	return res, nil
 }
diff --git a/providers/runtime.go b/providers/runtime.go
index 2997581526..6e6c689cc8 100644
--- a/providers/runtime.go
+++ b/providers/runtime.go
@@ -33,7 +33,7 @@ type Runtime struct {
 	recording llx.Recording
 	features []byte
 	// coordinator is used to grab providers
-	coordinator *coordinator
+	coordinator Coordinator
 	// providers for with open connections
 	providers map[string]*ConnectedProvider
 	// schema aggregates all resources executable on this asset
@@ -149,7 +149,7 @@ func (r *Runtime) addProvider(id string, isEphemeral bool) (*ConnectedProvider,

 	} else {
 		// TODO: we need to detect only the shared running providers
-		running = r.coordinator.RunningByID[id]
+		running = r.coordinator.GetRunningProviderById(id)
 		if running == nil {
 			var err error
 			running, err = r.coordinator.Start(id, false, r.AutoUpdate)
@@ -193,7 +193,7 @@ func (r *Runtime) providerForAsset(asset *inventory.Asset) (*Provider, error) {
 			conn.Type = inventory.ConnBackendToType(conn.Backend)
 		}

-		provider, err := EnsureProvider(ProviderLookup{ConnType: conn.Type}, true, r.coordinator.Providers)
+		provider, err := EnsureProvider(ProviderLookup{ConnType: conn.Type}, true, r.coordinator.GetProviders())
 		if err != nil {
 			errs.Add(err)
 			continue

From d856b59769eab2263f1c78f208b56223535616ab Mon Sep 17 00:00:00 2001
From: Ivan Milchev
Date: Tue, 6 Feb 2024 16:47:51 +0200
Subject: [PATCH 2/5] fix local coordinator logic

Signed-off-by: Ivan Milchev

---
 explorer/scan/local_scanner.go  |  3 +++
 providers/global_coordinator.go |  1 +
 providers/local_coordinator.go  | 36 ++++++++++++++++++++++++++-------
 3 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/explorer/scan/local_scanner.go b/explorer/scan/local_scanner.go
index 2b4197d3d2..521966c612 100644
--- a/explorer/scan/local_scanner.go
+++ b/explorer/scan/local_scanner.go
@@ -351,6 +351,9 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
 			}
 		}()
 		wg.Wait()
+
+		// Shutdown the coordinator for the current root asset
+		root.Coordinator.Shutdown()
 	}
 	scanGroups.Wait()
 	return reporter.Reports(), nil
diff --git a/providers/global_coordinator.go b/providers/global_coordinator.go
index f7a43ad4e7..5eae2b41a2 100644
--- a/providers/global_coordinator.go
+++ b/providers/global_coordinator.go
@@ -28,6 +28,7 @@ type Coordinator interface {
 	Start(id string, isEphemeral bool, update UpdateProvidersConfig) (*RunningProvider, error)
 	Stop(provider *RunningProvider, isEphemeral bool) error
 	NewRuntime() *Runtime
+	NewRuntimeFrom(parent *Runtime) *Runtime
 	RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error)
 	GetRunningProviderById(id string) *RunningProvider
 	GetProviders() Providers
diff --git a/providers/local_coordinator.go b/providers/local_coordinator.go
index e24a84a624..276902c350 100644
--- a/providers/local_coordinator.go
+++ b/providers/local_coordinator.go
@@ -6,6 +6,7 @@ package providers
 import (
 	"sync"

+	"github.com/rs/zerolog/log"
 	"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
 	"go.mondoo.com/cnquery/v10/providers-sdk/v1/resources"
 )
@@ -51,6 +52,10 @@ func (lc *localCoordinator) Stop(provider *RunningProvider, isEphemeral bool) er
 	lc.mutex.Lock()
 	defer lc.mutex.Unlock()

+	if err := lc.parent.Stop(provider, true); err != nil {
+		return err
+	}
+
 	if isEphemeral {
 		delete(lc.runningEphemeral, provider)
 	} else {
@@ -59,15 +64,30 @@ func (lc *localCoordinator) Stop(provider *RunningProvider, isEphemeral bool) er
 			delete(lc.runningByID, provider.ID)
 		}
 	}
-	return lc.parent.Stop(provider, true)
+	return nil
 }

 func (lc *localCoordinator) NewRuntime() *Runtime {
-	return lc.parent.NewRuntime()
+	runtime := lc.parent.NewRuntime()
+	// Override the coordinator with the local one, so providers are managed
+	// by the local coordinator
+	runtime.coordinator = lc
+	return runtime
+}
+
+func (lc *localCoordinator) NewRuntimeFrom(parent *Runtime) *Runtime {
+	res := lc.NewRuntime()
+	res.recording = parent.Recording()
+	for k, v := range parent.providers {
+		res.providers[k] = v
+	}
+	return res
 }

 func (lc *localCoordinator) RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error) {
-	return lc.parent.RuntimeFor(asset, parent)
+	runtime := lc.parent.NewRuntimeFrom(parent)
+	runtime.coordinator = lc
+	return runtime, runtime.DetectProvider(asset)
 }

 func (lc *localCoordinator) GetRunningProviderById(id string) *RunningProvider {
@@ -92,13 +112,15 @@ func (lc *localCoordinator) Shutdown() {
 	lc.mutex.Lock()
 	defer lc.mutex.Unlock()

-	for cur := range lc.runningEphemeral {
-		lc.parent.Stop(cur, true)
+	for provider := range lc.runningEphemeral {
+		log.Debug().Str("provider", provider.Name).Msg("Shutting down ephemeral provider")
+		lc.parent.Stop(provider, true)
 	}
 	lc.runningEphemeral = map[*RunningProvider]struct{}{}

-	for _, runtime := range lc.runningByID {
-		lc.parent.Stop(runtime, true)
+	for _, provider := range lc.runningByID {
+		log.Debug().Str("provider", provider.Name).Msg("Shutting down provider")
+		lc.parent.Stop(provider, true)
 	}
 	lc.runningByID = map[string]*RunningProvider{}
 }

From bf90bcd05de9f194055b5cdf399b6f1aa63d2187 Mon Sep 17 00:00:00 2001
From: Ivan Milchev
Date: Tue, 6 Feb 2024 17:22:53 +0200
Subject: [PATCH 3/5] fix tests

Signed-off-by: Ivan Milchev

---
 explorer/scan/discovery.go              |  2 +-
 explorer/scan/discovery_test.go         | 87 +++++++++++++++++--------
 providers-sdk/v1/testutils/testutils.go |  2 +-
 3 files changed, 62 insertions(+), 29 deletions(-)

diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go
index 7991890257..d187593d09 100644
--- a/explorer/scan/discovery.go
+++ b/explorer/scan/discovery.go
@@ -90,7 +90,7 @@ func (d *DiscoveredAssets) GetAssetsByPlatformID(platformID string) []*inventory
 		for _, c := range a.Children {
 			for _, p := range c.Asset.PlatformIds {
 				if platformID == "" || p == platformID {
-					assets = append(assets, a.Asset)
+					assets = append(assets, c.Asset)
 					break
 				}
 			}
 		}
diff --git a/explorer/scan/discovery_test.go b/explorer/scan/discovery_test.go
index 1e7afd4ecf..3545340d65 100644
--- a/explorer/scan/discovery_test.go
+++ b/explorer/scan/discovery_test.go
@@ -15,57 +15,86 @@ import (
 	inventory "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
 )

+func TestDiscoveredAssets_AddRoot(t *testing.T) {
+	d := &DiscoveredAssets{
+		platformIds: map[string]struct{}{},
+		RootAssets: map[*inventory.Asset]*RootAsset{},
+		Errors: []*AssetWithError{},
+	}
+
+	root := &inventory.Asset{}
+	assert.True(t, d.AddRoot(root, nil, nil))
+	assert.Len(t, d.GetFlattenedChildren(), 0)
+	assert.Len(t, d.Errors, 0)
+
+	// Make sure adding duplicates is not possible
+	assert.False(t, d.AddRoot(root, nil, nil))
+	assert.Len(t, d.GetFlattenedChildren(), 0)
+	assert.Len(t, d.Errors, 0)
+}
+
 func TestDiscoveredAssets_Add(t *testing.T) {
 	d := &DiscoveredAssets{
 		platformIds: map[string]struct{}{},
-		Assets: []*AssetWithRuntime{},
+		RootAssets: map[*inventory.Asset]*RootAsset{},
 		Errors: []*AssetWithError{},
 	}
+
+	root := &inventory.Asset{}
+	assert.True(t, d.AddRoot(root, nil, nil))
+
 	asset := &inventory.Asset{
 		PlatformIds: []string{"platform1"},
 	}
 	runtime := &providers.Runtime{}
-	assert.True(t, d.Add(asset, runtime))
-	assert.Len(t, d.Assets, 1)
+	assert.True(t, d.Add(root, asset, runtime))
+	assert.Len(t, d.GetFlattenedChildren(), 1)
 	assert.Len(t, d.Errors, 0)

 	// Make sure adding duplicates is not possible
-	assert.False(t, d.Add(asset, runtime))
-	assert.Len(t, d.Assets, 1)
+	assert.False(t, d.Add(root, asset, runtime))
+	assert.Len(t, d.GetFlattenedChildren(), 1)
 	assert.Len(t, d.Errors, 0)
 }

 func TestDiscoveredAssets_Add_MultiplePlatformIDs(t *testing.T) {
 	d := &DiscoveredAssets{
 		platformIds: map[string]struct{}{},
-		Assets: []*AssetWithRuntime{},
+		RootAssets: map[*inventory.Asset]*RootAsset{},
 		Errors: []*AssetWithError{},
 	}
+
+	root := &inventory.Asset{}
+	assert.True(t, d.AddRoot(root, nil, nil))
+
 	asset := &inventory.Asset{
 		PlatformIds: []string{"platform1", "platform2"},
 	}
 	runtime := &providers.Runtime{}
-	assert.True(t, d.Add(asset, runtime))
-	assert.Len(t, d.Assets, 1)
+	assert.True(t, d.Add(root, asset, runtime))
+	assert.Len(t, d.GetFlattenedChildren(), 1)
 	assert.Len(t, d.Errors, 0)

 	// Make sure adding duplicates is not possible
-	assert.False(t, d.Add(&inventory.Asset{
+	assert.False(t, d.Add(root, &inventory.Asset{
 		PlatformIds: []string{"platform3", asset.PlatformIds[0]},
 	}, runtime))
-	assert.Len(t, d.Assets, 1)
+	assert.Len(t, d.GetFlattenedChildren(), 1)
 	assert.Len(t, d.Errors, 0)
 }

 func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) {
 	d := &DiscoveredAssets{
 		platformIds: map[string]struct{}{},
-		Assets: []*AssetWithRuntime{},
+		RootAssets: map[*inventory.Asset]*RootAsset{},
 		Errors: []*AssetWithError{},
 	}

+	root := &inventory.Asset{}
+	assert.True(t, d.AddRoot(root, nil, nil))
+
 	allPlatformIds := []string{}
 	for i := 0; i < 10; i++ {
 		pId := fmt.Sprintf("platform1%d", i)
@@ -75,9 +104,9 @@ func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) {
 		}

 		runtime := &providers.Runtime{}
-		assert.True(t, d.Add(asset, runtime))
+		assert.True(t, d.Add(root, asset, runtime))
 	}
-	assert.Len(t, d.Assets, 10)
+	assert.Len(t, d.GetFlattenedChildren(), 10)

 	// Make sure adding duplicates is not possible
 	assets := d.GetAssetsByPlatformID(allPlatformIds[0])
@@ -88,10 +117,13 @@ func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) {
 func TestDiscoveredAssets_GetAssetsByPlatformID_Empty(t *testing.T) {
 	d := &DiscoveredAssets{
 		platformIds: map[string]struct{}{},
-		Assets: []*AssetWithRuntime{},
+		RootAssets: map[*inventory.Asset]*RootAsset{},
 		Errors: []*AssetWithError{},
 	}

+	root := &inventory.Asset{}
+	assert.True(t, d.AddRoot(root, nil, nil))
+
 	allPlatformIds := []string{}
 	for i := 0; i < 10; i++ {
 		pId := fmt.Sprintf("platform1%d", i)
@@ -101,9 +133,9 @@ func TestDiscoveredAssets_GetAssetsByPlatformID_Empty(t *testing.T) {
 		}

 		runtime := &providers.Runtime{}
-		assert.True(t, d.Add(asset, runtime))
+		assert.True(t, d.Add(root, asset, runtime))
 	}
-	assert.Len(t, d.Assets, 10)
+	assert.Len(t, d.GetFlattenedChildren(), 10)

 	// Make sure adding duplicates is not possible
 	assets := d.GetAssetsByPlatformID("")
@@ -143,11 +175,12 @@ func TestDiscoverAssets(t *testing.T) {
 		inv := getInventory()
 		discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{})
 		require.NoError(t, err)
-		assert.Len(t, discoveredAssets.Assets, 3)
+		assets := discoveredAssets.GetFlattenedChildren()
+		assert.Len(t, assets, 3)
 		assert.Len(t, discoveredAssets.Errors, 0)
-		assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[0].Asset.ManagedBy)
-		assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[1].Asset.ManagedBy)
-		assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[2].Asset.ManagedBy)
+		assert.Equal(t, "mondoo-operator-123", assets[0].Asset.ManagedBy)
+		assert.Equal(t, "mondoo-operator-123", assets[1].Asset.ManagedBy)
+		assert.Equal(t, "mondoo-operator-123", assets[2].Asset.ManagedBy)
 	})

 	t.Run("with duplicate root assets", func(t *testing.T) {
@@ -157,7 +190,7 @@ func TestDiscoverAssets(t *testing.T) {
 		require.NoError(t, err)

 		// Make sure no duplicates are returned
-		assert.Len(t, discoveredAssets.Assets, 3)
+		assert.Len(t, discoveredAssets.GetFlattenedChildren(), 3)
 		assert.Len(t, discoveredAssets.Errors, 0)
 	})

@@ -168,7 +201,7 @@ func TestDiscoverAssets(t *testing.T) {
 		require.NoError(t, err)

 		// Make sure no duplicates are returned
-		assert.Len(t, discoveredAssets.Assets, 3)
+		assert.Len(t, discoveredAssets.GetFlattenedChildren(), 3)
 		assert.Len(t, discoveredAssets.Errors, 0)
 	})

@@ -181,7 +214,7 @@ func TestDiscoverAssets(t *testing.T) {
 		discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{})
 		require.NoError(t, err)

-		for _, asset := range discoveredAssets.Assets {
+		for _, asset := range discoveredAssets.GetFlattenedChildren() {
 			for k, v := range inv.Spec.Assets[0].Annotations {
 				require.Contains(t, asset.Asset.Annotations, k)
 				assert.Equal(t, v, asset.Asset.Annotations[k])
@@ -195,7 +228,7 @@ func TestDiscoverAssets(t *testing.T) {
 		discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{})
 		require.NoError(t, err)

-		for _, asset := range discoveredAssets.Assets {
+		for _, asset := range discoveredAssets.GetFlattenedChildren() {
 			assert.Equal(t, inv.Spec.Assets[0].ManagedBy, asset.Asset.ManagedBy)
 		}
 	})
@@ -216,7 +249,7 @@ func TestDiscoverAssets(t *testing.T) {
 		discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{})
 		require.NoError(t, err)

-		for _, asset := range discoveredAssets.Assets {
+		for _, asset := range discoveredAssets.GetFlattenedChildren() {
 			require.Contains(t, asset.Asset.Labels, "mondoo.com/exec-environment")
 			assert.Equal(t, "actions.github.com", asset.Asset.Labels["mondoo.com/exec-environment"])
 		}
@@ -239,7 +272,7 @@ func TestDiscoverAssets(t *testing.T) {
 		discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{})
 		require.NoError(t, err)

-		for _, asset := range discoveredAssets.Assets {
+		for _, asset := range discoveredAssets.GetFlattenedChildren() {
 			require.Contains(t, asset.Asset.Labels, "mondoo.com/exec-environment")
 			assert.Equal(t, "actions.github.com", asset.Asset.Labels["mondoo.com/exec-environment"])
 		}
@@ -251,6 +284,6 @@ func TestDiscoverAssets(t *testing.T) {
 		discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{})
 		require.NoError(t, err)

-		assert.Len(t, discoveredAssets.Assets, 1)
+		assert.Len(t, discoveredAssets.GetFlattenedChildren(), 1)
 	})
 }
diff --git a/providers-sdk/v1/testutils/testutils.go b/providers-sdk/v1/testutils/testutils.go
index 1ec57740d4..899a6921f7 100644
--- a/providers-sdk/v1/testutils/testutils.go
+++ b/providers-sdk/v1/testutils/testutils.go
@@ -200,7 +200,7 @@ func Local() llx.Runtime {
 	networkSchema := MustLoadSchema(SchemaProvider{Provider: "network"})
 	mockSchema := MustLoadSchema(SchemaProvider{Provider: "mockprovider"})

-	runtime := providers.Coordinator.NewRuntime()
+	runtime := providers.GlobalCoordinator.NewRuntime()

 	provider := &providers.RunningProvider{
 		Name: osconf.Config.Name,

From 82692cc94b2ba100b6b6310cca4dcfaae93b18ab Mon Sep 17 00:00:00 2001
From: Ivan Milchev
Date: Tue, 6 Feb 2024 17:43:16 +0200
Subject: [PATCH 4/5] move some things around in local scanner to make output look better

Signed-off-by: Ivan Milchev

---
 explorer/scan/local_scanner.go | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/explorer/scan/local_scanner.go b/explorer/scan/local_scanner.go
index 521966c612..48ffadd5cf 100644
--- a/explorer/scan/local_scanner.go
+++ b/explorer/scan/local_scanner.go
@@ -231,19 +231,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
 		return reporter.Reports(), nil
 	}

-	multiprogress, err := CreateProgressBar(discoveredAssets, s.disableProgressBar)
-	if err != nil {
-		return nil, err
-	}
-	// start the progress bar
-	scanGroups := sync.WaitGroup{}
-	scanGroups.Add(1)
-	go func() {
-		defer scanGroups.Done()
-		if err := multiprogress.Open(); err != nil {
-			log.Error().Err(err).Msg("failed to open progress bar")
-		}
-	}()
+	log.Info().Msgf("discovered %d assets", len(assets))

 	// if a bundle was provided check that it matches the filter, bundles can also be downloaded
 	// later therefore we do not want to stop execution here
@@ -257,7 +245,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up

 		// sync assets
 		if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito {
-			log.Info().Msg("synchronize assets")
+			log.Info().Msgf("synchronize %d assets", len(batch))
 			client, err := upstream.InitClient()
 			if err != nil {
 				return nil, err
@@ -311,6 +299,20 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
 		}
 	}

+	multiprogress, err := CreateProgressBar(discoveredAssets, s.disableProgressBar)
+	if err != nil {
+		return nil, err
+	}
+	// start the progress bar
+	scanGroups := sync.WaitGroup{}
+	scanGroups.Add(1)
+	go func() {
+		defer scanGroups.Done()
+		if err := multiprogress.Open(); err != nil {
+			log.Error().Err(err).Msg("failed to open progress bar")
+		}
+	}()
+
 	for k := range discoveredAssets.RootAssets {
 		root := discoveredAssets.RootAssets[k]
 		wg := sync.WaitGroup{}

From 7f1d9853ed87b3cd59d7b04e407365736c035f64 Mon Sep 17 00:00:00 2001
From: Ivan Milchev
Date: Tue, 6 Feb 2024 17:46:26 +0200
Subject: [PATCH 5/5] fix linter errors

Signed-off-by: Ivan Milchev

---
 providers/local_coordinator.go | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/providers/local_coordinator.go b/providers/local_coordinator.go
index 276902c350..8bdf45ce11 100644
--- a/providers/local_coordinator.go
+++ b/providers/local_coordinator.go
@@ -113,14 +113,18 @@ func (lc *localCoordinator) Shutdown() {
 	defer lc.mutex.Unlock()

 	for provider := range lc.runningEphemeral {
-		log.Debug().Str("provider", provider.Name).Msg("Shutting down ephemeral provider")
-		lc.parent.Stop(provider, true)
+		log.Debug().Str("provider", provider.Name).Msg("shutting down ephemeral provider")
+		if err := lc.parent.Stop(provider, true); err != nil {
+			log.Error().Err(err).Str("provider", provider.Name).Msg("error stopping ephemeral provider")
+		}
 	}
 	lc.runningEphemeral = map[*RunningProvider]struct{}{}

 	for _, provider := range lc.runningByID {
-		log.Debug().Str("provider", provider.Name).Msg("Shutting down provider")
-		lc.parent.Stop(provider, true)
+		log.Debug().Str("provider", provider.Name).Msg("shutting down provider")
+		if err := lc.parent.Stop(provider, true); err != nil {
+			log.Error().Err(err).Str("provider", provider.Name).Msg("error stopping provider")
+		}
 	}
 	lc.runningByID = map[string]*RunningProvider{}
 }