Skip to content

Commit

Permalink
Adding details progress observer and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ycombinator committed Oct 13, 2023
1 parent 94da5bb commit 1a8147d
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/internal/pkg/agent/errors"
"github.com/elastic/elastic-agent/pkg/core/logger"
)
Expand All @@ -44,13 +45,14 @@ const (

// Downloader is a downloader able to fetch artifacts from elastic.co web page.
type Downloader struct {
log *logger.Logger
config *artifact.Config
client http.Client
log *logger.Logger
config *artifact.Config
client http.Client
upgradeDetails *details.Details
}

// NewDownloader creates and configures Elastic Downloader
func NewDownloader(log *logger.Logger, config *artifact.Config) (*Downloader, error) {
func NewDownloader(log *logger.Logger, config *artifact.Config, upgradeDetails *details.Details) (*Downloader, error) {
client, err := config.HTTPTransportSettings.Client(
httpcommon.WithAPMHTTPInstrumentation(),
httpcommon.WithKeepaliveSettings{Disable: false, IdleConnTimeout: 30 * time.Second},
Expand All @@ -60,15 +62,16 @@ func NewDownloader(log *logger.Logger, config *artifact.Config) (*Downloader, er
}

client.Transport = download.WithHeaders(client.Transport, download.Headers)
return NewDownloaderWithClient(log, config, *client), nil
return NewDownloaderWithClient(log, config, *client, upgradeDetails), nil
}

// NewDownloaderWithClient creates Elastic Downloader with specific client used
func NewDownloaderWithClient(log *logger.Logger, config *artifact.Config, client http.Client) *Downloader {
func NewDownloaderWithClient(log *logger.Logger, config *artifact.Config, client http.Client, upgradeDetails *details.Details) *Downloader {
return &Downloader{
log: log,
config: config,
client: client,
log: log,
config: config,
client: client,
upgradeDetails: upgradeDetails,
}
}

Expand Down Expand Up @@ -207,7 +210,8 @@ func (e *Downloader) downloadFile(ctx context.Context, artifactName, filename, f
}

loggingObserver := newLoggingProgressObserver(e.log, e.config.HTTPTransportSettings.Timeout)
dp := newDownloadProgressReporter(sourceURI, e.config.HTTPTransportSettings.Timeout, fileSize, loggingObserver)
detailsObserver := newDetailsProgressObserver(e.upgradeDetails)
dp := newDownloadProgressReporter(sourceURI, e.config.HTTPTransportSettings.Timeout, fileSize, loggingObserver, detailsObserver)
dp.Report(ctx)
_, err = io.Copy(destinationFile, io.TeeReader(resp.Body, dp))
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"go.uber.org/zap/zaptest/observer"

"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/pkg/core/logger"

"github.com/docker/go-units"
Expand Down Expand Up @@ -63,7 +64,8 @@ func TestDownloadBodyError(t *testing.T) {
}

log, obs := logger.NewTesting("downloader")
testClient := NewDownloaderWithClient(log, config, *client)
upgradeDetails := details.NewDetails("8.12.0", details.StateRequested, "")
testClient := NewDownloaderWithClient(log, config, *client, upgradeDetails)
artifactPath, err := testClient.Download(context.Background(), beatSpec, version)
os.Remove(artifactPath)
if err == nil {
Expand Down Expand Up @@ -119,7 +121,8 @@ func TestDownloadLogProgressWithLength(t *testing.T) {
}

log, obs := logger.NewTesting("downloader")
testClient := NewDownloaderWithClient(log, config, *client)
upgradeDetails := details.NewDetails("8.12.0", details.StateRequested, "")
testClient := NewDownloaderWithClient(log, config, *client, upgradeDetails)
artifactPath, err := testClient.Download(context.Background(), beatSpec, version)
os.Remove(artifactPath)
require.NoError(t, err, "Download should not have errored")
Expand Down Expand Up @@ -201,7 +204,8 @@ func TestDownloadLogProgressWithoutLength(t *testing.T) {
}

log, obs := logger.NewTesting("downloader")
testClient := NewDownloaderWithClient(log, config, *client)
upgradeDetails := details.NewDetails("8.12.0", details.StateRequested, "")
testClient := NewDownloaderWithClient(log, config, *client, upgradeDetails)
artifactPath, err := testClient.Download(context.Background(), beatSpec, version)
os.Remove(artifactPath)
require.NoError(t, err, "Download should not have errored")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/elastic/elastic-agent-libs/transport/httpcommon"

"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/pkg/core/logger"
)

Expand Down Expand Up @@ -70,7 +71,8 @@ func TestDownload(t *testing.T) {
config.OperatingSystem = testCase.system
config.Architecture = testCase.arch

testClient := NewDownloaderWithClient(log, config, elasticClient)
upgradeDetails := details.NewDetails("8.12.0", details.StateRequested, "")
testClient := NewDownloaderWithClient(log, config, elasticClient, upgradeDetails)
artifactPath, err := testClient.Download(context.Background(), beatSpec, version)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -111,7 +113,8 @@ func TestVerify(t *testing.T) {
config.OperatingSystem = testCase.system
config.Architecture = testCase.arch

testClient := NewDownloaderWithClient(log, config, elasticClient)
upgradeDetails := details.NewDetails("8.12.0", details.StateRequested, "")
testClient := NewDownloaderWithClient(log, config, elasticClient, upgradeDetails)
artifact, err := testClient.Download(context.Background(), beatSpec, version)
if err != nil {
t.Fatal(err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/docker/go-units"

"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/pkg/core/logger"
)

Expand Down Expand Up @@ -95,3 +96,28 @@ func (lpObs *loggingProgressObserver) ReportFailed(sourceURI string, timePast ti
lpObs.log.Warnf(msg, args...)
}
}

type detailsProgressObserver struct {
upgradeDetails *details.Details
}

func newDetailsProgressObserver(upgradeDetails *details.Details) *detailsProgressObserver {
upgradeDetails.SetState(details.StateDownloading)
return &detailsProgressObserver{
upgradeDetails: upgradeDetails,
}
}

func (dpObs *detailsProgressObserver) Report(sourceURI string, timePast time.Duration, downloadedBytes, totalBytes, percentComplete, downloadRate float64) {
dpObs.upgradeDetails.Metadata.DownloadPercent = percentComplete
dpObs.upgradeDetails.NotifyObservers()
}

func (dpObs *detailsProgressObserver) ReportCompleted(sourceURI string, timePast time.Duration, downloadRate float64) {
dpObs.upgradeDetails.Metadata.DownloadPercent = 1
dpObs.upgradeDetails.NotifyObservers()
}

func (dpObs *detailsProgressObserver) ReportFailed(sourceURI string, timePast time.Duration, downloadedBytes, totalBytes, percentComplete, downloadRate float64, err error) {
dpObs.upgradeDetails.Fail(err)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package http

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/docker/go-units"

"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/internal/pkg/agent/errors"
)

func TestDetailsProgressObserver(t *testing.T) {
upgradeDetails := details.NewDetails("8.11.0", details.StateRequested, "")
detailsObs := newDetailsProgressObserver(upgradeDetails)

detailsObs.Report("http://some/uri", 20*time.Second, 400*units.MiB, 500*units.MiB, 0.8, 4455)
require.Equal(t, details.StateDownloading, upgradeDetails.State)
require.Equal(t, 0.8, upgradeDetails.Metadata.DownloadPercent)

detailsObs.ReportCompleted("http://some/uri", 30*time.Second, 3333)
require.Equal(t, details.StateDownloading, upgradeDetails.State)
require.Equal(t, 1.0, upgradeDetails.Metadata.DownloadPercent)

err := errors.New("some download error")
detailsObs.ReportFailed("http://some/uri", 30*time.Second, 450*units.MiB, 500*units.MiB, 0.9, 1122, err)
require.Equal(t, details.StateFailed, upgradeDetails.State)
require.Equal(t, details.StateDownloading, upgradeDetails.Metadata.FailedState)
require.Equal(t, err.Error(), upgradeDetails.Metadata.ErrorMsg)
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ import (
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/fs"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/http"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/snapshot"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/internal/pkg/release"
"github.com/elastic/elastic-agent/pkg/core/logger"
)

// NewDownloader creates a downloader which first checks local directory
// and then fallbacks to remote if configured.
func NewDownloader(log *logger.Logger, config *artifact.Config) (download.Downloader, error) {
func NewDownloader(log *logger.Logger, config *artifact.Config, upgradeDetails *details.Details) (download.Downloader, error) {
downloaders := make([]download.Downloader, 0, 3)
downloaders = append(downloaders, fs.NewDownloader(config))

Expand All @@ -26,15 +27,15 @@ func NewDownloader(log *logger.Logger, config *artifact.Config) (download.Downlo
// a snapshot version of fleet, for example.
// try snapshot repo before official
if release.Snapshot() {
snapDownloader, err := snapshot.NewDownloader(log, config, nil)
snapDownloader, err := snapshot.NewDownloader(log, config, nil, upgradeDetails)
if err != nil {
log.Error(err)
} else {
downloaders = append(downloaders, snapDownloader)
}
}

httpDownloader, err := http.NewDownloader(log, config)
httpDownloader, err := http.NewDownloader(log, config, upgradeDetails)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/http"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/internal/pkg/release"
"github.com/elastic/elastic-agent/pkg/core/logger"
agtversion "github.com/elastic/elastic-agent/pkg/version"
Expand All @@ -32,13 +33,13 @@ type Downloader struct {
// We need to pass the versionOverride separately from the config as
// artifact.Config struct is part of agent configuration and a version
// override makes no sense there
func NewDownloader(log *logger.Logger, config *artifact.Config, versionOverride *agtversion.ParsedSemVer) (download.Downloader, error) {
func NewDownloader(log *logger.Logger, config *artifact.Config, versionOverride *agtversion.ParsedSemVer, upgradeDetails *details.Details) (download.Downloader, error) {
cfg, err := snapshotConfig(config, versionOverride)
if err != nil {
return nil, fmt.Errorf("error creating snapshot config: %w", err)
}

httpDownloader, err := http.NewDownloader(log, cfg)
httpDownloader, err := http.NewDownloader(log, cfg, upgradeDetails)
if err != nil {
return nil, fmt.Errorf("failed to create snapshot downloader: %w", err)
}
Expand Down
13 changes: 10 additions & 3 deletions internal/pkg/agent/application/upgrade/details/details.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ type DetailsMetadata struct {
ErrorMsg string `yaml:"error_msg" json:"error_msg"`
}

func NewDetails(targetVersion string, initialState State, actionID string, metadata DetailsMetadata) *Details {
func NewDetails(targetVersion string, initialState State, actionID string) *Details {
return &Details{
TargetVersion: targetVersion,
State: initialState,
ActionID: actionID,
Metadata: metadata,
Metadata: DetailsMetadata{},
observers: []Observer{},
}
}
Expand All @@ -55,7 +55,14 @@ func (d *Details) Fail(err error) {
d.mu.Lock()
defer d.mu.Unlock()

d.Metadata.FailedState = d.State
// Record the state the upgrade process was in right before it
// failed, but only do this if we haven't already transitioned the
// state to the StateFailed state; otherwise we'll just end up recording
// the state we failed from as StateFailed which is not useful.
if d.State != StateFailed {
d.Metadata.FailedState = d.State
}

d.Metadata.ErrorMsg = err.Error()
d.State = StateFailed
d.notifyObservers()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@ import (
)

func TestDetailsNew(t *testing.T) {
det := NewDetails("99.999.9999", StateRequested, "test_action_id", DetailsMetadata{})
det := NewDetails("99.999.9999", StateRequested, "test_action_id")
require.Equal(t, StateRequested, det.State)
require.Equal(t, "99.999.9999", det.TargetVersion)
require.Equal(t, "test_action_id", det.ActionID)
require.Equal(t, DetailsMetadata{}, det.Metadata)
}

func TestDetailsSetState(t *testing.T) {
det := NewDetails("99.999.9999", StateRequested, "test_action_id", DetailsMetadata{})
det := NewDetails("99.999.9999", StateRequested, "test_action_id")
require.Equal(t, StateRequested, det.State)

det.SetState(StateDownloading)
require.Equal(t, StateDownloading, det.State)
}

func TestDetailsFail(t *testing.T) {
det := NewDetails("99.999.9999", StateRequested, "test_action_id", DetailsMetadata{})
det := NewDetails("99.999.9999", StateRequested, "test_action_id")
require.Equal(t, StateRequested, det.State)

err := errors.New("test error")
Expand All @@ -39,7 +39,7 @@ func TestDetailsFail(t *testing.T) {
}

func TestDetailsObserver(t *testing.T) {
det := NewDetails("99.999.9999", StateRequested, "test_action_id", DetailsMetadata{})
det := NewDetails("99.999.9999", StateRequested, "test_action_id")
require.Equal(t, StateRequested, det.State)

var observedDetails *Details
Expand Down
17 changes: 9 additions & 8 deletions internal/pkg/agent/application/upgrade/step_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ const (
fleetUpgradeFallbackPGPFormat = "/api/agents/upgrades/%d.%d.%d/pgp-public-key"
)

func (u *Upgrader) downloadArtifact(ctx context.Context, version, sourceURI string, details *details.Details, skipVerifyOverride bool, skipDefaultPgp bool, pgpBytes ...string) (_ string, err error) {
func (u *Upgrader) downloadArtifact(ctx context.Context, version, sourceURI string, upgradeDetails *details.Details, skipVerifyOverride bool, skipDefaultPgp bool, pgpBytes ...string) (_ string, err error) {
span, ctx := apm.StartSpan(ctx, "downloadArtifact", "app.internal")
defer func() {
apm.CaptureError(ctx, err).Send()
Expand Down Expand Up @@ -71,7 +71,7 @@ func (u *Upgrader) downloadArtifact(ctx context.Context, version, sourceURI stri
return "", errors.New(err, fmt.Sprintf("failed to create download directory at %s", paths.Downloads()))
}

path, err := u.downloadWithRetries(ctx, newDownloader, parsedVersion, &settings)
path, err := u.downloadWithRetries(ctx, newDownloader, parsedVersion, &settings, upgradeDetails)
if err != nil {
return "", errors.New(err, "failed download of agent binary")
}
Expand Down Expand Up @@ -123,20 +123,20 @@ func (u *Upgrader) appendFallbackPGP(targetVersion string, pgpBytes []string) []
return pgpBytes
}

func newDownloader(version *agtversion.ParsedSemVer, log *logger.Logger, settings *artifact.Config) (download.Downloader, error) {
func newDownloader(version *agtversion.ParsedSemVer, log *logger.Logger, settings *artifact.Config, upgradeDetails *details.Details) (download.Downloader, error) {
if !version.IsSnapshot() {
return localremote.NewDownloader(log, settings)
return localremote.NewDownloader(log, settings, upgradeDetails)
}

// TODO since we know if it's a snapshot or not, shouldn't we add EITHER the snapshot downloader OR the release one ?

// try snapshot repo before official
snapDownloader, err := snapshot.NewDownloader(log, settings, version)
snapDownloader, err := snapshot.NewDownloader(log, settings, version, upgradeDetails)
if err != nil {
return nil, err
}

httpDownloader, err := http.NewDownloader(log, settings)
httpDownloader, err := http.NewDownloader(log, settings, upgradeDetails)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -171,9 +171,10 @@ func newVerifier(version *agtversion.ParsedSemVer, log *logger.Logger, settings

func (u *Upgrader) downloadWithRetries(
ctx context.Context,
downloaderCtor func(*agtversion.ParsedSemVer, *logger.Logger, *artifact.Config) (download.Downloader, error),
downloaderCtor func(*agtversion.ParsedSemVer, *logger.Logger, *artifact.Config, *details.Details) (download.Downloader, error),
version *agtversion.ParsedSemVer,
settings *artifact.Config,
upgradeDetails *details.Details,
) (string, error) {
cancelCtx, cancel := context.WithTimeout(ctx, settings.Timeout)
defer cancel()
Expand All @@ -189,7 +190,7 @@ func (u *Upgrader) downloadWithRetries(
attempt++
u.log.Infof("download attempt %d", attempt)

downloader, err := downloaderCtor(version, u.log, settings)
downloader, err := downloaderCtor(version, u.log, settings, upgradeDetails)
if err != nil {
return fmt.Errorf("unable to create fetcher: %w", err)
}
Expand Down

0 comments on commit 1a8147d

Please sign in to comment.