Skip to content

Commit

Permalink
⚡ re-use runtimes across providers
Browse files Browse the repository at this point in the history
  • Loading branch information
arlimus committed Oct 2, 2023
1 parent df07d0f commit 792def7
Show file tree
Hide file tree
Showing 7 changed files with 270 additions and 161 deletions.
8 changes: 4 additions & 4 deletions apps/cnquery/cmd/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru

for i := range assets {
connectAsset := assets[i]
connectAssetRuntime := providers.Coordinator.NewRuntimeFrom(runtime)

if err := connectAssetRuntime.DetectProvider(connectAsset); err != nil {
connectAssetRuntime, err := providers.Coordinator.RuntimeFor(connectAsset, runtime)
if err != nil {
return err
}
err := connectAssetRuntime.Connect(&pp.ConnectReq{

err = connectAssetRuntime.Connect(&pp.ConnectReq{
Features: config.Features,
Asset: connectAsset,
Upstream: upstreamConfig,
Expand Down
3 changes: 3 additions & 0 deletions cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu

// TODO: add flag to set timeout and then use RuntimeWithShutdownTimeout
runtime := providers.Coordinator.NewRuntime()
if err = providers.SetDefaultRuntime(runtime); err != nil {
log.Error().Msg(err.Error())
}

autoUpdate := true
if viper.IsSet("auto_update") {
Expand Down
38 changes: 21 additions & 17 deletions explorer/scan/local_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,15 @@ func preprocessQueryPackFilters(filters []string) []string {
func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*explorer.ReportCollection, bool, error) {
log.Info().Msgf("discover related assets for %d asset(s)", len(job.Inventory.Spec.Assets))

im, err := manager.NewManager(manager.WithInventory(job.Inventory, providers.Coordinator.NewRuntime()))
im, err := manager.NewManager(manager.WithInventory(job.Inventory, providers.DefaultRuntime()))
if err != nil {
return nil, false, errors.New("failed to resolve inventory for connection")
}
assetList := im.GetAssets()

var assets []*assetWithRuntime
var assetCandidates []*inventory.Asset
// note: asset candidate runtimes are the runtime that discovered them
var assetCandidates []*assetWithRuntime

// we connect and perform discovery for each asset in the job inventory
for i := range assetList {
Expand All @@ -190,11 +191,10 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
if err != nil {
return nil, false, err
}
runtime := providers.Coordinator.NewRuntime()

err = runtime.DetectProvider(resolvedAsset)
runtime, err := providers.Coordinator.RuntimeFor(asset, providers.DefaultRuntime())
if err != nil {
log.Error().Err(err).Msg("unable to detect provider for asset")
log.Error().Err(err).Str("asset", asset.Name).Msg("unable to create runtime for asset")
continue
}
runtime.SetRecording(s.recording)
Expand All @@ -211,7 +211,12 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
if err != nil {
return nil, false, err
}
assetCandidates = append(assetCandidates, processedAssets...)
for i := range processedAssets {
assetCandidates = append(assetCandidates, &assetWithRuntime{
asset: processedAssets[i],
runtime: runtime,
})
}
// TODO: we want to keep better track of errors, since there may be
// multiple assets coming in. It's annoying to abort the scan if we get one
// error at this stage.
Expand All @@ -222,28 +227,26 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
}

// for each asset candidate, we initialize a new runtime and connect to it.
for _, asset := range assetCandidates {
runtime := providers.Coordinator.NewRuntime()
// Make sure the provider for the asset is present
if err := runtime.DetectProvider(asset); err != nil {
for i := range assetCandidates {
candidate := assetCandidates[i]

runtime, err := providers.Coordinator.RuntimeFor(candidate.asset, candidate.runtime)
if err != nil {
return nil, false, err
}

// attach recording before connect, so it is tied to the asset
runtime.SetRecording(s.recording)

err := runtime.Connect(&plugin.ConnectReq{
err = runtime.Connect(&plugin.ConnectReq{
Features: config.Features,
Asset: asset,
Asset: candidate.asset,
Upstream: upstream,
})
if err != nil {
log.Error().Err(err).Msg("unable to connect to asset")
log.Error().Err(err).Str("asset", candidate.asset.Name).Msg("unable to connect to asset")
continue
}

assets = append(assets, &assetWithRuntime{
asset: asset,
asset: candidate.asset,
runtime: runtime,
})
}
Expand Down Expand Up @@ -383,6 +386,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
defer scanGroup.Done()
multiprogress.Open()
}()

scanGroup.Wait()
return reporter.Reports(), finished, nil
}
Expand Down
106 changes: 104 additions & 2 deletions providers/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ package providers
import (
"os"
"os/exec"
"strconv"
"sync"
"time"

"github.com/cockroachdb/errors"
"github.com/hashicorp/go-plugin"
"github.com/muesli/termenv"
"github.com/rs/zerolog/log"
"go.mondoo.com/cnquery/providers-sdk/v1/inventory"
pp "go.mondoo.com/cnquery/providers-sdk/v1/plugin"
"go.mondoo.com/cnquery/providers-sdk/v1/resources"
coreconf "go.mondoo.com/cnquery/providers/core/config"
Expand All @@ -22,13 +24,18 @@ import (
var BuiltinCoreID = coreconf.Config.ID

var Coordinator = coordinator{
Running: []*RunningProvider{},
Running: []*RunningProvider{},
runtimes: map[string]*Runtime{},
}

type coordinator struct {
Providers Providers
Running []*RunningProvider
mutex sync.Mutex

unprocessedRuntimes []*Runtime
runtimes map[string]*Runtime
runtimeCnt int
mutex sync.Mutex
}

type builtinProvider struct {
Expand Down Expand Up @@ -246,6 +253,101 @@ func (c *coordinator) tryProviderUpdate(provider *Provider, update UpdateProvide
return provider, nil
}

func (c *coordinator) NewRuntime() *Runtime {
res := &Runtime{
coordinator: c,
providers: map[string]*ConnectedProvider{},
schema: extensibleSchema{
loaded: map[string]struct{}{},
Schema: resources.Schema{
Resources: map[string]*resources.ResourceInfo{},
},
},
Recording: NullRecording{},
shutdownTimeout: defaultShutdownTimeout,
}
res.schema.runtime = res

// TODO: do this dynamically in the future
res.schema.loadAllSchemas()

c.mutex.Lock()
c.unprocessedRuntimes = append(c.unprocessedRuntimes, res)
c.runtimeCnt++
cnt := c.runtimeCnt
c.mutex.Unlock()

log.Warn().Msg("Started a new runtime (" + strconv.Itoa(cnt) + " total)")

return res
}

func (c *coordinator) NewRuntimeFrom(parent *Runtime) *Runtime {
res := c.NewRuntime()
res.Recording = parent.Recording
for k, v := range parent.providers {
res.providers[k] = v
}
return res
}

// RuntimFor an asset will return a new or existing runtime for a given asset.
// If a runtime for this asset already exists, it will re-use it. If the runtime
// is new, it will create it and detect the provider.
// The asset and parent must be defined.
func (c *coordinator) RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error) {
c.mutex.Lock()
c.unsafeRefreshRuntimes()
res := c.unsafeGetAssetRuntime(asset)
c.mutex.Unlock()

if res != nil {
return res, nil
}

res = c.NewRuntimeFrom(parent)
return res, res.DetectProvider(asset)
}

// Only call this with a mutex lock around it!
func (c *coordinator) unsafeRefreshRuntimes() {
var remaining []*Runtime
for i := range c.unprocessedRuntimes {
rt := c.unprocessedRuntimes[i]
if asset := rt.asset(); asset == nil && !c.unsafeSetAssetRuntime(asset, rt) {
remaining = append(remaining, rt)
}
}
c.unprocessedRuntimes = remaining
}

func (c *coordinator) unsafeGetAssetRuntime(asset *inventory.Asset) *Runtime {
if asset.Mrn != "" {
if rt := c.runtimes[asset.Mrn]; rt != nil {
return rt
}
}
for _, id := range asset.PlatformIds {
if rt := c.runtimes[id]; rt != nil {
return rt
}
}
return nil
}

func (c *coordinator) unsafeSetAssetRuntime(asset *inventory.Asset, runtime *Runtime) bool {
found := false
if asset.Mrn != "" {
c.runtimes[asset.Mrn] = runtime
found = true
}
for _, id := range asset.PlatformIds {
c.runtimes[id] = runtime
found = true
}
return found
}

func (c *coordinator) Close(p *RunningProvider) {
if !p.isClosed {
p.isClosed = true
Expand Down
13 changes: 12 additions & 1 deletion providers/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

package providers

import "go.mondoo.com/cnquery/providers-sdk/v1/plugin"
import (
"github.com/cockroachdb/errors"
"go.mondoo.com/cnquery/providers-sdk/v1/plugin"
)

const DefaultOsID = "go.mondoo.com/cnquery/providers/os"

Expand All @@ -16,6 +19,14 @@ func DefaultRuntime() *Runtime {
return defaultRuntime
}

func SetDefaultRuntime(rt *Runtime) error {
if rt == nil {
return errors.New("attempted to set default runtime to null")
}
defaultRuntime = rt
return nil
}

// DefaultProviders are useful when working in air-gapped environments
// to tell users what providers are used for common connections, when there
// is no other way to find out.
Expand Down
Loading

0 comments on commit 792def7

Please sign in to comment.