Skip to content

Commit

Permalink
some improvements to error filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
michalpristas committed Sep 21, 2023
1 parent 6b4f5b6 commit 7169425
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (v *Verifier) verifyAsc(fullPath string, skipDefaultPgp bool, pgpSources ..
if len(check) == 0 {
continue
}
raw, err := download.PgpBytesFromSource(v.log, check, v.client)
raw, err := download.PgpBytesFromSource(v.log, check, &v.client)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (v *Verifier) verifyAsc(a artifact.Artifact, version string, skipDefaultPgp
if len(check) == 0 {
continue
}
raw, err := download.PgpBytesFromSource(v.log, check, v.client)
raw, err := download.PgpBytesFromSource(v.log, check, &v.client)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const (
var (
ErrRemotePGPDownloadFailed = errors.New("Remote PGP download failed")
ErrInvalidLocation = errors.New("Remote PGP location is invalid")
ErrUnknownPGPSource = errors.New("unknown pgp source")
)

// warnLogger is a logger that only needs to implement Warnf, as that is the only functions
Expand Down Expand Up @@ -180,7 +181,7 @@ func VerifyGPGSignature(file string, asciiArmorSignature, publicKey []byte) erro
return nil
}

func PgpBytesFromSource(log warnLogger, source string, client http.Client) ([]byte, error) {
func PgpBytesFromSource(log warnLogger, source string, client HTTPClient) ([]byte, error) {
if strings.HasPrefix(source, PgpSourceRawPrefix) {
return []byte(strings.TrimPrefix(source, PgpSourceRawPrefix)), nil
}
Expand All @@ -189,11 +190,14 @@ func PgpBytesFromSource(log warnLogger, source string, client http.Client) ([]by
pgpBytes, err := fetchPgpFromURI(strings.TrimPrefix(source, PgpSourceURIPrefix), client)
if errors.Is(err, ErrRemotePGPDownloadFailed) || errors.Is(err, ErrInvalidLocation) {
log.Warnf("Skipped remote PGP located at %q because it's unavailable: %v", strings.TrimPrefix(source, PgpSourceURIPrefix), err)
} else if err != nil {
log.Warnf("Failed to fetch remote PGP")
}

return pgpBytes, nil
}

return nil, errors.New("unknown pgp source")
return nil, ErrUnknownPGPSource
}

func CheckValidDownloadUri(rawURI string) error {
Expand All @@ -209,7 +213,7 @@ func CheckValidDownloadUri(rawURI string) error {
return nil
}

func fetchPgpFromURI(uri string, client http.Client) ([]byte, error) {
func fetchPgpFromURI(uri string, client HTTPClient) ([]byte, error) {
if err := CheckValidDownloadUri(uri); err != nil {
return nil, err
}
Expand All @@ -221,7 +225,7 @@ func fetchPgpFromURI(uri string, client http.Client) ([]byte, error) {
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, multierror.Append(err, ErrRemotePGPDownloadFailed)
}
Expand All @@ -233,3 +237,7 @@ func fetchPgpFromURI(uri string, client http.Client) ([]byte, error) {

return ioutil.ReadAll(resp.Body)
}

type HTTPClient interface {
Do(*http.Request) (*http.Response, error)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package download

import (
"bytes"
"io"
"net/http"
"testing"

Check failure on line 12 in internal/pkg/agent/application/upgrade/artifact/download/verifier_test.go

View workflow job for this annotation

GitHub Actions / lint (macos-latest)

File is not `goimports`-ed with -local github.com/elastic (goimports)
"github.com/elastic/elastic-agent/internal/pkg/agent/errors"
"github.com/elastic/elastic-agent/pkg/core/logger"
"github.com/stretchr/testify/require"
)

func TestPgpBytesFromSource(t *testing.T) {
testCases := []struct {
Name string
Source string
ClientDoErr error
ClientBody []byte
ClientStatus int

ExpectedPGP []byte
ExpectedErr error
ExpectedLogMessage string
}{
{
"successfull call",

Check failure on line 31 in internal/pkg/agent/application/upgrade/artifact/download/verifier_test.go

View workflow job for this annotation

GitHub Actions / lint (macos-latest)

`successfull` is a misspelling of `successful` (misspell)
PgpSourceURIPrefix + "https://location/path",
nil,
[]byte("pgp-body"),
200,
[]byte("pgp-body"),
nil,
"",
},
{
"unknown source call",
"https://location/path",
nil,
[]byte("pgp-body"),
200,
nil,
ErrUnknownPGPSource,
"",
},
{
"invalid location is filtered call",
PgpSourceURIPrefix + "http://location/path",
nil,
[]byte("pgp-body"),
200,
nil,
nil,
"Skipped remote PGP located ",
},
{
"do error is filtered",
PgpSourceURIPrefix + "https://location/path",
errors.New("error"),
[]byte("pgp-body"),
200,
nil,
nil,
"Skipped remote PGP located",
},
{
"invalid status code is filtered out",
PgpSourceURIPrefix + "https://location/path",
nil,
[]byte("pgp-body"),
500,
nil,
nil,
"Failed to fetch remote PGP",
},
{
"invalid status code is filtered out",
PgpSourceURIPrefix + "https://location/path",
nil,
[]byte("pgp-body"),
404,
nil,
nil,
"Failed to fetch remote PGP",
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
log, obs := logger.NewTesting(tc.Name)
mockClient := &MockClient{
DoFunc: func(req *http.Request) (*http.Response, error) {
if tc.ClientDoErr != nil {
return nil, tc.ClientDoErr
}

return &http.Response{
StatusCode: tc.ClientStatus,
Body: io.NopCloser(bytes.NewReader(tc.ClientBody)),
}, nil
},
}

resPgp, resErr := PgpBytesFromSource(log, tc.Source, mockClient)
require.Equal(t, tc.ExpectedErr, resErr)
require.Equal(t, tc.ExpectedPGP, resPgp)
if tc.ExpectedLogMessage != "" {
logs := obs.FilterMessageSnippet(tc.ExpectedLogMessage)
require.NotEqual(t, 0, logs.Len())
}

})
}
}

type MockClient struct {
DoFunc func(req *http.Request) (*http.Response, error)
}

func (m *MockClient) Do(req *http.Request) (*http.Response, error) {
return m.DoFunc(req)
}
7 changes: 6 additions & 1 deletion internal/pkg/agent/application/upgrade/step_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"fmt"
"os"
"path"
"strings"
"time"

Expand Down Expand Up @@ -103,7 +104,11 @@ func (u *Upgrader) appendFallbackPGP(targetVersion string, pgpBytes []string) []
// best effort log failure
u.log.Warnf("failed to parse secondary fallback %q: %v", u.fleetServerURI, err)
} else {
secondaryFallback := download.PgpSourceURIPrefix + u.fleetServerURI + fmt.Sprintf(fleetUpgradeFallbackPGPFormat, tpv.Major(), tpv.Minor(), tpv.Patch())
secondaryPath := path.Join(
u.fleetServerURI,
fmt.Sprintf(fleetUpgradeFallbackPGPFormat, tpv.Major(), tpv.Minor(), tpv.Patch()),
)
secondaryFallback := download.PgpSourceURIPrefix + secondaryPath
pgpBytes = append(pgpBytes, secondaryFallback)
}
}
Expand Down

0 comments on commit 7169425

Please sign in to comment.