Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🧹 providers coordinator v2 #3218

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/cnquery/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ You remain logged in until you explicitly log out using the 'logout' subcommand.
viper.BindPFlag("name", cmd.Flags().Lookup("name"))
},
RunE: func(cmd *cobra.Command, args []string) error {
defer cnquery_providers.Coordinator.Shutdown()
defer cnquery_providers.GlobalCoordinator.Shutdown()
token, _ := cmd.Flags().GetString("token")
annotations, _ := cmd.Flags().GetStringToString("annotation")
return register(token, annotations)
Expand Down
2 changes: 1 addition & 1 deletion apps/cnquery/cmd/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ ensure the credentials cannot be used in the future.
viper.BindPFlag("force", cmd.Flags().Lookup("force"))
},
RunE: func(cmd *cobra.Command, args []string) error {
defer cnquery_providers.Coordinator.Shutdown()
defer cnquery_providers.GlobalCoordinator.Shutdown()
var err error

// its perfectly fine not to have a config here, therefore we ignore errors
Expand Down
2 changes: 1 addition & 1 deletion apps/cnquery/cmd/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru
return err
}

connectAssetRuntime, err := providers.Coordinator.RuntimeFor(asset, runtime)
connectAssetRuntime, err := providers.GlobalCoordinator.RuntimeFor(asset, runtime)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion apps/cnquery/cmd/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func StartShell(runtime *providers.Runtime, conf *ShellConfig) error {
// when we close the shell, we need to close the backend and store the recording
onCloseHandler := func() {
runtime.Close()
providers.Coordinator.Shutdown()
providers.GlobalCoordinator.Shutdown()
}

shellOptions := []shell.ShellOption{}
Expand Down
2 changes: 1 addition & 1 deletion apps/cnquery/cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Status sends a ping to Mondoo Platform to verify the credentials.
viper.BindPFlag("output", cmd.Flags().Lookup("output"))
},
RunE: func(cmd *cobra.Command, args []string) error {
defer providers.Coordinator.Shutdown()
defer providers.GlobalCoordinator.Shutdown()
opts, optsErr := config.Read()
if optsErr != nil {
return cli_errors.NewCommandError(errors.Wrap(optsErr, "could not load configuration"), 1)
Expand Down
16 changes: 8 additions & 8 deletions cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
}

// TODO: add flag to set timeout and then use RuntimeWithShutdownTimeout
runtime := providers.Coordinator.NewRuntime()
runtime := providers.GlobalCoordinator.NewRuntime()
if err = providers.SetDefaultRuntime(runtime); err != nil {
log.Error().Msg(err.Error())
}
Expand All @@ -440,12 +440,12 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
}

if err := runtime.UseProvider(provider.ID); err != nil {
providers.Coordinator.Shutdown()
providers.GlobalCoordinator.Shutdown()
log.Fatal().Err(err).Msg("failed to start provider " + provider.Name)
}

if record != "" && useRecording != "" {
providers.Coordinator.Shutdown()
providers.GlobalCoordinator.Shutdown()
log.Fatal().Msg("please only use --record or --use-recording, but not both at the same time")
}
recordingPath := record
Expand All @@ -459,7 +459,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
PrettyPrintJSON: pretty,
})
if err != nil {
providers.Coordinator.Shutdown()
providers.GlobalCoordinator.Shutdown()
log.Fatal().Msg(err.Error())
}
runtime.SetRecording(recording)
Expand All @@ -471,21 +471,21 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
})
if err != nil {
runtime.Close()
providers.Coordinator.Shutdown()
providers.GlobalCoordinator.Shutdown()
log.Fatal().Err(err).Msg("failed to parse cli arguments")
}

if cliRes == nil {
runtime.Close()
providers.Coordinator.Shutdown()
providers.GlobalCoordinator.Shutdown()
log.Fatal().Msg("failed to process CLI arguments, nothing was returned")
return // adding this here as a compiler hint to stop warning about nil-dereferences
}

if cliRes.Asset == nil {
log.Warn().Err(err).Msg("failed to discover assets after processing CLI arguments")
} else {
assetRuntime, err := providers.Coordinator.RuntimeFor(cliRes.Asset, runtime)
assetRuntime, err := providers.GlobalCoordinator.RuntimeFor(cliRes.Asset, runtime)
if err != nil {
log.Warn().Err(err).Msg("failed to get runtime for an asset that was detected after parsing the CLI")
} else {
Expand All @@ -495,7 +495,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu

run(cc, runtime, cliRes)
runtime.Close()
providers.Coordinator.Shutdown()
providers.GlobalCoordinator.Shutdown()
}

attachFlags(cmd.Flags(), allFlags)
Expand Down
78 changes: 58 additions & 20 deletions explorer/scan/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ import (
"go.mondoo.com/cnquery/v10/providers-sdk/v1/upstream"
)

type RootAsset struct {
Coordinator providers.Coordinator
Asset *inventory.Asset
Children []*AssetWithRuntime
}

type AssetWithRuntime struct {
Asset *inventory.Asset
Runtime *providers.Runtime
Expand All @@ -30,13 +36,26 @@ type AssetWithError struct {

type DiscoveredAssets struct {
platformIds map[string]struct{}
Assets []*AssetWithRuntime
Errors []*AssetWithError
// Assets []*AssetWithRuntime
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this line be removed

Errors []*AssetWithError

RootAssets map[*inventory.Asset]*RootAsset
}

func (d *DiscoveredAssets) AddRoot(root *inventory.Asset, coordinator providers.Coordinator, runtime *providers.Runtime) bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

runtime isn't used. is that expected?

if _, ok := d.RootAssets[root]; ok {
return false
}
d.RootAssets[root] = &RootAsset{
Coordinator: coordinator,
Asset: root,
}
return true
}

// Add adds an asset and its runtime to the discovered assets list. It returns true if the
// asset has been added, false if it is a duplicate
func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool {
func (d *DiscoveredAssets) Add(root, asset *inventory.Asset, runtime *providers.Runtime) bool {
isDuplicate := false
for _, platformId := range asset.PlatformIds {
if _, ok := d.platformIds[platformId]; ok {
Expand All @@ -49,21 +68,31 @@ func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtim
return false
}

d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime})
d.RootAssets[root].Children = append(d.RootAssets[root].Children, &AssetWithRuntime{Asset: asset, Runtime: runtime})
return true
}

func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) {
d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err})
}

