
⚡ discover assets in parallel #4973

Merged Dec 16, 2024 (14 commits)
Changes from 7 commits
12 changes: 12 additions & 0 deletions .vscode/launch.json
@@ -198,6 +198,18 @@
"shell", "ssh", "[email protected]",
],
},
+    {
+        "name": "scan github org",

[Contributor] did you mean to keep this in here?

[Contributor Author] Yes. I think it is useful for folks testing, especially now that we have https://github.com/hit-training. I can remove it.

"type": "go",
"request": "launch",
"program": "${workspaceRoot}/apps/cnquery/cnquery.go",
"args": [
"scan",
"github",
"org", "hit-training",
"--log-level", "trace"
]
},
{
"name": "Configure Built-in Providers",
"type": "go",
71 changes: 44 additions & 27 deletions explorer/scan/discovery.go
@@ -6,11 +6,13 @@ package scan
import (
	"context"
	"errors"
+	"sync"
	"time"

	"github.com/rs/zerolog/log"
	"go.mondoo.com/cnquery/v11/cli/config"
	"go.mondoo.com/cnquery/v11/cli/execruntime"
+	"go.mondoo.com/cnquery/v11/internal/workerpool"
	"go.mondoo.com/cnquery/v11/llx"
	"go.mondoo.com/cnquery/v11/logger"
	"go.mondoo.com/cnquery/v11/providers"
@@ -20,6 +22,9 @@ import (
	"go.mondoo.com/cnquery/v11/providers-sdk/v1/upstream"
)

+// number of parallel goroutines discovering assets
+const workers = 10
+
type AssetWithRuntime struct {
	Asset   *inventory.Asset
	Runtime *providers.Runtime
@@ -34,28 +39,30 @@ type DiscoveredAssets struct {
	platformIds map[string]struct{}
	Assets      []*AssetWithRuntime
	Errors      []*AssetWithError
+	assetsLock  sync.Mutex
}

// Add adds an asset and its runtime to the discovered assets list. It returns true if the
// asset has been added, false if it is a duplicate
func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool {
-	isDuplicate := false
+	d.assetsLock.Lock()
+	defer d.assetsLock.Unlock()
+
	for _, platformId := range asset.PlatformIds {
		if _, ok := d.platformIds[platformId]; ok {
-			isDuplicate = true
-			break
+			// duplicate
+			return false
		}
		d.platformIds[platformId] = struct{}{}
	}
-	if isDuplicate {
-		return false
-	}

	d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime})
	return true
}

func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) {
+	d.assetsLock.Lock()
+	defer d.assetsLock.Unlock()
	d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err})
}

Expand Down Expand Up @@ -161,35 +168,45 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i
return
}

pool := workerpool.New[bool](workers)
pool.Start()
defer pool.Close()

// for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset
for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets {
// create runtime for root asset
assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording)
if err != nil {
log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset")
discoveredAssets.AddError(a, err)
continue
}
pool.Submit(func() (bool, error) {
// create runtime for root asset
assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording)
if err != nil {
log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset")
discoveredAssets.AddError(a, err)
return false, err
}

// If no asset was returned and no error, then we observed a duplicate asset with a
// runtime that already exists.
if assetWithRuntime == nil {
continue
}
// If no asset was returned and no error, then we observed a duplicate asset with a
// runtime that already exists.
if assetWithRuntime == nil {
return false, nil
}

resolvedAsset := assetWithRuntime.Runtime.Provider.Connection.Asset
if len(resolvedAsset.PlatformIds) > 0 {
prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels)
resolvedAsset := assetWithRuntime.Runtime.Provider.Connection.Asset
if len(resolvedAsset.PlatformIds) > 0 {
prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels)

// If the asset has been already added, we should close its runtime
if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) {
// If the asset has been already added, we should close its runtime
if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) {
assetWithRuntime.Runtime.Close()
}
} else {
discoverAssets(assetWithRuntime, resolvedRootAsset, discoveredAssets, runtimeLabels, upstream, recording)
assetWithRuntime.Runtime.Close()
}
} else {
discoverAssets(assetWithRuntime, resolvedRootAsset, discoveredAssets, runtimeLabels, upstream, recording)
assetWithRuntime.Runtime.Close()
}
return true, nil
})
}

// Wait for the workers to finish processing
pool.Wait()
}

func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) {
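
The concurrency pattern above is simple enough to sketch in isolation: fan one task per discovered asset out to a fixed-size pool, and dedupe through a mutex-guarded accumulator, the same role `DiscoveredAssets.Add` plays. A minimal sketch using the new workerpool package; the accumulator type and inputs are invented for illustration:

package main

import (
	"fmt"
	"sync"

	"go.mondoo.com/cnquery/v11/internal/workerpool"
)

// accumulator mirrors DiscoveredAssets: the mutex is required because
// pool workers call Add concurrently.
type accumulator struct {
	mu    sync.Mutex
	seen  map[string]struct{}
	items []string
}

func (a *accumulator) Add(id string) bool {
	a.mu.Lock()
	defer a.mu.Unlock()
	if _, ok := a.seen[id]; ok {
		return false // duplicate
	}
	a.seen[id] = struct{}{}
	a.items = append(a.items, id)
	return true
}

func main() {
	acc := &accumulator{seen: map[string]struct{}{}}

	pool := workerpool.New[bool](10)
	pool.Start()
	defer pool.Close()

	for _, id := range []string{"asset-a", "asset-b", "asset-a", "asset-c"} {
		id := id // re-bind so each closure captures its own copy (needed before Go 1.22)
		pool.Submit(func() (bool, error) {
			return acc.Add(id), nil
		})
	}

	pool.Wait()
	fmt.Println(acc.items) // order varies; the duplicate "asset-a" is dropped
}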
55 changes: 55 additions & 0 deletions internal/workerpool/collector.go
@@ -0,0 +1,55 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

import (
"sync"
"sync/atomic"
)

type collector[R any] struct {
resultsCh <-chan R
results []R
read sync.Mutex

errorsCh <-chan error
errors []error

requestsRead int64
}

func (c *collector[R]) start() {
go func() {
for {
select {
case result := <-c.resultsCh:
c.read.Lock()
c.results = append(c.results, result)
c.read.Unlock()

case err := <-c.errorsCh:
c.read.Lock()
c.errors = append(c.errors, err)
c.read.Unlock()
}

atomic.AddInt64(&c.requestsRead, 1)
}
}()
}

func (c *collector[R]) GetResults() []R {
c.read.Lock()
defer c.read.Unlock()
return c.results
}

func (c *collector[R]) GetErrors() []error {
c.read.Lock()
defer c.read.Unlock()
return c.errors
}

func (c *collector[R]) RequestsRead() int64 {
return atomic.LoadInt64(&c.requestsRead)
}
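
Since `GetResults` and `GetErrors` take the same mutex the collector goroutine holds while appending, it should be safe to peek at partial results while tasks are still in flight; the pool's docs still recommend calling `Wait()` before reading the final set. A small sketch, with invented task bodies and timings:

package main

import (
	"fmt"
	"time"

	"go.mondoo.com/cnquery/v11/internal/workerpool"
)

func main() {
	pool := workerpool.New[int](2)
	pool.Start()
	defer pool.Close()

	for i := 0; i < 5; i++ {
		n := i
		pool.Submit(func() (int, error) {
			time.Sleep(50 * time.Millisecond) // simulate slow discovery work
			return n * n, nil
		})
	}

	// A concurrent peek: guarded by the collector's mutex, so no data race.
	fmt.Println("partial:", pool.GetResults())

	pool.Wait()
	fmt.Println("final:", pool.GetResults()) // all 5 results, order not guaranteed
}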
112 changes: 112 additions & 0 deletions internal/workerpool/pool.go
@@ -0,0 +1,112 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

import (
"sync"
"sync/atomic"
"time"

"github.com/cockroachdb/errors"
)

type Task[R any] func() (result R, err error)

// Pool is a generic pool of workers.
type Pool[R any] struct {
queueCh chan Task[R]
resultsCh chan R
errorsCh chan error

requestsSent int64
once sync.Once

workers []*worker[R]
workerCount int

collector[R]
}

// New initializes a new Pool with the provided number of workers. The pool is generic and
// accepts any Task matching the signature `func() (R, error)`.
//
// For example, a Pool[int] will accept Tasks similar to:
//
// task := func() (int, error) {
// return 42, nil
// }
func New[R any](count int) *Pool[R] {
resultsCh := make(chan R)
errorsCh := make(chan error)
return &Pool[R]{
queueCh: make(chan Task[R]),
resultsCh: resultsCh,
errorsCh: errorsCh,
workerCount: count,
collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh},
}
}

// Start starts the pool workers and the collector. Make sure to call `Close()` to clean up the pool.
//
// pool := workerpool.New[int](10)
// pool.Start()
// defer pool.Close()
func (p *Pool[R]) Start() {
p.once.Do(func() {
for i := 0; i < p.workerCount; i++ {
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh}
w.start()
p.workers = append(p.workers, &w)
}

p.collector.start()
})
}

// Submit sends a task to the workers
func (p *Pool[R]) Submit(t Task[R]) {
p.queueCh <- t
atomic.AddInt64(&p.requestsSent, 1)
}

// GetErrors returns the errors from all processed tasks, joined into a single error
func (p *Pool[R]) GetErrors() error {
return errors.Join(p.collector.GetErrors()...)
}

// GetResults returns the tasks' results.
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetResults() []R {
return p.collector.GetResults()
}

// Close waits for workers and collector to process all the requests, and then closes
// the task queue channel. After closing the pool, calling `Submit()` will panic.
func (p *Pool[R]) Close() {
p.Wait()
close(p.queueCh)
}

// Wait waits until all tasks have been processed.
func (p *Pool[R]) Wait() {
ticker := time.NewTicker(10 * time.Millisecond)

[Contributor Author] This fixed the benchmark failure.

for {
if !p.Processing() {
return
}
<-ticker.C
}
}

// PendingRequests returns the number of pending requests.
func (p *Pool[R]) PendingRequests() int64 {
return atomic.LoadInt64(&p.requestsSent) - p.collector.RequestsRead()
}

// Processing returns true while tasks are still being processed.
func (p *Pool[R]) Processing() bool {
return p.PendingRequests() != 0
}
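
Finally, a sketch of the error path. It assumes, based on the split results/errors channels, that a failing task reports only its error and contributes no result; the tasks themselves are invented:

package main

import (
	"errors"
	"fmt"

	"go.mondoo.com/cnquery/v11/internal/workerpool"
)

func main() {
	pool := workerpool.New[string](4)
	pool.Start()
	defer pool.Close()

	pool.Submit(func() (string, error) { return "ok", nil })
	pool.Submit(func() (string, error) { return "", errors.New("boom") })

	pool.Wait()

	fmt.Println(pool.GetResults()) // [ok]
	if err := pool.GetErrors(); err != nil {
		fmt.Println(err) // boom (joined via errors.Join)
	}
}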