diff --git a/apps/cnquery/cmd/login.go b/apps/cnquery/cmd/login.go index 2699654ef8..22bd37ac53 100644 --- a/apps/cnquery/cmd/login.go +++ b/apps/cnquery/cmd/login.go @@ -52,7 +52,7 @@ You remain logged in until you explicitly log out using the 'logout' subcommand. viper.BindPFlag("name", cmd.Flags().Lookup("name")) }, RunE: func(cmd *cobra.Command, args []string) error { - defer cnquery_providers.Coordinator.Shutdown() + defer cnquery_providers.GlobalCoordinator.Shutdown() token, _ := cmd.Flags().GetString("token") annotations, _ := cmd.Flags().GetStringToString("annotation") return register(token, annotations) diff --git a/apps/cnquery/cmd/logout.go b/apps/cnquery/cmd/logout.go index 08dd8c0045..9c68d1297e 100644 --- a/apps/cnquery/cmd/logout.go +++ b/apps/cnquery/cmd/logout.go @@ -36,7 +36,7 @@ ensure the credentials cannot be used in the future. viper.BindPFlag("force", cmd.Flags().Lookup("force")) }, RunE: func(cmd *cobra.Command, args []string) error { - defer cnquery_providers.Coordinator.Shutdown() + defer cnquery_providers.GlobalCoordinator.Shutdown() var err error // its perfectly fine not to have a config here, therefore we ignore errors diff --git a/apps/cnquery/cmd/plugin.go b/apps/cnquery/cmd/plugin.go index c216f0939e..31c13c9eec 100644 --- a/apps/cnquery/cmd/plugin.go +++ b/apps/cnquery/cmd/plugin.go @@ -129,7 +129,7 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru return err } - connectAssetRuntime, err := providers.Coordinator.RuntimeFor(asset, runtime) + connectAssetRuntime, err := providers.GlobalCoordinator.RuntimeFor(asset, runtime) if err != nil { return err } diff --git a/apps/cnquery/cmd/shell.go b/apps/cnquery/cmd/shell.go index 6db9ee4a10..4c5607c156 100644 --- a/apps/cnquery/cmd/shell.go +++ b/apps/cnquery/cmd/shell.go @@ -153,7 +153,7 @@ func StartShell(runtime *providers.Runtime, conf *ShellConfig) error { // when we close the shell, we need to close the backend and store the recording onCloseHandler := func() { runtime.Close() - providers.Coordinator.Shutdown() + providers.GlobalCoordinator.Shutdown() } shellOptions := []shell.ShellOption{} diff --git a/apps/cnquery/cmd/status.go b/apps/cnquery/cmd/status.go index 841d0eed4d..177d09cf83 100644 --- a/apps/cnquery/cmd/status.go +++ b/apps/cnquery/cmd/status.go @@ -43,7 +43,7 @@ Status sends a ping to Mondoo Platform to verify the credentials. 
viper.BindPFlag("output", cmd.Flags().Lookup("output")) }, RunE: func(cmd *cobra.Command, args []string) error { - defer providers.Coordinator.Shutdown() + defer providers.GlobalCoordinator.Shutdown() opts, optsErr := config.Read() if optsErr != nil { return cli_errors.NewCommandError(errors.Wrap(optsErr, "could not load configuration"), 1) diff --git a/cli/providers/providers.go b/cli/providers/providers.go index 425cb6d4e0..e115c723ff 100644 --- a/cli/providers/providers.go +++ b/cli/providers/providers.go @@ -424,7 +424,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu } // TODO: add flag to set timeout and then use RuntimeWithShutdownTimeout - runtime := providers.Coordinator.NewRuntime() + runtime := providers.GlobalCoordinator.NewRuntime() if err = providers.SetDefaultRuntime(runtime); err != nil { log.Error().Msg(err.Error()) } @@ -440,12 +440,12 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu } if err := runtime.UseProvider(provider.ID); err != nil { - providers.Coordinator.Shutdown() + providers.GlobalCoordinator.Shutdown() log.Fatal().Err(err).Msg("failed to start provider " + provider.Name) } if record != "" && useRecording != "" { - providers.Coordinator.Shutdown() + providers.GlobalCoordinator.Shutdown() log.Fatal().Msg("please only use --record or --use-recording, but not both at the same time") } recordingPath := record @@ -459,7 +459,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu PrettyPrintJSON: pretty, }) if err != nil { - providers.Coordinator.Shutdown() + providers.GlobalCoordinator.Shutdown() log.Fatal().Msg(err.Error()) } runtime.SetRecording(recording) @@ -471,13 +471,13 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu }) if err != nil { runtime.Close() - providers.Coordinator.Shutdown() + providers.GlobalCoordinator.Shutdown() log.Fatal().Err(err).Msg("failed to parse cli arguments") } if cliRes == nil { runtime.Close() - providers.Coordinator.Shutdown() + providers.GlobalCoordinator.Shutdown() log.Fatal().Msg("failed to process CLI arguments, nothing was returned") return // adding this here as a compiler hint to stop warning about nil-dereferences } @@ -485,7 +485,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu if cliRes.Asset == nil { log.Warn().Err(err).Msg("failed to discover assets after processing CLI arguments") } else { - assetRuntime, err := providers.Coordinator.RuntimeFor(cliRes.Asset, runtime) + assetRuntime, err := providers.GlobalCoordinator.RuntimeFor(cliRes.Asset, runtime) if err != nil { log.Warn().Err(err).Msg("failed to get runtime for an asset that was detected after parsing the CLI") } else { @@ -495,7 +495,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu run(cc, runtime, cliRes) runtime.Close() - providers.Coordinator.Shutdown() + providers.GlobalCoordinator.Shutdown() } attachFlags(cmd.Flags(), allFlags) diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go index 129b0bb702..d187593d09 100644 --- a/explorer/scan/discovery.go +++ b/explorer/scan/discovery.go @@ -18,6 +18,12 @@ import ( "go.mondoo.com/cnquery/v10/providers-sdk/v1/upstream" ) +type RootAsset struct { + Coordinator providers.Coordinator + Asset *inventory.Asset + Children []*AssetWithRuntime +} + type AssetWithRuntime struct { Asset *inventory.Asset Runtime *providers.Runtime @@ -30,13 +36,26 @@ type AssetWithError struct { type DiscoveredAssets 
struct { platformIds map[string]struct{} - Assets []*AssetWithRuntime - Errors []*AssetWithError + // Assets []*AssetWithRuntime + Errors []*AssetWithError + + RootAssets map[*inventory.Asset]*RootAsset +} + +func (d *DiscoveredAssets) AddRoot(root *inventory.Asset, coordinator providers.Coordinator, runtime *providers.Runtime) bool { + if _, ok := d.RootAssets[root]; ok { + return false + } + d.RootAssets[root] = &RootAsset{ + Coordinator: coordinator, + Asset: root, + } + return true } // Add adds an asset and its runtime to the discovered assets list. It returns true if the // asset has been added, false if it is a duplicate -func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool { +func (d *DiscoveredAssets) Add(root, asset *inventory.Asset, runtime *providers.Runtime) bool { isDuplicate := false for _, platformId := range asset.PlatformIds { if _, ok := d.platformIds[platformId]; ok { @@ -49,7 +68,7 @@ func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtim return false } - d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime}) + d.RootAssets[root].Children = append(d.RootAssets[root].Children, &AssetWithRuntime{Asset: asset, Runtime: runtime}) return true } @@ -57,13 +76,23 @@ func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) { d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err}) } +func (d *DiscoveredAssets) GetFlattenedChildren() []*AssetWithRuntime { + var assets []*AssetWithRuntime + for _, a := range d.RootAssets { + assets = append(assets, a.Children...) + } + return assets +} + func (d *DiscoveredAssets) GetAssetsByPlatformID(platformID string) []*inventory.Asset { var assets []*inventory.Asset - for _, a := range d.Assets { - for _, p := range a.Asset.PlatformIds { - if platformID == "" || p == platformID { - assets = append(assets, a.Asset) - break + for _, a := range d.RootAssets { + for _, c := range a.Children { + for _, p := range c.Asset.PlatformIds { + if platformID == "" || p == platformID { + assets = append(assets, c.Asset) + break + } } } } @@ -91,7 +120,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups runtimeLabels = runtimeEnv.Labels() } - discoveredAssets := &DiscoveredAssets{platformIds: map[string]struct{}{}} + discoveredAssets := &DiscoveredAssets{platformIds: map[string]struct{}{}, RootAssets: map[*inventory.Asset]*RootAsset{}} // we connect and perform discovery for each asset in the job inventory for _, rootAsset := range invAssets { @@ -100,20 +129,29 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups return nil, err } + coordinator := providers.NewLocalCoordinator(providers.GlobalCoordinator) + // create runtime for root asset - rootAssetWithRuntime, err := createRuntimeForAsset(resolvedRootAsset, upstream, recording) + rootAssetWithRuntime, err := createRuntimeForAsset(resolvedRootAsset, coordinator, upstream, recording) if err != nil { log.Error().Err(err).Str("asset", resolvedRootAsset.Name).Msg("unable to create runtime for asset") discoveredAssets.AddError(rootAssetWithRuntime.Asset, err) + coordinator.Shutdown() continue } resolvedRootAsset = rootAssetWithRuntime.Asset // to ensure we get all the information the connect call gave us + if !discoveredAssets.AddRoot(resolvedRootAsset, coordinator, rootAssetWithRuntime.Runtime) { + rootAssetWithRuntime.Runtime.Close() + coordinator.Shutdown() + continue + } + // If the root asset has platform IDs, then it is a 
scannable asset, so we need to add it if len(resolvedRootAsset.PlatformIds) > 0 { prepareAsset(resolvedRootAsset, resolvedRootAsset, runtimeLabels) - if !discoveredAssets.Add(rootAssetWithRuntime.Asset, rootAssetWithRuntime.Runtime) { + if !discoveredAssets.Add(resolvedRootAsset, rootAssetWithRuntime.Asset, rootAssetWithRuntime.Runtime) { rootAssetWithRuntime.Runtime.Close() } } @@ -126,7 +164,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups // for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets { // create runtime for root asset - assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording) + assetWithRuntime, err := createRuntimeForAsset(a, coordinator, upstream, recording) if err != nil { log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset") discoveredAssets.AddError(assetWithRuntime.Asset, err) @@ -137,7 +175,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels) // If the asset has been already added, we should close its runtime - if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) { + if !discoveredAssets.Add(resolvedRootAsset, resolvedAsset, assetWithRuntime.Runtime) { assetWithRuntime.Runtime.Close() } } @@ -146,15 +184,15 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups // if there is exactly one asset, assure that the --asset-name is used // TODO: make it so that the --asset-name is set for the root asset only even if multiple assets are there // This is a temporary fix that only works if there is only one asset - if len(discoveredAssets.Assets) == 1 && invAssets[0].Name != "" && invAssets[0].Name != discoveredAssets.Assets[0].Asset.Name { - log.Debug().Str("asset", discoveredAssets.Assets[0].Asset.Name).Msg("Overriding asset name with --asset-name flag") - discoveredAssets.Assets[0].Asset.Name = invAssets[0].Name - } + // if len(discoveredAssets.Assets) == 1 && invAssets[0].Name != "" && invAssets[0].Name != discoveredAssets.Assets[0].Asset.Name { + // log.Debug().Str("asset", discoveredAssets.Assets[0].Asset.Name).Msg("Overriding asset name with --asset-name flag") + // discoveredAssets.Assets[0].Asset.Name = invAssets[0].Name + // } return discoveredAssets, nil } -func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) { +func createRuntimeForAsset(asset *inventory.Asset, coordinator providers.Coordinator, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) { var runtime *providers.Runtime var err error // Close the runtime if an error occured @@ -164,7 +202,7 @@ func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamCo } }() - runtime, err = providers.Coordinator.RuntimeFor(asset, providers.DefaultRuntime()) + runtime, err = coordinator.RuntimeFor(asset, providers.DefaultRuntime()) if err != nil { return nil, err } diff --git a/explorer/scan/discovery_test.go b/explorer/scan/discovery_test.go index 1e7afd4ecf..3545340d65 100644 --- a/explorer/scan/discovery_test.go +++ b/explorer/scan/discovery_test.go @@ -15,57 +15,86 @@ import ( inventory "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory" ) +func TestDiscoveredAssets_AddRoot(t *testing.T) { + d := &DiscoveredAssets{ + 
platformIds: map[string]struct{}{}, + RootAssets: map[*inventory.Asset]*RootAsset{}, + Errors: []*AssetWithError{}, + } + + root := &inventory.Asset{} + assert.True(t, d.AddRoot(root, nil, nil)) + assert.Len(t, d.GetFlattenedChildren(), 0) + assert.Len(t, d.Errors, 0) + + // Make sure adding duplicates is not possible + assert.False(t, d.AddRoot(root, nil, nil)) + assert.Len(t, d.GetFlattenedChildren(), 0) + assert.Len(t, d.Errors, 0) +} + func TestDiscoveredAssets_Add(t *testing.T) { d := &DiscoveredAssets{ platformIds: map[string]struct{}{}, - Assets: []*AssetWithRuntime{}, + RootAssets: map[*inventory.Asset]*RootAsset{}, Errors: []*AssetWithError{}, } + + root := &inventory.Asset{} + assert.True(t, d.AddRoot(root, nil, nil)) + asset := &inventory.Asset{ PlatformIds: []string{"platform1"}, } runtime := &providers.Runtime{} - assert.True(t, d.Add(asset, runtime)) - assert.Len(t, d.Assets, 1) + assert.True(t, d.Add(root, asset, runtime)) + assert.Len(t, d.GetFlattenedChildren(), 1) assert.Len(t, d.Errors, 0) // Make sure adding duplicates is not possible - assert.False(t, d.Add(asset, runtime)) - assert.Len(t, d.Assets, 1) + assert.False(t, d.Add(root, asset, runtime)) + assert.Len(t, d.GetFlattenedChildren(), 1) assert.Len(t, d.Errors, 0) } func TestDiscoveredAssets_Add_MultiplePlatformIDs(t *testing.T) { d := &DiscoveredAssets{ platformIds: map[string]struct{}{}, - Assets: []*AssetWithRuntime{}, + RootAssets: map[*inventory.Asset]*RootAsset{}, Errors: []*AssetWithError{}, } + + root := &inventory.Asset{} + assert.True(t, d.AddRoot(root, nil, nil)) + asset := &inventory.Asset{ PlatformIds: []string{"platform1", "platform2"}, } runtime := &providers.Runtime{} - assert.True(t, d.Add(asset, runtime)) - assert.Len(t, d.Assets, 1) + assert.True(t, d.Add(root, asset, runtime)) + assert.Len(t, d.GetFlattenedChildren(), 1) assert.Len(t, d.Errors, 0) // Make sure adding duplicates is not possible - assert.False(t, d.Add(&inventory.Asset{ + assert.False(t, d.Add(root, &inventory.Asset{ PlatformIds: []string{"platform3", asset.PlatformIds[0]}, }, runtime)) - assert.Len(t, d.Assets, 1) + assert.Len(t, d.GetFlattenedChildren(), 1) assert.Len(t, d.Errors, 0) } func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) { d := &DiscoveredAssets{ platformIds: map[string]struct{}{}, - Assets: []*AssetWithRuntime{}, + RootAssets: map[*inventory.Asset]*RootAsset{}, Errors: []*AssetWithError{}, } + root := &inventory.Asset{} + assert.True(t, d.AddRoot(root, nil, nil)) + allPlatformIds := []string{} for i := 0; i < 10; i++ { pId := fmt.Sprintf("platform1%d", i) @@ -75,9 +104,9 @@ func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) { } runtime := &providers.Runtime{} - assert.True(t, d.Add(asset, runtime)) + assert.True(t, d.Add(root, asset, runtime)) } - assert.Len(t, d.Assets, 10) + assert.Len(t, d.GetFlattenedChildren(), 10) // Make sure adding duplicates is not possible assets := d.GetAssetsByPlatformID(allPlatformIds[0]) @@ -88,10 +117,13 @@ func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) { func TestDiscoveredAssets_GetAssetsByPlatformID_Empty(t *testing.T) { d := &DiscoveredAssets{ platformIds: map[string]struct{}{}, - Assets: []*AssetWithRuntime{}, + RootAssets: map[*inventory.Asset]*RootAsset{}, Errors: []*AssetWithError{}, } + root := &inventory.Asset{} + assert.True(t, d.AddRoot(root, nil, nil)) + allPlatformIds := []string{} for i := 0; i < 10; i++ { pId := fmt.Sprintf("platform1%d", i) @@ -101,9 +133,9 @@ func TestDiscoveredAssets_GetAssetsByPlatformID_Empty(t *testing.T) 
{ } runtime := &providers.Runtime{} - assert.True(t, d.Add(asset, runtime)) + assert.True(t, d.Add(root, asset, runtime)) } - assert.Len(t, d.Assets, 10) + assert.Len(t, d.GetFlattenedChildren(), 10) // Make sure adding duplicates is not possible assets := d.GetAssetsByPlatformID("") @@ -143,11 +175,12 @@ func TestDiscoverAssets(t *testing.T) { inv := getInventory() discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) require.NoError(t, err) - assert.Len(t, discoveredAssets.Assets, 3) + assets := discoveredAssets.GetFlattenedChildren() + assert.Len(t, assets, 3) assert.Len(t, discoveredAssets.Errors, 0) - assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[0].Asset.ManagedBy) - assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[1].Asset.ManagedBy) - assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[2].Asset.ManagedBy) + assert.Equal(t, "mondoo-operator-123", assets[0].Asset.ManagedBy) + assert.Equal(t, "mondoo-operator-123", assets[1].Asset.ManagedBy) + assert.Equal(t, "mondoo-operator-123", assets[2].Asset.ManagedBy) }) t.Run("with duplicate root assets", func(t *testing.T) { @@ -157,7 +190,7 @@ func TestDiscoverAssets(t *testing.T) { require.NoError(t, err) // Make sure no duplicates are returned - assert.Len(t, discoveredAssets.Assets, 3) + assert.Len(t, discoveredAssets.GetFlattenedChildren(), 3) assert.Len(t, discoveredAssets.Errors, 0) }) @@ -168,7 +201,7 @@ func TestDiscoverAssets(t *testing.T) { require.NoError(t, err) // Make sure no duplicates are returned - assert.Len(t, discoveredAssets.Assets, 3) + assert.Len(t, discoveredAssets.GetFlattenedChildren(), 3) assert.Len(t, discoveredAssets.Errors, 0) }) @@ -181,7 +214,7 @@ func TestDiscoverAssets(t *testing.T) { discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) require.NoError(t, err) - for _, asset := range discoveredAssets.Assets { + for _, asset := range discoveredAssets.GetFlattenedChildren() { for k, v := range inv.Spec.Assets[0].Annotations { require.Contains(t, asset.Asset.Annotations, k) assert.Equal(t, v, asset.Asset.Annotations[k]) @@ -195,7 +228,7 @@ func TestDiscoverAssets(t *testing.T) { discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) require.NoError(t, err) - for _, asset := range discoveredAssets.Assets { + for _, asset := range discoveredAssets.GetFlattenedChildren() { assert.Equal(t, inv.Spec.Assets[0].ManagedBy, asset.Asset.ManagedBy) } }) @@ -216,7 +249,7 @@ func TestDiscoverAssets(t *testing.T) { discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) require.NoError(t, err) - for _, asset := range discoveredAssets.Assets { + for _, asset := range discoveredAssets.GetFlattenedChildren() { require.Contains(t, asset.Asset.Labels, "mondoo.com/exec-environment") assert.Equal(t, "actions.github.com", asset.Asset.Labels["mondoo.com/exec-environment"]) } @@ -239,7 +272,7 @@ func TestDiscoverAssets(t *testing.T) { discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) require.NoError(t, err) - for _, asset := range discoveredAssets.Assets { + for _, asset := range discoveredAssets.GetFlattenedChildren() { require.Contains(t, asset.Asset.Labels, "mondoo.com/exec-environment") assert.Equal(t, "actions.github.com", asset.Asset.Labels["mondoo.com/exec-environment"]) } @@ -251,6 +284,6 @@ func TestDiscoverAssets(t *testing.T) { discoveredAssets, err := 
DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) require.NoError(t, err) - assert.Len(t, discoveredAssets.Assets, 1) + assert.Len(t, discoveredAssets.GetFlattenedChildren(), 1) }) } diff --git a/explorer/scan/local_scanner.go b/explorer/scan/local_scanner.go index cc998ce0fc..48ffadd5cf 100644 --- a/explorer/scan/local_scanner.go +++ b/explorer/scan/local_scanner.go @@ -176,8 +176,9 @@ func CreateProgressBar(discoveredAssets *DiscoveredAssets, disableProgressBar bo if isatty.IsTerminal(os.Stdout.Fd()) && !disableProgressBar && !strings.EqualFold(logger.GetLevel(), "debug") && !strings.EqualFold(logger.GetLevel(), "trace") { progressBarElements := map[string]string{} orderedKeys := []string{} - for i := range discoveredAssets.Assets { - asset := discoveredAssets.Assets[i].Asset + assets := discoveredAssets.GetFlattenedChildren() + for i := range assets { + asset := assets[i].Asset // this shouldn't happen, but might // it normally indicates a bug in the provider if presentAsset, present := progressBarElements[asset.PlatformIds[0]]; present { @@ -201,9 +202,6 @@ func CreateProgressBar(discoveredAssets *DiscoveredAssets, disableProgressBar bo func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*explorer.ReportCollection, error) { log.Info().Msgf("discover related assets for %d asset(s)", len(job.Inventory.Spec.Assets)) - // Always shut down the coordinator, to make sure providers are killed - defer providers.Coordinator.Shutdown() - discoveredAssets, err := DiscoverAssets(ctx, job.Inventory, upstream, s.recording) if err != nil { return nil, err @@ -213,10 +211,10 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up // Within this process, we set up a catch-all deferred function, that shuts // down all runtimes, in case we exit early. defer func() { - for _, asset := range discoveredAssets.Assets { + for _, asset := range discoveredAssets.RootAssets { // we can call close multiple times and it will only execute once - if asset.Runtime != nil { - asset.Runtime.Close() + if asset.Coordinator != nil { + asset.Coordinator.Shutdown() } } }() @@ -228,31 +226,26 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up reporter.AddScanError(discoveredAssets.Errors[i].Asset, discoveredAssets.Errors[i].Err) } - if len(discoveredAssets.Assets) == 0 { + assets := discoveredAssets.GetFlattenedChildren() + if len(assets) == 0 { return reporter.Reports(), nil } - multiprogress, err := CreateProgressBar(discoveredAssets, s.disableProgressBar) - if err != nil { - return nil, err + log.Info().Msgf("discovered %d assets", len(assets)) + + // if a bundle was provided check that it matches the filter, bundles can also be downloaded + // later therefore we do not want to stop execution here + if job.Bundle != nil && job.Bundle.FilterQueryPacks(job.QueryPackFilters) { + return nil, errors.New("all available packs filtered out. 
nothing to do") } - // start the progress bar - scanGroups := sync.WaitGroup{} - scanGroups.Add(1) - go func() { - defer scanGroups.Done() - if err := multiprogress.Open(); err != nil { - log.Error().Err(err).Msg("failed to open progress bar") - } - }() - assetBatches := slicesx.Batch(discoveredAssets.Assets, 100) + assetBatches := slicesx.Batch(assets, 100) for i := range assetBatches { batch := assetBatches[i] // sync assets if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito { - log.Info().Msg("synchronize assets") + log.Info().Msgf("synchronize %d assets", len(batch)) client, err := upstream.InitClient() if err != nil { return nil, err @@ -304,20 +297,31 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up } } } + } - // if a bundle was provided check that it matches the filter, bundles can also be downloaded - // later therefore we do not want to stop execution here - if job.Bundle != nil && job.Bundle.FilterQueryPacks(job.QueryPackFilters) { - return nil, errors.New("all available packs filtered out. nothing to do") + multiprogress, err := CreateProgressBar(discoveredAssets, s.disableProgressBar) + if err != nil { + return nil, err + } + // start the progress bar + scanGroups := sync.WaitGroup{} + scanGroups.Add(1) + go func() { + defer scanGroups.Done() + if err := multiprogress.Open(); err != nil { + log.Error().Err(err).Msg("failed to open progress bar") } + }() + for k := range discoveredAssets.RootAssets { + root := discoveredAssets.RootAssets[k] wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() - for i := range batch { - asset := batch[i].Asset - runtime := batch[i].Runtime + for i := range root.Children { + asset := root.Children[i].Asset + runtime := root.Children[i].Runtime // Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function, // we need to solve this at a different level. 
@@ -349,6 +353,9 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up } }() wg.Wait() + + // Shutdown the coordinator for the current root asset + root.Coordinator.Shutdown() } scanGroups.Wait() return reporter.Reports(), nil diff --git a/providers-sdk/v1/testutils/testutils.go b/providers-sdk/v1/testutils/testutils.go index 1ec57740d4..899a6921f7 100644 --- a/providers-sdk/v1/testutils/testutils.go +++ b/providers-sdk/v1/testutils/testutils.go @@ -200,7 +200,7 @@ func Local() llx.Runtime { networkSchema := MustLoadSchema(SchemaProvider{Provider: "network"}) mockSchema := MustLoadSchema(SchemaProvider{Provider: "mockprovider"}) - runtime := providers.Coordinator.NewRuntime() + runtime := providers.GlobalCoordinator.NewRuntime() provider := &providers.RunningProvider{ Name: osconf.Config.Name, diff --git a/providers/builtin.go b/providers/builtin.go index b9714687c4..81c4bad31a 100644 --- a/providers/builtin.go +++ b/providers/builtin.go @@ -35,7 +35,7 @@ var builtinProviders = map[string]*builtinProvider{ Runtime: &RunningProvider{ Name: mockProvider.Name, ID: mockProvider.ID, - Plugin: &mockProviderService{coordinator: &Coordinator}, + Plugin: &mockProviderService{coordinator: GlobalCoordinator}, isClosed: false, }, Config: mockProvider.Provider, diff --git a/providers/defaults_shared.go b/providers/defaults_shared.go index d5d85e3426..2817fccf22 100644 --- a/providers/defaults_shared.go +++ b/providers/defaults_shared.go @@ -21,7 +21,7 @@ var defaultRuntime *Runtime func DefaultRuntime() *Runtime { if defaultRuntime == nil { - defaultRuntime = Coordinator.NewRuntime() + defaultRuntime = GlobalCoordinator.NewRuntime() } return defaultRuntime } diff --git a/providers/coordinator.go b/providers/global_coordinator.go similarity index 93% rename from providers/coordinator.go rename to providers/global_coordinator.go index fea0969bd9..5eae2b41a2 100644 --- a/providers/coordinator.go +++ b/providers/global_coordinator.go @@ -24,9 +24,22 @@ import ( "google.golang.org/grpc/status" ) +type Coordinator interface { + Start(id string, isEphemeral bool, update UpdateProvidersConfig) (*RunningProvider, error) + Stop(provider *RunningProvider, isEphemeral bool) error + NewRuntime() *Runtime + NewRuntimeFrom(parent *Runtime) *Runtime + RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error) + GetRunningProviderById(id string) *RunningProvider + GetProviders() Providers + SetProviders(providers Providers) + LoadSchema(name string) (*resources.Schema, error) + Shutdown() +} + var BuiltinCoreID = coreconf.Config.ID -var Coordinator = coordinator{ +var GlobalCoordinator Coordinator = &coordinator{ RunningByID: map[string]*RunningProvider{}, RunningEphemeral: map[*RunningProvider]struct{}{}, runtimes: map[string]*Runtime{}, @@ -262,6 +275,20 @@ func (c *coordinator) Start(id string, isEphemeral bool, update UpdateProvidersC return res, nil } +func (c *coordinator) GetRunningProviderById(id string) *RunningProvider { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.RunningByID[id] +} + +func (c *coordinator) GetProviders() Providers { + return c.Providers +} + +func (c *coordinator) SetProviders(providers Providers) { + c.Providers = providers +} + type ProviderVersions struct { Providers []ProviderVersion `json:"providers"` } @@ -376,7 +403,7 @@ func (c *coordinator) NewRuntimeFrom(parent *Runtime) *Runtime { return res } -// RuntimFor an asset will return a new or existing runtime for a given asset. 
+// RuntimeFor an asset will return a new or existing runtime for a given asset. // If a runtime for this asset already exists, it will re-use it. If the runtime // is new, it will create it and detect the provider. // The asset and parent must be defined. @@ -468,6 +495,7 @@ func (c *coordinator) Stop(provider *RunningProvider, isEphemeral bool) error { func (c *coordinator) Shutdown() { c.mutex.Lock() + defer c.mutex.Unlock() for cur := range c.RunningEphemeral { if err := cur.Shutdown(); err != nil { @@ -489,8 +517,6 @@ func (c *coordinator) Shutdown() { c.runtimes = map[string]*Runtime{} c.runtimeCnt = 0 c.unprocessedRuntimes = []*Runtime{} - - c.mutex.Unlock() } // LoadSchema for a given provider. Providers also cache their Schemas, so diff --git a/providers/local_coordinator.go b/providers/local_coordinator.go new file mode 100644 index 0000000000..8bdf45ce11 --- /dev/null +++ b/providers/local_coordinator.go @@ -0,0 +1,130 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package providers + +import ( + "sync" + + "github.com/rs/zerolog/log" + "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory" + "go.mondoo.com/cnquery/v10/providers-sdk/v1/resources" +) + +type localCoordinator struct { + parent Coordinator + + runningByID map[string]*RunningProvider + runningEphemeral map[*RunningProvider]struct{} + mutex sync.Mutex +} + +func NewLocalCoordinator(parent Coordinator) Coordinator { + return &localCoordinator{ + parent: parent, + runningByID: map[string]*RunningProvider{}, + runningEphemeral: map[*RunningProvider]struct{}{}, + } +} + +func (lc *localCoordinator) Start(id string, isEphemeral bool, update UpdateProvidersConfig) (*RunningProvider, error) { + // From the parent's perspective, all providers from its children are ephemeral + provider, err := lc.parent.Start(id, true, update) + if err != nil { + return nil, err + } + + lc.mutex.Lock() + if isEphemeral { + lc.runningEphemeral[provider] = struct{}{} + } else { + lc.runningByID[provider.ID] = provider + } + lc.mutex.Unlock() + return provider, nil +} + +func (lc *localCoordinator) Stop(provider *RunningProvider, isEphemeral bool) error { + if provider == nil { + return nil + } + + lc.mutex.Lock() + defer lc.mutex.Unlock() + + if err := lc.parent.Stop(provider, true); err != nil { + return err + } + + if isEphemeral { + delete(lc.runningEphemeral, provider) + } else { + found := lc.runningByID[provider.ID] + if found != nil { + delete(lc.runningByID, provider.ID) + } + } + return nil +} + +func (lc *localCoordinator) NewRuntime() *Runtime { + runtime := lc.parent.NewRuntime() + // Override the coordinator with the local one, so providers are managed + // by the local coordinator + runtime.coordinator = lc + return runtime +} + +func (lc *localCoordinator) NewRuntimeFrom(parent *Runtime) *Runtime { + res := lc.NewRuntime() + res.recording = parent.Recording() + for k, v := range parent.providers { + res.providers[k] = v + } + return res +} + +func (lc *localCoordinator) RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error) { + runtime := lc.parent.NewRuntimeFrom(parent) + runtime.coordinator = lc + return runtime, runtime.DetectProvider(asset) +} + +func (lc *localCoordinator) GetRunningProviderById(id string) *RunningProvider { + lc.mutex.Lock() + defer lc.mutex.Unlock() + return lc.runningByID[id] +} + +func (lc *localCoordinator) GetProviders() Providers { + return lc.parent.GetProviders() +} + +func (lc *localCoordinator) SetProviders(providers Providers) { + 
lc.parent.SetProviders(providers) +} + +func (lc *localCoordinator) LoadSchema(name string) (*resources.Schema, error) { + return lc.parent.LoadSchema(name) +} + +func (lc *localCoordinator) Shutdown() { + lc.mutex.Lock() + defer lc.mutex.Unlock() + + for provider := range lc.runningEphemeral { + log.Debug().Str("provider", provider.Name).Msg("shutting down ephemeral provider") + if err := lc.parent.Stop(provider, true); err != nil { + log.Error().Err(err).Str("provider", provider.Name).Msg("error stopping ephemeral provider") + } + } + lc.runningEphemeral = map[*RunningProvider]struct{}{} + + for _, provider := range lc.runningByID { + log.Debug().Str("provider", provider.Name).Msg("shutting down provider") + if err := lc.parent.Stop(provider, true); err != nil { + log.Error().Err(err).Str("provider", provider.Name).Msg("error stopping provider") + } + } + lc.runningByID = map[string]*RunningProvider{} +} diff --git a/providers/mock.go b/providers/mock.go index 10fb847ec0..32b86c02e7 100644 --- a/providers/mock.go +++ b/providers/mock.go @@ -23,7 +23,7 @@ var mockProvider = Provider{ } type mockProviderService struct { - coordinator *coordinator + coordinator Coordinator initialized bool runtime *Runtime } diff --git a/providers/providers.go b/providers/providers.go index 63f493bafd..a49a52fc4f 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -206,7 +206,7 @@ func ListActive() (Providers, error) { } // useful for caching; even if the structure gets updated with new providers - Coordinator.Providers = res + GlobalCoordinator.SetProviders(res) return res, nil } diff --git a/providers/runtime.go b/providers/runtime.go index 2997581526..6e6c689cc8 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -33,7 +33,7 @@ type Runtime struct { recording llx.Recording features []byte // coordinator is used to grab providers - coordinator *coordinator + coordinator Coordinator // providers for with open connections providers map[string]*ConnectedProvider // schema aggregates all resources executable on this asset @@ -149,7 +149,7 @@ func (r *Runtime) addProvider(id string, isEphemeral bool) (*ConnectedProvider, } else { // TODO: we need to detect only the shared running providers - running = r.coordinator.RunningByID[id] + running = r.coordinator.GetRunningProviderById(id) if running == nil { var err error running, err = r.coordinator.Start(id, false, r.AutoUpdate) @@ -193,7 +193,7 @@ func (r *Runtime) providerForAsset(asset *inventory.Asset) (*Provider, error) { conn.Type = inventory.ConnBackendToType(conn.Backend) } - provider, err := EnsureProvider(ProviderLookup{ConnType: conn.Type}, true, r.coordinator.Providers) + provider, err := EnsureProvider(ProviderLookup{ConnType: conn.Type}, true, r.coordinator.GetProviders()) if err != nil { errs.Add(err) continue
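
The patch replaces the single global `providers.Coordinator` with a `Coordinator` interface, a `GlobalCoordinator` implementation, and a per-root `localCoordinator` created via `NewLocalCoordinator`. The sketch below is a hypothetical illustration (not part of the patch) of how the new pieces fit together, mirroring what `DiscoverAssets` and `LocalScanner.distributeJob` do after this change; `scanRoot` is an invented helper name and error handling is abbreviated.

```go
// Hypothetical sketch: per-root provider lifetimes via a local coordinator.
package main

import (
	"log"

	"go.mondoo.com/cnquery/v10/providers"
	inventory "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
)

// scanRoot is a hypothetical helper; the real flow lives in DiscoverAssets
// and LocalScanner.distributeJob.
func scanRoot(root *inventory.Asset) error {
	// Each root asset gets its own local coordinator that delegates provider
	// start/stop to the global coordinator but tracks them locally.
	coordinator := providers.NewLocalCoordinator(providers.GlobalCoordinator)
	// Shutting down the local coordinator stops only the providers it started,
	// so providers owned by other root assets keep running.
	defer coordinator.Shutdown()

	// RuntimeFor detects the provider for the asset and returns a runtime
	// whose providers are managed by the local coordinator.
	runtime, err := coordinator.RuntimeFor(root, providers.DefaultRuntime())
	if err != nil {
		return err
	}
	defer runtime.Close()

	// ... discover children via the runtime and scan them here ...
	return nil
}

func main() {
	if err := scanRoot(&inventory.Asset{Name: "example"}); err != nil {
		log.Fatal(err)
	}
}
```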
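`DiscoveredAssets` also changes shape: instead of a flat `Assets` slice, discovered assets are grouped under their root asset (`RootAssets`), each root carrying the coordinator that owns its providers. The snippet below is a hedged sketch in the style of `discovery_test.go` (it has to live in the `scan` package because `platformIds` is unexported); `demoRootChildren` is an invented name used only for illustration.

```go
// Hypothetical sketch of the new root/children bookkeeping in DiscoveredAssets.
package scan

import (
	"go.mondoo.com/cnquery/v10/providers"
	inventory "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
)

func demoRootChildren() {
	d := &DiscoveredAssets{
		platformIds: map[string]struct{}{},
		RootAssets:  map[*inventory.Asset]*RootAsset{},
	}

	root := &inventory.Asset{}
	child := &inventory.Asset{PlatformIds: []string{"platform1"}}

	// A root is registered exactly once; AddRoot returns false on duplicates.
	// The tests pass nil for the coordinator and runtime, as done here.
	d.AddRoot(root, nil, nil)

	// Children are attached to their root and deduplicated by platform ID.
	d.Add(root, child, &providers.Runtime{})

	// Callers that previously ranged over d.Assets now flatten all children.
	for _, a := range d.GetFlattenedChildren() {
		_ = a.Asset   // *inventory.Asset
		_ = a.Runtime // *providers.Runtime
	}
}
```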