batch sync assets
Signed-off-by: Ivan Milchev <[email protected]>
imilchev committed Dec 12, 2023
1 parent a7e8a32 commit fe16ac1
Showing 1 changed file with 138 additions and 121 deletions.
259 changes: 138 additions & 121 deletions policy/scan/local_scanner.go
@@ -160,7 +160,7 @@ func (s *LocalScanner) Run(ctx context.Context, job *Job) (*ScanResult, error) {
return nil, err
}

reports, _, err := s.distributeJob(job, ctx, upstream)
reports, err := s.distributeJob(job, ctx, upstream)
if err != nil {
return nil, err
}
@@ -186,7 +186,7 @@ func (s *LocalScanner) RunIncognito(ctx context.Context, job *Job) (*ScanResult,
return nil, err
}

reports, _, err := s.distributeJob(job, ctx, upstream)
reports, err := s.distributeJob(job, ctx, upstream)
if err != nil {
return nil, err
}
@@ -280,7 +280,7 @@ func createAssetCandidateList(ctx context.Context, job *Job, upstream *upstream.
return assetList, assetCandidates, nil
}

func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*ScanResult, bool, error) {
func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*ScanResult, error) {
// Always shut down the coordinator, to make sure providers are killed
defer providers.Coordinator.Shutdown()

@@ -294,15 +294,15 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
case ReportType_NONE:
reporter = NewNoOpReporter()
default:
return nil, false, errors.Errorf("unknown report type: %s", job.ReportType)
return nil, errors.Errorf("unknown report type: %s", job.ReportType)
}

log.Info().Msgf("discover related assets for %d asset(s)", len(job.Inventory.Spec.Assets))

var assets []*assetWithRuntime
assetList, assetCandidates, err := createAssetCandidateList(ctx, job, upstream, s.recording)
if err != nil {
return nil, false, err
return nil, err
}

// For each asset candidate, we initialize a new runtime and connect to it.
@@ -327,12 +327,12 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
if candidate.asset.Connections[0].Type == "k8s" {
runtime, err = providers.Coordinator.RuntimeFor(candidate.asset, providers.DefaultRuntime())
if err != nil {
return nil, false, err
return nil, err
}
} else {
runtime, err = providers.Coordinator.EphemeralRuntimeFor(candidate.asset)
if err != nil {
return nil, false, err
return nil, err
}
}
runtime.UpstreamConfig = upstream
@@ -360,7 +360,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
}

if len(assets) == 0 {
return nil, false, nil
return nil, nil
}

// if there is exactly one asset, assure that the --asset-name is used
@@ -381,146 +381,163 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up
runtimeLabels = runtimeEnv.Labels()
}

justAssets := []*inventory.Asset{}
for _, asset := range assets {
asset.asset.KindString = asset.asset.GetPlatform().Kind
for k, v := range runtimeLabels {
if asset.asset.Labels == nil {
asset.asset.Labels = map[string]string{}
}
asset.asset.Labels[k] = v
}
justAssets = append(justAssets, asset.asset)
}

// sync assets
if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito {
log.Info().Msg("synchronize assets")
client, err := upstream.InitClient()
if err != nil {
return nil, false, err
}

services, err := policy.NewRemoteServices(client.ApiEndpoint, client.Plugins, client.HttpClient)
if err != nil {
return nil, false, err
}

inventory.DeprecatedV8CompatAssets(justAssets)
resp, err := services.SynchronizeAssets(ctx, &policy.SynchronizeAssetsReq{
SpaceMrn: client.SpaceMrn,
List: justAssets,
})
if err != nil {
return nil, false, err
}
log.Debug().Int("assets", len(resp.Details)).Msg("got assets details")
platformAssetMapping := make(map[string]*policy.SynchronizeAssetsRespAssetDetail)
for i := range resp.Details {
log.Debug().Str("platform-mrn", resp.Details[i].PlatformMrn).Str("asset", resp.Details[i].AssetMrn).Msg("asset mapping")
platformAssetMapping[resp.Details[i].PlatformMrn] = resp.Details[i]
}

// attach the asset details to the assets list
for i := range assets {
log.Debug().Str("asset", assets[i].asset.Name).Strs("platform-ids", assets[i].asset.PlatformIds).Msg("update asset")
platformMrn := assets[i].asset.PlatformIds[0]
assets[i].asset.Mrn = platformAssetMapping[platformMrn].AssetMrn
assets[i].asset.Url = platformAssetMapping[platformMrn].Url
}
} else {
// ensure we have non-empty asset MRNs
for i := range assets {
cur := assets[i]
if cur.asset.Mrn == "" {
randID := "//" + policy.POLICY_SERVICE_NAME + "/" + policy.MRN_RESOURCE_ASSET + "/" + ksuid.New().String()
x, err := mrn.NewMRN(randID)
if err != nil {
return nil, false, multierr.Wrap(err, "failed to generate a random asset MRN")
}
cur.asset.Mrn = x.String()
}
}
}

// // if a bundle was provided check that it matches the filter, bundles can also be downloaded
// // later therefore we do not want to stop execution here
// if job.Bundle != nil && job.Bundle.FilterPolicies(job.PolicyFilters) {
// return nil, false, errors.New("all available packs filtered out. nothing to do.")
// }

progressBarElements := map[string]string{}
var multiprogress progress.MultiProgress
orderedKeys := []string{}
for i := range assets {
// this shouldn't happen, but might
// it normally indicates a bug in the provider
if presentAsset, present := progressBarElements[assets[i].asset.PlatformIds[0]]; present {
return nil, false, fmt.Errorf("asset %s and %s have the same platform id %s", presentAsset, assets[i].asset.Name, assets[i].asset.PlatformIds[0])
return nil, fmt.Errorf("asset %s and %s have the same platform id %s", presentAsset, assets[i].asset.Name, assets[i].asset.PlatformIds[0])
}
progressBarElements[assets[i].asset.PlatformIds[0]] = assets[i].asset.Name
orderedKeys = append(orderedKeys, assets[i].asset.PlatformIds[0])
}
var multiprogress progress.MultiProgress

if isatty.IsTerminal(os.Stdout.Fd()) && !s.disableProgressBar && !strings.EqualFold(logger.GetLevel(), "debug") && !strings.EqualFold(logger.GetLevel(), "trace") {
var err error
multiprogress, err = progress.NewMultiProgressBars(progressBarElements, orderedKeys, progress.WithScore())
if err != nil {
return nil, false, multierr.Wrap(err, "failed to create progress bars")
return nil, multierr.Wrap(err, "failed to create progress bars")
}
} else {
// TODO: adjust naming
multiprogress = progress.NoopMultiProgressBars{}
}
scanGroups := sync.WaitGroup{}

scanGroup := sync.WaitGroup{}
scanGroup.Add(1)
finished := false
// start the progress bar
scanGroups.Add(1)
go func() {
defer scanGroup.Done()
for i := range assets {
asset := assets[i].asset
runtime := assets[i].runtime

log.Debug().Interface("asset", asset).Msg("start scan")

// Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function,
// we need to solve this at a different level.
select {
case <-ctx.Done():
log.Warn().Msg("request context has been canceled")
// When we scan concurrently, we need to call Errored(asset.Mrn) status for this asset
multiprogress.Close()
return
default:
defer scanGroups.Done()
multiprogress.Open()

[GitHub Actions / golangci-lint, check failure on line 413 in policy/scan/local_scanner.go] Error return value of `multiprogress.Open` is not checked (errcheck). A possible fix is sketched after the diff.
}()

assetBatches := batch(assets, 10)
for i := range assetBatches {
batch := assetBatches[i]
justAssets := []*inventory.Asset{}
for _, asset := range batch {
asset.asset.KindString = asset.asset.GetPlatform().Kind
for k, v := range runtimeLabels {
if asset.asset.Labels == nil {
asset.asset.Labels = map[string]string{}
}
asset.asset.Labels[k] = v
}
justAssets = append(justAssets, asset.asset)
}

inventory.DeprecatedV8CompatAssets(justAssets)

// sync assets
if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito {
log.Info().Msg("synchronize assets")
client, err := upstream.InitClient()
if err != nil {
return nil, err
}

services, err := policy.NewRemoteServices(client.ApiEndpoint, client.Plugins, client.HttpClient)
if err != nil {
return nil, err
}

p := &progress.MultiProgressAdapter{Key: asset.PlatformIds[0], Multi: multiprogress}
s.RunAssetJob(&AssetJob{
DoRecord: job.DoRecord,
UpstreamConfig: upstream,
Asset: asset,
Bundle: job.Bundle,
Props: job.Props,
PolicyFilters: preprocessPolicyFilters(job.PolicyFilters),
Ctx: ctx,
Reporter: reporter,
ProgressReporter: p,
runtime: runtime,
resp, err := services.SynchronizeAssets(ctx, &policy.SynchronizeAssetsReq{
SpaceMrn: client.SpaceMrn,
List: justAssets,
})
if err != nil {
return nil, err
}
log.Debug().Int("assets", len(resp.Details)).Msg("got assets details")
platformAssetMapping := make(map[string]*policy.SynchronizeAssetsRespAssetDetail)
for i := range resp.Details {
log.Debug().Str("platform-mrn", resp.Details[i].PlatformMrn).Str("asset", resp.Details[i].AssetMrn).Msg("asset mapping")
platformAssetMapping[resp.Details[i].PlatformMrn] = resp.Details[i]
}

// shut down all ephemeral runtimes
runtime.Close()
// attach the asset details to the assets list
for i := range batch {
log.Debug().Str("asset", batch[i].asset.Name).Strs("platform-ids", batch[i].asset.PlatformIds).Msg("update asset")
platformMrn := batch[i].asset.PlatformIds[0]
batch[i].asset.Mrn = platformAssetMapping[platformMrn].AssetMrn
batch[i].asset.Url = platformAssetMapping[platformMrn].Url
}
} else {
// ensure we have non-empty asset MRNs
for i := range batch {
cur := batch[i]
if cur.asset.Mrn == "" {
randID := "//" + policy.POLICY_SERVICE_NAME + "/" + policy.MRN_RESOURCE_ASSET + "/" + ksuid.New().String()
x, err := mrn.NewMRN(randID)
if err != nil {
return nil, multierr.Wrap(err, "failed to generate a random asset MRN")
}
cur.asset.Mrn = x.String()
}
}
}
finished = true
}()

scanGroup.Add(1)
go func() {
defer scanGroup.Done()
multiprogress.Open()
}()
scanGroup.Wait()
return reporter.Reports(), finished, nil
// // if a bundle was provided check that it matches the filter, bundles can also be downloaded
// // later therefore we do not want to stop execution here
// if job.Bundle != nil && job.Bundle.FilterPolicies(job.PolicyFilters) {
// return nil, false, errors.New("all available packs filtered out. nothing to do.")
// }

scanGroups.Add(1)
go func() {
defer scanGroups.Done()
for i := range batch {
asset := batch[i].asset
runtime := batch[i].runtime

log.Debug().Interface("asset", asset).Msg("start scan")

// Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function,
// we need to solve this at a different level.
select {
case <-ctx.Done():
log.Warn().Msg("request context has been canceled")
// When we scan concurrently, we need to call Errored(asset.Mrn) status for this asset
multiprogress.Close()
return
default:
}

p := &progress.MultiProgressAdapter{Key: asset.PlatformIds[0], Multi: multiprogress}
s.RunAssetJob(&AssetJob{
DoRecord: job.DoRecord,
UpstreamConfig: upstream,
Asset: asset,
Bundle: job.Bundle,
Props: job.Props,
PolicyFilters: preprocessPolicyFilters(job.PolicyFilters),
Ctx: ctx,
Reporter: reporter,
ProgressReporter: p,
runtime: runtime,
})

// shut down all ephemeral runtimes
runtime.Close()
}
}()
}
scanGroups.Wait() // wait for all scans to complete
return reporter.Reports(), nil
}

func batch[T any](list []T, batchSize int) [][]T {
var res [][]T
for i := 0; i < len(list); i += batchSize {
end := i + batchSize
if end > len(list) {
end = len(list)
}
res = append(res, list[i:end])
}
return res
}

func (s *LocalScanner) upstreamServices(conf *upstream.UpstreamConfig) *policy.Services {
(remaining lines of the file collapsed)
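
For readers unfamiliar with the new generic helper, here is a minimal, self-contained sketch of how the batch function added in this commit behaves. The helper body is copied from the diff above; the main wrapper and the sample slice are illustrative only.

package main

import "fmt"

// batch splits a slice into consecutive chunks of at most batchSize elements
// (copied from the diff above).
func batch[T any](list []T, batchSize int) [][]T {
	var res [][]T
	for i := 0; i < len(list); i += batchSize {
		end := i + batchSize
		if end > len(list) {
			end = len(list)
		}
		res = append(res, list[i:end])
	}
	return res
}

func main() {
	assets := []string{"a1", "a2", "a3", "a4", "a5"}
	for i, b := range batch(assets, 2) {
		// Prints: batch 0: [a1 a2], batch 1: [a3 a4], batch 2: [a5]
		fmt.Printf("batch %d: %v\n", i, b)
	}
}

In the scanner itself the helper is called as batch(assets, 10), so asset synchronization and scanning proceed in groups of at most ten assets.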

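The golangci-lint annotation above flags that the error returned by multiprogress.Open() is dropped inside the progress-bar goroutine. One possible follow-up is sketched below; it assumes the identifiers from the diff (scanGroups, multiprogress, and the zerolog log package already used in this file) and is not part of this commit.

// start the progress bar, checking the error instead of discarding it
scanGroups.Add(1)
go func() {
	defer scanGroups.Done()
	if err := multiprogress.Open(); err != nil {
		// One possible policy: log and continue, so a broken progress bar
		// does not abort the scan itself.
		log.Error().Err(err).Msg("failed to open progress bar")
	}
}()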