From 1cc42f0ec063c2c851cdd1e3ef9b8e1096da9a6f Mon Sep 17 00:00:00 2001 From: "Sascha L. Teichmann" Date: Fri, 29 Sep 2023 09:46:51 +0200 Subject: [PATCH] Downloader: unit test forwarder (#470) * Simplify forward method * Add unit test for validation status * Add unit test for stats logging in forwarder. * Add unit test for http client creation. * Add unit test for replaceExt * Add unit test for buildRequest * Add unit test for limitedString * Add unit test for storeFailedAdvisory * Add unit test for storeFailedAdvisory ... fixed * Add unit test for storeFailed * Add unit test for forward * comment wording --- cmd/csaf_downloader/forwarder.go | 42 +-- cmd/csaf_downloader/forwarder_test.go | 429 ++++++++++++++++++++++++++ 2 files changed, 453 insertions(+), 18 deletions(-) create mode 100644 cmd/csaf_downloader/forwarder_test.go diff --git a/cmd/csaf_downloader/forwarder.go b/cmd/csaf_downloader/forwarder.go index 067ba99e..8ceb0e57 100644 --- a/cmd/csaf_downloader/forwarder.go +++ b/cmd/csaf_downloader/forwarder.go @@ -187,16 +187,10 @@ func (f *forwarder) buildRequest( // storeFailedAdvisory stores an advisory in a special folder // in case the forwarding failed. func (f *forwarder) storeFailedAdvisory(filename, doc, sha256, sha512 string) error { - dir := filepath.Join(f.cfg.Directory, failedForwardDir) // Create special folder if it does not exist. - if _, err := os.Stat(dir); err != nil { - if os.IsNotExist(err) { - if err := os.MkdirAll(dir, 0755); err != nil { - return err - } - } else { - return err - } + dir := filepath.Join(f.cfg.Directory, failedForwardDir) + if err := os.MkdirAll(dir, 0755); err != nil { + return err } // Store parts which are not empty. for _, x := range []struct { @@ -226,6 +220,19 @@ func (f *forwarder) storeFailed(filename, doc, sha256, sha512 string) { } } +// limitedString reads max bytes from reader and returns it as a string. +// Longer strings are indicated by "..." as a suffix. +func limitedString(r io.Reader, max int) (string, error) { + var msg strings.Builder + if _, err := io.Copy(&msg, io.LimitReader(r, int64(max))); err != nil { + return "", err + } + if msg.Len() >= max { + msg.WriteString("...") + } + return msg.String(), nil +} + // forward sends a given document with filename, status and // checksums to the forwarder. This is async to the degree // till the configured queue size is filled. @@ -252,16 +259,15 @@ func (f *forwarder) forward( } if res.StatusCode != http.StatusCreated { defer res.Body.Close() - var msg strings.Builder - io.Copy(&msg, io.LimitReader(res.Body, 512)) - var dots string - if msg.Len() >= 512 { - dots = "..." + if msg, err := limitedString(res.Body, 512); err != nil { + slog.Error("reading forward result failed", + "error", err) + } else { + slog.Error("forwarding failed", + "filename", filename, + "body", msg, + "status_code", res.StatusCode) } - slog.Error("forwarding failed", - "filename", filename, - "body", msg.String()+dots, - "status_code", res.StatusCode) f.storeFailed(filename, doc, sha256, sha512) } else { f.succeeded++ diff --git a/cmd/csaf_downloader/forwarder_test.go b/cmd/csaf_downloader/forwarder_test.go new file mode 100644 index 00000000..c7f8634e --- /dev/null +++ b/cmd/csaf_downloader/forwarder_test.go @@ -0,0 +1,429 @@ +// This file is Free Software under the MIT License +// without warranty, see README.md and LICENSES/MIT.txt for details. +// +// SPDX-License-Identifier: MIT +// +// SPDX-FileCopyrightText: 2023 German Federal Office for Information Security (BSI) +// Software-Engineering: 2023 Intevation GmbH + +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "io" + "log/slog" + "mime" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/csaf-poc/csaf_distribution/v3/internal/options" + "github.com/csaf-poc/csaf_distribution/v3/util" +) + +func TestValidationStatusUpdate(t *testing.T) { + sv := validValidationStatus + sv.update(invalidValidationStatus) + sv.update(validValidationStatus) + if sv != invalidValidationStatus { + t.Fatalf("got %q expected %q", sv, invalidValidationStatus) + } + sv = notValidatedValidationStatus + sv.update(validValidationStatus) + sv.update(notValidatedValidationStatus) + if sv != notValidatedValidationStatus { + t.Fatalf("got %q expected %q", sv, notValidatedValidationStatus) + } +} + +func TestForwarderLogStats(t *testing.T) { + orig := slog.Default() + defer slog.SetDefault(orig) + + var buf bytes.Buffer + h := slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }) + lg := slog.New(h) + slog.SetDefault(lg) + + cfg := &config{} + fw := newForwarder(cfg) + fw.failed = 11 + fw.succeeded = 13 + + done := make(chan struct{}) + go func() { + defer close(done) + fw.run() + }() + fw.log() + fw.close() + <-done + + type fwStats struct { + Msg string `json:"msg"` + Succeeded int `json:"succeeded"` + Failed int `json:"failed"` + } + sc := bufio.NewScanner(bytes.NewReader(buf.Bytes())) + found := false + for sc.Scan() { + var fws fwStats + if err := json.Unmarshal(sc.Bytes(), &fws); err != nil { + t.Fatalf("JSON parsing log failed: %v", err) + } + if fws.Msg == "Forward statistics" && + fws.Failed == 11 && + fws.Succeeded == 13 { + found = true + break + } + } + if err := sc.Err(); err != nil { + t.Fatalf("scanning log failed: %v", err) + } + if !found { + t.Fatal("Cannot find forward statistics in log") + } +} + +func TestForwarderHTTPClient(t *testing.T) { + cfg := &config{ + ForwardInsecure: true, + ForwardHeader: http.Header{ + "User-Agent": []string{"curl/7.55.1"}, + }, + LogLevel: &options.LogLevel{Level: slog.LevelDebug}, + } + fw := newForwarder(cfg) + if c1, c2 := fw.httpClient(), fw.httpClient(); c1 != c2 { + t.Fatal("expected to return same client twice") + } +} + +func TestForwarderReplaceExtension(t *testing.T) { + for _, x := range [][2]string{ + {"foo", "foo.ext"}, + {"foo.bar", "foo.ext"}, + {".bar", ".ext"}, + {"", ".ext"}, + } { + if got := replaceExt(x[0], ".ext"); got != x[1] { + t.Fatalf("got %q expected %q", got, x[1]) + } + } +} + +func TestForwarderBuildRequest(t *testing.T) { + + // Good case ... + cfg := &config{ + ForwardURL: "https://example.com", + } + fw := newForwarder(cfg) + + req, err := fw.buildRequest( + "test.json", "{}", + invalidValidationStatus, + "256", + "512") + + if err != nil { + t.Fatalf("buildRequest failed: %v", err) + } + mediaType, params, err := mime.ParseMediaType(req.Header.Get("Content-Type")) + if err != nil { + t.Fatalf("no Content-Type found") + } + if !strings.HasPrefix(mediaType, "multipart/") { + t.Fatalf("media type is not multipart") + } + mr := multipart.NewReader(req.Body, params["boundary"]) + + var foundAdvisory, foundValidationStatus, found256, found512 bool + + for { + p, err := mr.NextPart() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("parsing multipart failed: %v", err) + } + data, err := io.ReadAll(p) + if err != nil { + t.Fatal(err) + } + cd := p.Header["Content-Disposition"] + if len(cd) == 0 { + continue + } + + switch contains := func(name string) bool { + return strings.Contains(cd[0], `name="`+name+`"`) + }; { + case contains("advisory"): + if a := string(data); a != "{}" { + t.Fatalf("advisory: got %q expected %q", a, "{}") + } + foundAdvisory = true + case contains("validation_status"): + if vs := validationStatus(data); vs != invalidValidationStatus { + t.Fatalf("validation_status: got %q expected %q", + vs, invalidValidationStatus) + } + foundValidationStatus = true + case contains("hash-256"): + if h := string(data); h != "256" { + t.Fatalf("hash-256: got %q expected %q", h, "256") + } + found256 = true + case contains("hash-512"): + if h := string(data); h != "512" { + t.Fatalf("hash-512: got %q expected %q", h, "512") + } + found512 = true + } + } + + switch { + case !foundAdvisory: + t.Fatal("advisory not found") + case !foundValidationStatus: + t.Fatal("validation_status not found") + case !found256: + t.Fatal("hash-256 not found") + case !found512: + t.Fatal("hash-512 not found") + } + + // Bad case ... + cfg.ForwardURL = "%" + + if _, err := fw.buildRequest( + "test.json", "{}", + invalidValidationStatus, + "256", + "512", + ); err == nil { + t.Fatal("bad forward URL should result in an error") + } +} + +type badReader struct{ error } + +func (br *badReader) Read([]byte) (int, error) { return 0, br.error } + +func TestLimitedString(t *testing.T) { + for _, x := range [][2]string{ + {"xx", "xx"}, + {"xxx", "xxx..."}, + {"xxxx", "xxx..."}, + } { + got, err := limitedString(strings.NewReader(x[0]), 3) + if err != nil { + t.Fatal(err) + } + if got != x[1] { + t.Fatalf("got %q expected %q", got, x[1]) + } + } + + if _, err := limitedString(&badReader{error: os.ErrInvalid}, 3); err == nil { + t.Fatal("expected to fail with an error") + } +} + +func TestStoreFailedAdvisory(t *testing.T) { + dir, err := os.MkdirTemp("", "storeFailedAdvisory") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + cfg := &config{Directory: dir} + fw := newForwarder(cfg) + + badDir := filepath.Join(dir, failedForwardDir) + if err := os.WriteFile(badDir, []byte("test"), 0664); err != nil { + t.Fatal(err) + } + + if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err == nil { + t.Fatal("if the destination exists as a file an error should occur") + } + + if err := os.Remove(badDir); err != nil { + t.Fatal(err) + } + + if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err != nil { + t.Fatal(err) + } + + sha256Path := filepath.Join(dir, failedForwardDir, "advisory.json.sha256") + + // Write protect advisory. + if err := os.Chmod(sha256Path, 0); err != nil { + t.Fatal(err) + } + + if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err == nil { + t.Fatal("expected to fail with an error") + } + + if err := os.Chmod(sha256Path, 0644); err != nil { + t.Fatal(err) + } +} + +func TestStoredFailed(t *testing.T) { + dir, err := os.MkdirTemp("", "storeFailed") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + orig := slog.Default() + defer slog.SetDefault(orig) + + var buf bytes.Buffer + h := slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelError, + }) + lg := slog.New(h) + slog.SetDefault(lg) + + cfg := &config{Directory: dir} + fw := newForwarder(cfg) + + // An empty filename should lead to an error. + fw.storeFailed("", "{}", "256", "512") + + if fw.failed != 1 { + t.Fatalf("got %d expected 1", fw.failed) + } + + type entry struct { + Msg string `json:"msg"` + Level string `json:"level"` + } + + sc := bufio.NewScanner(bytes.NewReader(buf.Bytes())) + found := false + for sc.Scan() { + var e entry + if err := json.Unmarshal(sc.Bytes(), &e); err != nil { + t.Fatalf("JSON parsing log failed: %v", err) + } + if e.Msg == "Storing advisory failed forwarding failed" && e.Level == "ERROR" { + found = true + break + } + } + if err := sc.Err(); err != nil { + t.Fatalf("scanning log failed: %v", err) + } + if !found { + t.Fatal("Cannot error logging statistics in log") + } +} + +type fakeClient struct { + util.Client + state int +} + +func (fc *fakeClient) Do(*http.Request) (*http.Response, error) { + // The different states simulates different responses from the remote API. + switch fc.state { + case 0: + fc.state = 1 + return &http.Response{ + Status: http.StatusText(http.StatusCreated), + StatusCode: http.StatusCreated, + }, nil + case 1: + fc.state = 2 + return nil, errors.New("does not work") + case 2: + fc.state = 3 + return &http.Response{ + Status: http.StatusText(http.StatusBadRequest), + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(&badReader{error: os.ErrInvalid}), + }, nil + default: + return &http.Response{ + Status: http.StatusText(http.StatusBadRequest), + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader("This was bad!")), + }, nil + } +} + +func TestForwarderForward(t *testing.T) { + dir, err := os.MkdirTemp("", "forward") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + orig := slog.Default() + defer slog.SetDefault(orig) + + // We dont care in details here as we captured them + // in the other test cases. + h := slog.NewJSONHandler(io.Discard, nil) + lg := slog.New(h) + slog.SetDefault(lg) + + cfg := &config{ + ForwardURL: "http://example.com", + Directory: dir, + } + fw := newForwarder(cfg) + + // Use the fact that http client is cached. + fw.client = &fakeClient{} + + done := make(chan struct{}) + + go func() { + defer close(done) + fw.run() + }() + + // Iterate through states of http client. + for i := 0; i <= 3; i++ { + fw.forward( + "test.json", "{}", + invalidValidationStatus, + "256", + "512") + } + + // Make buildRequest fail. + wait := make(chan struct{}) + fw.cmds <- func(f *forwarder) { + f.cfg.ForwardURL = "%" + close(wait) + } + <-wait + fw.forward( + "test.json", "{}", + invalidValidationStatus, + "256", + "512") + + fw.close() + + <-done +}