From 67f80f63fe8c8359e381bc3233234c99ce4f1c41 Mon Sep 17 00:00:00 2001 From: Shaunak Kashyap Date: Thu, 12 Oct 2023 05:45:17 -0700 Subject: [PATCH] Remove context and handle cancellation internally instead --- .../artifact/download/http/downloader.go | 5 +--- .../download/http/progress_reporter.go | 26 +++++++++++++++---- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go index e4ebbf26dc3..0641f9cab68 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go @@ -207,17 +207,14 @@ func (e *Downloader) downloadFile(ctx context.Context, artifactName, filename, f } loggingObserver := newLoggingProgressObserver(e.log, e.config.HTTPTransportSettings.Timeout) - reportCtx, reportCancel := context.WithCancel(ctx) dp := newDownloadProgressReporter(sourceURI, e.config.HTTPTransportSettings.Timeout, fileSize, loggingObserver) - dp.Report(reportCtx) + dp.Report() _, err = io.Copy(destinationFile, io.TeeReader(resp.Body, dp)) if err != nil { - reportCancel() dp.ReportFailed(err) // return path, file already exists and needs to be cleaned up return fullPath, errors.New(err, "copying fetched package failed", errors.TypeNetwork, errors.M(errors.MetaKeyURI, sourceURI)) } - reportCancel() dp.ReportComplete() return fullPath, nil diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go b/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go index 3114016d284..9834b3488ae 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go @@ -5,7 +5,6 @@ package http import ( - "context" "time" "github.com/elastic/elastic-agent-libs/atomic" @@ -21,6 +20,7 @@ type downloadProgressReporter struct { started time.Time progressObservers []progressObserver + done chan struct{} } func newDownloadProgressReporter(sourceURI string, timeout time.Duration, length int, progressObservers ...progressObserver) *downloadProgressReporter { @@ -35,6 +35,7 @@ func newDownloadProgressReporter(sourceURI string, timeout time.Duration, length warnTimeout: time.Duration(float64(timeout) * warningProgressIntervalPercentage), length: float64(length), progressObservers: progressObservers, + done: make(chan struct{}), } } @@ -44,9 +45,10 @@ func (dp *downloadProgressReporter) Write(b []byte) (int, error) { return n, nil } -// Report periodically reports download progress to registered observers. Callers MUST cancel -// the context passed to this method to avoid resource leaks. -func (dp *downloadProgressReporter) Report(ctx context.Context) { +// Report periodically reports download progress to registered observers. Callers MUST call +// either ReportComplete or ReportFailed when they no longer need the downloadProgressReporter +// to avoid resource leaks. +func (dp *downloadProgressReporter) Report() { started := time.Now() dp.started = started sourceURI := dp.sourceURI @@ -63,7 +65,7 @@ func (dp *downloadProgressReporter) Report(ctx context.Context) { defer t.Stop() for { select { - case <-ctx.Done(): + case <-dp.done: return case <-t.C: now := time.Now() @@ -83,6 +85,9 @@ func (dp *downloadProgressReporter) Report(ctx context.Context) { }() } +// ReportComplete reports the completion of a download to registered observers. Callers MUST call +// either ReportComplete or ReportFailed when they no longer need the downloadProgressReporter +// to avoid resource leaks. func (dp *downloadProgressReporter) ReportComplete() { now := time.Now() timePast := now.Sub(dp.started) @@ -92,8 +97,13 @@ func (dp *downloadProgressReporter) ReportComplete() { for _, obs := range dp.progressObservers { obs.ReportCompleted(dp.sourceURI, timePast, bytesPerSecond) } + + dp.close() } +// ReportFailed reports the failure of a download to registered observers. Callers MUST call +// either ReportFailed or ReportComplete when they no longer need the downloadProgressReporter +// to avoid resource leaks. func (dp *downloadProgressReporter) ReportFailed(err error) { now := time.Now() timePast := now.Sub(dp.started) @@ -107,4 +117,10 @@ func (dp *downloadProgressReporter) ReportFailed(err error) { for _, obs := range dp.progressObservers { obs.ReportFailed(dp.sourceURI, timePast, downloaded, dp.length, percentComplete, bytesPerSecond, err) } + + dp.close() +} + +func (dp *downloadProgressReporter) close() { + dp.done <- struct{}{} }