Skip to content

Commit

Permalink
feat(misconf): Support private registries for misconf check bundle (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffResc authored Apr 1, 2024
1 parent df024e8 commit f23ed77
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pkg/cloud/aws/scanner/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (s *AWSScanner) Scan(ctx context.Context, option flag.Options) (scan.Result
var policyPaths []string
var downloadedPolicyPaths []string
var err error
downloadedPolicyPaths, err = operation.InitBuiltinPolicies(context.Background(), option.CacheDir, option.Quiet, option.SkipPolicyUpdate, option.MisconfOptions.PolicyBundleRepository)
downloadedPolicyPaths, err = operation.InitBuiltinPolicies(context.Background(), option.CacheDir, option.Quiet, option.SkipPolicyUpdate, option.MisconfOptions.PolicyBundleRepository, option.RegistryOpts())
if err != nil {
if !option.SkipPolicyUpdate {
log.Logger.Errorf("Falling back to embedded policies: %s", err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/commands/artifact/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ func initScannerConfig(opts flag.Options, cacheClient cache.Cache) (ScannerConfi

var downloadedPolicyPaths []string
var disableEmbedded bool
downloadedPolicyPaths, err := operation.InitBuiltinPolicies(context.Background(), opts.CacheDir, opts.Quiet, opts.SkipPolicyUpdate, opts.MisconfOptions.PolicyBundleRepository)
downloadedPolicyPaths, err := operation.InitBuiltinPolicies(context.Background(), opts.CacheDir, opts.Quiet, opts.SkipPolicyUpdate, opts.MisconfOptions.PolicyBundleRepository, opts.RegistryOpts())
if err != nil {
if !opts.SkipPolicyUpdate {
log.Logger.Errorf("Falling back to embedded policies: %s", err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/commands/operation/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func showDBInfo(cacheDir string) error {
}

// InitBuiltinPolicies downloads the built-in policies and loads them
func InitBuiltinPolicies(ctx context.Context, cacheDir string, quiet, skipUpdate bool, policyBundleRepository string) ([]string, error) {
func InitBuiltinPolicies(ctx context.Context, cacheDir string, quiet, skipUpdate bool, policyBundleRepository string, registryOpts ftypes.RegistryOptions) ([]string, error) {
mu.Lock()
defer mu.Unlock()

Expand All @@ -159,7 +159,7 @@ func InitBuiltinPolicies(ctx context.Context, cacheDir string, quiet, skipUpdate

needsUpdate := false
if !skipUpdate {
needsUpdate, err = client.NeedsUpdate(ctx)
needsUpdate, err = client.NeedsUpdate(ctx, registryOpts)
if err != nil {
return nil, xerrors.Errorf("unable to check if built-in policies need to be updated: %w", err)
}
Expand All @@ -168,7 +168,7 @@ func InitBuiltinPolicies(ctx context.Context, cacheDir string, quiet, skipUpdate
if needsUpdate {
log.Logger.Info("Need to update the built-in policies")
log.Logger.Info("Downloading the built-in policies...")
if err = client.DownloadBuiltinPolicies(ctx); err != nil {
if err = client.DownloadBuiltinPolicies(ctx, registryOpts); err != nil {
return nil, xerrors.Errorf("failed to download built-in policies: %w", err)
}
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ func NewClient(cacheDir string, quiet bool, policyBundleRepo string, opts ...Opt
}, nil
}

func (c *Client) populateOCIArtifact() error {
func (c *Client) populateOCIArtifact(registryOpts types.RegistryOptions) error {
if c.artifact == nil {
log.Logger.Debugf("Using URL: %s to load policy bundle", c.policyBundleRepo)
art, err := oci.NewArtifact(c.policyBundleRepo, c.quiet, types.RegistryOptions{})
art, err := oci.NewArtifact(c.policyBundleRepo, c.quiet, registryOpts)
if err != nil {
return xerrors.Errorf("OCI artifact error: %w", err)
}
Expand All @@ -102,8 +102,8 @@ func (c *Client) populateOCIArtifact() error {
}

// DownloadBuiltinPolicies download default policies from GitHub Pages
func (c *Client) DownloadBuiltinPolicies(ctx context.Context) error {
if err := c.populateOCIArtifact(); err != nil {
func (c *Client) DownloadBuiltinPolicies(ctx context.Context, registryOpts types.RegistryOptions) error {
if err := c.populateOCIArtifact(registryOpts); err != nil {
return xerrors.Errorf("OPA bundle error: %w", err)
}

Expand Down Expand Up @@ -154,7 +154,7 @@ func (c *Client) LoadBuiltinPolicies() ([]string, error) {
}

// NeedsUpdate returns if the default policy should be updated
func (c *Client) NeedsUpdate(ctx context.Context) (bool, error) {
func (c *Client) NeedsUpdate(ctx context.Context, registryOpts types.RegistryOptions) (bool, error) {
meta, err := c.GetMetadata()
if err != nil {
return true, nil
Expand All @@ -165,7 +165,7 @@ func (c *Client) NeedsUpdate(ctx context.Context) (bool, error) {
return false, nil
}

if err = c.populateOCIArtifact(); err != nil {
if err = c.populateOCIArtifact(registryOpts); err != nil {
return false, xerrors.Errorf("OPA bundle error: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/policy/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func TestClient_NeedsUpdate(t *testing.T) {
require.NoError(t, err)

// Assert results
got, err := c.NeedsUpdate(context.Background())
got, err := c.NeedsUpdate(context.Background(), ftypes.RegistryOptions{})
assert.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, tt.want, got)
})
Expand Down Expand Up @@ -367,7 +367,7 @@ func TestClient_DownloadBuiltinPolicies(t *testing.T) {
c, err := policy.NewClient(tempDir, true, "", policy.WithClock(tt.clock), policy.WithOCIArtifact(art))
require.NoError(t, err)

err = c.DownloadBuiltinPolicies(context.Background())
err = c.DownloadBuiltinPolicies(context.Background(), ftypes.RegistryOptions{})
if tt.wantErr != "" {
require.NotNil(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
Expand Down

0 comments on commit f23ed77

Please sign in to comment.