Skip to content

Commit

Permalink
create pgptest package and fix broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersonQ committed Oct 13, 2023
1 parent aa2f22d commit f99cadb
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,34 @@ func (e *Downloader) Download(ctx context.Context, a artifact.Artifact, version
}()

// download from source to dest
path, err := e.download(e.config.OS(), a, version)
path, err := e.download(e.config.OS(), a, version, "")
downloadedFiles = append(downloadedFiles, path)
if err != nil {
return "", err
}

hashPath, err := e.downloadHash(e.config.OS(), a, version)
hashPath, err := e.download(e.config.OS(), a, version, ".sha512")
downloadedFiles = append(downloadedFiles, hashPath)
return path, err
}

func (e *Downloader) download(operatingSystem string, a artifact.Artifact, version string) (string, error) {
filename, err := artifact.GetArtifactName(a, version, operatingSystem, e.config.Arch())
if err != nil {
return "", errors.New(err, "generating package name failed")
}

fullPath, err := artifact.GetArtifactPath(a, version, operatingSystem, e.config.Arch(), e.config.TargetDirectory)
// DownloadAsc downloads the package .asc file from configured source.
// It returns absolute path to the downloaded file and a no-nil error if any occurs.
func (e *Downloader) DownloadAsc(_ context.Context, a artifact.Artifact, version string) (string, error) {
path, err := e.download(e.config.OS(), a, version, ".asc")
if err != nil {
return "", errors.New(err, "generating package path failed")
os.Remove(path)
return "", err
}

return e.downloadFile(filename, fullPath)
return path, nil
}

func (e *Downloader) downloadHash(operatingSystem string, a artifact.Artifact, version string) (string, error) {
func (e *Downloader) download(
operatingSystem string,
a artifact.Artifact,
version,
extension string) (string, error) {
filename, err := artifact.GetArtifactName(a, version, operatingSystem, e.config.Arch())
if err != nil {
return "", errors.New(err, "generating package name failed")
Expand All @@ -88,8 +90,10 @@ func (e *Downloader) downloadHash(operatingSystem string, a artifact.Artifact, v
return "", errors.New(err, "generating package path failed")
}

filename = filename + ".sha512"
fullPath = fullPath + ".sha512"
if extension != "" {
filename += extension
fullPath += extension
}

return e.downloadFile(filename, fullPath)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
package fs

import (
"bytes"
"context"
"crypto/sha512"
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
Expand All @@ -23,6 +23,7 @@ import (
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download"
"github.com/elastic/elastic-agent/internal/pkg/release"
"github.com/elastic/elastic-agent/pkg/core/logger"
"github.com/elastic/elastic-agent/testing/pgptest"
)

const (
Expand Down Expand Up @@ -203,8 +204,7 @@ func TestVerify(t *testing.T) {
for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) {
log, obs := logger.NewTesting("TestVerify")
targetDir, err := ioutil.TempDir(os.TempDir(), "")
require.NoError(t, err)
targetDir := t.TempDir()

timeout := 30 * time.Second

Expand All @@ -218,23 +218,18 @@ func TestVerify(t *testing.T) {
},
}

err = prepareTestCase(beatSpec, version, config)
require.NoError(t, err)
pgpKey := prepareTestCase(t, beatSpec, version, config)

testClient := NewDownloader(config)
artifact, err := testClient.Download(context.Background(), beatSpec, version)
require.NoError(t, err)

t.Cleanup(func() {
os.Remove(artifact)
os.Remove(artifact + ".sha512")
os.RemoveAll(config.DropPath)
})
artifactPath, err := testClient.Download(context.Background(), beatSpec, version)
require.NoError(t, err, "fs.Downloader could not download artifacts")
_, err = testClient.DownloadAsc(context.Background(), beatSpec, version)
require.NoError(t, err, "fs.Downloader could not download artifacts .asc file")

_, err = os.Stat(artifact)
_, err = os.Stat(artifactPath)
require.NoError(t, err)

testVerifier, err := NewVerifier(log, config, nil)
testVerifier, err := NewVerifier(log, config, pgpKey)
require.NoError(t, err)

err = testVerifier.Verify(beatSpec, version, false, tc.RemotePGPUris...)
Expand All @@ -247,25 +242,40 @@ func TestVerify(t *testing.T) {
}
}

func prepareTestCase(a artifact.Artifact, version string, cfg *artifact.Config) error {
// prepareTestCase prepares the test case by creating an artifact file defined by
// a and version its corresponding checksum, .sha512, and signature, .asc, files.
// It creates the necessary key to sing the artifact and returns the public key
// to verify the signature.
func prepareTestCase(
t *testing.T,
a artifact.Artifact,
version string,
cfg *artifact.Config) []byte {

filename, err := artifact.GetArtifactName(a, version, cfg.OperatingSystem, cfg.Architecture)
if err != nil {
return err
}
require.NoErrorf(t, err, "could not get artifact name")

if err := os.MkdirAll(cfg.DropPath, 0777); err != nil {
return err
}
err = os.MkdirAll(cfg.DropPath, 0777)
require.NoErrorf(t, err, "failed creating directory %q", cfg.DropPath)

filePath := filepath.Join(cfg.DropPath, filename)
filePathSHA := filePath + ".sha512"
filePathASC := filePath + ".asc"

content := []byte("sample content")
err = os.WriteFile(filePath, content, 0644)
require.NoErrorf(t, err, "could not write %q file", filePath)

hash := sha512.Sum512(content)
hashContent := fmt.Sprintf("%x %s", hash, filename)
err = os.WriteFile(filePathSHA, []byte(hashContent), 0644)
require.NoErrorf(t, err, "could not write %q file", filePathSHA)

if err := ioutil.WriteFile(filepath.Join(cfg.DropPath, filename), content, 0644); err != nil {
return err
}
pub, sig := pgptest.Sing(t, bytes.NewReader(content))
err = os.WriteFile(filePathASC, sig, 0644)
require.NoErrorf(t, err, "could not write %q file", filePathASC)

return ioutil.WriteFile(filepath.Join(cfg.DropPath, filename+".sha512"), []byte(hashContent), 0644)
return pub
}

func assertFileExists(t testing.TB, path string) {
Expand Down
37 changes: 37 additions & 0 deletions testing/pgptest/pgp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package pgptest

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/crypto/openpgp"
"golang.org/x/crypto/openpgp/armor"
)

// Sing signs data using RSA. It creates the key, sings data and returns the
// ASCII armored public key and detached signature.
func Sing(t *testing.T, data io.Reader) ([]byte, []byte) {
pub := &bytes.Buffer{}
asc := &bytes.Buffer{}

// Create a new key. The openpgp.Entity hold the private and public keys.
entity, err := openpgp.NewEntity("somekey", "", "", nil)

// Create an encoder to serialize the public key.
wPubKey, err := armor.Encode(pub, openpgp.PublicKeyType, nil)
require.NoError(t, err, "could not create PGP ASCII Armor encoder")

// Writes the public key to the io.Writer padded to armor.Encode.
// Use entity.SerializePrivate if you need the private key.
err = entity.Serialize(wPubKey)
require.NoError(t, err, "could not serialize the public key")
// cannot use defer as it needs to be closed before pub.Bytes() is invoked.
wPubKey.Close()

err = openpgp.ArmoredDetachSign(asc, entity, data, nil)
require.NoError(t, err, "failed signing the data")

return pub.Bytes(), asc.Bytes()
}

0 comments on commit f99cadb

Please sign in to comment.