Skip to content

Commit

Permalink
🐛 cascade provider connection errors (#2728)
Browse files Browse the repository at this point in the history
When an auxiliary provider is detected, its connection errors currently
don't propagate after the initial connect call.

Example: You are connecting to AWS and request `asset{*}`. This will
request fields form the `os` provider, including `cpes` and
`vulnerabilityReport`.

The first call to `cpes` will fail to connect (as it should) and the
error will be reported. On the second call (when we request
`vulnerabilityReport`) it will return the provider during `provider :=
r.providers[info.Provider]; provider != nil `. However, that provider
now doesn't have a valid connection.

This PR adds the connection errors as a stable field within the provider
and also returns them on returning a cached provider.

Fixes #2724

Signed-off-by: Dominik Richter <[email protected]>
  • Loading branch information
arlimus authored Dec 2, 2023
1 parent 09d5a4a commit f4e3630
Showing 1 changed file with 20 additions and 25 deletions.
45 changes: 20 additions & 25 deletions providers/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ type Runtime struct {
}

type ConnectedProvider struct {
Instance *RunningProvider
Connection *plugin.ConnectRes
Instance *RunningProvider
Connection *plugin.ConnectRes
ConnectionError error
}

func (c *coordinator) RuntimeWithShutdownTimeout(timeout time.Duration) *Runtime {
Expand Down Expand Up @@ -219,10 +220,9 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error {
runtime: r,
}

var err error
r.Provider.Connection, err = r.Provider.Instance.Plugin.Connect(req, &callbacks)
if err != nil {
return err
r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks)
if r.Provider.ConnectionError != nil {
return r.Provider.ConnectionError
}

// TODO: This is a stopgap that detects if the connect call returned an asset
Expand All @@ -244,9 +244,9 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error {
if postProvider.ID != r.Provider.Instance.ID {
req.Asset = r.Provider.Connection.Asset
r.UseProvider(postProvider.ID)
r.Provider.Connection, err = r.Provider.Instance.Plugin.Connect(req, &callbacks)
if err != nil {
return err
r.Provider.Connection, r.Provider.ConnectionError = r.Provider.Instance.Plugin.Connect(req, &callbacks)
if r.Provider.ConnectionError != nil {
return r.Provider.ConnectionError
}
}

Expand Down Expand Up @@ -536,15 +536,14 @@ func (r *Runtime) SetMockRecording(anyRecording Recording, providerID string, mo
runtime: r,
}

res, err := provider.Instance.Plugin.Connect(&plugin.ConnectReq{
provider.Connection, provider.ConnectionError = provider.Instance.Plugin.Connect(&plugin.ConnectReq{
Asset: asset,
Upstream: r.UpstreamConfig,
HasRecording: true,
}, &callbacks)
if err != nil {
return multierr.Wrap(err, "failed to set mock connection for recording")
if provider.ConnectionError != nil {
return multierr.Wrap(provider.ConnectionError, "failed to set mock connection for recording")
}
provider.Connection = res
}

if provider.Connection == nil {
Expand Down Expand Up @@ -573,7 +572,7 @@ func (r *Runtime) lookupResourceProvider(resource string) (*ConnectedProvider, *
}

if provider := r.providers[info.Provider]; provider != nil {
return provider, info, nil
return provider, info, provider.ConnectionError
}

providerConn := r.Provider.Instance.ID
Expand All @@ -586,17 +585,15 @@ func (r *Runtime) lookupResourceProvider(resource string) (*ConnectedProvider, *
return nil, nil, multierr.Wrap(err, "failed to start provider '"+info.Provider+"'")
}

conn, err := res.Instance.Plugin.Connect(&plugin.ConnectReq{
res.Connection, res.ConnectionError = res.Instance.Plugin.Connect(&plugin.ConnectReq{
Features: r.features,
Upstream: r.UpstreamConfig,
Asset: r.Provider.Connection.Asset,
}, nil)
if err != nil {
return nil, nil, err
if res.ConnectionError != nil {
return nil, nil, res.ConnectionError
}

res.Connection = conn

return res, info, nil
}

Expand All @@ -610,25 +607,23 @@ func (r *Runtime) lookupFieldProvider(resource string, field string) (*Connected
}

if provider := r.providers[fieldInfo.Provider]; provider != nil {
return provider, resourceInfo, fieldInfo, nil
return provider, resourceInfo, fieldInfo, provider.ConnectionError
}

res, err := r.addProvider(fieldInfo.Provider, false)
if err != nil {
return nil, nil, nil, multierr.Wrap(err, "failed to start provider '"+fieldInfo.Provider+"'")
}

conn, err := res.Instance.Plugin.Connect(&plugin.ConnectReq{
res.Connection, res.ConnectionError = res.Instance.Plugin.Connect(&plugin.ConnectReq{
Features: r.features,
Upstream: r.UpstreamConfig,
Asset: r.Provider.Connection.Asset,
}, nil)
if err != nil {
return nil, nil, nil, err
if res.ConnectionError != nil {
return nil, nil, nil, res.ConnectionError
}

res.Connection = conn

return res, resourceInfo, fieldInfo, nil
}

Expand Down

0 comments on commit f4e3630

Please sign in to comment.