From f4e36303f13e09039f32124df5939c3b5593f5d8 Mon Sep 17 00:00:00 2001 From: Dominik Richter Date: Sat, 2 Dec 2023 12:28:23 -0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20cascade=20provider=20connection?= =?UTF-8?q?=20errors=20(#2728)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When an auxiliary provider is detected, its connection errors currently don't propagate after the initial connect call. Example: You are connecting to AWS and request `asset{*}`. This will request fields form the `os` provider, including `cpes` and `vulnerabilityReport`. The first call to `cpes` will fail to connect (as it should) and the error will be reported. On the second call (when we request `vulnerabilityReport`) it will return the provider during `provider := r.providers[info.Provider]; provider != nil `. However, that provider now doesn't have a valid connection. This PR adds the connection errors as a stable field within the provider and also returns them on returning a cached provider. Fixes https://github.com/mondoohq/cnquery/issues/2724 Signed-off-by: Dominik Richter --- providers/runtime.go | 45 ++++++++++++++++++++------------------------ 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/providers/runtime.go b/providers/runtime.go index fb543f6ca1..9a3fdb4c17 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -44,8 +44,9 @@ type Runtime struct { } type ConnectedProvider struct { - Instance *RunningProvider - Connection *plugin.ConnectRes + Instance *RunningProvider + Connection *plugin.ConnectRes + ConnectionError error } func (c *coordinator) RuntimeWithShutdownTimeout(timeout time.Duration) *Runtime { @@ -219,10 +220,9 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error { runtime: r, } - var err error - r.Provider.Connection, err = r.Provider.Instance.Plugin.Connect(req, &callbacks) - if err != nil { - return err + r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks) + if r.Provider.ConnectionError != nil { + return r.Provider.ConnectionError } // TODO: This is a stopgap that detects if the connect call returned an asset @@ -244,9 +244,9 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error { if postProvider.ID != r.Provider.Instance.ID { req.Asset = r.Provider.Connection.Asset r.UseProvider(postProvider.ID) - r.Provider.Connection, err = r.Provider.Instance.Plugin.Connect(req, &callbacks) - if err != nil { - return err + r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks) + if r.Provider.ConnectionError != nil { + return r.Provider.ConnectionError } } @@ -536,15 +536,14 @@ func (r *Runtime) SetMockRecording(anyRecording Recording, providerID string, mo runtime: r, } - res, err := provider.Instance.Plugin.Connect(&plugin.ConnectReq{ + provider.Connection, provider.ConnectionError = provider.Instance.Plugin.Connect(&plugin.ConnectReq{ Asset: asset, Upstream: r.UpstreamConfig, HasRecording: true, }, &callbacks) - if err != nil { - return multierr.Wrap(err, "failed to set mock connection for recording") + if provider.ConnectionError != nil { + return multierr.Wrap(provider.ConnectionError, "failed to set mock connection for recording") } - provider.Connection = res } if provider.Connection == nil { @@ -573,7 +572,7 @@ func (r *Runtime) lookupResourceProvider(resource string) (*ConnectedProvider, * } if provider := r.providers[info.Provider]; provider != nil { - return provider, info, nil + return provider, info, provider.ConnectionError } providerConn := r.Provider.Instance.ID @@ -586,17 +585,15 @@ func (r *Runtime) lookupResourceProvider(resource string) (*ConnectedProvider, * return nil, nil, multierr.Wrap(err, "failed to start provider '"+info.Provider+"'") } - conn, err := res.Instance.Plugin.Connect(&plugin.ConnectReq{ + res.Connection, res.ConnectionError = res.Instance.Plugin.Connect(&plugin.ConnectReq{ Features: r.features, Upstream: r.UpstreamConfig, Asset: r.Provider.Connection.Asset, }, nil) - if err != nil { - return nil, nil, err + if res.ConnectionError != nil { + return nil, nil, res.ConnectionError } - res.Connection = conn - return res, info, nil } @@ -610,7 +607,7 @@ func (r *Runtime) lookupFieldProvider(resource string, field string) (*Connected } if provider := r.providers[fieldInfo.Provider]; provider != nil { - return provider, resourceInfo, fieldInfo, nil + return provider, resourceInfo, fieldInfo, provider.ConnectionError } res, err := r.addProvider(fieldInfo.Provider, false) @@ -618,17 +615,15 @@ func (r *Runtime) lookupFieldProvider(resource string, field string) (*Connected return nil, nil, nil, multierr.Wrap(err, "failed to start provider '"+fieldInfo.Provider+"'") } - conn, err := res.Instance.Plugin.Connect(&plugin.ConnectReq{ + res.Connection, res.ConnectionError = res.Instance.Plugin.Connect(&plugin.ConnectReq{ Features: r.features, Upstream: r.UpstreamConfig, Asset: r.Provider.Connection.Asset, }, nil) - if err != nil { - return nil, nil, nil, err + if res.ConnectionError != nil { + return nil, nil, nil, res.ConnectionError } - res.Connection = conn - return res, resourceInfo, fieldInfo, nil }