diff --git a/providers/coordinator.go b/providers/coordinator.go index d23c621a61..8cc75d3c8c 100644 --- a/providers/coordinator.go +++ b/providers/coordinator.go @@ -43,7 +43,37 @@ type RunningProvider struct { Client *plugin.Client Schema *resources.Schema + // isClosed is true for any provider that is not running anymore, + // either via shutdown or via crash isClosed bool + // isShutdown is set once Shutdown has run, so repeated calls return early + isShutdown bool + // err holds an earlier provider error (e.g. a crash); it is logged during Shutdown + err error + lock sync.Mutex +} + +func (p *RunningProvider) Shutdown() error { + p.lock.Lock() + defer p.lock.Unlock() + + if p.isShutdown { + return nil + } + + // This is an error that happened earlier, so we print it directly. + // The error this function returns is about failing to shut down. + if p.err != nil { + log.Error().Msg(p.err.Error()) + } + + var err error + if !p.isClosed { + _, err = p.Plugin.Shutdown(&pp.ShutdownReq{}) + p.isClosed = true + } + p.isShutdown = true + return err } type UpdateProvidersConfig struct { diff --git a/providers/runtime.go b/providers/runtime.go index b3a84722bd..091fd4bc31 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -16,6 +16,7 @@ import ( "go.mondoo.com/cnquery/providers-sdk/v1/upstream" "go.mondoo.com/cnquery/types" "go.mondoo.com/cnquery/utils/multierr" + "google.golang.org/grpc/status" ) const defaultShutdownTimeout = time.Duration(time.Second * 120) @@ -36,6 +37,7 @@ type Runtime struct { // schema aggregates all resources executable on this asset schema extensibleSchema isClosed bool + close sync.Once shutdownTimeout time.Duration } @@ -76,39 +78,40 @@ type shutdownResult struct { } func (r *Runtime) tryShutdown() shutdownResult { - resp, err := r.Provider.Instance.Plugin.Shutdown(&plugin.ShutdownReq{}) + var errs multierr.Errors + for _, provider := range r.providers { + errs.Add(provider.Instance.Shutdown()) + } + return shutdownResult{ - Response: resp, - Error: 
err, + Error: errs.Deduplicate(), } } func (r *Runtime) Close() { - if r.isClosed { - return - } r.isClosed = true + r.close.Do(func() { + if err := r.Recording.Save(); err != nil { + log.Error().Err(err).Msg("failed to save recording") + } - if err := r.Recording.Save(); err != nil { - log.Error().Err(err).Msg("failed to save recording") - } - - response := make(chan shutdownResult, 1) - go func() { - response <- r.tryShutdown() - }() - select { - case <-time.After(r.shutdownTimeout): - log.Error().Str("provider", r.Provider.Instance.Name).Msg("timed out shutting down the provider") - case result := <-response: - if result.Error != nil { - log.Error().Err(result.Error).Msg("failed to shutdown the provider") + response := make(chan shutdownResult, 1) + go func() { + response <- r.tryShutdown() + }() + select { + case <-time.After(r.shutdownTimeout): + log.Error().Str("provider", r.Provider.Instance.Name).Msg("timed out shutting down the provider") + case result := <-response: + if result.Error != nil { + log.Error().Err(result.Error).Msg("failed to shutdown the provider") + } } - } - // TODO: ideally, we try to close the provider here but only if there are no more assets that need it - // r.coordinator.Close(r.Provider.Instance) - r.schema.Close() + // TODO: ideally, we try to close the provider here but only if there are no more assets that need it + // r.coordinator.Close(r.Provider.Instance) + r.schema.Close() + }) } func (r *Runtime) DeactivateProviderDiscovery() { @@ -333,7 +336,14 @@ func (r *Runtime) watchAndUpdate(resource string, resourceID string, field strin Field: field, }) if err != nil { - return nil, err + // Recoverable errors can continue with the execution, + // they only store errors in the place of actual data. + // Every other error is thrown up the chain. 
+ handled, err := r.handlePluginError(err, provider) + if !handled { + return nil, err + } + data = &plugin.DataRes{Error: err.Error()} } var raw *llx.RawData @@ -347,6 +357,23 @@ func (r *Runtime) watchAndUpdate(resource string, resourceID string, field strin return raw, nil } +func (r *Runtime) handlePluginError(err error, provider *ConnectedProvider) (bool, error) { + st, ok := status.FromError(err) + if !ok { + return false, err + } + + switch st.Code() { + case 14: + // gRPC code 14 is codes.Unavailable. Happens when the plugin crashes. + // TODO: try to restart the plugin and reset its connections + provider.Instance.isClosed = true + provider.Instance.err = errors.New("the '" + provider.Instance.Name + "' provider crashed") + return false, provider.Instance.err + } + return false, err +} + type providerCallbacks struct { recording *assetRecording runtime *Runtime diff --git a/utils/multierr/errors.go b/utils/multierr/errors.go index 94de41a57d..284b08d1e6 100644 --- a/utils/multierr/errors.go +++ b/utils/multierr/errors.go @@ -34,10 +34,11 @@ type Errors struct { } func (m *Errors) Add(err ...error) { - if err == nil { - return + for i := range err { + if err[i] != nil { + m.errors = append(m.errors, err[i]) + } } - m.errors = append(m.errors, err...) } func (m *Errors) Error() string { diff --git a/utils/multierr/errors_test.go b/utils/multierr/errors_test.go new file mode 100644 index 0000000000..de902a7d30 --- /dev/null +++ b/utils/multierr/errors_test.go @@ -0,0 +1,29 @@ +// Copyright (c) Mondoo, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package multierr_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "go.mondoo.com/cnquery/utils/multierr" +) + +func TestMultiErr(t *testing.T) { + t.Run("add nil errors", func(t *testing.T) { + var e multierr.Errors + e.Add(nil) + e.Add(nil, nil, nil) + assert.Nil(t, e.Deduplicate()) + }) + + t.Run("add mixed errors", func(t *testing.T) { + var e multierr.Errors + e.Add(errors.New("1"), nil, errors.New("1")) + var b multierr.Errors + b.Add(errors.New("1")) + assert.Equal(t, b.Deduplicate(), e.Deduplicate()) + }) +}