From 42f0627f5b4f55dee932d9a9ef1591fb2d83bd80 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Tue, 10 Dec 2024 23:08:13 -0800 Subject: [PATCH 01/13] :zap: workerpool package to submit parallel requests Signed-off-by: Salim Afiune Maya --- internal/workerpool/pool.go | 71 +++++++++++++++ internal/workerpool/pool_test.go | 145 +++++++++++++++++++++++++++++++ internal/workerpool/worker.go | 28 ++++++ 3 files changed, 244 insertions(+) create mode 100644 internal/workerpool/pool.go create mode 100644 internal/workerpool/pool_test.go create mode 100644 internal/workerpool/worker.go diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go new file mode 100644 index 0000000000..3bf8d71bfc --- /dev/null +++ b/internal/workerpool/pool.go @@ -0,0 +1,71 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workerpool + +import ( + "github.com/cockroachdb/errors" +) + +type Task[R any] func() (result R, err error) + +type Pool[R any] struct { + queue chan Task[R] + results chan R + errors chan error + workerCount int + requestsSent int + requestsRead int + + err error +} + +func New[R any](count int) *Pool[R] { + return &Pool[R]{ + queue: make(chan Task[R]), + results: make(chan R), + errors: make(chan error), + workerCount: count, + } +} + +func (p *Pool[R]) Start() { + for i := 0; i < p.workerCount; i++ { + w := worker[R]{id: i, queue: p.queue, results: p.results, errors: p.errors} + w.Start() + } + + p.errorCollector() +} + +func (p *Pool[R]) errorCollector() { + go func() { + for e := range p.errors { + p.err = errors.Join(p.err, e) + } + }() +} + +func (p *Pool[R]) GetError() error { + return p.err +} + +func (p *Pool[R]) Submit(t Task[R]) { + p.queue <- t + p.requestsSent++ +} + +func (p *Pool[R]) GetResult() R { + defer func() { + p.requestsRead++ + }() + return <-p.results +} + +func (p *Pool[R]) HasPendingRequests() bool { + return p.requestsSent-p.requestsRead > 0 +} + +func (p *Pool[R]) Close() { + close(p.queue) +} diff --git a/internal/workerpool/pool_test.go b/internal/workerpool/pool_test.go new file mode 100644 index 0000000000..6337ca28c0 --- /dev/null +++ b/internal/workerpool/pool_test.go @@ -0,0 +1,145 @@ +// Copyright (c) Mondoo, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package workerpool_test + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.mondoo.com/cnquery/v11/internal/workerpool" +) + +func TestPoolSubmitAndRetrieveResult(t *testing.T) { + pool := workerpool.New[int](2) + pool.Start() + defer pool.Close() + + task := func() (int, error) { + return 42, nil + } + + // no requests + assert.False(t, pool.HasPendingRequests()) + + // submit a request + pool.Submit(task) + + // should have pending requests + assert.True(t, pool.HasPendingRequests()) + + // assert results comes back + result := pool.GetResult() + assert.Equal(t, 42, result) + + // no more requests pending + assert.False(t, pool.HasPendingRequests()) + + // no errors + assert.Nil(t, pool.GetError()) +} + +func TestPoolHandleErrors(t *testing.T) { + pool := workerpool.New[int](5) + pool.Start() + defer pool.Close() + + // submit a task that will return an error + task := func() (int, error) { + return 0, errors.New("task error") + } + pool.Submit(task) + + // Wait for error collector to process + time.Sleep(100 * time.Millisecond) + + err := pool.GetError() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "task error") + } +} + +func TestPoolMultipleTasksWithErrors(t *testing.T) { + type test struct { + data int + } + pool := workerpool.New[*test](5) + pool.Start() + defer pool.Close() + + tasks := []workerpool.Task[*test]{ + func() (*test, error) { return &test{1}, nil }, + func() (*test, error) { return &test{2}, nil }, + func() (*test, error) { + return nil, errors.New("task error") + }, + func() (*test, error) { return &test{3}, nil }, + } + + for _, task := range tasks { + pool.Submit(task) + } + + var results []*test + for range tasks { + results = append(results, pool.GetResult()) + } + + assert.ElementsMatch(t, []*test{nil, &test{1}, &test{2}, &test{3}}, results) + assert.False(t, pool.HasPendingRequests()) + +} + +func TestPoolHandlesNilTasks(t *testing.T) { + pool := workerpool.New[int](2) + pool.Start() + defer pool.Close() + + var nilTask workerpool.Task[int] + pool.Submit(nilTask) + + // Wait for worker to process the nil task + time.Sleep(100 * time.Millisecond) + + err := pool.GetError() + assert.NoError(t, err) +} + +func TestPoolHasPendingRequests(t *testing.T) { + pool := workerpool.New[int](2) + pool.Start() + defer pool.Close() + + task := func() (int, error) { + time.Sleep(50 * time.Millisecond) + return 10, nil + } + + pool.Submit(task) + assert.True(t, pool.HasPendingRequests()) + + result := pool.GetResult() + assert.Equal(t, 10, result) + assert.False(t, pool.HasPendingRequests()) +} + +func TestPoolClosesGracefully(t *testing.T) { + pool := workerpool.New[int](1) + pool.Start() + + task := func() (int, error) { + time.Sleep(100 * time.Millisecond) + return 42, nil + } + + pool.Submit(task) + + pool.Close() + + // Ensure no panic occurs and channels are closed + assert.PanicsWithError(t, "send on closed channel", func() { + pool.Submit(task) + }) +} diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go new file mode 100644 index 0000000000..4a391d44b6 --- /dev/null +++ b/internal/workerpool/worker.go @@ -0,0 +1,28 @@ +// Copyright (c) Mondoo, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package workerpool + +type worker[R any] struct { + id int + queue <-chan Task[R] + results chan<- R + errors chan<- error +} + +func (w *worker[R]) Start() { + go func() { + for task := range w.queue { + if task == nil { + continue + } + + data, err := task() + if err != nil { + w.errors <- err + } + + w.results <- data + } + }() +} From 5ae3d75d7b480f3adee9c740e87c85aa8095bddb Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Tue, 10 Dec 2024 23:11:48 -0800 Subject: [PATCH 02/13] :zap: fetch org repositories in parallel Signed-off-by: Salim Afiune Maya --- providers-sdk/v1/inventory/inventory.pb.go | 4 +- providers-sdk/v1/plugin/plugin.pb.go | 4 +- providers-sdk/v1/plugin/plugin_grpc.pb.go | 46 ++++--------- providers-sdk/v1/resources/resources.pb.go | 4 +- providers-sdk/v1/vault/vault.pb.go | 4 +- providers/github/provider/provider.go | 3 +- providers/github/resources/github.go | 1 + providers/github/resources/github.lr | 2 + providers/github/resources/github.lr.go | 12 ++++ .../github/resources/github.lr.manifest.yaml | 2 + providers/github/resources/github_org.go | 65 +++++++++++++++---- 11 files changed, 93 insertions(+), 54 deletions(-) diff --git a/providers-sdk/v1/inventory/inventory.pb.go b/providers-sdk/v1/inventory/inventory.pb.go index 07cf91ce9f..6ae36eef19 100644 --- a/providers-sdk/v1/inventory/inventory.pb.go +++ b/providers-sdk/v1/inventory/inventory.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.35.1 -// protoc v5.28.3 +// protoc-gen-go v1.35.2 +// protoc v5.29.0 // source: inventory.proto package inventory diff --git a/providers-sdk/v1/plugin/plugin.pb.go b/providers-sdk/v1/plugin/plugin.pb.go index bc99be0316..844819aa20 100644 --- a/providers-sdk/v1/plugin/plugin.pb.go +++ b/providers-sdk/v1/plugin/plugin.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.35.1 -// protoc v5.28.3 +// protoc-gen-go v1.35.2 +// protoc v5.29.0 // source: plugin.proto package plugin diff --git a/providers-sdk/v1/plugin/plugin_grpc.pb.go b/providers-sdk/v1/plugin/plugin_grpc.pb.go index 4d1d3d352d..81b221fc9e 100644 --- a/providers-sdk/v1/plugin/plugin_grpc.pb.go +++ b/providers-sdk/v1/plugin/plugin_grpc.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.28.3 +// - protoc-gen-go-grpc v1.4.0 +// - protoc v5.29.0 // source: plugin.proto package plugin @@ -18,8 +18,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( ProviderPlugin_Heartbeat_FullMethodName = "/cnquery.providers.v1.ProviderPlugin/Heartbeat" @@ -136,7 +136,7 @@ func (c *providerPluginClient) StoreData(ctx context.Context, in *StoreReq, opts // ProviderPluginServer is the server API for ProviderPlugin service. // All implementations must embed UnimplementedProviderPluginServer -// for forward compatibility. 
+// for forward compatibility type ProviderPluginServer interface { Heartbeat(context.Context, *HeartbeatReq) (*HeartbeatRes, error) ParseCLI(context.Context, *ParseCLIReq) (*ParseCLIRes, error) @@ -149,12 +149,9 @@ type ProviderPluginServer interface { mustEmbedUnimplementedProviderPluginServer() } -// UnimplementedProviderPluginServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedProviderPluginServer struct{} +// UnimplementedProviderPluginServer must be embedded to have forward compatible implementations. +type UnimplementedProviderPluginServer struct { +} func (UnimplementedProviderPluginServer) Heartbeat(context.Context, *HeartbeatReq) (*HeartbeatRes, error) { return nil, status.Errorf(codes.Unimplemented, "method Heartbeat not implemented") @@ -181,7 +178,6 @@ func (UnimplementedProviderPluginServer) StoreData(context.Context, *StoreReq) ( return nil, status.Errorf(codes.Unimplemented, "method StoreData not implemented") } func (UnimplementedProviderPluginServer) mustEmbedUnimplementedProviderPluginServer() {} -func (UnimplementedProviderPluginServer) testEmbeddedByValue() {} // UnsafeProviderPluginServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to ProviderPluginServer will @@ -191,13 +187,6 @@ type UnsafeProviderPluginServer interface { } func RegisterProviderPluginServer(s grpc.ServiceRegistrar, srv ProviderPluginServer) { - // If the following call pancis, it indicates UnimplementedProviderPluginServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&ProviderPlugin_ServiceDesc, srv) } @@ -444,7 +433,7 @@ func (c *providerCallbackClient) GetData(ctx context.Context, in *DataReq, opts // ProviderCallbackServer is the server API for ProviderCallback service. // All implementations must embed UnimplementedProviderCallbackServer -// for forward compatibility. +// for forward compatibility type ProviderCallbackServer interface { Collect(context.Context, *DataRes) (*CollectRes, error) GetRecording(context.Context, *DataReq) (*ResourceData, error) @@ -452,12 +441,9 @@ type ProviderCallbackServer interface { mustEmbedUnimplementedProviderCallbackServer() } -// UnimplementedProviderCallbackServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedProviderCallbackServer struct{} +// UnimplementedProviderCallbackServer must be embedded to have forward compatible implementations. 
+type UnimplementedProviderCallbackServer struct { +} func (UnimplementedProviderCallbackServer) Collect(context.Context, *DataRes) (*CollectRes, error) { return nil, status.Errorf(codes.Unimplemented, "method Collect not implemented") @@ -469,7 +455,6 @@ func (UnimplementedProviderCallbackServer) GetData(context.Context, *DataReq) (* return nil, status.Errorf(codes.Unimplemented, "method GetData not implemented") } func (UnimplementedProviderCallbackServer) mustEmbedUnimplementedProviderCallbackServer() {} -func (UnimplementedProviderCallbackServer) testEmbeddedByValue() {} // UnsafeProviderCallbackServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to ProviderCallbackServer will @@ -479,13 +464,6 @@ type UnsafeProviderCallbackServer interface { } func RegisterProviderCallbackServer(s grpc.ServiceRegistrar, srv ProviderCallbackServer) { - // If the following call pancis, it indicates UnimplementedProviderCallbackServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&ProviderCallback_ServiceDesc, srv) } diff --git a/providers-sdk/v1/resources/resources.pb.go b/providers-sdk/v1/resources/resources.pb.go index 35ba6cad44..19797b6200 100644 --- a/providers-sdk/v1/resources/resources.pb.go +++ b/providers-sdk/v1/resources/resources.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.35.1 -// protoc v5.28.3 +// protoc-gen-go v1.35.2 +// protoc v5.29.0 // source: resources.proto package resources diff --git a/providers-sdk/v1/vault/vault.pb.go b/providers-sdk/v1/vault/vault.pb.go index cfb80dda7e..ad919c8cb4 100644 --- a/providers-sdk/v1/vault/vault.pb.go +++ b/providers-sdk/v1/vault/vault.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: -// protoc-gen-go v1.35.1 -// protoc v5.28.3 +// protoc-gen-go v1.35.2 +// protoc v5.29.0 // source: vault.proto package vault diff --git a/providers/github/provider/provider.go b/providers/github/provider/provider.go index 970dc541af..25cdfdeaa1 100644 --- a/providers/github/provider/provider.go +++ b/providers/github/provider/provider.go @@ -53,7 +53,8 @@ func (s *Service) ParseCLI(req *plugin.ParseCLIReq) (*plugin.ParseCLIRes, error) } isAppAuth := false - if appId, ok := req.Flags[connection.OPTION_APP_ID]; ok && len(appId.Value) > 0 { + appId, ok := flags[connection.OPTION_APP_ID] + if ok && len(appId.Value) > 0 { conf.Options[connection.OPTION_APP_ID] = string(appId.Value) installId := req.Flags[connection.OPTION_APP_INSTALLATION_ID] diff --git a/providers/github/resources/github.go b/providers/github/resources/github.go index 2510b176eb..bccbd372fd 100644 --- a/providers/github/resources/github.go +++ b/providers/github/resources/github.go @@ -77,4 +77,5 @@ func githubTimestamp(ts *github.Timestamp) *time.Time { const ( paginationPerPage = 100 + workers = 10 ) diff --git a/providers/github/resources/github.lr b/providers/github/resources/github.lr index b3b3a6b3cb..67bb1f8e6e 100644 --- a/providers/github/resources/github.lr +++ b/providers/github/resources/github.lr @@ -80,6 +80,8 @@ github.organization @defaults("login name") { updatedAt time // Number of private repositories totalPrivateRepos int + // Number of public repositories + totalPublicRepos int // Number of owned private repositories for the organization ownedPrivateRepos int // Number of private gists diff --git a/providers/github/resources/github.lr.go b/providers/github/resources/github.lr.go index beb8b4516f..44725e97b2 100644 --- a/providers/github/resources/github.lr.go +++ b/providers/github/resources/github.lr.go @@ -276,6 +276,9 @@ var getDataFields = map[string]func(r plugin.Resource) *plugin.DataRes{ "github.organization.totalPrivateRepos": func(r plugin.Resource) *plugin.DataRes { return (r.(*mqlGithubOrganization).GetTotalPrivateRepos()).ToDataRes(types.Int) }, + "github.organization.totalPublicRepos": func(r plugin.Resource) *plugin.DataRes { + return (r.(*mqlGithubOrganization).GetTotalPublicRepos()).ToDataRes(types.Int) + }, "github.organization.ownedPrivateRepos": func(r plugin.Resource) *plugin.DataRes { return (r.(*mqlGithubOrganization).GetOwnedPrivateRepos()).ToDataRes(types.Int) }, @@ -1118,6 +1121,10 @@ var setDataFields = map[string]func(r plugin.Resource, v *llx.RawData) bool { r.(*mqlGithubOrganization).TotalPrivateRepos, ok = plugin.RawToTValue[int64](v.Value, v.Error) return }, + "github.organization.totalPublicRepos": func(r plugin.Resource, v *llx.RawData) (ok bool) { + r.(*mqlGithubOrganization).TotalPublicRepos, ok = plugin.RawToTValue[int64](v.Value, v.Error) + return + }, "github.organization.ownedPrivateRepos": func(r plugin.Resource, v *llx.RawData) (ok bool) { r.(*mqlGithubOrganization).OwnedPrivateRepos, ok = plugin.RawToTValue[int64](v.Value, v.Error) return @@ -2404,6 +2411,7 @@ type mqlGithubOrganization struct { CreatedAt plugin.TValue[*time.Time] UpdatedAt plugin.TValue[*time.Time] TotalPrivateRepos plugin.TValue[int64] + TotalPublicRepos plugin.TValue[int64] OwnedPrivateRepos plugin.TValue[int64] PrivateGists plugin.TValue[int64] DiskUsage plugin.TValue[int64] @@ -2533,6 +2541,10 @@ func (c *mqlGithubOrganization) GetTotalPrivateRepos() *plugin.TValue[int64] { return &c.TotalPrivateRepos } +func (c *mqlGithubOrganization) GetTotalPublicRepos() *plugin.TValue[int64] 
{ + return &c.TotalPublicRepos +} + func (c *mqlGithubOrganization) GetOwnedPrivateRepos() *plugin.TValue[int64] { return &c.OwnedPrivateRepos } diff --git a/providers/github/resources/github.lr.manifest.yaml b/providers/github/resources/github.lr.manifest.yaml index a5f4adf8dd..af798f362b 100755 --- a/providers/github/resources/github.lr.manifest.yaml +++ b/providers/github/resources/github.lr.manifest.yaml @@ -274,6 +274,8 @@ resources: total_private_repos: {} totalPrivateRepos: min_mondoo_version: 6.11.0 + totalPublicRepos: + min_mondoo_version: 9.0.0 twitter_username: {} twitterUsername: min_mondoo_version: 6.11.0 diff --git a/providers/github/resources/github_org.go b/providers/github/resources/github_org.go index 67005eb7ab..ad783c57c8 100644 --- a/providers/github/resources/github_org.go +++ b/providers/github/resources/github_org.go @@ -10,6 +10,8 @@ import ( "time" "github.com/google/go-github/v67/github" + "github.com/rs/zerolog/log" + "go.mondoo.com/cnquery/v11/internal/workerpool" "go.mondoo.com/cnquery/v11/llx" "go.mondoo.com/cnquery/v11/logger" "go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin" @@ -70,6 +72,7 @@ func initGithubOrganization(runtime *plugin.Runtime, args map[string]*llx.RawDat args["createdAt"] = llx.TimeDataPtr(githubTimestamp(org.CreatedAt)) args["updatedAt"] = llx.TimeDataPtr(githubTimestamp(org.UpdatedAt)) args["totalPrivateRepos"] = llx.IntDataPtr(org.TotalPrivateRepos) + args["totalPublicRepos"] = llx.IntDataPtr(org.PublicRepos) args["ownedPrivateRepos"] = llx.IntDataPtr(org.OwnedPrivateRepos) args["privateGists"] = llx.IntDataDefault(org.PrivateGists, 0) args["diskUsage"] = llx.IntDataDefault(org.DiskUsage, 0) @@ -262,26 +265,66 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { return nil, g.Login.Error } orgLogin := g.Login.Data - - listOpts := &github.RepositoryListByOrgOptions{ - ListOptions: github.ListOptions{PerPage: paginationPerPage}, - Type: "all", + listOpts := github.RepositoryListByOrgOptions{ + ListOptions: github.ListOptions{ + PerPage: paginationPerPage, + Page: 1, + }, + Type: "all", } + repoCount := g.TotalPrivateRepos.Data + g.TotalPublicRepos.Data + workerPool := workerpool.New[[]*github.Repository](workers) + workerPool.Start() + defer workerPool.Close() + + log.Debug(). + Int("workers", workers). + Int64("total_repos", repoCount). + Str("organization", g.Name.Data). + Msg("list repositories") + var allRepos []*github.Repository for { - repos, resp, err := conn.Client().Repositories.ListByOrg(conn.Context(), orgLogin, listOpts) - if err != nil { + + // exit as soon as we collect all repositories + if len(allRepos) >= int(repoCount) { + break + } + + // send as many request as workers we have + for i := 1; i <= workers; i++ { + opts := listOpts + workerPool.Submit(func() ([]*github.Repository, error) { + repos, _, err := conn.Client().Repositories.ListByOrg(conn.Context(), orgLogin, &opts) + return repos, err + }) + + // check if we need to submit more requests + newRepoCount := len(allRepos) + i*paginationPerPage + if newRepoCount > int(repoCount) { + break + } + + // next page + listOpts.Page++ + } + + // wait for the results + for i := 0; i < workers; i++ { + if workerPool.HasPendingRequests() { + allRepos = append(allRepos, workerPool.GetResult()...) + } + } + + // check if any request failed + if err := workerPool.GetError(); err != nil { if strings.Contains(err.Error(), "404") { return nil, nil } return nil, err } - allRepos = append(allRepos, repos...) 
-
-		if resp.NextPage == 0 {
-			break
-		}
-		listOpts.Page = resp.NextPage
 	}
 
 	if g.repoCacheMap == nil {

From 93ca08d3392f41325b41af0b5dfbe9c98e4540c3 Mon Sep 17 00:00:00 2001
From: Salim Afiune Maya
Date: Wed, 11 Dec 2024 14:01:15 -0800
Subject: [PATCH 03/13] =?UTF-8?q?=E2=9A=99=EF=B8=8F=20=20add=20a=20collect?=
 =?UTF-8?q?or=20to=20the=20workerpool?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This will help us submit as many requests as we want without knowing
about the workers.

Signed-off-by: Salim Afiune Maya
---
 internal/workerpool/collector.go         |  34 +++++++
 internal/workerpool/pool.go              | 112 ++++++++++++++++------
 internal/workerpool/pool_test.go         |  96 +++++++++++++------
 internal/workerpool/worker.go            |  18 ++--
 providers/github/resources/github_org.go |  56 +++++------
 5 files changed, 212 insertions(+), 104 deletions(-)
 create mode 100644 internal/workerpool/collector.go

diff --git a/internal/workerpool/collector.go b/internal/workerpool/collector.go
new file mode 100644
index 0000000000..4c5257afda
--- /dev/null
+++ b/internal/workerpool/collector.go
@@ -0,0 +1,34 @@
+// Copyright (c) Mondoo, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package workerpool
+
+type collector[R any] struct {
+	resultsCh <-chan R
+	results   []R
+
+	errorsCh <-chan error
+	errors   []error
+
+	requestsRead int64
+}
+
+func (c *collector[R]) Start() {
+	go func() {
+		for {
+			select {
+			case result := <-c.resultsCh:
+				c.results = append(c.results, result)
+
+			case err := <-c.errorsCh:
+				c.errors = append(c.errors, err)
+			}
+
+			c.requestsRead++
+		}
+	}()
+}
+
+func (c *collector[R]) RequestsRead() int64 {
+	return c.requestsRead
+}
diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go
index 3bf8d71bfc..d407543cff 100644
--- a/internal/workerpool/pool.go
+++ b/internal/workerpool/pool.go
@@ -4,68 +4,114 @@
 package workerpool
 
 import (
+	"sync/atomic"
+	"time"
+
 	"github.com/cockroachdb/errors"
 )
 
 type Task[R any] func() (result R, err error)
 
+// Pool is a generic pool of workers.
 type Pool[R any] struct {
-	queue        chan Task[R]
-	results      chan R
-	errors       chan error
-	workerCount  int
-	requestsSent int
-	requestsRead int
-
-	err error
+	queueCh      chan Task[R]
+	resultsCh    chan R
+	errorsCh     chan error
+	requestsSent int64
+
+	workers     []*worker[R]
+	workerCount int
+
+	collector[R]
 }
 
+// New initializes a new Pool with the provided number of workers. The pool is generic and can
+// accept any type of Task that matches the signature `func() (R, error)`.
+//
+// For example, a Pool[int] will accept Tasks similar to:
+//
+//	task := func() (int, error) {
+//		return 42, nil
+//	}
func New[R any](count int) *Pool[R] {
+	resultsCh := make(chan R)
+	errorsCh := make(chan error)
 	return &Pool[R]{
-		queue:       make(chan Task[R]),
-		results:     make(chan R),
-		errors:      make(chan error),
+		queueCh:     make(chan Task[R]),
+		resultsCh:   resultsCh,
+		errorsCh:    errorsCh,
 		workerCount: count,
+		collector:   collector[R]{resultsCh: resultsCh, errorsCh: errorsCh},
 	}
 }
 
+// Start the pool workers and collector. Make sure to call `Close()` to clear the pool.
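+// Workers run as goroutines and exit once the queue channel is closed.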
+// +// pool := workerpool.New[int](10) +// pool.Start() +// defer pool.Close() func (p *Pool[R]) Start() { for i := 0; i < p.workerCount; i++ { - w := worker[R]{id: i, queue: p.queue, results: p.results, errors: p.errors} + w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} w.Start() + p.workers = append(p.workers, &w) } - p.errorCollector() + p.collector.Start() } -func (p *Pool[R]) errorCollector() { - go func() { - for e := range p.errors { - p.err = errors.Join(p.err, e) - } - }() +// Submit sends a task to the workers +func (p *Pool[R]) Submit(t Task[R]) { + p.queueCh <- t + atomic.AddInt64(&p.requestsSent, 1) } -func (p *Pool[R]) GetError() error { - return p.err +// GetErrors returns any error from a processed task +func (p *Pool[R]) GetErrors() error { + return errors.Join(p.collector.errors...) } -func (p *Pool[R]) Submit(t Task[R]) { - p.queue <- t - p.requestsSent++ +// GetResults returns the tasks results. +// +// It is recommended to call `Wait()` before reading the results. +func (p *Pool[R]) GetResults() []R { + return p.collector.results +} + +// Close waits for workers and collector to process all the requests, and then closes +// the task queue channel. After closing the pool, calling `Submit()` will panic. +func (p *Pool[R]) Close() { + p.Wait() + close(p.queueCh) } -func (p *Pool[R]) GetResult() R { - defer func() { - p.requestsRead++ - }() - return <-p.results +// Wait waits until all tasks have been processed. +func (p *Pool[R]) Wait() { + ticker := time.NewTicker(100 * time.Millisecond) + for { + if !p.Processing() { + return + } + <-ticker.C + } } -func (p *Pool[R]) HasPendingRequests() bool { - return p.requestsSent-p.requestsRead > 0 +// PendingRequests returns the number of pending requests. +func (p *Pool[R]) PendingRequests() int64 { + return p.requestsSent - p.collector.RequestsRead() } -func (p *Pool[R]) Close() { - close(p.queue) +// Processing return true if tasks are being processed. 
+func (p *Pool[R]) Processing() bool { + if !p.empty() { + return false + } + + return p.PendingRequests() != 0 +} + +func (p *Pool[R]) empty() bool { + return len(p.queueCh) == 0 && + len(p.resultsCh) == 0 && + len(p.errorsCh) == 0 } diff --git a/internal/workerpool/pool_test.go b/internal/workerpool/pool_test.go index 6337ca28c0..3b3946df1e 100644 --- a/internal/workerpool/pool_test.go +++ b/internal/workerpool/pool_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "math/rand" + "github.com/stretchr/testify/assert" "go.mondoo.com/cnquery/v11/internal/workerpool" ) @@ -21,24 +23,23 @@ func TestPoolSubmitAndRetrieveResult(t *testing.T) { return 42, nil } - // no requests - assert.False(t, pool.HasPendingRequests()) + // no results + assert.Empty(t, pool.GetResults()) // submit a request pool.Submit(task) - // should have pending requests - assert.True(t, pool.HasPendingRequests()) - - // assert results comes back - result := pool.GetResult() - assert.Equal(t, 42, result) + // wait for the request to process + pool.Wait() - // no more requests pending - assert.False(t, pool.HasPendingRequests()) + // should have one result + results := pool.GetResults() + if assert.Len(t, results, 1) { + assert.Equal(t, 42, results[0]) + } // no errors - assert.Nil(t, pool.GetError()) + assert.Nil(t, pool.GetErrors()) } func TestPoolHandleErrors(t *testing.T) { @@ -53,9 +54,9 @@ func TestPoolHandleErrors(t *testing.T) { pool.Submit(task) // Wait for error collector to process - time.Sleep(100 * time.Millisecond) + pool.Wait() - err := pool.GetError() + err := pool.GetErrors() if assert.Error(t, err) { assert.Contains(t, err.Error(), "task error") } @@ -82,14 +83,15 @@ func TestPoolMultipleTasksWithErrors(t *testing.T) { pool.Submit(task) } - var results []*test - for range tasks { - results = append(results, pool.GetResult()) - } - - assert.ElementsMatch(t, []*test{nil, &test{1}, &test{2}, &test{3}}, results) - assert.False(t, pool.HasPendingRequests()) + // Wait for error collector to process + pool.Wait() + results := pool.GetResults() + assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results) + err := pool.GetErrors() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "task error") + } } func TestPoolHandlesNilTasks(t *testing.T) { @@ -100,14 +102,13 @@ func TestPoolHandlesNilTasks(t *testing.T) { var nilTask workerpool.Task[int] pool.Submit(nilTask) - // Wait for worker to process the nil task - time.Sleep(100 * time.Millisecond) + pool.Wait() - err := pool.GetError() + err := pool.GetErrors() assert.NoError(t, err) } -func TestPoolHasPendingRequests(t *testing.T) { +func TestPoolProcessing(t *testing.T) { pool := workerpool.New[int](2) pool.Start() defer pool.Close() @@ -118,11 +119,19 @@ func TestPoolHasPendingRequests(t *testing.T) { } pool.Submit(task) - assert.True(t, pool.HasPendingRequests()) - result := pool.GetResult() - assert.Equal(t, 10, result) - assert.False(t, pool.HasPendingRequests()) + // should be processing + assert.True(t, pool.Processing()) + + // wait + pool.Wait() + + // read results + result := pool.GetResults() + assert.Equal(t, []int{10}, result) + + // should not longer be processing + assert.False(t, pool.Processing()) } func TestPoolClosesGracefully(t *testing.T) { @@ -143,3 +152,34 @@ func TestPoolClosesGracefully(t *testing.T) { pool.Submit(task) }) } + +func TestPoolWithManyTasks(t *testing.T) { + // 30k requests with a pool of 100 workers + // should be around 15 seconds + requestCount := 30000 + pool := workerpool.New[int](100) + 
pool.Start() + defer pool.Close() + + task := func() (int, error) { + random := rand.Intn(100) + time.Sleep(time.Duration(random) * time.Millisecond) + return random, nil + } + + for i := 0; i < requestCount; i++ { + pool.Submit(task) + } + + // should be processing + assert.True(t, pool.Processing()) + + // wait + pool.Wait() + + // read results + assert.Equal(t, requestCount, len(pool.GetResults())) + + // should not longer be processing + assert.False(t, pool.Processing()) +} diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go index 4a391d44b6..19b21de1e7 100644 --- a/internal/workerpool/worker.go +++ b/internal/workerpool/worker.go @@ -4,25 +4,27 @@ package workerpool type worker[R any] struct { - id int - queue <-chan Task[R] - results chan<- R - errors chan<- error + id int + queueCh <-chan Task[R] + resultsCh chan<- R + errorsCh chan<- error } func (w *worker[R]) Start() { go func() { - for task := range w.queue { + for task := range w.queueCh { if task == nil { + // let the collector know we processed the request + w.errorsCh <- nil continue } data, err := task() if err != nil { - w.errors <- err + w.errorsCh <- err + } else { + w.resultsCh <- data } - - w.results <- data } }() } diff --git a/providers/github/resources/github_org.go b/providers/github/resources/github_org.go index ad783c57c8..ef39f97159 100644 --- a/providers/github/resources/github_org.go +++ b/providers/github/resources/github_org.go @@ -5,6 +5,7 @@ package resources import ( "errors" + "slices" "strconv" "strings" "time" @@ -284,47 +285,30 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { Str("organization", g.Name.Data). Msg("list repositories") - var allRepos []*github.Repository for { - // exit as soon as we collect all repositories - if len(allRepos) >= int(repoCount) { + reposLen := len(slices.Concat(workerPool.GetResults()...)) + if reposLen >= int(repoCount) { break } - // send as many request as workers we have - for i := 1; i <= workers; i++ { - opts := listOpts - workerPool.Submit(func() ([]*github.Repository, error) { - repos, _, err := conn.Client().Repositories.ListByOrg(conn.Context(), orgLogin, &opts) - return repos, err - }) - - // check if we need to submit more requests - newRepoCount := len(allRepos) + i*paginationPerPage - if newRepoCount > int(repoCount) { - break - } + // send requests to workers + opts := listOpts + workerPool.Submit(func() ([]*github.Repository, error) { + repos, _, err := conn.Client().Repositories.ListByOrg(conn.Context(), orgLogin, &opts) + return repos, err + }) - // next page - listOpts.Page++ - } - - // wait for the results - for i := 0; i < workers; i++ { - if workerPool.HasPendingRequests() { - allRepos = append(allRepos, workerPool.GetResult()...) 
- } - } + // next page + listOpts.Page++ // check if any request failed - if err := workerPool.GetError(); err != nil { + if err := workerPool.GetErrors(); err != nil { if strings.Contains(err.Error(), "404") { return nil, nil } return nil, err } - } if g.repoCacheMap == nil { @@ -332,15 +316,17 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { } res := []interface{}{} - for i := range allRepos { - repo := allRepos[i] + for _, repos := range workerPool.GetResults() { + for i := range repos { + repo := repos[i] - r, err := newMqlGithubRepository(g.MqlRuntime, repo) - if err != nil { - return nil, err + r, err := newMqlGithubRepository(g.MqlRuntime, repo) + if err != nil { + return nil, err + } + res = append(res, r) + g.repoCacheMap[repo.GetName()] = r } - res = append(res, r) - g.repoCacheMap[repo.GetName()] = r } return res, nil From 4765ace06eec98ddf8121044c6474b57ec8b7cb7 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Wed, 11 Dec 2024 18:02:00 -0800 Subject: [PATCH 04/13] :rotating_light: fix race conditions Signed-off-by: Salim Afiune Maya --- internal/workerpool/collector.go | 27 +++++++++++++++++++--- internal/workerpool/pool.go | 39 ++++++++++++++------------------ internal/workerpool/worker.go | 2 +- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/internal/workerpool/collector.go b/internal/workerpool/collector.go index 4c5257afda..2d105501be 100644 --- a/internal/workerpool/collector.go +++ b/internal/workerpool/collector.go @@ -3,9 +3,15 @@ package workerpool +import ( + "sync" + "sync/atomic" +) + type collector[R any] struct { resultsCh <-chan R results []R + read sync.Mutex errorsCh <-chan error errors []error @@ -13,22 +19,37 @@ type collector[R any] struct { requestsRead int64 } -func (c *collector[R]) Start() { +func (c *collector[R]) start() { go func() { for { select { case result := <-c.resultsCh: + c.read.Lock() c.results = append(c.results, result) + c.read.Unlock() case err := <-c.errorsCh: + c.read.Lock() c.errors = append(c.errors, err) + c.read.Unlock() } - c.requestsRead++ + atomic.AddInt64(&c.requestsRead, 1) } }() } +func (c *collector[R]) GetResults() []R { + c.read.Lock() + defer c.read.Unlock() + return c.results +} + +func (c *collector[R]) GetErrors() []error { + c.read.Lock() + defer c.read.Unlock() + return c.errors +} func (c *collector[R]) RequestsRead() int64 { - return c.requestsRead + return atomic.LoadInt64(&c.requestsRead) } diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go index d407543cff..8553ca25d5 100644 --- a/internal/workerpool/pool.go +++ b/internal/workerpool/pool.go @@ -4,6 +4,7 @@ package workerpool import ( + "sync" "sync/atomic" "time" @@ -14,10 +15,12 @@ type Task[R any] func() (result R, err error) // Pool is a generic pool of workers. 
type Pool[R any] struct { - queueCh chan Task[R] - resultsCh chan R - errorsCh chan error + queueCh chan Task[R] + resultsCh chan R + errorsCh chan error + requestsSent int64 + once sync.Once workers []*worker[R] workerCount int @@ -51,13 +54,15 @@ func New[R any](count int) *Pool[R] { // pool.Start() // defer pool.Close() func (p *Pool[R]) Start() { - for i := 0; i < p.workerCount; i++ { - w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} - w.Start() - p.workers = append(p.workers, &w) - } + p.once.Do(func() { + for i := 0; i < p.workerCount; i++ { + w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} + w.start() + p.workers = append(p.workers, &w) + } - p.collector.Start() + p.collector.start() + }) } // Submit sends a task to the workers @@ -68,14 +73,14 @@ func (p *Pool[R]) Submit(t Task[R]) { // GetErrors returns any error from a processed task func (p *Pool[R]) GetErrors() error { - return errors.Join(p.collector.errors...) + return errors.Join(p.collector.GetErrors()...) } // GetResults returns the tasks results. // // It is recommended to call `Wait()` before reading the results. func (p *Pool[R]) GetResults() []R { - return p.collector.results + return p.collector.GetResults() } // Close waits for workers and collector to process all the requests, and then closes @@ -98,20 +103,10 @@ func (p *Pool[R]) Wait() { // PendingRequests returns the number of pending requests. func (p *Pool[R]) PendingRequests() int64 { - return p.requestsSent - p.collector.RequestsRead() + return atomic.LoadInt64(&p.requestsSent) - p.collector.RequestsRead() } // Processing return true if tasks are being processed. func (p *Pool[R]) Processing() bool { - if !p.empty() { - return false - } - return p.PendingRequests() != 0 } - -func (p *Pool[R]) empty() bool { - return len(p.queueCh) == 0 && - len(p.resultsCh) == 0 && - len(p.errorsCh) == 0 -} diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go index 19b21de1e7..77b5c81f15 100644 --- a/internal/workerpool/worker.go +++ b/internal/workerpool/worker.go @@ -10,7 +10,7 @@ type worker[R any] struct { errorsCh chan<- error } -func (w *worker[R]) Start() { +func (w *worker[R]) start() { go func() { for task := range w.queueCh { if task == nil { From 8c989040693f91d4b216995182acafdea1ec5212 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Wed, 11 Dec 2024 14:44:49 -0800 Subject: [PATCH 05/13] :zap: discover assets in parallel Signed-off-by: Salim Afiune Maya --- .vscode/launch.json | 12 +++++ explorer/scan/discovery.go | 61 +++++++++++++++-------- providers-sdk/v1/plugin/service.go | 18 ++++--- providers/github/connection/connection.go | 1 + 4 files changed, 65 insertions(+), 27 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 85931b3a76..2fab2c0990 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -198,6 +198,18 @@ "shell", "ssh", "user@18.215.249.49", ], }, + { + "name": "scan github org", + "type": "go", + "request": "launch", + "program": "${workspaceRoot}/apps/cnquery/cnquery.go", + "args": [ + "scan", + "github", + "org", "hit-training", + "--log-level", "trace" + ] + }, { "name": "Configure Built-in Providers", "type": "go", diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go index be7a67e7e0..92d16e3dea 100644 --- a/explorer/scan/discovery.go +++ b/explorer/scan/discovery.go @@ -6,11 +6,13 @@ package scan import ( "context" "errors" + "sync" "time" "github.com/rs/zerolog/log" 
"go.mondoo.com/cnquery/v11/cli/config" "go.mondoo.com/cnquery/v11/cli/execruntime" + "go.mondoo.com/cnquery/v11/internal/workerpool" "go.mondoo.com/cnquery/v11/llx" "go.mondoo.com/cnquery/v11/logger" "go.mondoo.com/cnquery/v11/providers" @@ -20,6 +22,9 @@ import ( "go.mondoo.com/cnquery/v11/providers-sdk/v1/upstream" ) +// number of parallel goroutines discovering assets +const workers = 10 + type AssetWithRuntime struct { Asset *inventory.Asset Runtime *providers.Runtime @@ -34,11 +39,15 @@ type DiscoveredAssets struct { platformIds map[string]struct{} Assets []*AssetWithRuntime Errors []*AssetWithError + assetsLock sync.Mutex } // Add adds an asset and its runtime to the discovered assets list. It returns true if the // asset has been added, false if it is a duplicate func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool { + d.assetsLock.Lock() + defer d.assetsLock.Unlock() + isDuplicate := false for _, platformId := range asset.PlatformIds { if _, ok := d.platformIds[platformId]; ok { @@ -161,35 +170,45 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i return } + pool := workerpool.New[bool](workers) + pool.Start() + defer pool.Close() + // for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets { - // create runtime for root asset - assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording) - if err != nil { - log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset") - discoveredAssets.AddError(a, err) - continue - } + pool.Submit(func() (bool, error) { + // create runtime for root asset + assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording) + if err != nil { + log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset") + discoveredAssets.AddError(a, err) + return false, err + } - // If no asset was returned and no error, then we observed a duplicate asset with a - // runtime that already exists. - if assetWithRuntime == nil { - continue - } + // If no asset was returned and no error, then we observed a duplicate asset with a + // runtime that already exists. 
+ if assetWithRuntime == nil { + return false, nil + } - resolvedAsset := assetWithRuntime.Runtime.Provider.Connection.Asset - if len(resolvedAsset.PlatformIds) > 0 { - prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels) + resolvedAsset := assetWithRuntime.Runtime.Provider.Connection.Asset + if len(resolvedAsset.PlatformIds) > 0 { + prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels) - // If the asset has been already added, we should close its runtime - if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) { + // If the asset has been already added, we should close its runtime + if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) { + assetWithRuntime.Runtime.Close() + } + } else { + discoverAssets(assetWithRuntime, resolvedRootAsset, discoveredAssets, runtimeLabels, upstream, recording) assetWithRuntime.Runtime.Close() } - } else { - discoverAssets(assetWithRuntime, resolvedRootAsset, discoveredAssets, runtimeLabels, upstream, recording) - assetWithRuntime.Runtime.Close() - } + return true, nil + }) } + + // Wait for the workers to finish processing + pool.Wait() } func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) { diff --git a/providers-sdk/v1/plugin/service.go b/providers-sdk/v1/plugin/service.go index 382efcc90b..ea36c6f9a1 100644 --- a/providers-sdk/v1/plugin/service.go +++ b/providers-sdk/v1/plugin/service.go @@ -51,11 +51,8 @@ func (s *Service) AddRuntime(conf *inventory.Config, createRuntime func(connId u } // ^^ - s.runtimesLock.Lock() - defer s.runtimesLock.Unlock() - // If a runtime with this ID already exists, then return that - if runtime, ok := s.runtimes[conf.Id]; ok { + if runtime, err := s.GetRuntime(conf.Id); err == nil { return runtime, nil } @@ -66,7 +63,7 @@ func (s *Service) AddRuntime(conf *inventory.Config, createRuntime func(connId u if runtime.Connection != nil { if parentId := runtime.Connection.ParentID(); parentId > 0 { - parentRuntime, err := s.doGetRuntime(parentId) + parentRuntime, err := s.GetRuntime(parentId) if err != nil { return nil, errors.New("parent connection " + strconv.FormatUint(uint64(parentId), 10) + " not found") } @@ -74,10 +71,19 @@ func (s *Service) AddRuntime(conf *inventory.Config, createRuntime func(connId u } } - s.runtimes[conf.Id] = runtime + + // store the new runtime + s.addRuntime(conf.Id, runtime) + return runtime, nil } +func (s *Service) addRuntime(id uint32, runtime *Runtime) { + s.runtimesLock.Lock() + defer s.runtimesLock.Unlock() + s.runtimes[id] = runtime +} + // FIXME: DEPRECATED, remove in v12.0 vv func (s *Service) deprecatedAddRuntime(createRuntime func(connId uint32) (*Runtime, error)) (*Runtime, error) { s.runtimesLock.Lock() diff --git a/providers/github/connection/connection.go b/providers/github/connection/connection.go index 8e7cdb5942..c974ef0b06 100644 --- a/providers/github/connection/connection.go +++ b/providers/github/connection/connection.go @@ -74,6 +74,7 @@ func NewGithubConnection(id uint32, asset *inventory.Asset) (*GithubConnection, ctx := context.WithValue(context.Background(), github.SleepUntilPrimaryRateLimitResetWhenRateLimited, true) // perform a quick call to verify the token's validity. + // @afiune do we need to validate the token for every connection? can this be a "once" operation? 
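+	// note: Zen is a lightweight endpoint, so this check costs a single small request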
_, resp, err := client.Meta.Zen(ctx) if err != nil { if resp != nil && resp.StatusCode == 401 { From b3dc6d8105ecc8faa9a2daa77586a9b0c8d87a8c Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Thu, 12 Dec 2024 06:15:46 -0800 Subject: [PATCH 06/13] =?UTF-8?q?=F0=9F=A7=AA=20decrease=20workerpool=20wa?= =?UTF-8?q?it=20ticker=20to=2010ms?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Salim Afiune Maya --- internal/workerpool/pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go index 8553ca25d5..d59d5b0098 100644 --- a/internal/workerpool/pool.go +++ b/internal/workerpool/pool.go @@ -92,7 +92,7 @@ func (p *Pool[R]) Close() { // Wait waits until all tasks have been processed. func (p *Pool[R]) Wait() { - ticker := time.NewTicker(100 * time.Millisecond) + ticker := time.NewTicker(10 * time.Millisecond) for { if !p.Processing() { return From 17f7beb32da340261499a5f300823c4aa73b40c6 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Thu, 12 Dec 2024 06:54:51 -0800 Subject: [PATCH 07/13] =?UTF-8?q?=F0=9F=90=9B=20make=20`DiscoveredAssets.A?= =?UTF-8?q?ddError()`=20thread=20safe?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Salim Afiune Maya --- explorer/scan/discovery.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go index 92d16e3dea..1405f7ed75 100644 --- a/explorer/scan/discovery.go +++ b/explorer/scan/discovery.go @@ -48,23 +48,21 @@ func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtim d.assetsLock.Lock() defer d.assetsLock.Unlock() - isDuplicate := false for _, platformId := range asset.PlatformIds { if _, ok := d.platformIds[platformId]; ok { - isDuplicate = true - break + // duplicate + return false } d.platformIds[platformId] = struct{}{} } - if isDuplicate { - return false - } d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime}) return true } func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) { + d.assetsLock.Lock() + defer d.assetsLock.Unlock() d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err}) } From 842e1e75200ace3976bd9d0150a2c73fcade9b47 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Thu, 12 Dec 2024 08:10:32 -0800 Subject: [PATCH 08/13] :thread: add mutex when running `provider.connect()` Signed-off-by: Salim Afiune Maya --- Makefile | 4 ++++ providers/runtime.go | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/Makefile b/Makefile index 753969f04e..b70761b4f9 100644 --- a/Makefile +++ b/Makefile @@ -700,6 +700,10 @@ test: test/go test/lint benchmark/go: go test -bench=. 
-benchmem go.mondoo.com/cnquery/v11/explorer/scan/benchmark +race/go: + go test -race go.mondoo.com/cnquery/v11/internal/workerpool + go test -race go.mondoo.com/cnquery/v11/explorer/scan + test/generate: prep/tools/mockgen go generate ./providers diff --git a/providers/runtime.go b/providers/runtime.go index 840f09fa25..6b38972648 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -39,6 +39,9 @@ type Runtime struct { isClosed bool close sync.Once shutdownTimeout time.Duration + + // used to lock unsafe tasks + mu sync.Mutex } type ConnectedProvider struct { @@ -232,7 +235,9 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error { // } + r.mu.Lock() r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks) + r.mu.Unlock() if r.Provider.ConnectionError != nil { return r.Provider.ConnectionError } From e3a58fd6552622d602e080fb3d3b9ce15ea43b29 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Thu, 12 Dec 2024 10:21:40 -0800 Subject: [PATCH 09/13] =?UTF-8?q?=E2=9A=99=EF=B8=8F=20=20reduce=20the=20wo?= =?UTF-8?q?rkerpool=20Task=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We really don't need to do everything inside the workerpool, do we? Signed-off-by: Salim Afiune Maya --- explorer/scan/discovery.go | 55 +++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go index 1405f7ed75..d645b7c091 100644 --- a/explorer/scan/discovery.go +++ b/explorer/scan/discovery.go @@ -168,45 +168,46 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i return } - pool := workerpool.New[bool](workers) + pool := workerpool.New[*AssetWithRuntime](workers) pool.Start() defer pool.Close() // for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset - for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets { - pool.Submit(func() (bool, error) { - // create runtime for root asset - assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording) + for _, asset := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets { + pool.Submit(func() (*AssetWithRuntime, error) { + assetWithRuntime, err := createRuntimeForAsset(asset, upstream, recording) if err != nil { - log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset") - discoveredAssets.AddError(a, err) - return false, err + log.Error().Err(err).Str("asset", asset.GetName()).Msg("unable to create runtime for asset") + discoveredAssets.AddError(asset, err) } + return assetWithRuntime, nil + }) + } - // If no asset was returned and no error, then we observed a duplicate asset with a - // runtime that already exists. - if assetWithRuntime == nil { - return false, nil - } + // Wait for the workers to finish processing + pool.Wait() + + // Get all assets with runtimes from the pool + for _, assetWithRuntime := range pool.GetResults() { + // If asset is nil, then we observed a duplicate asset with a + // runtime that already exists. 
+ if assetWithRuntime == nil { + continue + } - resolvedAsset := assetWithRuntime.Runtime.Provider.Connection.Asset - if len(resolvedAsset.PlatformIds) > 0 { - prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels) + resolvedAsset := assetWithRuntime.Runtime.Provider.Connection.Asset + if len(resolvedAsset.PlatformIds) > 0 { + prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels) - // If the asset has been already added, we should close its runtime - if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) { - assetWithRuntime.Runtime.Close() - } - } else { - discoverAssets(assetWithRuntime, resolvedRootAsset, discoveredAssets, runtimeLabels, upstream, recording) + // If the asset has been already added, we should close its runtime + if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) { assetWithRuntime.Runtime.Close() } - return true, nil - }) + } else { + discoverAssets(assetWithRuntime, resolvedRootAsset, discoveredAssets, runtimeLabels, upstream, recording) + assetWithRuntime.Runtime.Close() + } } - - // Wait for the workers to finish processing - pool.Wait() } func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) { From 4ffb815692e2b4e88ce28412b5227db8bd237fa9 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Sat, 14 Dec 2024 22:46:08 +0100 Subject: [PATCH 10/13] =?UTF-8?q?=E2=9A=99=EF=B8=8F=20=20return=20`pool.Re?= =?UTF-8?q?sult`=20as=20a=20combined=20struct?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Salim Afiune Maya --- explorer/scan/discovery.go | 4 +- internal/workerpool/collector.go | 38 +++++++------- internal/workerpool/pool.go | 64 ++++++++++++++++-------- internal/workerpool/pool_test.go | 50 +++++++++++------- internal/workerpool/worker.go | 15 +----- providers/github/resources/github_org.go | 9 ++-- 6 files changed, 106 insertions(+), 74 deletions(-) diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go index d645b7c091..994b9f402f 100644 --- a/explorer/scan/discovery.go +++ b/explorer/scan/discovery.go @@ -188,7 +188,9 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i pool.Wait() // Get all assets with runtimes from the pool - for _, assetWithRuntime := range pool.GetResults() { + for _, result := range pool.GetResults() { + assetWithRuntime := result.Value + // If asset is nil, then we observed a duplicate asset with a // runtime that already exists. if assetWithRuntime == nil { diff --git a/internal/workerpool/collector.go b/internal/workerpool/collector.go index 2d105501be..bb33bb836e 100644 --- a/internal/workerpool/collector.go +++ b/internal/workerpool/collector.go @@ -9,13 +9,11 @@ import ( ) type collector[R any] struct { - resultsCh <-chan R - results []R + resultsCh <-chan Result[R] + results []Result[R] read sync.Mutex - errorsCh <-chan error - errors []error - + // The total number of requests read. 
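+	// It is only read and updated with atomic operations.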
requestsRead int64 } @@ -27,29 +25,35 @@ func (c *collector[R]) start() { c.read.Lock() c.results = append(c.results, result) c.read.Unlock() - - case err := <-c.errorsCh: - c.read.Lock() - c.errors = append(c.errors, err) - c.read.Unlock() } atomic.AddInt64(&c.requestsRead, 1) } }() } -func (c *collector[R]) GetResults() []R { + +func (c *collector[R]) RequestsRead() int64 { + return atomic.LoadInt64(&c.requestsRead) +} + +func (c *collector[R]) GetResults() []Result[R] { c.read.Lock() defer c.read.Unlock() return c.results } -func (c *collector[R]) GetErrors() []error { - c.read.Lock() - defer c.read.Unlock() - return c.errors +func (c *collector[R]) GetValues() (slice []R) { + results := c.GetResults() + for i := range results { + slice = append(slice, results[i].Value) + } + return } -func (c *collector[R]) RequestsRead() int64 { - return atomic.LoadInt64(&c.requestsRead) +func (c *collector[R]) GetErrors() (slice []error) { + results := c.GetResults() + for i := range results { + slice = append(slice, results[i].Error) + } + return } diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go index d59d5b0098..1ad4afa86b 100644 --- a/internal/workerpool/pool.go +++ b/internal/workerpool/pool.go @@ -7,25 +7,40 @@ import ( "sync" "sync/atomic" "time" - - "github.com/cockroachdb/errors" ) +// Represent the tasks that can be sent to the pool. type Task[R any] func() (result R, err error) +// The result generated from a task. +type Result[R any] struct { + Value R + Error error +} + // Pool is a generic pool of workers. type Pool[R any] struct { - queueCh chan Task[R] - resultsCh chan R - errorsCh chan error + // The queue where tasks are submitted. + queueCh chan Task[R] + // Where workers send the results after a task is executed, + // the collector then reads them and aggregate them. + resultsCh chan Result[R] + + // The total number of requests sent. requestsSent int64 - once sync.Once - workers []*worker[R] + // Number of workers to spawn. workerCount int + // The list of workers that are listening to the queue. + workers []*worker[R] + + // A single collector to aggregate results. collector[R] + + // used to protect starting the pool multiple times + once sync.Once } // New initializes a new Pool with the provided number of workers. The pool is generic and can @@ -37,14 +52,12 @@ type Pool[R any] struct { // return 42, nil // } func New[R any](count int) *Pool[R] { - resultsCh := make(chan R) - errorsCh := make(chan error) + resultsCh := make(chan Result[R]) return &Pool[R]{ queueCh: make(chan Task[R]), resultsCh: resultsCh, - errorsCh: errorsCh, workerCount: count, - collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh}, + collector: collector[R]{resultsCh: resultsCh}, } } @@ -56,7 +69,7 @@ func New[R any](count int) *Pool[R] { func (p *Pool[R]) Start() { p.once.Do(func() { for i := 0; i < p.workerCount; i++ { - w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} + w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh} w.start() p.workers = append(p.workers, &w) } @@ -67,22 +80,33 @@ func (p *Pool[R]) Start() { // Submit sends a task to the workers func (p *Pool[R]) Submit(t Task[R]) { - p.queueCh <- t - atomic.AddInt64(&p.requestsSent, 1) -} - -// GetErrors returns any error from a processed task -func (p *Pool[R]) GetErrors() error { - return errors.Join(p.collector.GetErrors()...) + if t != nil { + p.queueCh <- t + atomic.AddInt64(&p.requestsSent, 1) + } } // GetResults returns the tasks results. 
 //
 // It is recommended to call `Wait()` before reading the results.
-func (p *Pool[R]) GetResults() []R {
+func (p *Pool[R]) GetResults() []Result[R] {
 	return p.collector.GetResults()
 }
 
+// GetValues returns only the values of the pool results
+//
+// It is recommended to call `Wait()` before reading the results.
+func (p *Pool[R]) GetValues() []R {
+	return p.collector.GetValues()
+}
+
+// GetErrors returns only the errors of the pool results
+//
+// It is recommended to call `Wait()` before reading the results.
+func (p *Pool[R]) GetErrors() []error {
+	return p.collector.GetErrors()
+}
+
 // Close waits for workers and collector to process all the requests, and then closes
 // the task queue channel. After closing the pool, calling `Submit()` will panic.
 func (p *Pool[R]) Close() {
diff --git a/internal/workerpool/pool_test.go b/internal/workerpool/pool_test.go
index 3b3946df1e..222dad5707 100644
--- a/internal/workerpool/pool_test.go
+++ b/internal/workerpool/pool_test.go
@@ -35,12 +35,11 @@ func TestPoolSubmitAndRetrieveResult(t *testing.T) {
 	// should have one result
 	results := pool.GetResults()
 	if assert.Len(t, results, 1) {
-		assert.Equal(t, 42, results[0])
+		assert.Equal(t, 42, results[0].Value)
+		// without errors
+		assert.NoError(t, results[0].Error)
 	}
-
-	// no errors
-	assert.Nil(t, pool.GetErrors())
 }
 
 func TestPoolHandleErrors(t *testing.T) {
@@ -53,12 +52,12 @@ func TestPoolHandleErrors(t *testing.T) {
 	}
 	pool.Submit(task)
 
-	// Wait for error collector to process
+	// Wait for collector to process the results
 	pool.Wait()
 
-	err := pool.GetErrors()
-	if assert.Error(t, err) {
-		assert.Contains(t, err.Error(), "task error")
+	errs := pool.GetErrors()
+	if assert.Len(t, errs, 1) {
+		assert.Equal(t, errs[0].Error(), "task error")
 	}
 }
 
@@ -83,12 +82,25 @@ func TestPoolMultipleTasksWithErrors(t *testing.T) {
 	// Wait for error collector to process
 	pool.Wait()
 
-	results := pool.GetResults()
-	assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results)
-	err := pool.GetErrors()
-	if assert.Error(t, err) {
-		assert.Contains(t, err.Error(), "task error")
-	}
+	// Access results together
+	assert.ElementsMatch(t,
+		[]workerpool.Result[*test]{
+			{&test{1}, nil},
+			{&test{2}, nil},
+			{&test{3}, nil},
+			{nil, errors.New("task error")},
+		},
+		pool.GetResults(),
+	)
+
+	// You can also access values and errors directly
+	assert.ElementsMatch(t,
+		[]*test{nil, &test{1}, &test{2}, &test{3}},
+		pool.GetValues(),
+	)
+	assert.ElementsMatch(t,
+		[]error{nil, nil, errors.New("task error"), nil},
+		pool.GetErrors(),
+	)
 }
 
 func TestPoolHandlesNilTasks(t *testing.T) {
@@ -104,8 +116,8 @@ func TestPoolHandlesNilTasks(t *testing.T) {
 
 	pool.Wait()
 
-	err := pool.GetErrors()
-	assert.NoError(t, err)
+	assert.Empty(t, pool.GetErrors())
+	assert.Empty(t, pool.GetValues())
 }
 
 func TestPoolProcessing(t *testing.T) {
@@ -127,9 +139,8 @@ func TestPoolProcessing(t *testing.T) {
 	// wait
 	pool.Wait()
 
-	// read results
-	result := pool.GetResults()
-	assert.Equal(t, []int{10}, result)
+	// read values
+	assert.Equal(t, []int{10}, pool.GetValues())
 
 	// should not longer be processing
 	assert.False(t, pool.Processing())
diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go
index 77b5c81f15..31257353c6 100644
--- a/internal/workerpool/worker.go
+++ b/internal/workerpool/worker.go
@@ -6,25 +6,14 @@ package workerpool
 type worker[R any] struct {
 	id        int
 	queueCh   <-chan Task[R]
-	resultsCh chan<- R
-	errorsCh  chan<- error
+	resultsCh chan<- Result[R]
 }
 
 func (w *worker[R]) start() {
 	go func() {
 		for 
 		for task := range w.queueCh {
-			if task == nil {
-				// let the collector know we processed the request
-				w.errorsCh <- nil
-				continue
-			}
-
 			data, err := task()
-			if err != nil {
-				w.errorsCh <- err
-			} else {
-				w.resultsCh <- data
-			}
+			w.resultsCh <- Result[R]{data, err}
 		}
 	}()
 }
diff --git a/providers/github/resources/github_org.go b/providers/github/resources/github_org.go
index ef39f97159..0876e4a71f 100644
--- a/providers/github/resources/github_org.go
+++ b/providers/github/resources/github_org.go
@@ -4,12 +4,12 @@
 package resources
 
 import (
-	"errors"
 	"slices"
 	"strconv"
 	"strings"
 	"time"
 
+	"github.com/cockroachdb/errors"
 	"github.com/google/go-github/v67/github"
 	"github.com/rs/zerolog/log"
 	"go.mondoo.com/cnquery/v11/internal/workerpool"
@@ -287,7 +287,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) {
 
 	for {
 		// exit as soon as we collect all repositories
-		reposLen := len(slices.Concat(workerPool.GetResults()...))
+		reposLen := len(slices.Concat(workerPool.GetValues()...))
 		if reposLen >= int(repoCount) {
 			break
 		}
@@ -303,7 +303,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) {
 		listOpts.Page++
 
 		// check if any request failed
-		if err := workerPool.GetErrors(); err != nil {
+		if err := errors.Join(workerPool.GetErrors()...); err != nil {
 			if strings.Contains(err.Error(), "404") {
 				return nil, nil
 			}
@@ -316,7 +316,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) {
 	}
 
 	res := []interface{}{}
-	for _, repos := range workerPool.GetResults() {
+	for _, repos := range workerPool.GetValues() {
 		for i := range repos {
 			repo := repos[i]

From 842f60e66971eb7c185b3a75d25b1745e6487e74 Mon Sep 17 00:00:00 2001
From: Salim Afiune Maya
Date: Sun, 15 Dec 2024 03:56:48 +0100
Subject: [PATCH 11/13] =?UTF-8?q?=F0=9F=8F=8E=EF=B8=8F=20fix=20more=20data?=
 =?UTF-8?q?=20race=20conditions?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Salim Afiune Maya

---
 providers/runtime.go | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/providers/runtime.go b/providers/runtime.go
index 6b38972648..162c3abb20 100644
--- a/providers/runtime.go
+++ b/providers/runtime.go
@@ -121,12 +121,16 @@ func (r *Runtime) UseProvider(id string) error {
 		return err
 	}
 
+	r.mu.Lock()
 	r.Provider = res
+	r.mu.Unlock()
 	return nil
 }
 
 func (r *Runtime) AddConnectedProvider(c *ConnectedProvider) {
+	r.mu.Lock()
 	r.providers[c.Instance.ID] = c
+	r.mu.Unlock()
 }
 
 func (r *Runtime) addProvider(id string) (*ConnectedProvider, error) {
@@ -752,6 +756,9 @@ func (r *Runtime) Schema() resources.ResourcesSchema {
 }
 
 func (r *Runtime) asset() *inventory.Asset {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
 	if r.Provider == nil || r.Provider.Connection == nil {
 		return nil
 	}

From 8badda7dee2fd12842a2109ef2513fe19930bca8 Mon Sep 17 00:00:00 2001
From: Salim Afiune Maya
Date: Sun, 15 Dec 2024 10:33:39 +0100
Subject: [PATCH 12/13] =?UTF-8?q?=F0=9F=A4=96=20Run=20race=20detector=20on?=
 =?UTF-8?q?=20CI?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Salim Afiune Maya

---
 .github/workflows/main-benchmark.yml |  2 +-
 .github/workflows/pr-test-lint.yml   | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/main-benchmark.yml b/.github/workflows/main-benchmark.yml
index e7d31ef9f4..0fbf8eec86 100644
--- a/.github/workflows/main-benchmark.yml
+++ b/.github/workflows/main-benchmark.yml
@@ -63,4 +63,4 @@ jobs:
         uses: actions/cache/save@v4
         with:
           path: ./cache
-          key: ${{ runner.os }}-benchmark-${{ github.run_id }}
\ No newline at end of file
+          key: ${{ runner.os }}-benchmark-${{ github.run_id }}
diff --git a/.github/workflows/pr-test-lint.yml b/.github/workflows/pr-test-lint.yml
index a46d51920b..33490eb187 100644
--- a/.github/workflows/pr-test-lint.yml
+++ b/.github/workflows/pr-test-lint.yml
@@ -128,6 +128,24 @@
           name: test-results-cli
           path: report.xml
 
+  go-race:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Import environment variables from file
+        run: cat ".github/env" >> $GITHUB_ENV
+
+      - name: Install Go
+        uses: actions/setup-go@v5
+        with:
+          go-version: ">=${{ env.golang-version }}"
+          cache: false
+
+      - name: Run race detector on selected packages
+        run: make race/go
+
   go-bench:
     runs-on: ubuntu-latest
     if: github.ref != 'refs/heads/main'

From ceb2585ff4680540f0403e564d3f126e27ad58f4 Mon Sep 17 00:00:00 2001
From: Salim Afiune Maya
Date: Sun, 15 Dec 2024 22:19:21 +0100
Subject: [PATCH 13/13] =?UTF-8?q?=E2=9A=99=EF=B8=8F=20=20split=20plugin=20?=
 =?UTF-8?q?connect=20func=20and=20assignation?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Salim Afiune Maya

---
 providers/runtime.go | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/providers/runtime.go b/providers/runtime.go
index 162c3abb20..ee71ab7af3 100644
--- a/providers/runtime.go
+++ b/providers/runtime.go
@@ -133,6 +133,13 @@ func (r *Runtime) AddConnectedProvider(c *ConnectedProvider) {
 	r.mu.Unlock()
 }
 
+func (r *Runtime) setProviderConnection(c *plugin.ConnectRes, err error) {
+	r.mu.Lock()
+	r.Provider.Connection = c
+	r.Provider.ConnectionError = err
+	r.mu.Unlock()
+}
+
 func (r *Runtime) addProvider(id string) (*ConnectedProvider, error) {
 	// TODO: we need to detect only the shared running providers
 	running, err := r.coordinator.GetRunningProvider(id, r.AutoUpdate)
@@ -239,11 +246,10 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error {
 	// }
 
-	r.mu.Lock()
-	r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks)
-	r.mu.Unlock()
-	if r.Provider.ConnectionError != nil {
-		return r.Provider.ConnectionError
+	conn, err := r.Provider.Instance.Plugin.Connect(req, &callbacks)
+	r.setProviderConnection(conn, err)
+	if err != nil {
+		return err
 	}
 
 	// TODO: This is a stopgap that detects if the connect call returned an asset
@@ -265,9 +271,10 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error {
 	if postProvider.ID != r.Provider.Instance.ID {
 		req.Asset = r.Provider.Connection.Asset
 		r.UseProvider(postProvider.ID)
-		r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks)
-		if r.Provider.ConnectionError != nil {
-			return r.Provider.ConnectionError
+		conn, err := r.Provider.Instance.Plugin.Connect(req, &callbacks)
+		r.setProviderConnection(conn, err)
+		if err != nil {
+			return err
 		}
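
Reviewer note: the sketch below is not part of the series; it is a minimal usage example of
the reworked pool API, assuming only what the tests above exercise (Result[R] with Value and
Error fields, Wait() blocking until all submitted tasks are collected, and GetErrors()
returning one entry per task, nil for successes). The main package and the failing-task
pattern are illustrative; since internal/workerpool is an internal package, this only
compiles from within the cnquery module.

	package main

	import (
		"fmt"

		"go.mondoo.com/cnquery/v11/internal/workerpool"
	)

	func main() {
		// four workers, tasks that return an int
		pool := workerpool.New[int](4)
		pool.Start()
		defer pool.Close()

		// submit a mix of succeeding and failing tasks
		for i := 0; i < 8; i++ {
			i := i // capture loop variable (pre-Go 1.22 idiom)
			pool.Submit(func() (int, error) {
				if i%3 == 0 {
					return 0, fmt.Errorf("task %d failed", i)
				}
				return i * i, nil
			})
		}

		// block until every submitted task has been collected
		pool.Wait()

		// values and errors kept together, one Result per task
		for _, res := range pool.GetResults() {
			fmt.Println(res.Value, res.Error)
		}

		// or read them separately; GetErrors() includes a nil entry
		// for each task that succeeded, as the tests above assert
		fmt.Println("values:", pool.GetValues())
		fmt.Println("errors:", pool.GetErrors())
	}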