diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go index 0641f9cab68..7be3ae1066f 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/downloader.go @@ -208,7 +208,7 @@ func (e *Downloader) downloadFile(ctx context.Context, artifactName, filename, f loggingObserver := newLoggingProgressObserver(e.log, e.config.HTTPTransportSettings.Timeout) dp := newDownloadProgressReporter(sourceURI, e.config.HTTPTransportSettings.Timeout, fileSize, loggingObserver) - dp.Report() + dp.Report(ctx) _, err = io.Copy(destinationFile, io.TeeReader(resp.Body, dp)) if err != nil { dp.ReportFailed(err) diff --git a/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go b/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go index f50168c88ca..491646b3ab5 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/http/progress_reporter.go @@ -5,6 +5,7 @@ package http import ( + "context" "time" "github.com/elastic/elastic-agent-libs/atomic" @@ -45,10 +46,10 @@ func (dp *downloadProgressReporter) Write(b []byte) (int, error) { return n, nil } -// Report periodically reports download progress to registered observers. Callers MUST call -// either ReportComplete or ReportFailed when they no longer need the downloadProgressReporter -// to avoid resource leaks. -func (dp *downloadProgressReporter) Report() { +// Report periodically reports download progress to registered observers. Callers MUST either +// cancel the context provided to this method OR call either ReportComplete or ReportFailed when +// they no longer need the downloadProgressReporter to avoid resource leaks. +func (dp *downloadProgressReporter) Report(ctx context.Context) { started := time.Now() dp.started = started sourceURI := dp.sourceURI @@ -65,6 +66,8 @@ func (dp *downloadProgressReporter) Report() { defer t.Stop() for { select { + case <-ctx.Done(): + return case <-dp.done: return case <-t.C: @@ -89,7 +92,7 @@ func (dp *downloadProgressReporter) Report() { // either ReportComplete or ReportFailed when they no longer need the downloadProgressReporter // to avoid resource leaks. func (dp *downloadProgressReporter) ReportComplete() { - defer dp.close() + defer close(dp.done) // If there are no observers to report progress to, there is nothing to do! if len(dp.progressObservers) == 0 { @@ -110,7 +113,7 @@ func (dp *downloadProgressReporter) ReportComplete() { // either ReportFailed or ReportComplete when they no longer need the downloadProgressReporter // to avoid resource leaks. func (dp *downloadProgressReporter) ReportFailed(err error) { - defer dp.close() + defer close(dp.done) // If there are no observers to report progress to, there is nothing to do! if len(dp.progressObservers) == 0 { @@ -130,7 +133,3 @@ func (dp *downloadProgressReporter) ReportFailed(err error) { obs.ReportFailed(dp.sourceURI, timePast, downloaded, dp.length, percentComplete, bytesPerSecond, err) } } - -func (dp *downloadProgressReporter) close() { - dp.done <- struct{}{} -}