diff --git a/.github/workflows/main-benchmark.yml b/.github/workflows/main-benchmark.yml index e7d31ef9f4..0fbf8eec86 100644 --- a/.github/workflows/main-benchmark.yml +++ b/.github/workflows/main-benchmark.yml @@ -63,4 +63,4 @@ jobs: uses: actions/cache/save@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark-${{ github.run_id }} \ No newline at end of file + key: ${{ runner.os }}-benchmark-${{ github.run_id }} diff --git a/.github/workflows/pr-test-lint.yml b/.github/workflows/pr-test-lint.yml index a46d51920b..33490eb187 100644 --- a/.github/workflows/pr-test-lint.yml +++ b/.github/workflows/pr-test-lint.yml @@ -128,6 +128,24 @@ jobs: name: test-results-cli path: report.xml + go-race: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Import environment variables from file + run: cat ".github/env" >> $GITHUB_ENV + + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: ">=${{ env.golang-version }}" + cache: false + + - name: Run race detector on selected packages + run: make race/go + go-bench: runs-on: ubuntu-latest if: github.ref != 'refs/heads/main' diff --git a/.vscode/launch.json b/.vscode/launch.json index 85931b3a76..2fab2c0990 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -198,6 +198,18 @@ "shell", "ssh", "user@18.215.249.49", ], }, + { + "name": "scan github org", + "type": "go", + "request": "launch", + "program": "${workspaceRoot}/apps/cnquery/cnquery.go", + "args": [ + "scan", + "github", + "org", "hit-training", + "--log-level", "trace" + ] + }, { "name": "Configure Built-in Providers", "type": "go", diff --git a/Makefile b/Makefile index 753969f04e..b70761b4f9 100644 --- a/Makefile +++ b/Makefile @@ -700,6 +700,10 @@ test: test/go test/lint benchmark/go: go test -bench=. -benchmem go.mondoo.com/cnquery/v11/explorer/scan/benchmark +race/go: + go test -race go.mondoo.com/cnquery/v11/internal/workerpool + go test -race go.mondoo.com/cnquery/v11/explorer/scan + test/generate: prep/tools/mockgen go generate ./providers diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go index be7a67e7e0..994b9f402f 100644 --- a/explorer/scan/discovery.go +++ b/explorer/scan/discovery.go @@ -6,11 +6,13 @@ package scan import ( "context" "errors" + "sync" "time" "github.com/rs/zerolog/log" "go.mondoo.com/cnquery/v11/cli/config" "go.mondoo.com/cnquery/v11/cli/execruntime" + "go.mondoo.com/cnquery/v11/internal/workerpool" "go.mondoo.com/cnquery/v11/llx" "go.mondoo.com/cnquery/v11/logger" "go.mondoo.com/cnquery/v11/providers" @@ -20,6 +22,9 @@ import ( "go.mondoo.com/cnquery/v11/providers-sdk/v1/upstream" ) +// number of parallel goroutines discovering assets +const workers = 10 + type AssetWithRuntime struct { Asset *inventory.Asset Runtime *providers.Runtime @@ -34,28 +39,30 @@ type DiscoveredAssets struct { platformIds map[string]struct{} Assets []*AssetWithRuntime Errors []*AssetWithError + assetsLock sync.Mutex } // Add adds an asset and its runtime to the discovered assets list. It returns true if the // asset has been added, false if it is a duplicate func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool { - isDuplicate := false + d.assetsLock.Lock() + defer d.assetsLock.Unlock() + for _, platformId := range asset.PlatformIds { if _, ok := d.platformIds[platformId]; ok { - isDuplicate = true - break + // duplicate + return false } d.platformIds[platformId] = struct{}{} } - if isDuplicate { - return false - } d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime}) return true } func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) { + d.assetsLock.Lock() + defer d.assetsLock.Unlock() d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err}) } @@ -161,17 +168,30 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i return } + pool := workerpool.New[*AssetWithRuntime](workers) + pool.Start() + defer pool.Close() + // for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset - for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets { - // create runtime for root asset - assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording) - if err != nil { - log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset") - discoveredAssets.AddError(a, err) - continue - } + for _, asset := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets { + pool.Submit(func() (*AssetWithRuntime, error) { + assetWithRuntime, err := createRuntimeForAsset(asset, upstream, recording) + if err != nil { + log.Error().Err(err).Str("asset", asset.GetName()).Msg("unable to create runtime for asset") + discoveredAssets.AddError(asset, err) + } + return assetWithRuntime, nil + }) + } + + // Wait for the workers to finish processing + pool.Wait() + + // Get all assets with runtimes from the pool + for _, result := range pool.GetResults() { + assetWithRuntime := result.Value - // If no asset was returned and no error, then we observed a duplicate asset with a + // If asset is nil, then we observed a duplicate asset with a // runtime that already exists. if assetWithRuntime == nil { continue diff --git a/internal/workerpool/collector.go b/internal/workerpool/collector.go index 2d105501be..bb33bb836e 100644 --- a/internal/workerpool/collector.go +++ b/internal/workerpool/collector.go @@ -9,13 +9,11 @@ import ( ) type collector[R any] struct { - resultsCh <-chan R - results []R + resultsCh <-chan Result[R] + results []Result[R] read sync.Mutex - errorsCh <-chan error - errors []error - + // The total number of requests read. requestsRead int64 } @@ -27,29 +25,35 @@ func (c *collector[R]) start() { c.read.Lock() c.results = append(c.results, result) c.read.Unlock() - - case err := <-c.errorsCh: - c.read.Lock() - c.errors = append(c.errors, err) - c.read.Unlock() } atomic.AddInt64(&c.requestsRead, 1) } }() } -func (c *collector[R]) GetResults() []R { + +func (c *collector[R]) RequestsRead() int64 { + return atomic.LoadInt64(&c.requestsRead) +} + +func (c *collector[R]) GetResults() []Result[R] { c.read.Lock() defer c.read.Unlock() return c.results } -func (c *collector[R]) GetErrors() []error { - c.read.Lock() - defer c.read.Unlock() - return c.errors +func (c *collector[R]) GetValues() (slice []R) { + results := c.GetResults() + for i := range results { + slice = append(slice, results[i].Value) + } + return } -func (c *collector[R]) RequestsRead() int64 { - return atomic.LoadInt64(&c.requestsRead) +func (c *collector[R]) GetErrors() (slice []error) { + results := c.GetResults() + for i := range results { + slice = append(slice, results[i].Error) + } + return } diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go index 8553ca25d5..1ad4afa86b 100644 --- a/internal/workerpool/pool.go +++ b/internal/workerpool/pool.go @@ -7,25 +7,40 @@ import ( "sync" "sync/atomic" "time" - - "github.com/cockroachdb/errors" ) +// Represent the tasks that can be sent to the pool. type Task[R any] func() (result R, err error) +// The result generated from a task. +type Result[R any] struct { + Value R + Error error +} + // Pool is a generic pool of workers. type Pool[R any] struct { - queueCh chan Task[R] - resultsCh chan R - errorsCh chan error + // The queue where tasks are submitted. + queueCh chan Task[R] + // Where workers send the results after a task is executed, + // the collector then reads them and aggregate them. + resultsCh chan Result[R] + + // The total number of requests sent. requestsSent int64 - once sync.Once - workers []*worker[R] + // Number of workers to spawn. workerCount int + // The list of workers that are listening to the queue. + workers []*worker[R] + + // A single collector to aggregate results. collector[R] + + // used to protect starting the pool multiple times + once sync.Once } // New initializes a new Pool with the provided number of workers. The pool is generic and can @@ -37,14 +52,12 @@ type Pool[R any] struct { // return 42, nil // } func New[R any](count int) *Pool[R] { - resultsCh := make(chan R) - errorsCh := make(chan error) + resultsCh := make(chan Result[R]) return &Pool[R]{ queueCh: make(chan Task[R]), resultsCh: resultsCh, - errorsCh: errorsCh, workerCount: count, - collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh}, + collector: collector[R]{resultsCh: resultsCh}, } } @@ -56,7 +69,7 @@ func New[R any](count int) *Pool[R] { func (p *Pool[R]) Start() { p.once.Do(func() { for i := 0; i < p.workerCount; i++ { - w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} + w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh} w.start() p.workers = append(p.workers, &w) } @@ -67,22 +80,33 @@ func (p *Pool[R]) Start() { // Submit sends a task to the workers func (p *Pool[R]) Submit(t Task[R]) { - p.queueCh <- t - atomic.AddInt64(&p.requestsSent, 1) -} - -// GetErrors returns any error from a processed task -func (p *Pool[R]) GetErrors() error { - return errors.Join(p.collector.GetErrors()...) + if t != nil { + p.queueCh <- t + atomic.AddInt64(&p.requestsSent, 1) + } } // GetResults returns the tasks results. // // It is recommended to call `Wait()` before reading the results. -func (p *Pool[R]) GetResults() []R { +func (p *Pool[R]) GetResults() []Result[R] { return p.collector.GetResults() } +// GetValues returns only the values of the pool results +// +// It is recommended to call `Wait()` before reading the results. +func (p *Pool[R]) GetValues() []R { + return p.collector.GetValues() +} + +// GetErrors returns only the errors of the pool results +// +// It is recommended to call `Wait()` before reading the results. +func (p *Pool[R]) GettErrors() []error { + return p.collector.GetErrors() +} + // Close waits for workers and collector to process all the requests, and then closes // the task queue channel. After closing the pool, calling `Submit()` will panic. func (p *Pool[R]) Close() { @@ -92,7 +116,7 @@ func (p *Pool[R]) Close() { // Wait waits until all tasks have been processed. func (p *Pool[R]) Wait() { - ticker := time.NewTicker(100 * time.Millisecond) + ticker := time.NewTicker(10 * time.Millisecond) for { if !p.Processing() { return diff --git a/internal/workerpool/pool_test.go b/internal/workerpool/pool_test.go index 3b3946df1e..222dad5707 100644 --- a/internal/workerpool/pool_test.go +++ b/internal/workerpool/pool_test.go @@ -35,11 +35,10 @@ func TestPoolSubmitAndRetrieveResult(t *testing.T) { // should have one result results := pool.GetResults() if assert.Len(t, results, 1) { - assert.Equal(t, 42, results[0]) + assert.Equal(t, 42, results[0].Value) + // without errors + assert.NoError(t, results[0].Error) } - - // no errors - assert.Nil(t, pool.GetErrors()) } func TestPoolHandleErrors(t *testing.T) { @@ -53,12 +52,12 @@ func TestPoolHandleErrors(t *testing.T) { } pool.Submit(task) - // Wait for error collector to process + // Wait for collector to process the results pool.Wait() - err := pool.GetErrors() - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "task error") + errs := pool.GetErrors() + if assert.Len(t, errs, 1) { + assert.Equal(t, errs[0].Error(), "task error") } } @@ -86,12 +85,26 @@ func TestPoolMultipleTasksWithErrors(t *testing.T) { // Wait for error collector to process pool.Wait() - results := pool.GetResults() - assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results) - err := pool.GetErrors() - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "task error") - } + // Access results together + assert.ElementsMatch(t, + []workerpool.Result[*test]{ + {&test{1}, nil}, + {&test{2}, nil}, + {&test{3}, nil}, + {nil, errors.New("task error")}, + }, + pool.GetResults(), + ) + + // You can also access values and errors directly + assert.ElementsMatch(t, + []*test{nil, &test{1}, &test{2}, &test{3}}, + pool.GetValues(), + ) + assert.ElementsMatch(t, + []error{nil, nil, errors.New("task error"), nil}, + pool.GetErrors(), + ) } func TestPoolHandlesNilTasks(t *testing.T) { @@ -104,8 +117,8 @@ func TestPoolHandlesNilTasks(t *testing.T) { pool.Wait() - err := pool.GetErrors() - assert.NoError(t, err) + assert.Empty(t, pool.GetErrors()) + assert.Empty(t, pool.GetValues()) } func TestPoolProcessing(t *testing.T) { @@ -126,9 +139,8 @@ func TestPoolProcessing(t *testing.T) { // wait pool.Wait() - // read results - result := pool.GetResults() - assert.Equal(t, []int{10}, result) + // read values + assert.Equal(t, []int{10}, pool.GetValues()) // should not longer be processing assert.False(t, pool.Processing()) diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go index 77b5c81f15..31257353c6 100644 --- a/internal/workerpool/worker.go +++ b/internal/workerpool/worker.go @@ -6,25 +6,14 @@ package workerpool type worker[R any] struct { id int queueCh <-chan Task[R] - resultsCh chan<- R - errorsCh chan<- error + resultsCh chan<- Result[R] } func (w *worker[R]) start() { go func() { for task := range w.queueCh { - if task == nil { - // let the collector know we processed the request - w.errorsCh <- nil - continue - } - data, err := task() - if err != nil { - w.errorsCh <- err - } else { - w.resultsCh <- data - } + w.resultsCh <- Result[R]{data, err} } }() } diff --git a/providers-sdk/v1/plugin/service.go b/providers-sdk/v1/plugin/service.go index 382efcc90b..ea36c6f9a1 100644 --- a/providers-sdk/v1/plugin/service.go +++ b/providers-sdk/v1/plugin/service.go @@ -51,11 +51,8 @@ func (s *Service) AddRuntime(conf *inventory.Config, createRuntime func(connId u } // ^^ - s.runtimesLock.Lock() - defer s.runtimesLock.Unlock() - // If a runtime with this ID already exists, then return that - if runtime, ok := s.runtimes[conf.Id]; ok { + if runtime, err := s.GetRuntime(conf.Id); err == nil { return runtime, nil } @@ -66,7 +63,7 @@ func (s *Service) AddRuntime(conf *inventory.Config, createRuntime func(connId u if runtime.Connection != nil { if parentId := runtime.Connection.ParentID(); parentId > 0 { - parentRuntime, err := s.doGetRuntime(parentId) + parentRuntime, err := s.GetRuntime(parentId) if err != nil { return nil, errors.New("parent connection " + strconv.FormatUint(uint64(parentId), 10) + " not found") } @@ -74,10 +71,19 @@ func (s *Service) AddRuntime(conf *inventory.Config, createRuntime func(connId u } } - s.runtimes[conf.Id] = runtime + + // store the new runtime + s.addRuntime(conf.Id, runtime) + return runtime, nil } +func (s *Service) addRuntime(id uint32, runtime *Runtime) { + s.runtimesLock.Lock() + defer s.runtimesLock.Unlock() + s.runtimes[id] = runtime +} + // FIXME: DEPRECATED, remove in v12.0 vv func (s *Service) deprecatedAddRuntime(createRuntime func(connId uint32) (*Runtime, error)) (*Runtime, error) { s.runtimesLock.Lock() diff --git a/providers/github/connection/connection.go b/providers/github/connection/connection.go index 8e7cdb5942..c974ef0b06 100644 --- a/providers/github/connection/connection.go +++ b/providers/github/connection/connection.go @@ -74,6 +74,7 @@ func NewGithubConnection(id uint32, asset *inventory.Asset) (*GithubConnection, ctx := context.WithValue(context.Background(), github.SleepUntilPrimaryRateLimitResetWhenRateLimited, true) // perform a quick call to verify the token's validity. + // @afiune do we need to validate the token for every connection? can this be a "once" operation? _, resp, err := client.Meta.Zen(ctx) if err != nil { if resp != nil && resp.StatusCode == 401 { diff --git a/providers/github/resources/github_org.go b/providers/github/resources/github_org.go index ef39f97159..0876e4a71f 100644 --- a/providers/github/resources/github_org.go +++ b/providers/github/resources/github_org.go @@ -4,12 +4,12 @@ package resources import ( - "errors" "slices" "strconv" "strings" "time" + "github.com/cockroachdb/errors" "github.com/google/go-github/v67/github" "github.com/rs/zerolog/log" "go.mondoo.com/cnquery/v11/internal/workerpool" @@ -287,7 +287,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { for { // exit as soon as we collect all repositories - reposLen := len(slices.Concat(workerPool.GetResults()...)) + reposLen := len(slices.Concat(workerPool.GetValues()...)) if reposLen >= int(repoCount) { break } @@ -303,7 +303,8 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { listOpts.Page++ // check if any request failed - if err := workerPool.GetErrors(); err != nil { + if errs := workerPool.GetErrors(); len(errs) != 0 { + err := errors.Join(errs...) if strings.Contains(err.Error(), "404") { return nil, nil } @@ -316,7 +317,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { } res := []interface{}{} - for _, repos := range workerPool.GetResults() { + for _, repos := range workerPool.GetValues() { for i := range repos { repo := repos[i] diff --git a/providers/runtime.go b/providers/runtime.go index 840f09fa25..ee71ab7af3 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -39,6 +39,9 @@ type Runtime struct { isClosed bool close sync.Once shutdownTimeout time.Duration + + // used to lock unsafe tasks + mu sync.Mutex } type ConnectedProvider struct { @@ -118,12 +121,23 @@ func (r *Runtime) UseProvider(id string) error { return err } + r.mu.Lock() r.Provider = res + r.mu.Unlock() return nil } func (r *Runtime) AddConnectedProvider(c *ConnectedProvider) { + r.mu.Lock() r.providers[c.Instance.ID] = c + r.mu.Unlock() +} + +func (r *Runtime) setProviderConnection(c *plugin.ConnectRes, err error) { + r.mu.Lock() + r.Provider.Connection = c + r.Provider.ConnectionError = err + r.mu.Unlock() } func (r *Runtime) addProvider(id string) (*ConnectedProvider, error) { @@ -232,9 +246,10 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error { // } - r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks) - if r.Provider.ConnectionError != nil { - return r.Provider.ConnectionError + conn, err := r.Provider.Instance.Plugin.Connect(req, &callbacks) + r.setProviderConnection(conn, err) + if err != nil { + return err } // TODO: This is a stopgap that detects if the connect call returned an asset @@ -256,9 +271,10 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error { if postProvider.ID != r.Provider.Instance.ID { req.Asset = r.Provider.Connection.Asset r.UseProvider(postProvider.ID) - r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks) - if r.Provider.ConnectionError != nil { - return r.Provider.ConnectionError + conn, err := r.Provider.Instance.Plugin.Connect(req, &callbacks) + r.setProviderConnection(conn, err) + if err != nil { + return err } } @@ -747,6 +763,9 @@ func (r *Runtime) Schema() resources.ResourcesSchema { } func (r *Runtime) asset() *inventory.Asset { + r.mu.Lock() + defer r.mu.Unlock() + if r.Provider == nil || r.Provider.Connection == nil { return nil }