func (d *DiscoveredAssets) GetFlattenedChildren() []*AssetWithRuntime {
var assets []*AssetWithRuntime
for _, a := range d.RootAssets {
assets = append(assets, a.Children...)
}
return assets
}

func (d *DiscoveredAssets) GetAssetsByPlatformID(platformID string) []*inventory.Asset {
var assets []*inventory.Asset
for _, a := range d.Assets {
for _, p := range a.Asset.PlatformIds {
if platformID == "" || p == platformID {
assets = append(assets, a.Asset)
break
for _, a := range d.RootAssets {
for _, c := range a.Children {
for _, p := range c.Asset.PlatformIds {
if platformID == "" || p == platformID {
assets = append(assets, c.Asset)
break
}
}
}
}
Expand Down Expand Up @@ -91,7 +120,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
runtimeLabels = runtimeEnv.Labels()
}

discoveredAssets := &DiscoveredAssets{platformIds: map[string]struct{}{}}
discoveredAssets := &DiscoveredAssets{platformIds: map[string]struct{}{}, RootAssets: map[*inventory.Asset]*RootAsset{}}

// we connect and perform discovery for each asset in the job inventory
for _, rootAsset := range invAssets {
Expand All @@ -100,20 +129,29 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
return nil, err
}

coordinator := providers.NewLocalCoordinator(providers.GlobalCoordinator)

// create runtime for root asset
rootAssetWithRuntime, err := createRuntimeForAsset(resolvedRootAsset, upstream, recording)
rootAssetWithRuntime, err := createRuntimeForAsset(resolvedRootAsset, coordinator, upstream, recording)
if err != nil {
log.Error().Err(err).Str("asset", resolvedRootAsset.Name).Msg("unable to create runtime for asset")
discoveredAssets.AddError(rootAssetWithRuntime.Asset, err)
coordinator.Shutdown()
continue
}

resolvedRootAsset = rootAssetWithRuntime.Asset // to ensure we get all the information the connect call gave us

if !discoveredAssets.AddRoot(resolvedRootAsset, coordinator, rootAssetWithRuntime.Runtime) {
rootAssetWithRuntime.Runtime.Close()
coordinator.Shutdown()
continue
}

// If the root asset has platform IDs, then it is a scannable asset, so we need to add it
if len(resolvedRootAsset.PlatformIds) > 0 {
prepareAsset(resolvedRootAsset, resolvedRootAsset, runtimeLabels)
if !discoveredAssets.Add(rootAssetWithRuntime.Asset, rootAssetWithRuntime.Runtime) {
if !discoveredAssets.Add(resolvedRootAsset, rootAssetWithRuntime.Asset, rootAssetWithRuntime.Runtime) {
rootAssetWithRuntime.Runtime.Close()
}
}
Expand All @@ -126,7 +164,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
// for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset
for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets {
// create runtime for root asset
assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording)
assetWithRuntime, err := createRuntimeForAsset(a, coordinator, upstream, recording)
if err != nil {
log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset")
discoveredAssets.AddError(assetWithRuntime.Asset, err)
Expand All @@ -137,7 +175,7 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels)

// If the asset has been already added, we should close its runtime
if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) {
if !discoveredAssets.Add(resolvedRootAsset, resolvedAsset, assetWithRuntime.Runtime) {
assetWithRuntime.Runtime.Close()
}
}
Expand All @@ -146,15 +184,15 @@ func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *ups
// if there is exactly one asset, assure that the --asset-name is used
// TODO: make it so that the --asset-name is set for the root asset only even if multiple assets are there
// This is a temporary fix that only works if there is only one asset
if len(discoveredAssets.Assets) == 1 && invAssets[0].Name != "" && invAssets[0].Name != discoveredAssets.Assets[0].Asset.Name {
log.Debug().Str("asset", discoveredAssets.Assets[0].Asset.Name).Msg("Overriding asset name with --asset-name flag")
discoveredAssets.Assets[0].Asset.Name = invAssets[0].Name
}
// if len(discoveredAssets.Assets) == 1 && invAssets[0].Name != "" && invAssets[0].Name != discoveredAssets.Assets[0].Asset.Name {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commented out code

// log.Debug().Str("asset", discoveredAssets.Assets[0].Asset.Name).Msg("Overriding asset name with --asset-name flag")
// discoveredAssets.Assets[0].Asset.Name = invAssets[0].Name
// }

return discoveredAssets, nil
}

func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) {
func createRuntimeForAsset(asset *inventory.Asset, coordinator providers.Coordinator, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) {
var runtime *providers.Runtime
var err error
// Close the runtime if an error occured
Expand All @@ -164,7 +202,7 @@ func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamCo
}
}()

runtime, err = providers.Coordinator.RuntimeFor(asset, providers.DefaultRuntime())
runtime, err = coordinator.RuntimeFor(asset, providers.DefaultRuntime())
if err != nil {
return nil, err
}
Expand Down
Loading
Loading