From fe16ac19d1ed121055d8a3bdb5237ab2292452f8 Mon Sep 17 00:00:00 2001 From: Ivan Milchev Date: Tue, 12 Dec 2023 11:10:44 +0200 Subject: [PATCH] batch sync assets Signed-off-by: Ivan Milchev --- policy/scan/local_scanner.go | 259 +++++++++++++++++++---------------- 1 file changed, 138 insertions(+), 121 deletions(-) diff --git a/policy/scan/local_scanner.go b/policy/scan/local_scanner.go index f7cb2adc..920c6ecd 100644 --- a/policy/scan/local_scanner.go +++ b/policy/scan/local_scanner.go @@ -160,7 +160,7 @@ func (s *LocalScanner) Run(ctx context.Context, job *Job) (*ScanResult, error) { return nil, err } - reports, _, err := s.distributeJob(job, ctx, upstream) + reports, err := s.distributeJob(job, ctx, upstream) if err != nil { return nil, err } @@ -186,7 +186,7 @@ func (s *LocalScanner) RunIncognito(ctx context.Context, job *Job) (*ScanResult, return nil, err } - reports, _, err := s.distributeJob(job, ctx, upstream) + reports, err := s.distributeJob(job, ctx, upstream) if err != nil { return nil, err } @@ -280,7 +280,7 @@ func createAssetCandidateList(ctx context.Context, job *Job, upstream *upstream. return assetList, assetCandidates, nil } -func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*ScanResult, bool, error) { +func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*ScanResult, error) { // Always shut down the coordinator, to make sure providers are killed defer providers.Coordinator.Shutdown() @@ -294,7 +294,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up case ReportType_NONE: reporter = NewNoOpReporter() default: - return nil, false, errors.Errorf("unknown report type: %s", job.ReportType) + return nil, errors.Errorf("unknown report type: %s", job.ReportType) } log.Info().Msgf("discover related assets for %d asset(s)", len(job.Inventory.Spec.Assets)) @@ -302,7 +302,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up var assets []*assetWithRuntime assetList, assetCandidates, err := createAssetCandidateList(ctx, job, upstream, s.recording) if err != nil { - return nil, false, err + return nil, err } // For each asset candidate, we initialize a new runtime and connect to it. @@ -327,12 +327,12 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up if candidate.asset.Connections[0].Type == "k8s" { runtime, err = providers.Coordinator.RuntimeFor(candidate.asset, providers.DefaultRuntime()) if err != nil { - return nil, false, err + return nil, err } } else { runtime, err = providers.Coordinator.EphemeralRuntimeFor(candidate.asset) if err != nil { - return nil, false, err + return nil, err } } runtime.UpstreamConfig = upstream @@ -360,7 +360,7 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up } if len(assets) == 0 { - return nil, false, nil + return nil, nil } // if there is exactly one asset, assure that the --asset-name is used @@ -381,146 +381,163 @@ func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *up runtimeLabels = runtimeEnv.Labels() } - justAssets := []*inventory.Asset{} - for _, asset := range assets { - asset.asset.KindString = asset.asset.GetPlatform().Kind - for k, v := range runtimeLabels { - if asset.asset.Labels == nil { - asset.asset.Labels = map[string]string{} - } - asset.asset.Labels[k] = v - } - justAssets = append(justAssets, asset.asset) - } - - // sync assets - if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito { - log.Info().Msg("synchronize assets") - client, err := upstream.InitClient() - if err != nil { - return nil, false, err - } - - services, err := policy.NewRemoteServices(client.ApiEndpoint, client.Plugins, client.HttpClient) - if err != nil { - return nil, false, err - } - - inventory.DeprecatedV8CompatAssets(justAssets) - resp, err := services.SynchronizeAssets(ctx, &policy.SynchronizeAssetsReq{ - SpaceMrn: client.SpaceMrn, - List: justAssets, - }) - if err != nil { - return nil, false, err - } - log.Debug().Int("assets", len(resp.Details)).Msg("got assets details") - platformAssetMapping := make(map[string]*policy.SynchronizeAssetsRespAssetDetail) - for i := range resp.Details { - log.Debug().Str("platform-mrn", resp.Details[i].PlatformMrn).Str("asset", resp.Details[i].AssetMrn).Msg("asset mapping") - platformAssetMapping[resp.Details[i].PlatformMrn] = resp.Details[i] - } - - // attach the asset details to the assets list - for i := range assets { - log.Debug().Str("asset", assets[i].asset.Name).Strs("platform-ids", assets[i].asset.PlatformIds).Msg("update asset") - platformMrn := assets[i].asset.PlatformIds[0] - assets[i].asset.Mrn = platformAssetMapping[platformMrn].AssetMrn - assets[i].asset.Url = platformAssetMapping[platformMrn].Url - } - } else { - // ensure we have non-empty asset MRNs - for i := range assets { - cur := assets[i] - if cur.asset.Mrn == "" { - randID := "//" + policy.POLICY_SERVICE_NAME + "/" + policy.MRN_RESOURCE_ASSET + "/" + ksuid.New().String() - x, err := mrn.NewMRN(randID) - if err != nil { - return nil, false, multierr.Wrap(err, "failed to generate a random asset MRN") - } - cur.asset.Mrn = x.String() - } - } - } - - // // if a bundle was provided check that it matches the filter, bundles can also be downloaded - // // later therefore we do not want to stop execution here - // if job.Bundle != nil && job.Bundle.FilterPolicies(job.PolicyFilters) { - // return nil, false, errors.New("all available packs filtered out. nothing to do.") - // } - progressBarElements := map[string]string{} + var multiprogress progress.MultiProgress orderedKeys := []string{} for i := range assets { // this shouldn't happen, but might // it normally indicates a bug in the provider if presentAsset, present := progressBarElements[assets[i].asset.PlatformIds[0]]; present { - return nil, false, fmt.Errorf("asset %s and %s have the same platform id %s", presentAsset, assets[i].asset.Name, assets[i].asset.PlatformIds[0]) + return nil, fmt.Errorf("asset %s and %s have the same platform id %s", presentAsset, assets[i].asset.Name, assets[i].asset.PlatformIds[0]) } progressBarElements[assets[i].asset.PlatformIds[0]] = assets[i].asset.Name orderedKeys = append(orderedKeys, assets[i].asset.PlatformIds[0]) } - var multiprogress progress.MultiProgress + if isatty.IsTerminal(os.Stdout.Fd()) && !s.disableProgressBar && !strings.EqualFold(logger.GetLevel(), "debug") && !strings.EqualFold(logger.GetLevel(), "trace") { var err error multiprogress, err = progress.NewMultiProgressBars(progressBarElements, orderedKeys, progress.WithScore()) if err != nil { - return nil, false, multierr.Wrap(err, "failed to create progress bars") + return nil, multierr.Wrap(err, "failed to create progress bars") } } else { // TODO: adjust naming multiprogress = progress.NoopMultiProgressBars{} } + scanGroups := sync.WaitGroup{} - scanGroup := sync.WaitGroup{} - scanGroup.Add(1) - finished := false + // start the progress bar + scanGroups.Add(1) go func() { - defer scanGroup.Done() - for i := range assets { - asset := assets[i].asset - runtime := assets[i].runtime - - log.Debug().Interface("asset", asset).Msg("start scan") - - // Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function, - // we need to solve this at a different level. - select { - case <-ctx.Done(): - log.Warn().Msg("request context has been canceled") - // When we scan concurrently, we need to call Errored(asset.Mrn) status for this asset - multiprogress.Close() - return - default: + defer scanGroups.Done() + multiprogress.Open() + }() + + assetBatches := batch(assets, 10) + for i := range assetBatches { + batch := assetBatches[i] + justAssets := []*inventory.Asset{} + for _, asset := range batch { + asset.asset.KindString = asset.asset.GetPlatform().Kind + for k, v := range runtimeLabels { + if asset.asset.Labels == nil { + asset.asset.Labels = map[string]string{} + } + asset.asset.Labels[k] = v + } + justAssets = append(justAssets, asset.asset) + } + + inventory.DeprecatedV8CompatAssets(justAssets) + + // sync assets + if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito { + log.Info().Msg("synchronize assets") + client, err := upstream.InitClient() + if err != nil { + return nil, err + } + + services, err := policy.NewRemoteServices(client.ApiEndpoint, client.Plugins, client.HttpClient) + if err != nil { + return nil, err } - p := &progress.MultiProgressAdapter{Key: asset.PlatformIds[0], Multi: multiprogress} - s.RunAssetJob(&AssetJob{ - DoRecord: job.DoRecord, - UpstreamConfig: upstream, - Asset: asset, - Bundle: job.Bundle, - Props: job.Props, - PolicyFilters: preprocessPolicyFilters(job.PolicyFilters), - Ctx: ctx, - Reporter: reporter, - ProgressReporter: p, - runtime: runtime, + resp, err := services.SynchronizeAssets(ctx, &policy.SynchronizeAssetsReq{ + SpaceMrn: client.SpaceMrn, + List: justAssets, }) + if err != nil { + return nil, err + } + log.Debug().Int("assets", len(resp.Details)).Msg("got assets details") + platformAssetMapping := make(map[string]*policy.SynchronizeAssetsRespAssetDetail) + for i := range resp.Details { + log.Debug().Str("platform-mrn", resp.Details[i].PlatformMrn).Str("asset", resp.Details[i].AssetMrn).Msg("asset mapping") + platformAssetMapping[resp.Details[i].PlatformMrn] = resp.Details[i] + } - // shut down all ephemeral runtimes - runtime.Close() + // attach the asset details to the assets list + for i := range batch { + log.Debug().Str("asset", batch[i].asset.Name).Strs("platform-ids", batch[i].asset.PlatformIds).Msg("update asset") + platformMrn := batch[i].asset.PlatformIds[0] + batch[i].asset.Mrn = platformAssetMapping[platformMrn].AssetMrn + batch[i].asset.Url = platformAssetMapping[platformMrn].Url + } + } else { + // ensure we have non-empty asset MRNs + for i := range batch { + cur := batch[i] + if cur.asset.Mrn == "" { + randID := "//" + policy.POLICY_SERVICE_NAME + "/" + policy.MRN_RESOURCE_ASSET + "/" + ksuid.New().String() + x, err := mrn.NewMRN(randID) + if err != nil { + return nil, multierr.Wrap(err, "failed to generate a random asset MRN") + } + cur.asset.Mrn = x.String() + } + } } - finished = true - }() - scanGroup.Add(1) - go func() { - defer scanGroup.Done() - multiprogress.Open() - }() - scanGroup.Wait() - return reporter.Reports(), finished, nil + // // if a bundle was provided check that it matches the filter, bundles can also be downloaded + // // later therefore we do not want to stop execution here + // if job.Bundle != nil && job.Bundle.FilterPolicies(job.PolicyFilters) { + // return nil, false, errors.New("all available packs filtered out. nothing to do.") + // } + + scanGroups.Add(1) + go func() { + defer scanGroups.Done() + for i := range batch { + asset := batch[i].asset + runtime := batch[i].runtime + + log.Debug().Interface("asset", asset).Msg("start scan") + + // Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function, + // we need to solve this at a different level. + select { + case <-ctx.Done(): + log.Warn().Msg("request context has been canceled") + // When we scan concurrently, we need to call Errored(asset.Mrn) status for this asset + multiprogress.Close() + return + default: + } + + p := &progress.MultiProgressAdapter{Key: asset.PlatformIds[0], Multi: multiprogress} + s.RunAssetJob(&AssetJob{ + DoRecord: job.DoRecord, + UpstreamConfig: upstream, + Asset: asset, + Bundle: job.Bundle, + Props: job.Props, + PolicyFilters: preprocessPolicyFilters(job.PolicyFilters), + Ctx: ctx, + Reporter: reporter, + ProgressReporter: p, + runtime: runtime, + }) + + // shut down all ephemeral runtimes + runtime.Close() + } + }() + } + scanGroups.Wait() // wait for all scans to complete + return reporter.Reports(), nil +} + +func batch[T any](list []T, batchSize int) [][]T { + var res [][]T + for i := 0; i < len(list); i += batchSize { + end := i + batchSize + if end > len(list) { + end = len(list) + } + res = append(res, list[i:end]) + } + return res } func (s *LocalScanner) upstreamServices(conf *upstream.UpstreamConfig) *policy.Services {