From 6f9ff457794500a9f52c3e39d82846b2500dc47f Mon Sep 17 00:00:00 2001 From: koplas Date: Wed, 19 Jun 2024 15:55:05 +0200 Subject: [PATCH] Refactor downloader to allow usage as library - Allow to specify custom logger - Add callback for downloaded documents - Separate config of downloader CLI and library - Remove `Forwarder` tests that test the handler --- cmd/csaf_downloader/config.go | 253 +++++++++++++++++++++++++++++++ cmd/csaf_downloader/main.go | 179 ++++++++++++++-------- lib/downloader/config.go | 179 ++++------------------ lib/downloader/downloader.go | 178 +++++++++------------- lib/downloader/forwarder.go | 70 +++------ lib/downloader/forwarder_test.go | 150 ++++-------------- 6 files changed, 519 insertions(+), 490 deletions(-) create mode 100644 cmd/csaf_downloader/config.go diff --git a/cmd/csaf_downloader/config.go b/cmd/csaf_downloader/config.go new file mode 100644 index 00000000..929ed608 --- /dev/null +++ b/cmd/csaf_downloader/config.go @@ -0,0 +1,253 @@ +// This file is Free Software under the Apache-2.0 License +// without warranty, see README.md and LICENSES/Apache-2.0.txt for details. +// +// SPDX-License-Identifier: Apache-2.0 +// +// SPDX-FileCopyrightText: 2022 German Federal Office for Information Security (BSI) +// Software-Engineering: 2022 Intevation GmbH + +package main + +import ( + "crypto/tls" + "io" + "log" + "log/slog" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/csaf-poc/csaf_distribution/v3/internal/certs" + "github.com/csaf-poc/csaf_distribution/v3/internal/filter" + "github.com/csaf-poc/csaf_distribution/v3/internal/models" + "github.com/csaf-poc/csaf_distribution/v3/internal/options" + "github.com/csaf-poc/csaf_distribution/v3/lib/downloader" +) + +const ( + defaultWorker = 2 + defaultPreset = "mandatory" + defaultForwardQueue = 5 + defaultValidationMode = downloader.ValidationStrict + defaultLogFile = "downloader.log" + defaultLogLevel = slog.LevelInfo +) + +// configPaths are the potential file locations of the Config file. +var configPaths = []string{ + "~/.config/csaf/downloader.toml", + "~/.csaf_downloader.toml", + "csaf_downloader.toml", +} + +type config struct { + Directory string `short:"d" long:"directory" description:"DIRectory to store the downloaded files in" value-name:"DIR" toml:"directory"` + Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider" toml:"insecure"` + IgnoreSignatureCheck bool `long:"ignore_sigcheck" description:"Ignore signature check results, just warn on mismatch" toml:"ignore_sigcheck"` + ClientCert *string `long:"client_cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE" toml:"client_cert"` + ClientKey *string `long:"client_key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE" toml:"client_key"` + ClientPassphrase *string `long:"client_passphrase" description:"Optional passphrase for the client cert (limited, experimental, see doc)" value-name:"PASSPHRASE" toml:"client_passphrase"` + Version bool `long:"version" description:"Display version of the binary" toml:"-"` + NoStore bool `long:"no_store" short:"n" description:"Do not store files" toml:"no_store"` + Rate *float64 `long:"rate" short:"r" description:"The average upper limit of https operations per second (defaults to unlimited)" toml:"rate"` + Worker int `long:"worker" short:"w" description:"NUMber of concurrent downloads" value-name:"NUM" toml:"worker"` + Range *models.TimeRange `long:"time_range" short:"t" description:"RANGE of time from which advisories to download" value-name:"RANGE" toml:"time_range"` + Folder string `long:"folder" short:"f" description:"Download into a given subFOLDER" value-name:"FOLDER" toml:"folder"` + IgnorePattern []string `long:"ignore_pattern" short:"i" description:"Do not download files if their URLs match any of the given PATTERNs" value-name:"PATTERN" toml:"ignore_pattern"` + ExtraHeader http.Header `long:"header" short:"H" description:"One or more extra HTTP header fields" toml:"header"` + + EnumeratePMDOnly bool `long:"enumerate_pmd_only" description:"If this flag is set to true, the downloader will only enumerate valid provider metadata files, but not download documents" toml:"enumerate_pmd_only"` + + RemoteValidator string `long:"validator" description:"URL to validate documents remotely" value-name:"URL" toml:"validator"` + RemoteValidatorCache string `long:"validator_cache" description:"FILE to cache remote validations" value-name:"FILE" toml:"validator_cache"` + RemoteValidatorPresets []string `long:"validator_preset" description:"One or more PRESETS to validate remotely" value-name:"PRESETS" toml:"validator_preset"` + + //lint:ignore SA5008 We are using choice twice: strict, unsafe. + ValidationMode downloader.ValidationMode `long:"validation_mode" short:"m" choice:"strict" choice:"unsafe" value-name:"MODE" description:"MODE how strict the validation is" toml:"validation_mode"` + + ForwardURL string `long:"forward_url" description:"URL of HTTP endpoint to forward downloads to" value-name:"URL" toml:"forward_url"` + ForwardHeader http.Header `long:"forward_header" description:"One or more extra HTTP header fields used by forwarding" toml:"forward_header"` + ForwardQueue int `long:"forward_queue" description:"Maximal queue LENGTH before forwarder" value-name:"LENGTH" toml:"forward_queue"` + ForwardInsecure bool `long:"forward_insecure" description:"Do not check TLS certificates from forward endpoint" toml:"forward_insecure"` + + LogFile *string `long:"log_file" description:"FILE to log downloading to" value-name:"FILE" toml:"log_file"` + //lint:ignore SA5008 We are using choice or than once: debug, info, warn, error + LogLevel *options.LogLevel `long:"log_level" description:"LEVEL of logging details" value-name:"LEVEL" choice:"debug" choice:"info" choice:"warn" choice:"error" toml:"log_level"` + + Config string `short:"c" long:"config" description:"Path to config TOML file" value-name:"TOML-FILE" toml:"-"` + + clientCerts []tls.Certificate + ignorePattern filter.PatternMatcher + logger *slog.Logger +} + +// parseArgsConfig parses the command line and if needed a config file. +func parseArgsConfig() ([]string, *config, error) { + var ( + logFile = defaultLogFile + logLevel = &options.LogLevel{Level: defaultLogLevel} + ) + p := options.Parser[config]{ + DefaultConfigLocations: configPaths, + ConfigLocation: func(cfg *config) string { return cfg.Config }, + Usage: "[OPTIONS] domain...", + HasVersion: func(cfg *config) bool { return cfg.Version }, + SetDefaults: func(cfg *config) { + cfg.Worker = defaultWorker + cfg.RemoteValidatorPresets = []string{defaultPreset} + cfg.ValidationMode = defaultValidationMode + cfg.ForwardQueue = defaultForwardQueue + cfg.LogFile = &logFile + cfg.LogLevel = logLevel + }, + // Re-establish default values if not set. + EnsureDefaults: func(cfg *config) { + if cfg.Worker == 0 { + cfg.Worker = defaultWorker + } + if cfg.RemoteValidatorPresets == nil { + cfg.RemoteValidatorPresets = []string{defaultPreset} + } + switch cfg.ValidationMode { + case downloader.ValidationStrict, downloader.ValidationUnsafe: + default: + cfg.ValidationMode = downloader.ValidationStrict + } + if cfg.LogFile == nil { + cfg.LogFile = &logFile + } + if cfg.LogLevel == nil { + cfg.LogLevel = logLevel + } + }, + } + return p.Parse() +} + +// prepareDirectory ensures that the working directory +// exists and is setup properly. +func (cfg *config) prepareDirectory() error { + // If not given use current working directory. + if cfg.Directory == "" { + dir, err := os.Getwd() + if err != nil { + return err + } + cfg.Directory = dir + return nil + } + // Use given directory + if _, err := os.Stat(cfg.Directory); err != nil { + // If it does not exist create it. + if os.IsNotExist(err) { + if err = os.MkdirAll(cfg.Directory, 0755); err != nil { + return err + } + } else { + return err + } + } + return nil +} + +// dropSubSeconds drops all parts below resolution of seconds. +func dropSubSeconds(_ []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + t := a.Value.Time() + a.Value = slog.TimeValue(t.Truncate(time.Second)) + } + return a +} + +// prepareLogging sets up the structured logging. +func (cfg *config) prepareLogging() error { + var w io.Writer + if cfg.LogFile == nil || *cfg.LogFile == "" { + log.Println("using STDERR for logging") + w = os.Stderr + } else { + var fname string + // We put the log inside the download folder + // if it is not absolute. + if filepath.IsAbs(*cfg.LogFile) { + fname = *cfg.LogFile + } else { + fname = filepath.Join(cfg.Directory, *cfg.LogFile) + } + f, err := os.OpenFile(fname, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) + if err != nil { + return err + } + log.Printf("using %q for logging\n", fname) + w = f + } + ho := slog.HandlerOptions{ + // AddSource: true, + Level: cfg.LogLevel.Level, + ReplaceAttr: dropSubSeconds, + } + handler := slog.NewJSONHandler(w, &ho) + cfg.logger = slog.New(handler) + return nil +} + +// compileIgnorePatterns compiles the configure patterns to be ignored. +func (cfg *config) compileIgnorePatterns() error { + pm, err := filter.NewPatternMatcher(cfg.IgnorePattern) + if err != nil { + return err + } + cfg.ignorePattern = pm + return nil +} + +// prepareCertificates loads the client side certificates used by the HTTP client. +func (cfg *config) prepareCertificates() error { + cert, err := certs.LoadCertificate( + cfg.ClientCert, cfg.ClientKey, cfg.ClientPassphrase) + if err != nil { + return err + } + cfg.clientCerts = cert + return nil +} + +// Prepare prepares internal state of a loaded configuration. +func (cfg *config) GetDownloadConfig() (*downloader.Config, error) { + for _, prepare := range []func(*config) error{ + (*config).prepareDirectory, + (*config).prepareLogging, + (*config).prepareCertificates, + (*config).compileIgnorePatterns, + } { + if err := prepare(cfg); err != nil { + return nil, err + } + } + dCfg := &downloader.Config{ + Insecure: cfg.Insecure, + IgnoreSignatureCheck: cfg.IgnoreSignatureCheck, + ClientCerts: cfg.clientCerts, + ClientKey: cfg.ClientKey, + ClientPassphrase: cfg.ClientPassphrase, + Rate: cfg.Rate, + Worker: cfg.Worker, + Range: cfg.Range, + IgnorePattern: cfg.ignorePattern, + ExtraHeader: cfg.ExtraHeader, + + RemoteValidator: cfg.RemoteValidator, + RemoteValidatorCache: cfg.RemoteValidatorCache, + RemoteValidatorPresets: cfg.RemoteValidatorPresets, + + ValidationMode: cfg.ValidationMode, + + ForwardURL: cfg.ForwardURL, + ForwardHeader: cfg.ForwardHeader, + ForwardQueue: cfg.ForwardQueue, + ForwardInsecure: cfg.ForwardInsecure, + Logger: cfg.logger, + } + return dCfg, nil +} diff --git a/cmd/csaf_downloader/main.go b/cmd/csaf_downloader/main.go index 295b0a35..8011f396 100644 --- a/cmd/csaf_downloader/main.go +++ b/cmd/csaf_downloader/main.go @@ -11,75 +11,40 @@ package main import ( "context" - "github.com/csaf-poc/csaf_distribution/v3/lib/downloader" "log/slog" "os" "os/signal" + "path" + "path/filepath" + "strconv" + "strings" + "sync" "github.com/csaf-poc/csaf_distribution/v3/internal/options" + "github.com/csaf-poc/csaf_distribution/v3/lib/downloader" ) -const ( - defaultWorker = 2 - defaultPreset = "mandatory" - defaultForwardQueue = 5 - defaultValidationMode = downloader.ValidationStrict - defaultLogFile = "downloader.log" - defaultLogLevel = slog.LevelInfo -) +// failedForwardDir is the name of the special sub folder +// where advisories get stored which fail forwarding. +const failedForwardDir = "failed_forward" -// configPaths are the potential file locations of the Config file. -var configPaths = []string{ - "~/.config/csaf/downloader.toml", - "~/.csaf_downloader.toml", - "csaf_downloader.toml", -} +// failedValidationDir is the name of the sub folder +// where advisories are stored that fail validation in +// unsafe mode. +const failedValidationDir = "failed_validation" -// parseArgsConfig parses the command line and if needed a config file. -func parseArgsConfig() ([]string, *downloader.Config, error) { - var ( - logFile = defaultLogFile - logLevel = &options.LogLevel{Level: defaultLogLevel} - ) - p := options.Parser[downloader.Config]{ - DefaultConfigLocations: configPaths, - ConfigLocation: func(cfg *downloader.Config) string { return cfg.Config }, - Usage: "[OPTIONS] domain...", - HasVersion: func(cfg *downloader.Config) bool { return cfg.Version }, - SetDefaults: func(cfg *downloader.Config) { - cfg.Worker = defaultWorker - cfg.RemoteValidatorPresets = []string{defaultPreset} - cfg.ValidationMode = defaultValidationMode - cfg.ForwardQueue = defaultForwardQueue - cfg.LogFile = &logFile - cfg.LogLevel = logLevel - }, - // Re-establish default values if not set. - EnsureDefaults: func(cfg *downloader.Config) { - if cfg.Worker == 0 { - cfg.Worker = defaultWorker - } - if cfg.RemoteValidatorPresets == nil { - cfg.RemoteValidatorPresets = []string{defaultPreset} - } - switch cfg.ValidationMode { - case downloader.ValidationStrict, downloader.ValidationUnsafe: - default: - cfg.ValidationMode = downloader.ValidationStrict - } - if cfg.LogFile == nil { - cfg.LogFile = &logFile - } - if cfg.LogLevel == nil { - cfg.LogLevel = logLevel - } - }, +var mkdirMu sync.Mutex + +func run(cfg *config, domains []string) error { + dCfg, err := cfg.GetDownloadConfig() + if err != nil { + return err } - return p.Parse() -} -func run(cfg *downloader.Config, domains []string) error { - d, err := downloader.NewDownloader(cfg) + dCfg.DownloadHandler = downloadHandler(cfg) + dCfg.FailedForwardHandler = storeFailedAdvisory(cfg) + + d, err := downloader.NewDownloader(dCfg) if err != nil { return err } @@ -91,7 +56,7 @@ func run(cfg *downloader.Config, domains []string) error { defer stop() if cfg.ForwardURL != "" { - f := downloader.NewForwarder(cfg) + f := downloader.NewForwarder(dCfg) go f.Run() defer func() { f.Log() @@ -108,11 +73,103 @@ func run(cfg *downloader.Config, domains []string) error { return d.Run(ctx, domains) } -func main() { +func mkdirAll(path string, perm os.FileMode) error { + mkdirMu.Lock() + defer mkdirMu.Unlock() + return os.MkdirAll(path, perm) +} + +func downloadHandler(cfg *config) func(d downloader.DownloadedDocument) error { + return func(d downloader.DownloadedDocument) error { + if cfg.NoStore { + // Do not write locally. + if d.ValStatus == downloader.ValidValidationStatus { + return nil + } + } + var lastDir string + + // Advisories that failed validation are stored in a special folder. + var newDir string + if d.ValStatus != downloader.ValidValidationStatus { + newDir = path.Join(cfg.Directory, failedValidationDir) + } else { + newDir = cfg.Directory + } + + lower := strings.ToLower(string(d.Label)) + + // Do we have a configured destination folder? + if cfg.Folder != "" { + newDir = path.Join(newDir, cfg.Folder) + } else { + newDir = path.Join(newDir, lower, strconv.Itoa(d.InitialReleaseDate.Year())) + } + + if newDir != lastDir { + if err := mkdirAll(newDir, 0755); err != nil { + return err + } + lastDir = newDir + } + + // Write advisory to file + filePath := filepath.Join(lastDir, d.Filename) + + for _, x := range []struct { + p string + d []byte + }{ + {filePath, d.Data.Bytes()}, + {filePath + ".sha256", d.S256Data}, + {filePath + ".sha512", d.S512Data}, + {filePath + ".asc", d.SignData}, + } { + if x.d != nil { + if err := os.WriteFile(x.p, x.d, 0644); err != nil { + return err + } + } + } + + slog.Info("Written advisory", "path", filePath) + return nil + } +} + +// storeFailedAdvisory stores an advisory in a special folder +// in case the forwarding failed. +func storeFailedAdvisory(cfg *config) func(filename, doc, sha256, sha512 string) error { + return func(filename, doc, sha256, sha512 string) error { + // Create special folder if it does not exist. + dir := filepath.Join(cfg.Directory, failedForwardDir) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + // Store parts which are not empty. + for _, x := range []struct { + p string + d string + }{ + {filename, doc}, + {filename + ".sha256", sha256}, + {filename + ".sha512", sha512}, + } { + if len(x.d) != 0 { + path := filepath.Join(dir, x.p) + if err := os.WriteFile(path, []byte(x.d), 0644); err != nil { + return err + } + } + } + return nil + } +} + +func main() { domains, cfg, err := parseArgsConfig() options.ErrorCheck(err) - options.ErrorCheck(cfg.Prepare()) if len(domains) == 0 { slog.Warn("No domains given.") diff --git a/lib/downloader/config.go b/lib/downloader/config.go index 30121ca4..1bd3200d 100644 --- a/lib/downloader/config.go +++ b/lib/downloader/config.go @@ -11,18 +11,11 @@ package downloader import ( "crypto/tls" "fmt" - "io" - "log" "log/slog" "net/http" - "os" - "path/filepath" - "time" - "github.com/csaf-poc/csaf_distribution/v3/internal/certs" "github.com/csaf-poc/csaf_distribution/v3/internal/filter" "github.com/csaf-poc/csaf_distribution/v3/internal/models" - "github.com/csaf-poc/csaf_distribution/v3/internal/options" ) // ValidationMode specifies the strict the validation is. @@ -37,43 +30,33 @@ const ( // Config provides the download configuration. type Config struct { - Directory string `short:"d" long:"directory" description:"DIRectory to store the downloaded files in" value-name:"DIR" toml:"directory"` - Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider" toml:"insecure"` - IgnoreSignatureCheck bool `long:"ignore_sigcheck" description:"Ignore signature check results, just warn on mismatch" toml:"ignore_sigcheck"` - ClientCert *string `long:"client_cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE" toml:"client_cert"` - ClientKey *string `long:"client_key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE" toml:"client_key"` - ClientPassphrase *string `long:"client_passphrase" description:"Optional passphrase for the client cert (limited, experimental, see doc)" value-name:"PASSPHRASE" toml:"client_passphrase"` - Version bool `long:"version" description:"Display version of the binary" toml:"-"` - NoStore bool `long:"no_store" short:"n" description:"Do not store files" toml:"no_store"` - Rate *float64 `long:"rate" short:"r" description:"The average upper limit of https operations per second (defaults to unlimited)" toml:"rate"` - Worker int `long:"worker" short:"w" description:"NUMber of concurrent downloads" value-name:"NUM" toml:"worker"` - Range *models.TimeRange `long:"time_range" short:"t" description:"RANGE of time from which advisories to download" value-name:"RANGE" toml:"time_range"` - Folder string `long:"folder" short:"f" description:"Download into a given subFOLDER" value-name:"FOLDER" toml:"folder"` - IgnorePattern []string `long:"ignore_pattern" short:"i" description:"Do not download files if their URLs match any of the given PATTERNs" value-name:"PATTERN" toml:"ignore_pattern"` - ExtraHeader http.Header `long:"header" short:"H" description:"One or more extra HTTP header fields" toml:"header"` - - EnumeratePMDOnly bool `long:"enumerate_pmd_only" description:"If this flag is set to true, the downloader will only enumerate valid provider metadata files, but not download documents" toml:"enumerate_pmd_only"` - - RemoteValidator string `long:"validator" description:"URL to validate documents remotely" value-name:"URL" toml:"validator"` - RemoteValidatorCache string `long:"validator_cache" description:"FILE to cache remote validations" value-name:"FILE" toml:"validator_cache"` - RemoteValidatorPresets []string `long:"validator_preset" description:"One or more PRESETS to validate remotely" value-name:"PRESETS" toml:"validator_preset"` - - //lint:ignore SA5008 We are using choice twice: strict, unsafe. - ValidationMode ValidationMode `long:"validation_mode" short:"m" choice:"strict" choice:"unsafe" value-name:"MODE" description:"MODE how strict the validation is" toml:"validation_mode"` - - ForwardURL string `long:"forward_url" description:"URL of HTTP endpoint to forward downloads to" value-name:"URL" toml:"forward_url"` - ForwardHeader http.Header `long:"forward_header" description:"One or more extra HTTP header fields used by forwarding" toml:"forward_header"` - ForwardQueue int `long:"forward_queue" description:"Maximal queue LENGTH before forwarder" value-name:"LENGTH" toml:"forward_queue"` - ForwardInsecure bool `long:"forward_insecure" description:"Do not check TLS certificates from forward endpoint" toml:"forward_insecure"` - - LogFile *string `long:"log_file" description:"FILE to log downloading to" value-name:"FILE" toml:"log_file"` - //lint:ignore SA5008 We are using choice or than once: debug, info, warn, error - LogLevel *options.LogLevel `long:"log_level" description:"LEVEL of logging details" value-name:"LEVEL" choice:"debug" choice:"info" choice:"warn" choice:"error" toml:"log_level"` - - Config string `short:"c" long:"config" description:"Path to config TOML file" value-name:"TOML-FILE" toml:"-"` - - clientCerts []tls.Certificate - ignorePattern filter.PatternMatcher + Insecure bool + IgnoreSignatureCheck bool + ClientCerts []tls.Certificate + ClientKey *string + ClientPassphrase *string + Rate *float64 + Worker int + Range *models.TimeRange + IgnorePattern filter.PatternMatcher + ExtraHeader http.Header + + RemoteValidator string + // CLI only? + RemoteValidatorCache string + RemoteValidatorPresets []string + + ValidationMode ValidationMode + + ForwardURL string + ForwardHeader http.Header + ForwardQueue int + ForwardInsecure bool + + DownloadHandler func(DownloadedDocument) error + FailedForwardHandler func(filename, doc, sha256, sha512 string) error + + Logger *slog.Logger } // UnmarshalText implements [encoding.TextUnmarshaler]. @@ -99,114 +82,10 @@ func (vm *ValidationMode) UnmarshalFlag(value string) error { // ignoreFile returns true if the given URL should not be downloaded. func (cfg *Config) ignoreURL(u string) bool { - return cfg.ignorePattern.Matches(u) + return cfg.IgnorePattern.Matches(u) } // verbose is considered a log level equal or less debug. func (cfg *Config) verbose() bool { - return cfg.LogLevel.Level <= slog.LevelDebug -} - -// prepareDirectory ensures that the working directory -// exists and is setup properly. -func (cfg *Config) prepareDirectory() error { - // If not given use current working directory. - if cfg.Directory == "" { - dir, err := os.Getwd() - if err != nil { - return err - } - cfg.Directory = dir - return nil - } - // Use given directory - if _, err := os.Stat(cfg.Directory); err != nil { - // If it does not exist create it. - if os.IsNotExist(err) { - if err = os.MkdirAll(cfg.Directory, 0755); err != nil { - return err - } - } else { - return err - } - } - return nil -} - -// dropSubSeconds drops all parts below resolution of seconds. -func dropSubSeconds(_ []string, a slog.Attr) slog.Attr { - if a.Key == slog.TimeKey { - t := a.Value.Time() - a.Value = slog.TimeValue(t.Truncate(time.Second)) - } - return a -} - -// prepareLogging sets up the structured logging. -func (cfg *Config) prepareLogging() error { - var w io.Writer - if cfg.LogFile == nil || *cfg.LogFile == "" { - log.Println("using STDERR for logging") - w = os.Stderr - } else { - var fname string - // We put the log inside the download folder - // if it is not absolute. - if filepath.IsAbs(*cfg.LogFile) { - fname = *cfg.LogFile - } else { - fname = filepath.Join(cfg.Directory, *cfg.LogFile) - } - f, err := os.OpenFile(fname, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) - if err != nil { - return err - } - log.Printf("using %q for logging\n", fname) - w = f - } - ho := slog.HandlerOptions{ - //AddSource: true, - Level: cfg.LogLevel.Level, - ReplaceAttr: dropSubSeconds, - } - handler := slog.NewJSONHandler(w, &ho) - logger := slog.New(handler) - slog.SetDefault(logger) - return nil -} - -// compileIgnorePatterns compiles the configure patterns to be ignored. -func (cfg *Config) compileIgnorePatterns() error { - pm, err := filter.NewPatternMatcher(cfg.IgnorePattern) - if err != nil { - return err - } - cfg.ignorePattern = pm - return nil -} - -// prepareCertificates loads the client side certificates used by the HTTP client. -func (cfg *Config) prepareCertificates() error { - cert, err := certs.LoadCertificate( - cfg.ClientCert, cfg.ClientKey, cfg.ClientPassphrase) - if err != nil { - return err - } - cfg.clientCerts = cert - return nil -} - -// Prepare prepares internal state of a loaded configuration. -func (cfg *Config) Prepare() error { - for _, prepare := range []func(*Config) error{ - (*Config).prepareDirectory, - (*Config).prepareLogging, - (*Config).prepareCertificates, - (*Config).compileIgnorePatterns, - } { - if err := prepare(cfg); err != nil { - return err - } - } - return nil + return cfg.Logger.Enabled(nil, slog.LevelDebug) } diff --git a/lib/downloader/downloader.go b/lib/downloader/downloader.go index ab3bc9a7..a43118a3 100644 --- a/lib/downloader/downloader.go +++ b/lib/downloader/downloader.go @@ -22,10 +22,7 @@ import ( "log/slog" "net/http" "net/url" - "os" - "path" "path/filepath" - "strconv" "strings" "sync" "time" @@ -49,6 +46,18 @@ type Downloader struct { stats stats } +// DownloadedDocument conatins the document data with additional metadata. +type DownloadedDocument struct { + Data bytes.Buffer + S256Data []byte + S512Data []byte + SignData []byte + InitialReleaseDate time.Time + Filename string + ValStatus ValidationStatus + Label csaf.TLPLabel +} + // failedValidationDir is the name of the sub folder // where advisories are stored that fail validation in // unsafe mode. @@ -56,7 +65,6 @@ const failedValidationDir = "failed_validation" // NewDownloader constructs a new downloader given the configuration. func NewDownloader(cfg *Config) (*Downloader, error) { - var validator csaf.RemoteValidator if cfg.RemoteValidator != "" { @@ -96,23 +104,24 @@ func (d *Downloader) addStats(o *stats) { } // logRedirect logs redirects of the http client. -func logRedirect(req *http.Request, via []*http.Request) error { - vs := make([]string, len(via)) - for i, v := range via { - vs[i] = v.URL.String() - } - slog.Debug("Redirecting", - "to", req.URL.String(), - "via", strings.Join(vs, " -> ")) - return nil +func logRedirect(logger *slog.Logger) func(req *http.Request, via []*http.Request) error { + return func(req *http.Request, via []*http.Request) error { + vs := make([]string, len(via)) + for i, v := range via { + vs[i] = v.URL.String() + } + logger.Debug("Redirecting", + "to", req.URL.String(), + "via", strings.Join(vs, " -> ")) + return nil + } } func (d *Downloader) httpClient() util.Client { - hClient := http.Client{} if d.cfg.verbose() { - hClient.CheckRedirect = logRedirect + hClient.CheckRedirect = logRedirect(d.cfg.Logger) } var tlsConfig tls.Config @@ -120,8 +129,8 @@ func (d *Downloader) httpClient() util.Client { tlsConfig.InsecureSkipVerify = true } - if len(d.cfg.clientCerts) != 0 { - tlsConfig.Certificates = d.cfg.clientCerts + if len(d.cfg.ClientCerts) != 0 { + tlsConfig.Certificates = d.cfg.ClientCerts } hClient.Transport = &http.Transport{ @@ -142,7 +151,7 @@ func (d *Downloader) httpClient() util.Client { if d.cfg.verbose() { client = &util.LoggingClient{ Client: client, - Log: httpLog("downloader"), + Log: httpLog("downloader", d.cfg.Logger), } } @@ -158,9 +167,9 @@ func (d *Downloader) httpClient() util.Client { } // httpLog does structured logging in a [util.LoggingClient]. -func httpLog(who string) func(string, string) { +func httpLog(who string, logger *slog.Logger) func(string, string) { return func(method, url string) { - slog.Debug("http", + logger.Debug("http", "who", who, "method", method, "url", url) @@ -178,7 +187,7 @@ func (d *Downloader) enumerate(domain string) error { for _, pmd := range lpmd { if d.cfg.verbose() { for i := range pmd.Messages { - slog.Debug("Enumerating provider-metadata.json", + d.cfg.Logger.Debug("Enumerating provider-metadata.json", "domain", domain, "message", pmd.Messages[i].Message) } @@ -190,7 +199,7 @@ func (d *Downloader) enumerate(domain string) error { // print the results doc, err := json.MarshalIndent(docs, "", " ") if err != nil { - slog.Error("Couldn't marshal PMD document json") + d.cfg.Logger.Error("Couldn't marshal PMD document json") } fmt.Println(string(doc)) @@ -206,7 +215,7 @@ func (d *Downloader) download(ctx context.Context, domain string) error { if d.cfg.verbose() { for i := range lpmd.Messages { - slog.Debug("Loading provider-metadata.json", + d.cfg.Logger.Debug("Loading provider-metadata.json", "domain", domain, "message", lpmd.Messages[i].Message) } @@ -237,7 +246,7 @@ func (d *Downloader) download(ctx context.Context, domain string) error { // Do we need time range based filtering? if d.cfg.Range != nil { - slog.Debug("Setting up filter to accept advisories within", + d.cfg.Logger.Debug("Setting up filter to accept advisories within", "timerange", d.cfg.Range) afp.AgeAccept = d.cfg.Range.Contains } @@ -252,7 +261,6 @@ func (d *Downloader) downloadFiles( label csaf.TLPLabel, files []csaf.AdvisoryFile, ) error { - var ( advisoryCh = make(chan csaf.AdvisoryFile) errorCh = make(chan error) @@ -301,7 +309,6 @@ func (d *Downloader) loadOpenPGPKeys( doc any, base *url.URL, ) error { - src, err := d.eval.Eval("$.public_openpgp_keys", doc) if err != nil { // no keys. @@ -326,7 +333,7 @@ func (d *Downloader) loadOpenPGPKeys( } up, err := url.Parse(*key.URL) if err != nil { - slog.Warn("Invalid URL", + d.cfg.Logger.Warn("Invalid URL", "url", *key.URL, "error", err) continue @@ -336,14 +343,14 @@ func (d *Downloader) loadOpenPGPKeys( res, err := client.Get(u) if err != nil { - slog.Warn( + d.cfg.Logger.Warn( "Fetching public OpenPGP key failed", "url", u, "error", err) continue } if res.StatusCode != http.StatusOK { - slog.Warn( + d.cfg.Logger.Warn( "Fetching public OpenPGP key failed", "url", u, "status_code", res.StatusCode, @@ -355,9 +362,8 @@ func (d *Downloader) loadOpenPGPKeys( defer res.Body.Close() return crypto.NewKeyFromArmoredReader(res.Body) }() - if err != nil { - slog.Warn( + d.cfg.Logger.Warn( "Reading public OpenPGP key failed", "url", u, "error", err) @@ -365,14 +371,14 @@ func (d *Downloader) loadOpenPGPKeys( } if !strings.EqualFold(ckey.GetFingerprint(), string(key.Fingerprint)) { - slog.Warn( + d.cfg.Logger.Warn( "Fingerprint of public OpenPGP key does not match remotely loaded", "url", u) continue } if d.keys == nil { if keyring, err := crypto.NewKeyRing(ckey); err != nil { - slog.Warn( + d.cfg.Logger.Warn( "Creating store for public OpenPGP key failed", "url", u, "error", err) @@ -389,18 +395,18 @@ func (d *Downloader) loadOpenPGPKeys( // logValidationIssues logs the issues reported by the advisory schema validation. func (d *Downloader) logValidationIssues(url string, errors []string, err error) { if err != nil { - slog.Error("Failed to validate", + d.cfg.Logger.Error("Failed to validate", "url", url, "error", err) return } if len(errors) > 0 { if d.cfg.verbose() { - slog.Error("CSAF file has validation errors", + d.cfg.Logger.Error("CSAF file has validation errors", "url", url, "error", strings.Join(errors, ", ")) } else { - slog.Error("CSAF file has validation errors", + d.cfg.Logger.Error("CSAF file has validation errors", "url", url, "count", len(errors)) } @@ -419,10 +425,8 @@ func (d *Downloader) downloadWorker( var ( client = d.httpClient() data bytes.Buffer - lastDir string initialReleaseDate time.Time dateExtract = util.TimeMatcher(&initialReleaseDate, time.RFC3339) - lower = strings.ToLower(string(label)) stats = stats{} ) @@ -445,14 +449,14 @@ nextAdvisory: u, err := url.Parse(file.URL()) if err != nil { stats.downloadFailed++ - slog.Warn("Ignoring invalid URL", + d.cfg.Logger.Warn("Ignoring invalid URL", "url", file.URL(), "error", err) continue } if d.cfg.ignoreURL(file.URL()) { - slog.Debug("Ignoring URL", "url", file.URL()) + d.cfg.Logger.Debug("Ignoring URL", "url", file.URL()) continue } @@ -460,7 +464,7 @@ nextAdvisory: filename := filepath.Base(u.Path) if !util.ConformingFileName(filename) { stats.filenameFailed++ - slog.Warn("Ignoring none conforming filename", + d.cfg.Logger.Warn("Ignoring none conforming filename", "filename", filename) continue } @@ -468,7 +472,7 @@ nextAdvisory: resp, err := client.Get(file.URL()) if err != nil { stats.downloadFailed++ - slog.Warn("Cannot GET", + d.cfg.Logger.Warn("Cannot GET", "url", file.URL(), "error", err) continue @@ -476,7 +480,7 @@ nextAdvisory: if resp.StatusCode != http.StatusOK { stats.downloadFailed++ - slog.Warn("Cannot load", + d.cfg.Logger.Warn("Cannot load", "url", file.URL(), "status", resp.Status, "status_code", resp.StatusCode) @@ -485,7 +489,7 @@ nextAdvisory: // Warn if we do not get JSON. if ct := resp.Header.Get("Content-Type"); ct != "application/json" { - slog.Warn("Content type is not 'application/json'", + d.cfg.Logger.Warn("Content type is not 'application/json'", "url", file.URL(), "content_type", ct) } @@ -500,7 +504,7 @@ nextAdvisory: // Only hash when we have a remote counter part we can compare it with. if remoteSHA256, s256Data, err = loadHash(client, file.SHA256URL()); err != nil { - slog.Warn("Cannot fetch SHA256", + d.cfg.Logger.Warn("Cannot fetch SHA256", "url", file.SHA256URL(), "error", err) } else { @@ -509,7 +513,7 @@ nextAdvisory: } if remoteSHA512, s512Data, err = loadHash(client, file.SHA512URL()); err != nil { - slog.Warn("Cannot fetch SHA512", + d.cfg.Logger.Warn("Cannot fetch SHA512", "url", file.SHA512URL(), "error", err) } else { @@ -532,7 +536,7 @@ nextAdvisory: return json.NewDecoder(tee).Decode(&doc) }(); err != nil { stats.downloadFailed++ - slog.Warn("Downloading failed", + d.cfg.Logger.Warn("Downloading failed", "url", file.URL(), "error", err) continue @@ -564,7 +568,7 @@ nextAdvisory: var sign *crypto.PGPSignature sign, signData, err = loadSignature(client, file.SignURL()) if err != nil { - slog.Warn("Downloading signature failed", + d.cfg.Logger.Warn("Downloading signature failed", "url", file.SignURL(), "error", err) } @@ -618,7 +622,7 @@ nextAdvisory: } // Run all the validations. - valStatus := notValidatedValidationStatus + valStatus := NotValidatedValidationStatus for _, check := range []func() error{ s256Check, s512Check, @@ -628,14 +632,14 @@ nextAdvisory: remoteValidatorCheck, } { if err := check(); err != nil { - slog.Error("Validation check failed", "error", err) - valStatus.update(invalidValidationStatus) + d.cfg.Logger.Error("Validation check failed", "error", err) + valStatus.update(InvalidValidationStatus) if d.cfg.ValidationMode == ValidationStrict { continue nextAdvisory } } } - valStatus.update(validValidationStatus) + valStatus.update(ValidValidationStatus) // Send to Forwarder if d.Forwarder != nil { @@ -645,15 +649,6 @@ nextAdvisory: string(s256Data), string(s512Data)) } - - if d.cfg.NoStore { - // Do not write locally. - if valStatus == validValidationStatus { - stats.succeeded++ - } - continue - } - if err := d.eval.Extract( `$.document.tracking.initial_release_date`, dateExtract, false, doc, ); err != nil { @@ -663,61 +658,26 @@ nextAdvisory: } initialReleaseDate = initialReleaseDate.UTC() - // Advisories that failed validation are stored in a special folder. - var newDir string - if valStatus != validValidationStatus { - newDir = path.Join(d.cfg.Directory, failedValidationDir) - } else { - newDir = d.cfg.Directory + download := DownloadedDocument{ + Data: data, + S256Data: s256Data, + S512Data: s512Data, + SignData: signData, + InitialReleaseDate: initialReleaseDate, + Filename: filename, + ValStatus: valStatus, + Label: label, } - // Do we have a configured destination folder? - if d.cfg.Folder != "" { - newDir = path.Join(newDir, d.cfg.Folder) + err = d.cfg.DownloadHandler(download) + if err != nil { + errorCh <- err } else { - newDir = path.Join(newDir, lower, strconv.Itoa(initialReleaseDate.Year())) - } - - if newDir != lastDir { - if err := d.mkdirAll(newDir, 0755); err != nil { - errorCh <- err - continue - } - lastDir = newDir + stats.succeeded++ } - - // Write advisory to file - filePath := filepath.Join(lastDir, filename) - - // Write data to disk. - for _, x := range []struct { - p string - d []byte - }{ - {filePath, data.Bytes()}, - {filePath + ".sha256", s256Data}, - {filePath + ".sha512", s512Data}, - {filePath + ".asc", signData}, - } { - if x.d != nil { - if err := os.WriteFile(x.p, x.d, 0644); err != nil { - errorCh <- err - continue nextAdvisory - } - } - } - - stats.succeeded++ - slog.Info("Written advisory", "path", filePath) } } -func (d *Downloader) mkdirAll(path string, perm os.FileMode) error { - d.mkdirMu.Lock() - defer d.mkdirMu.Unlock() - return os.MkdirAll(path, perm) -} - func (d *Downloader) checkSignature(data []byte, sign *crypto.PGPSignature) error { pm := crypto.NewPlainMessage(data) t := crypto.GetUnixTime() diff --git a/lib/downloader/forwarder.go b/lib/downloader/forwarder.go index 78adcaed..9e6de35d 100644 --- a/lib/downloader/forwarder.go +++ b/lib/downloader/forwarder.go @@ -12,10 +12,8 @@ import ( "bytes" "crypto/tls" "io" - "log/slog" "mime/multipart" "net/http" - "os" "path/filepath" "strings" @@ -27,19 +25,22 @@ import ( // where advisories get stored which fail forwarding. const failedForwardDir = "failed_forward" -// validationStatus represents the validation status +// ValidationStatus represents the validation status // known to the HTTP endpoint. -type validationStatus string +type ValidationStatus string const ( - validValidationStatus = validationStatus("valid") - invalidValidationStatus = validationStatus("invalid") - notValidatedValidationStatus = validationStatus("not_validated") + // ValidValidationStatus represents a valid document. + ValidValidationStatus = ValidationStatus("valid") + // InvalidValidationStatus represents an invalid document. + InvalidValidationStatus = ValidationStatus("invalid") + // NotValidatedValidationStatus represents a not validated document. + NotValidatedValidationStatus = ValidationStatus("not_validated") ) -func (vs *validationStatus) update(status validationStatus) { +func (vs *ValidationStatus) update(status ValidationStatus) { // Cannot heal after it fails at least once. - if *vs != invalidValidationStatus { + if *vs != InvalidValidationStatus { *vs = status } } @@ -69,7 +70,7 @@ func NewForwarder(cfg *Config) *Forwarder { // Run runs the Forwarder. Meant to be used in a Go routine. func (f *Forwarder) Run() { - defer slog.Debug("Forwarder done") + defer f.cfg.Logger.Debug("Forwarder done") for cmd := range f.cmds { cmd(f) @@ -84,7 +85,7 @@ func (f *Forwarder) Close() { // Log logs the current statistics. func (f *Forwarder) Log() { f.cmds <- func(f *Forwarder) { - slog.Info("Forward statistics", + f.cfg.Logger.Info("Forward statistics", "succeeded", f.succeeded, "failed", f.failed) } @@ -122,7 +123,7 @@ func (f *Forwarder) httpClient() util.Client { if f.cfg.verbose() { client = &util.LoggingClient{ Client: client, - Log: httpLog("Forwarder"), + Log: httpLog("Forwarder", f.cfg.Logger), } } @@ -139,7 +140,7 @@ func replaceExt(fname, nExt string) string { // buildRequest creates an HTTP request suited to forward the given advisory. func (f *Forwarder) buildRequest( filename, doc string, - status validationStatus, + status ValidationStatus, sha256, sha512 string, ) (*http.Request, error) { body := new(bytes.Buffer) @@ -187,38 +188,11 @@ func (f *Forwarder) buildRequest( return req, nil } -// storeFailedAdvisory stores an advisory in a special folder -// in case the forwarding failed. -func (f *Forwarder) storeFailedAdvisory(filename, doc, sha256, sha512 string) error { - // Create special folder if it does not exist. - dir := filepath.Join(f.cfg.Directory, failedForwardDir) - if err := os.MkdirAll(dir, 0755); err != nil { - return err - } - // Store parts which are not empty. - for _, x := range []struct { - p string - d string - }{ - {filename, doc}, - {filename + ".sha256", sha256}, - {filename + ".sha512", sha512}, - } { - if len(x.d) != 0 { - path := filepath.Join(dir, x.p) - if err := os.WriteFile(path, []byte(x.d), 0644); err != nil { - return err - } - } - } - return nil -} - // storeFailed is a logging wrapper around storeFailedAdvisory. func (f *Forwarder) storeFailed(filename, doc, sha256, sha512 string) { f.failed++ - if err := f.storeFailedAdvisory(filename, doc, sha256, sha512); err != nil { - slog.Error("Storing advisory failed forwarding failed", + if err := f.cfg.FailedForwardHandler(filename, doc, sha256, sha512); err != nil { + f.cfg.Logger.Error("Storing advisory failed forwarding failed", "error", err) } } @@ -241,21 +215,21 @@ func limitedString(r io.Reader, max int) (string, error) { // till the configured queue size is filled. func (f *Forwarder) forward( filename, doc string, - status validationStatus, + status ValidationStatus, sha256, sha512 string, ) { // Run this in the main loop of the Forwarder. f.cmds <- func(f *Forwarder) { req, err := f.buildRequest(filename, doc, status, sha256, sha512) if err != nil { - slog.Error("building forward Request failed", + f.cfg.Logger.Error("building forward Request failed", "error", err) f.storeFailed(filename, doc, sha256, sha512) return } res, err := f.httpClient().Do(req) if err != nil { - slog.Error("sending forward request failed", + f.cfg.Logger.Error("sending forward request failed", "error", err) f.storeFailed(filename, doc, sha256, sha512) return @@ -263,10 +237,10 @@ func (f *Forwarder) forward( if res.StatusCode != http.StatusCreated { defer res.Body.Close() if msg, err := limitedString(res.Body, 512); err != nil { - slog.Error("reading forward result failed", + f.cfg.Logger.Error("reading forward result failed", "error", err) } else { - slog.Error("forwarding failed", + f.cfg.Logger.Error("forwarding failed", "filename", filename, "body", msg, "status_code", res.StatusCode) @@ -274,7 +248,7 @@ func (f *Forwarder) forward( f.storeFailed(filename, doc, sha256, sha512) } else { f.succeeded++ - slog.Debug( + f.cfg.Logger.Debug( "forwarding succeeded", "filename", filename) } diff --git a/lib/downloader/forwarder_test.go b/lib/downloader/forwarder_test.go index 8339d108..37e98d79 100644 --- a/lib/downloader/forwarder_test.go +++ b/lib/downloader/forwarder_test.go @@ -19,26 +19,24 @@ import ( "mime/multipart" "net/http" "os" - "path/filepath" "strings" "testing" - "github.com/csaf-poc/csaf_distribution/v3/internal/options" "github.com/csaf-poc/csaf_distribution/v3/util" ) func TestValidationStatusUpdate(t *testing.T) { - sv := validValidationStatus - sv.update(invalidValidationStatus) - sv.update(validValidationStatus) - if sv != invalidValidationStatus { - t.Fatalf("got %q expected %q", sv, invalidValidationStatus) - } - sv = notValidatedValidationStatus - sv.update(validValidationStatus) - sv.update(notValidatedValidationStatus) - if sv != notValidatedValidationStatus { - t.Fatalf("got %q expected %q", sv, notValidatedValidationStatus) + sv := ValidValidationStatus + sv.update(InvalidValidationStatus) + sv.update(ValidValidationStatus) + if sv != InvalidValidationStatus { + t.Fatalf("got %q expected %q", sv, InvalidValidationStatus) + } + sv = NotValidatedValidationStatus + sv.update(ValidValidationStatus) + sv.update(NotValidatedValidationStatus) + if sv != NotValidatedValidationStatus { + t.Fatalf("got %q expected %q", sv, NotValidatedValidationStatus) } } @@ -51,9 +49,10 @@ func TestForwarderLogStats(t *testing.T) { Level: slog.LevelInfo, }) lg := slog.New(h) - slog.SetDefault(lg) - cfg := &Config{} + cfg := &Config{ + Logger: lg, + } fw := NewForwarder(cfg) fw.failed = 11 fw.succeeded = 13 @@ -100,7 +99,7 @@ func TestForwarderHTTPClient(t *testing.T) { ForwardHeader: http.Header{ "User-Agent": []string{"curl/7.55.1"}, }, - LogLevel: &options.LogLevel{Level: slog.LevelDebug}, + Logger: slog.Default(), } fw := NewForwarder(cfg) if c1, c2 := fw.httpClient(), fw.httpClient(); c1 != c2 { @@ -122,7 +121,6 @@ func TestForwarderReplaceExtension(t *testing.T) { } func TestForwarderBuildRequest(t *testing.T) { - // Good case ... cfg := &Config{ ForwardURL: "https://example.com", @@ -131,10 +129,9 @@ func TestForwarderBuildRequest(t *testing.T) { req, err := fw.buildRequest( "test.json", "{}", - invalidValidationStatus, + InvalidValidationStatus, "256", "512") - if err != nil { t.Fatalf("buildRequest failed: %v", err) } @@ -175,9 +172,9 @@ func TestForwarderBuildRequest(t *testing.T) { } foundAdvisory = true case contains("validation_status"): - if vs := validationStatus(data); vs != invalidValidationStatus { + if vs := ValidationStatus(data); vs != InvalidValidationStatus { t.Fatalf("validation_status: got %q expected %q", - vs, invalidValidationStatus) + vs, InvalidValidationStatus) } foundValidationStatus = true case contains("hash-256"): @@ -209,7 +206,7 @@ func TestForwarderBuildRequest(t *testing.T) { if _, err := fw.buildRequest( "test.json", "{}", - invalidValidationStatus, + InvalidValidationStatus, "256", "512", ); err == nil { @@ -241,101 +238,6 @@ func TestLimitedString(t *testing.T) { } } -func TestStoreFailedAdvisory(t *testing.T) { - dir, err := os.MkdirTemp("", "storeFailedAdvisory") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) - - cfg := &Config{Directory: dir} - fw := NewForwarder(cfg) - - badDir := filepath.Join(dir, failedForwardDir) - if err := os.WriteFile(badDir, []byte("test"), 0664); err != nil { - t.Fatal(err) - } - - if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err == nil { - t.Fatal("if the destination exists as a file an error should occur") - } - - if err := os.Remove(badDir); err != nil { - t.Fatal(err) - } - - if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err != nil { - t.Fatal(err) - } - - sha256Path := filepath.Join(dir, failedForwardDir, "advisory.json.sha256") - - // Write protect advisory. - if err := os.Chmod(sha256Path, 0); err != nil { - t.Fatal(err) - } - - if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err == nil { - t.Fatal("expected to fail with an error") - } - - if err := os.Chmod(sha256Path, 0644); err != nil { - t.Fatal(err) - } -} - -func TestStoredFailed(t *testing.T) { - dir, err := os.MkdirTemp("", "storeFailed") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) - - orig := slog.Default() - defer slog.SetDefault(orig) - - var buf bytes.Buffer - h := slog.NewJSONHandler(&buf, &slog.HandlerOptions{ - Level: slog.LevelError, - }) - lg := slog.New(h) - slog.SetDefault(lg) - - cfg := &Config{Directory: dir} - fw := NewForwarder(cfg) - - // An empty filename should lead to an error. - fw.storeFailed("", "{}", "256", "512") - - if fw.failed != 1 { - t.Fatalf("got %d expected 1", fw.failed) - } - - type entry struct { - Msg string `json:"msg"` - Level string `json:"level"` - } - - sc := bufio.NewScanner(bytes.NewReader(buf.Bytes())) - found := false - for sc.Scan() { - var e entry - if err := json.Unmarshal(sc.Bytes(), &e); err != nil { - t.Fatalf("JSON parsing log failed: %v", err) - } - if e.Msg == "Storing advisory failed forwarding failed" && e.Level == "ERROR" { - found = true - break - } - } - if err := sc.Err(); err != nil { - t.Fatalf("scanning log failed: %v", err) - } - if !found { - t.Fatal("Cannot error logging statistics in log") - } -} - type fakeClient struct { util.Client state int @@ -383,11 +285,15 @@ func TestForwarderForward(t *testing.T) { // in the other test cases. h := slog.NewJSONHandler(io.Discard, nil) lg := slog.New(h) - slog.SetDefault(lg) + + failedHandler := func(filename, doc, sha256, sha512 string) error { + return nil + } cfg := &Config{ - ForwardURL: "http://example.com", - Directory: dir, + ForwardURL: "http://example.com", + Logger: lg, + FailedForwardHandler: failedHandler, } fw := NewForwarder(cfg) @@ -405,7 +311,7 @@ func TestForwarderForward(t *testing.T) { for i := 0; i <= 3; i++ { fw.forward( "test.json", "{}", - invalidValidationStatus, + InvalidValidationStatus, "256", "512") } @@ -419,7 +325,7 @@ func TestForwarderForward(t *testing.T) { <-wait fw.forward( "test.json", "{}", - invalidValidationStatus, + InvalidValidationStatus, "256", "512")