Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.x](backport #6276) Fix/5163 retry download upgrade verifiers #6378

Merged
merged 2 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Kind can be one of:
# - breaking-change: a change to previously-documented behavior
# - deprecation: functionality that is being removed in a later release
# - bug-fix: fixes a problem in a previous version
# - enhancement: extends functionality but does not break or fix existing behavior
# - feature: new functionality
# - known-issue: problems that we are aware of in a given version
# - security: impacts on the security of a product or a user’s deployment.
# - upgrade: important information for someone upgrading from a prior version
# - other: does not fit into any of the other categories
kind: bug-fix

# Change summary; a 80ish characters long description of the change.
summary: added retries for requesting download verifiers when upgrading the agent

# Long description; in case the summary is not enough to describe the change
# this field accommodate a description without length limits.
# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment.
#description:

# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc.
component: "elastic-agent"
# PR URL; optional; the PR number that added the changeset.
# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added.
# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number.
# Please provide it if you are adding a fragment for a different PR.
pr: https://github.com/elastic/elastic-agent/pull/6276
# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of).
# If not present is automatically filled by the tooling with the issue linked to the PR number.
#issue: https://github.com/owner/repo/1234
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package composed

import (
"context"
goerrors "errors"

"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
Expand Down Expand Up @@ -39,11 +40,11 @@ func NewVerifier(log *logger.Logger, verifiers ...download.Verifier) *Verifier {
}

// Verify checks the package from configured source.
func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
var errs []error

for _, verifier := range v.vv {
e := verifier.Verify(a, version, skipDefaultPgp, pgpBytes...)
e := verifier.Verify(ctx, a, version, skipDefaultPgp, pgpBytes...)
if e == nil {
// Success
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package composed

import (
"context"
"errors"
"testing"

Expand All @@ -24,7 +25,7 @@ func (d *ErrorVerifier) Name() string {
return "error"
}

func (d *ErrorVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
func (d *ErrorVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
d.called = true
return errors.New("failing")
}
Expand All @@ -39,7 +40,7 @@ func (d *FailVerifier) Name() string {
return "fail"
}

func (d *FailVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
func (d *FailVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
d.called = true
return &download.InvalidSignatureError{File: "", Err: errors.New("invalid signature")}
}
Expand All @@ -54,7 +55,7 @@ func (d *SuccVerifier) Name() string {
return "succ"
}

func (d *SuccVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
func (d *SuccVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
d.called = true
return nil
}
Expand Down Expand Up @@ -90,7 +91,7 @@ func TestVerifier(t *testing.T) {
testVersion := agtversion.NewParsedSemVer(1, 2, 3, "", "")
for _, tc := range testCases {
d := NewVerifier(log, tc.verifiers[0], tc.verifiers[1], tc.verifiers[2])
err := d.Verify(artifact.Artifact{Name: "a", Cmd: "a", Artifact: "a/a"}, *testVersion, false)
err := d.Verify(context.Background(), artifact.Artifact{Name: "a", Cmd: "a", Artifact: "a/a"}, *testVersion, false)

assert.Equal(t, tc.expectedResult, err == nil)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package fs

import (
"context"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -65,7 +66,7 @@ func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte) (*Veri

// Verify checks downloaded package on preconfigured
// location against a key stored on elastic.co website.
func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
filename, err := artifact.GetArtifactName(a, version, v.config.OS(), v.config.Arch())
if err != nil {
return fmt.Errorf("could not get artifact name: %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ import (

var testVersion = agtversion.NewParsedSemVer(7, 5, 1, "", "")

var (
agentSpec = artifact.Artifact{
Name: "Elastic Agent",
Cmd: "elastic-agent",
Artifact: "beat/elastic-agent"}
)
var agentSpec = artifact.Artifact{
Name: "Elastic Agent",
Cmd: "elastic-agent",
Artifact: "beat/elastic-agent",
}

func TestFetchVerify(t *testing.T) {
// See docs/pgp-sign-verify-artifact.md for how to generate a key, export
Expand All @@ -47,7 +46,8 @@ func TestFetchVerify(t *testing.T) {
targetPath := filepath.Join("testdata", "download")
ctx := context.Background()
a := artifact.Artifact{
Name: "elastic-agent", Cmd: "elastic-agent", Artifact: "beats/elastic-agent"}
Name: "elastic-agent", Cmd: "elastic-agent", Artifact: "beats/elastic-agent",
}
version := agtversion.NewParsedSemVer(8, 0, 0, "", "")

filename := "elastic-agent-8.0.0-darwin-x86_64.tar.gz"
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestFetchVerify(t *testing.T) {
// first download verify should fail:
// download skipped, as invalid package is prepared upfront
// verify fails and cleans download
err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
var checksumErr *download.ChecksumMismatchError
require.ErrorAs(t, err, &checksumErr)

Expand Down Expand Up @@ -109,7 +109,7 @@ func TestFetchVerify(t *testing.T) {
_, err = os.Stat(ascTargetFilePath)
require.NoError(t, err)

err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
require.NoError(t, err)

// Bad GPG public key.
Expand All @@ -126,7 +126,7 @@ func TestFetchVerify(t *testing.T) {

// Missing .asc file.
{
err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
require.Error(t, err)

// Don't delete these files when GPG validation failure.
Expand All @@ -139,7 +139,7 @@ func TestFetchVerify(t *testing.T) {
err = os.WriteFile(targetFilePath+".asc", []byte("bad sig"), 0o600)
require.NoError(t, err)

err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
var invalidSigErr *download.InvalidSignatureError
assert.ErrorAs(t, err, &invalidSigErr)

Expand All @@ -157,7 +157,8 @@ func prepareFetchVerifyTests(
targetDir,
filename,
targetFilePath,
hashTargetFilePath string) error {
hashTargetFilePath string,
) error {
sourceFilePath := filepath.Join(dropPath, filename)
hashSourceFilePath := filepath.Join(dropPath, filename+".sha512")

Expand Down Expand Up @@ -202,6 +203,7 @@ func TestVerify(t *testing.T) {

for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()
log, obs := loggertest.New("TestVerify")
targetDir := t.TempDir()

Expand All @@ -220,7 +222,7 @@ func TestVerify(t *testing.T) {
pgpKey := prepareTestCase(t, agentSpec, testVersion, config)

testClient := NewDownloader(config)
artifactPath, err := testClient.Download(context.Background(), agentSpec, testVersion)
artifactPath, err := testClient.Download(ctx, agentSpec, testVersion)
require.NoError(t, err, "fs.Downloader could not download artifacts")
_, err = testClient.DownloadAsc(context.Background(), agentSpec, *testVersion)
require.NoError(t, err, "fs.Downloader could not download artifacts .asc file")
Expand All @@ -231,7 +233,7 @@ func TestVerify(t *testing.T) {
testVerifier, err := NewVerifier(log, config, pgpKey)
require.NoError(t, err)

err = testVerifier.Verify(agentSpec, *testVersion, false, tc.RemotePGPUris...)
err = testVerifier.Verify(ctx, agentSpec, *testVersion, false, tc.RemotePGPUris...)
require.NoError(t, err)

// log message informing remote PGP was skipped
Expand All @@ -246,7 +248,6 @@ func TestVerify(t *testing.T) {
// It creates the necessary key to sing the artifact and returns the public key
// to verify the signature.
func prepareTestCase(t *testing.T, a artifact.Artifact, version *agtversion.ParsedSemVer, cfg *artifact.Config) []byte {

filename, err := artifact.GetArtifactName(a, *version, cfg.OperatingSystem, cfg.Architecture)
require.NoErrorf(t, err, "could not get artifact name")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,28 @@ func getTestCases() []testCase {
}
}

func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) {
type extResCode map[string]struct {
resCode int
count int
}

type testDials struct {
extResCode
}

func (td *testDials) withExtResCode(k string, statusCode int, count int) {
td.extResCode[k] = struct {
resCode int
count int
}{statusCode, count}
}

func (td *testDials) reset() {
*td = testDials{extResCode: make(extResCode)}
}

func getElasticCoServer(t *testing.T) (*httptest.Server, []byte, *testDials) {
td := testDials{extResCode: make(extResCode)}
correctValues := map[string]struct{}{
fmt.Sprintf("%s-%s-%s", beatSpec.Cmd, version, "i386.deb"): {},
fmt.Sprintf("%s-%s-%s", beatSpec.Cmd, version, "amd64.deb"): {},
Expand All @@ -81,7 +102,6 @@ func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) {
ext = ".tar.gz"
}
packageName = strings.TrimSuffix(packageName, ext)

switch ext {
case ".sha512":
resp = []byte(fmt.Sprintf("%x %s", hash, packageName))
Expand All @@ -103,11 +123,17 @@ func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) {
return
}

if v, ok := td.extResCode[ext]; ok && v.count != 0 {
w.WriteHeader(v.resCode)
v.count--
td.extResCode[ext] = v
}

_, err := w.Write(resp)
assert.NoErrorf(t, err, "mock elastic.co server: failes writing response")
})

return httptest.NewServer(handler), pub
return httptest.NewServer(handler), pub, &td
}

func getElasticCoClient(server *httptest.Server) http.Client {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestDownload(t *testing.T) {
log, _ := logger.New("", false)
timeout := 30 * time.Second
testCases := getTestCases()
server, _ := getElasticCoServer(t)
server, _, _ := getElasticCoServer(t)
elasticClient := getElasticCoClient(server)

config := &artifact.Config{
Expand Down Expand Up @@ -359,7 +359,6 @@ type downloadHttpResponse struct {
}

func TestDownloadVersion(t *testing.T) {

type fields struct {
config *artifact.Config
}
Expand Down Expand Up @@ -485,7 +484,6 @@ func TestDownloadVersion(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

targetDirPath := t.TempDir()

handleDownload := func(rw http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -527,5 +525,4 @@ func TestDownloadVersion(t *testing.T) {
assert.Equalf(t, filepath.Join(targetDirPath, tt.want), got, "Download(%v, %v)", tt.args.a, tt.args.version)
})
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte) (*Veri
httpcommon.WithModRoundtripper(func(rt http.RoundTripper) http.RoundTripper {
return download.WithHeaders(rt, download.Headers)
}),
httpcommon.WithModRoundtripper(func(rt http.RoundTripper) http.RoundTripper {
return WithBackoff(rt, log)
}),
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -88,7 +91,7 @@ func (v *Verifier) Reload(c *artifact.Config) error {

// Verify checks downloaded package on preconfigured
// location against a key stored on elastic.co website.
func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
artifactPath, err := artifact.GetArtifactPath(a, version, v.config.OS(), v.config.Arch(), v.config.TargetDirectory)
if err != nil {
return errors.New(err, "retrieving package path")
Expand All @@ -98,7 +101,7 @@ func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer,
return fmt.Errorf("failed to verify SHA512 hash: %w", err)
}

if err = v.verifyAsc(a, version, skipDefaultPgp, pgpBytes...); err != nil {
if err = v.verifyAsc(ctx, a, version, skipDefaultPgp, pgpBytes...); err != nil {
var invalidSignatureErr *download.InvalidSignatureError
if errors.As(err, &invalidSignatureErr) {
if err := os.Remove(artifactPath); err != nil {
Expand All @@ -116,7 +119,7 @@ func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer,
return nil
}

func (v *Verifier) verifyAsc(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultKey bool, pgpSources ...string) error {
func (v *Verifier) verifyAsc(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultKey bool, pgpSources ...string) error {
filename, err := artifact.GetArtifactName(a, version, v.config.OS(), v.config.Arch())
if err != nil {
return errors.New(err, "retrieving package name")
Expand All @@ -132,7 +135,7 @@ func (v *Verifier) verifyAsc(a artifact.Artifact, version agtversion.ParsedSemVe
return errors.New(err, "composing URI for fetching asc file", errors.TypeNetwork)
}

ascBytes, err := v.getPublicAsc(ascURI)
ascBytes, err := v.getPublicAsc(ctx, ascURI)
if err != nil {
return errors.New(err, fmt.Sprintf("fetching asc file from %s", ascURI), errors.TypeNetwork, errors.M(errors.MetaKeyURI, ascURI))
}
Expand Down Expand Up @@ -163,8 +166,8 @@ func (v *Verifier) composeURI(filename, artifactName string) (string, error) {
return uri.String(), nil
}

func (v *Verifier) getPublicAsc(sourceURI string) ([]byte, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
func (v *Verifier) getPublicAsc(ctx context.Context, sourceURI string) ([]byte, error) {
ctx, cancelFn := context.WithTimeout(ctx, 30*time.Second)
defer cancelFn()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURI, nil)
if err != nil {
Expand Down
Loading
Loading