diff --git a/apps/cnquery/cmd/plugin.go b/apps/cnquery/cmd/plugin.go index 44c049de7f..dac775358f 100644 --- a/apps/cnquery/cmd/plugin.go +++ b/apps/cnquery/cmd/plugin.go @@ -117,12 +117,12 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru for i := range assets { connectAsset := assets[i] - connectAssetRuntime := providers.Coordinator.NewRuntimeFrom(runtime) - - if err := connectAssetRuntime.DetectProvider(connectAsset); err != nil { + connectAssetRuntime, err := providers.Coordinator.RuntimeFor(connectAsset, runtime) + if err != nil { return err } - err := connectAssetRuntime.Connect(&pp.ConnectReq{ + + err = connectAssetRuntime.Connect(&pp.ConnectReq{ Features: config.Features, Asset: connectAsset, Upstream: upstreamConfig, diff --git a/cli/providers/providers.go b/cli/providers/providers.go index f192a2d59e..a3dc42c8cf 100644 --- a/cli/providers/providers.go +++ b/cli/providers/providers.go @@ -408,6 +408,9 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu // TODO: add flag to set timeout and then use RuntimeWithShutdownTimeout runtime := providers.Coordinator.NewRuntime() + if err = providers.SetDefaultRuntime(runtime); err != nil { + log.Error().Msg(err.Error()) + } autoUpdate := true if viper.IsSet("auto_update") { diff --git a/explorer/scan/local_scanner.go b/explorer/scan/local_scanner.go index 541777fd5f..5ff5d04df5 100644 --- a/explorer/scan/local_scanner.go +++ b/explorer/scan/local_scanner.go @@ -174,14 +174,15 @@ func preprocessQueryPackFilters(filters []string) []string { func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*explorer.ReportCollection, bool, error) { log.Info().Msgf("discover related assets for %d asset(s)", len(job.Inventory.Spec.Assets)) - im, err := manager.NewManager(manager.WithInventory(job.Inventory, providers.Coordinator.NewRuntime())) + im, err := manager.NewManager(manager.WithInventory(job.Inventory, providers.DefaultRuntime())) if err != nil { return nil, false, errors.New("failed to resolve inventory for connection") } assetList := im.GetAssets() var assets []*assetWithRuntime - var assetCandidates []*inventory.Asset + // note: asset candidate runtimes are the runtime that discovered them + var assetCandidates []*assetWithRuntime // we connect and perform discovery for each asset in the job inventory for i := range assetList { @@ -190,11 +191,10 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up if err != nil { return nil, false, err } - runtime := providers.Coordinator.NewRuntime() - err = runtime.DetectProvider(resolvedAsset) + runtime, err := providers.Coordinator.RuntimeFor(asset, providers.DefaultRuntime()) if err != nil { - log.Error().Err(err).Msg("unable to detect provider for asset") + log.Error().Err(err).Str("asset", asset.Name).Msg("unable to create runtime for asset") continue } runtime.SetRecording(s.recording) @@ -211,7 +211,12 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up if err != nil { return nil, false, err } - assetCandidates = append(assetCandidates, processedAssets...) + for i := range processedAssets { + assetCandidates = append(assetCandidates, &assetWithRuntime{ + asset: processedAssets[i], + runtime: runtime, + }) + } // TODO: we want to keep better track of errors, since there may be // multiple assets coming in. It's annoying to abort the scan if we get one // error at this stage. @@ -222,28 +227,26 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up } // for each asset candidate, we initialize a new runtime and connect to it. - for _, asset := range assetCandidates { - runtime := providers.Coordinator.NewRuntime() - // Make sure the provider for the asset is present - if err := runtime.DetectProvider(asset); err != nil { + for i := range assetCandidates { + candidate := assetCandidates[i] + + runtime, err := providers.Coordinator.RuntimeFor(candidate.asset, candidate.runtime) + if err != nil { return nil, false, err } - // attach recording before connect, so it is tied to the asset - runtime.SetRecording(s.recording) - - err := runtime.Connect(&plugin.ConnectReq{ + err = runtime.Connect(&plugin.ConnectReq{ Features: config.Features, - Asset: asset, + Asset: candidate.asset, Upstream: upstream, }) if err != nil { - log.Error().Err(err).Msg("unable to connect to asset") + log.Error().Err(err).Str("asset", candidate.asset.Name).Msg("unable to connect to asset") continue } assets = append(assets, &assetWithRuntime{ - asset: asset, + asset: candidate.asset, runtime: runtime, }) } @@ -383,6 +386,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up defer scanGroup.Done() multiprogress.Open() }() + scanGroup.Wait() return reporter.Reports(), finished, nil } diff --git a/providers/coordinator.go b/providers/coordinator.go index 0a83739b24..0b0a2403ec 100644 --- a/providers/coordinator.go +++ b/providers/coordinator.go @@ -6,6 +6,7 @@ package providers import ( "os" "os/exec" + "strconv" "sync" "time" @@ -13,6 +14,7 @@ import ( "github.com/hashicorp/go-plugin" "github.com/muesli/termenv" "github.com/rs/zerolog/log" + "go.mondoo.com/cnquery/providers-sdk/v1/inventory" pp "go.mondoo.com/cnquery/providers-sdk/v1/plugin" "go.mondoo.com/cnquery/providers-sdk/v1/resources" coreconf "go.mondoo.com/cnquery/providers/core/config" @@ -22,13 +24,18 @@ import ( var BuiltinCoreID = coreconf.Config.ID var Coordinator = coordinator{ - Running: []*RunningProvider{}, + Running: []*RunningProvider{}, + runtimes: map[string]*Runtime{}, } type coordinator struct { Providers Providers Running []*RunningProvider - mutex sync.Mutex + + unprocessedRuntimes []*Runtime + runtimes map[string]*Runtime + runtimeCnt int + mutex sync.Mutex } type builtinProvider struct { @@ -246,6 +253,101 @@ func (c *coordinator) tryProviderUpdate(provider *Provider, update UpdateProvide return provider, nil } +func (c *coordinator) NewRuntime() *Runtime { + res := &Runtime{ + coordinator: c, + providers: map[string]*ConnectedProvider{}, + schema: extensibleSchema{ + loaded: map[string]struct{}{}, + Schema: resources.Schema{ + Resources: map[string]*resources.ResourceInfo{}, + }, + }, + Recording: NullRecording{}, + shutdownTimeout: defaultShutdownTimeout, + } + res.schema.runtime = res + + // TODO: do this dynamically in the future + res.schema.loadAllSchemas() + + c.mutex.Lock() + c.unprocessedRuntimes = append(c.unprocessedRuntimes, res) + c.runtimeCnt++ + cnt := c.runtimeCnt + c.mutex.Unlock() + + log.Warn().Msg("Started a new runtime (" + strconv.Itoa(cnt) + " total)") + + return res +} + +func (c *coordinator) NewRuntimeFrom(parent *Runtime) *Runtime { + res := c.NewRuntime() + res.Recording = parent.Recording + for k, v := range parent.providers { + res.providers[k] = v + } + return res +} + +// RuntimFor an asset will return a new or existing runtime for a given asset. +// If a runtime for this asset already exists, it will re-use it. If the runtime +// is new, it will create it and detect the provider. +// The asset and parent must be defined. +func (c *coordinator) RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error) { + c.mutex.Lock() + c.unsafeRefreshRuntimes() + res := c.unsafeGetAssetRuntime(asset) + c.mutex.Unlock() + + if res != nil { + return res, nil + } + + res = c.NewRuntimeFrom(parent) + return res, res.DetectProvider(asset) +} + +// Only call this with a mutex lock around it! +func (c *coordinator) unsafeRefreshRuntimes() { + var remaining []*Runtime + for i := range c.unprocessedRuntimes { + rt := c.unprocessedRuntimes[i] + if asset := rt.asset(); asset == nil && !c.unsafeSetAssetRuntime(asset, rt) { + remaining = append(remaining, rt) + } + } + c.unprocessedRuntimes = remaining +} + +func (c *coordinator) unsafeGetAssetRuntime(asset *inventory.Asset) *Runtime { + if asset.Mrn != "" { + if rt := c.runtimes[asset.Mrn]; rt != nil { + return rt + } + } + for _, id := range asset.PlatformIds { + if rt := c.runtimes[id]; rt != nil { + return rt + } + } + return nil +} + +func (c *coordinator) unsafeSetAssetRuntime(asset *inventory.Asset, runtime *Runtime) bool { + found := false + if asset.Mrn != "" { + c.runtimes[asset.Mrn] = runtime + found = true + } + for _, id := range asset.PlatformIds { + c.runtimes[id] = runtime + found = true + } + return found +} + func (c *coordinator) Close(p *RunningProvider) { if !p.isClosed { p.isClosed = true diff --git a/providers/defaults.go b/providers/defaults.go index ffd53b01b5..2d8512774b 100644 --- a/providers/defaults.go +++ b/providers/defaults.go @@ -3,7 +3,10 @@ package providers -import "go.mondoo.com/cnquery/providers-sdk/v1/plugin" +import ( + "github.com/cockroachdb/errors" + "go.mondoo.com/cnquery/providers-sdk/v1/plugin" +) const DefaultOsID = "go.mondoo.com/cnquery/providers/os" @@ -16,6 +19,14 @@ func DefaultRuntime() *Runtime { return defaultRuntime } +func SetDefaultRuntime(rt *Runtime) error { + if rt == nil { + return errors.New("attempted to set default runtime to null") + } + defaultRuntime = rt + return nil +} + // DefaultProviders are useful when working in air-gapped environments // to tell users what providers are used for common connections, when there // is no other way to find out. diff --git a/providers/extensible_schema.go b/providers/extensible_schema.go new file mode 100644 index 0000000000..9239ff9bb1 --- /dev/null +++ b/providers/extensible_schema.go @@ -0,0 +1,123 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package providers + +import ( + "sync" + + "github.com/rs/zerolog/log" + "go.mondoo.com/cnquery/providers-sdk/v1/resources" +) + +type extensibleSchema struct { + resources.Schema + + loaded map[string]struct{} + runtime *Runtime + allLoaded bool + lockAll sync.Mutex // only used in getting all schemas + lockAdd sync.Mutex // only used when adding a schema +} + +func (x *extensibleSchema) loadAllSchemas() { + x.lockAll.Lock() + defer x.lockAll.Unlock() + + // If another goroutine started to load this before us, it will be locked until + // we complete to load everything and then it will be dumped into this + // position. At this point, if it has been loaded we can return safely, since + // we don't unlock until we are finished loading. + if x.allLoaded { + return + } + x.allLoaded = true + + providers, err := ListActive() + if err != nil { + log.Error().Err(err).Msg("failed to list all providers, can't load additional schemas") + return + } + + for name := range providers { + schema, err := x.runtime.coordinator.LoadSchema(name) + if err != nil { + log.Error().Err(err).Msg("load schema failed") + } else { + x.Add(name, schema) + } + } +} + +func (x *extensibleSchema) Close() { + x.loaded = map[string]struct{}{} + x.Schema.Resources = nil +} + +func (x *extensibleSchema) Lookup(name string) *resources.ResourceInfo { + if found, ok := x.Resources[name]; ok { + return found + } + if x.allLoaded { + return nil + } + + x.loadAllSchemas() + return x.Resources[name] +} + +func (x *extensibleSchema) LookupField(resource string, field string) (*resources.ResourceInfo, *resources.Field) { + found, ok := x.Resources[resource] + if !ok { + if x.allLoaded { + return nil, nil + } + + x.loadAllSchemas() + + found, ok = x.Resources[resource] + if !ok { + return nil, nil + } + return found, found.Fields[field] + } + + fieldObj, ok := found.Fields[field] + if ok { + return found, fieldObj + } + if x.allLoaded { + return found, nil + } + + x.loadAllSchemas() + return found, found.Fields[field] +} + +func (x *extensibleSchema) Add(name string, schema *resources.Schema) { + if schema == nil { + return + } + if name == "" { + log.Error().Msg("tried to add a schema with no name") + return + } + + x.lockAdd.Lock() + defer x.lockAdd.Unlock() + + if _, ok := x.loaded[name]; ok { + return + } + + x.loaded[name] = struct{}{} + x.Schema.Add(schema) +} + +func (x *extensibleSchema) AllResources() map[string]*resources.ResourceInfo { + if !x.allLoaded { + x.loadAllSchemas() + } + + return x.Resources +} diff --git a/providers/runtime.go b/providers/runtime.go index 9dfe53ec9f..7f342ab630 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -52,35 +52,6 @@ func (c *coordinator) RuntimeWithShutdownTimeout(timeout time.Duration) *Runtime return runtime } -func (c *coordinator) NewRuntime() *Runtime { - res := &Runtime{ - coordinator: c, - providers: map[string]*ConnectedProvider{}, - schema: extensibleSchema{ - loaded: map[string]struct{}{}, - Schema: resources.Schema{ - Resources: map[string]*resources.ResourceInfo{}, - }, - }, - Recording: NullRecording{}, - shutdownTimeout: defaultShutdownTimeout, - } - res.schema.runtime = res - - // TODO: do this dynamically in the future - res.schema.loadAllSchemas() - return res -} - -func (c *coordinator) NewRuntimeFrom(parent *Runtime) *Runtime { - res := c.NewRuntime() - res.Recording = parent.Recording - for k, v := range parent.providers { - res.providers[k] = v - } - return res -} - type shutdownResult struct { Response *plugin.ShutdownRes Error error @@ -610,114 +581,9 @@ func (r *Runtime) AddSchema(name string, schema *resources.Schema) { r.schema.Add(name, schema) } -type extensibleSchema struct { - resources.Schema - - loaded map[string]struct{} - runtime *Runtime - allLoaded bool - lockAll sync.Mutex // only used in getting all schemas - lockAdd sync.Mutex // only used when adding a schema -} - -func (x *extensibleSchema) loadAllSchemas() { - x.lockAll.Lock() - defer x.lockAll.Unlock() - - // If another goroutine started to load this before us, it will be locked until - // we complete to load everything and then it will be dumped into this - // position. At this point, if it has been loaded we can return safely, since - // we don't unlock until we are finished loading. - if x.allLoaded { - return - } - x.allLoaded = true - - providers, err := ListActive() - if err != nil { - log.Error().Err(err).Msg("failed to list all providers, can't load additional schemas") - return - } - - for name := range providers { - schema, err := x.runtime.coordinator.LoadSchema(name) - if err != nil { - log.Error().Err(err).Msg("load schema failed") - } else { - x.Add(name, schema) - } - } -} - -func (x *extensibleSchema) Close() { - x.loaded = map[string]struct{}{} - x.Schema.Resources = nil -} - -func (x *extensibleSchema) Lookup(name string) *resources.ResourceInfo { - if found, ok := x.Resources[name]; ok { - return found - } - if x.allLoaded { +func (r *Runtime) asset() *inventory.Asset { + if r.Provider == nil || r.Provider.Connection == nil { return nil } - - x.loadAllSchemas() - return x.Resources[name] -} - -func (x *extensibleSchema) LookupField(resource string, field string) (*resources.ResourceInfo, *resources.Field) { - found, ok := x.Resources[resource] - if !ok { - if x.allLoaded { - return nil, nil - } - - x.loadAllSchemas() - - found, ok = x.Resources[resource] - if !ok { - return nil, nil - } - return found, found.Fields[field] - } - - fieldObj, ok := found.Fields[field] - if ok { - return found, fieldObj - } - if x.allLoaded { - return found, nil - } - - x.loadAllSchemas() - return found, found.Fields[field] -} - -func (x *extensibleSchema) Add(name string, schema *resources.Schema) { - if schema == nil { - return - } - if name == "" { - log.Error().Msg("tried to add a schema with no name") - return - } - - x.lockAdd.Lock() - defer x.lockAdd.Unlock() - - if _, ok := x.loaded[name]; ok { - return - } - - x.loaded[name] = struct{}{} - x.Schema.Add(schema) -} - -func (x *extensibleSchema) AllResources() map[string]*resources.ResourceInfo { - if !x.allLoaded { - x.loadAllSchemas() - } - - return x.Resources + return r.Provider.Connection.Asset }