From 3013b95ac680510e44488d13852ce7e05e8e4359 Mon Sep 17 00:00:00 2001 From: David Schneider Date: Sat, 27 Jul 2024 19:58:17 +0200 Subject: [PATCH] Adapt tests --- cmd/pcert/cobra.go | 7 +-- cmd/pcert/create2.go | 13 ++-- cmd/pcert/create_test.go | 125 +++++++++++++++++++------------------- cmd/pcert/flags.go | 3 + cmd/pcert/main.go | 2 +- cmd/pcert/request.go | 63 +++++++++++-------- cmd/pcert/request_test.go | 31 +++++----- 7 files changed, 128 insertions(+), 116 deletions(-) diff --git a/cmd/pcert/cobra.go b/cmd/pcert/cobra.go index 9b101f2..9dec5ea 100644 --- a/cmd/pcert/cobra.go +++ b/cmd/pcert/cobra.go @@ -3,20 +3,17 @@ package main import ( "errors" "fmt" - "os" "strings" "github.com/spf13/cobra" "github.com/spf13/pflag" ) -func WithEnv(c *cobra.Command) *cobra.Command { +func WithEnv(c *cobra.Command, args []string, getEnv func(name string) (string, bool)) *cobra.Command { if c.HasParent() { c = c.Root() } - args := os.Args[1:] - var ( cmd *cobra.Command err error @@ -41,7 +38,7 @@ func WithEnv(c *cobra.Command) *cobra.Command { optName := strings.ToUpper(f.Name) optName = strings.ReplaceAll(optName, "-", "_") varName := envVarPrefix + optName - if val, ok := os.LookupEnv(varName); ok { + if val, ok := getEnv(varName); ok { err := f.Value.Set(val) if err != nil { errs = append(errs, fmt.Errorf("invalid environment variable '%s': %w", varName, err)) diff --git a/cmd/pcert/create2.go b/cmd/pcert/create2.go index 0633c72..2f80c36 100644 --- a/cmd/pcert/create2.go +++ b/cmd/pcert/create2.go @@ -16,7 +16,7 @@ import ( type createCommand struct { Out io.Writer - In io.Writer + In io.Reader CertificateOutputLocation string KeyOutputLocation string @@ -41,8 +41,6 @@ func getKeyRelativeToCert(certPath string) string { func newCreate2Cmd() *cobra.Command { createCommand := &createCommand{ - Out: os.Stdout, - In: os.Stdin, CertificateOutputLocation: "", KeyOutputLocation: "", SignCertificateLocation: "", @@ -63,6 +61,8 @@ pcert create tls.crt `, Args: cobra.MaximumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { + createCommand.In = cmd.InOrStdin() + createCommand.Out = cmd.OutOrStdout() // default key output file relative to certificate if len(args) == 1 && args[0] != "-" { createCommand.CertificateOutputLocation = args[0] @@ -103,7 +103,7 @@ pcert create tls.crt if createCommand.SignCertificateLocation != "" { slog.Info("process signer") if createCommand.SignCertificateLocation == "-" { - stdin, err = io.ReadAll(os.Stdin) + stdin, err = io.ReadAll(createCommand.In) if err != nil { return err } @@ -157,7 +157,10 @@ pcert create tls.crt } if createCommand.CertificateOutputLocation == "" || createCommand.CertificateOutputLocation == "-" { - createCommand.Out.Write(certPEM) + _, err := createCommand.Out.Write(certPEM) + if err != nil { + return err + } } else { err := os.WriteFile(createCommand.CertificateOutputLocation, certPEM, 0664) if err != nil { diff --git a/cmd/pcert/create_test.go b/cmd/pcert/create_test.go index 354c080..24f2a78 100644 --- a/cmd/pcert/create_test.go +++ b/cmd/pcert/create_test.go @@ -1,8 +1,11 @@ package main import ( + "bytes" "crypto/x509" "crypto/x509/pkix" + "fmt" + "io" "os" "testing" "time" @@ -10,56 +13,54 @@ import ( "github.com/dvob/pcert" ) -func runCmd(args []string, env map[string]string) error { - os.Clearenv() - for k, v := range env { - os.Setenv(k, v) - } +func runCmd(args []string, env map[string]string) (io.WriteCloser, *bytes.Buffer, *bytes.Buffer, error) { + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + stdinReader, stdinWriter := io.Pipe() cmd := newRootCmd() cmd.SetArgs(args) - return cmd.Execute() + cmd.SetIn(stdinReader) + cmd.SetOut(stdout) + cmd.SetErr(stderr) + cmd = WithEnv(cmd, args, func(name string) (string, bool) { + if env == nil { + return "", false + } + val, ok := env[name] + return val, ok + }) + + return stdinWriter, stdout, stderr, cmd.Execute() } -func runCreateAndLoad(name string, args []string, env map[string]string) (*x509.Certificate, error) { - defer os.Remove(name + ".crt") - defer os.Remove(name + ".key") - fullArgs := []string{"create", name} - fullArgs = append(fullArgs, args...) - err := runCmd(fullArgs, env) +func runAndLoad(args []string, env map[string]string) (*x509.Certificate, error) { + _, stdout, stderr, err := runCmd(args, env) if err != nil { return nil, err } - cert, err := pcert.Load(name + ".crt") - return cert, err -} + if stderr.Len() != 0 { + return nil, fmt.Errorf("stderr not empty '%s'", stderr.String()) + } -func Test_create(t *testing.T) { - name := "foo1" - cert, err := runCreateAndLoad("foo1", []string{}, nil) + cert, err := pcert.Parse(stdout.Bytes()) if err != nil { - t.Error(err) - return + return nil, fmt.Errorf("could not read certificate from standard output: %s", err) } - if cert.Subject.CommonName != name { - t.Errorf("common name no set correctly: got: %s, want: %s", cert.Subject.CommonName, name) - } + return cert, err } -func Test_create_subject(t *testing.T) { - cn := "myCommonName" - cert, err := runCreateAndLoad("foo2", []string{ - "--subject", - "CN=" + cn, - }, nil) +func Test_create(t *testing.T) { + name := "foo1" + cert, err := runAndLoad([]string{"create", "--subject", "/CN=" + name}, nil) if err != nil { - t.Error(err) + t.Fatal(err) return } - if cert.Subject.CommonName != cn { - t.Errorf("common name no set correctly: got: %s, want: %s", cert.Subject.CommonName, cn) + if cert.Subject.CommonName != name { + t.Fatalf("common name no set correctly: got: %s, want: %s", cert.Subject.CommonName, name) } } @@ -71,7 +72,8 @@ func Test_create_subject_multiple(t *testing.T) { Organization: []string{"Snakeoil Ltd."}, OrganizationalUnit: []string{"Group 1", "Group 2"}, } - cert, err := runCreateAndLoad("subject2", []string{ + cert, err := runAndLoad([]string{ + "create", "--subject", "CN=Bla bla bla/C=CH/L=Bern", "--subject", @@ -80,12 +82,12 @@ func Test_create_subject_multiple(t *testing.T) { "OU=Group 1/OU=Group 2", }, nil) if err != nil { - t.Error(err) + t.Fatal(err) return } if subject.String() != cert.Subject.String() { - t.Errorf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject) + t.Fatalf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject) } } @@ -100,54 +102,57 @@ func Test_create_subject_combined_with_environment(t *testing.T) { Organization: []string{"Snakeoil Ltd."}, OrganizationalUnit: []string{"Group 1", "Group 2"}, } - cert, err := runCreateAndLoad("subject3", []string{ + cert, err := runAndLoad([]string{ + "create", "--subject", "CN=Bla bla bla", "--subject", "OU=Group 1/OU=Group 2", }, env) if err != nil { - t.Error(err) + t.Fatal(err) return } if subject.String() != cert.Subject.String() { - t.Errorf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject) + t.Fatalf("subject no set correctly:\n got: %s\nwant: %s", cert.Subject, subject) } } func Test_create_not_before(t *testing.T) { notBefore := time.Date(2020, 10, 27, 12, 0, 0, 0, time.FixedZone("UTC+1", 60*60)) - cert, err := runCreateAndLoad("foo3", []string{ + cert, err := runAndLoad([]string{ + "create", "--not-before", "2020-10-27T12:00:00+01:00", }, nil) if err != nil { - t.Error(err) + t.Fatal(err) return } if !cert.NotBefore.Equal(notBefore) { - t.Errorf("not before not set correctly: got: %s, want: %s", cert.NotBefore, notBefore) + t.Fatalf("not before not set correctly: got: %s, want: %s", cert.NotBefore, notBefore) } notAfter := notBefore.Add(pcert.DefaultValidityPeriod) if !cert.NotAfter.Equal(notAfter) { - t.Errorf("not after not set correctly: got: %s, want: %s", cert.NotAfter, notAfter) + t.Fatalf("not after not set correctly: got: %s, want: %s", cert.NotAfter, notAfter) } } func Test_create_not_before_and_not_after(t *testing.T) { notBefore := time.Date(2020, 12, 30, 12, 0, 0, 0, time.FixedZone("UTC+1", 60*60)) notAfter := time.Date(2022, 12, 30, 12, 0, 0, 0, time.FixedZone("UTC+1", 60*60)) - cert, err := runCreateAndLoad("foo4", []string{ + cert, err := runAndLoad([]string{ + "create", "--not-before", "2020-12-30T12:00:00+01:00", "--not-after", "2022-12-30T12:00:00+01:00", }, nil) if err != nil { - t.Error(err) + t.Fatal(err) return } @@ -162,13 +167,13 @@ func Test_create_not_before_and_not_after(t *testing.T) { func Test_create_with_expiry(t *testing.T) { now := time.Now().Round(time.Minute) - cert, err := runCreateAndLoad("foo4", []string{ + cert, err := runAndLoad([]string{ + "create", "--expiry", "3y", }, nil) if err != nil { - t.Error(err) - return + t.Fatal(err) } actualNotBefore := cert.NotBefore.Round(time.Minute) @@ -185,14 +190,15 @@ func Test_create_with_expiry(t *testing.T) { func Test_create_not_before_with_expiry(t *testing.T) { notBefore := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) - cert, err := runCreateAndLoad("foo4", []string{ + cert, err := runAndLoad([]string{ + "create", "--not-before", "2020-01-01T00:00:00Z", "--expiry", "90d", }, nil) if err != nil { - t.Error(err) + t.Fatal(err) return } @@ -207,30 +213,23 @@ func Test_create_not_before_with_expiry(t *testing.T) { } func Test_create_output_parameter(t *testing.T) { - name := "foo2" - certFile := "mycert_foo2" - keyFile := "mykey_foo2" - defer os.Remove(certFile) - defer os.Remove(keyFile) - err := runCmd([]string{ + defer os.Remove("tls.crt") + defer os.Remove("tls.key") + _, _, _, err := runCmd([]string{ "create", - name, - "--cert", - certFile, - "--key", - keyFile, + "tls.crt", }, nil) if err != nil { - t.Error(err) + t.Fatal(err) return } - _, err = pcert.Load(certFile) + _, err = pcert.Load("tls.crt") if err != nil { t.Errorf("could not load certificate: %s", err) } - _, err = pcert.LoadKey(keyFile) + _, err = pcert.LoadKey("tls.key") if err != nil { t.Errorf("could not load key: %s", err) } diff --git a/cmd/pcert/flags.go b/cmd/pcert/flags.go index 3125bc7..b7541a9 100644 --- a/cmd/pcert/flags.go +++ b/cmd/pcert/flags.go @@ -69,8 +69,11 @@ func BindCertificateOptionsFlags(fs *pflag.FlagSet, co *pcert.CertificateOptions fs.IPSliceVar(&co.IPAddresses, "ip", []net.IP{}, "IP subject alternative name.") fs.Var(newURISliceValue(&co.URIs), "uri", "URI subject alternative name.") fs.Var(newSignAlgValue(&co.SignatureAlgorithm), "sign-alg", "Signature Algorithm. See 'pcert list' for available algorithms.") + fs.Var(newTimeValue(&co.NotBefore), "not-before", fmt.Sprintf("Not valid before time in RFC3339 format (e.g. '%s').", time.Now().UTC().Format(time.RFC3339))) fs.Var(newTimeValue(&co.NotAfter), "not-after", fmt.Sprintf("Not valid after time in RFC3339 format (e.g. '%s').", time.Now().Add(time.Hour*24*60).UTC().Format(time.RFC3339))) + fs.Var(newDurationValue(&co.Expiry), "expiry", "Validity period of the certificate. If --not-after is set this option has no effect.") + fs.Var(newSubjectValue(&co.Subject), "subject", "Subject in the form '/C=CH/O=My Org/OU=My Team'.") //fs.BoolVar(&co.BasicConstraintsValid, "basic-constraints", cert.BasicConstraintsValid, "Add basic constraints extension.") diff --git a/cmd/pcert/main.go b/cmd/pcert/main.go index 994219e..e318de1 100644 --- a/cmd/pcert/main.go +++ b/cmd/pcert/main.go @@ -20,7 +20,7 @@ var ( ) func main() { - err := WithEnv(newRootCmd()).Execute() + err := WithEnv(newRootCmd(), os.Args[1:], os.LookupEnv).Execute() if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) diff --git a/cmd/pcert/request.go b/cmd/pcert/request.go index bfba4d1..fa37ce9 100644 --- a/cmd/pcert/request.go +++ b/cmd/pcert/request.go @@ -11,30 +11,28 @@ import ( func newRequestCmd() *cobra.Command { var ( - csrFile string - csr = &x509.CertificateRequest{} - key = &key{} + csrOutput string + csr = &x509.CertificateRequest{} + + keyOutput string + keyOpts = pcert.KeyOptions{} ) cmd := &cobra.Command{ - Use: "request ", - Short: "Create a certificate signing request (CSR)", - Long: "Creates a CSR and a coresponding key.", - Args: cobra.ExactArgs(1), + Use: "request [OUTPUT-CSR [OUTPUT-KEY]]", + Short: "Create a certificate signing request (CSR) and key", + Args: cobra.MaximumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - name := args[0] - if csr.Subject.CommonName == "" { - csr.Subject.CommonName = name - } - - if csrFile == "" { - csrFile = name + csrFileSuffix + if len(args) == 1 && args[0] != "-" { + csrOutput = args[0] + keyOutput = getKeyRelativeToCert(args[0]) } - if key.path == "" { - key.path = name + keyFileSuffix + if len(args) == 2 { + csrOutput = args[0] + keyOutput = args[1] } - csrDER, privateKey, err := pcert.CreateRequestWithKeyOptions(csr, key.opts) + csrDER, privateKey, err := pcert.CreateRequestWithKeyOptions(csr, keyOpts) if err != nil { return err } @@ -46,23 +44,38 @@ func newRequestCmd() *cobra.Command { csrPEM := pcert.EncodeCSR(csrDER) - err = os.WriteFile(key.path, keyPEM, 0600) - if err != nil { - return fmt.Errorf("failed to write key '%s': %w", key.path, err) + if csrOutput == "" || csrOutput == "-" { + _, err := cmd.OutOrStdout().Write(csrPEM) + if err != nil { + return err + } + } else { + err := os.WriteFile(csrOutput, csrPEM, 0664) + if err != nil { + return fmt.Errorf("failed to write CSR '%s': %w", csrOutput, err) + } } - err = os.WriteFile(csrFile, csrPEM, 0640) - if err != nil { - return fmt.Errorf("failed to write CSR '%s': %w", csrFile, err) + + if keyOutput == "" || keyOutput == "-" { + _, err := cmd.OutOrStdout().Write(keyPEM) + if err != nil { + return err + } + } else { + err = os.WriteFile(keyOutput, keyPEM, 0600) + if err != nil { + return fmt.Errorf("failed to write key '%s': %w", keyOutput, err) + } } return nil }, } - key.bindFlags(cmd) + BindKeyFlags(cmd.Flags(), &keyOpts) + RegisterKeyCompletionFuncs(cmd) BindCertificateRequestFlags(cmd.Flags(), csr) RegisterCertificateRequestCompletionFuncs(cmd) - cmd.Flags().StringVar(&csrFile, "csr", "", "Output file for the CSR. Defaults to .csr") return cmd } diff --git a/cmd/pcert/request_test.go b/cmd/pcert/request_test.go index 388d2c4..58575a4 100644 --- a/cmd/pcert/request_test.go +++ b/cmd/pcert/request_test.go @@ -1,33 +1,30 @@ package main import ( - "crypto/x509" - "os" "testing" "github.com/dvob/pcert" ) -func runRequestAndLoad(name string, args []string, env map[string]string) (*x509.CertificateRequest, error) { - defer os.Remove(name + ".csr") - defer os.Remove(name + ".key") - fullArgs := []string{"request", name} - fullArgs = append(fullArgs, args...) - err := runCmd(fullArgs, env) +func Test_request(t *testing.T) { + name := "foo" + _, stdout, stderr, err := runCmd([]string{ + "request", + "--subject", + "/CN=" + name, + }, nil) if err != nil { - return nil, err + t.Fatal(err) + return } - csr, err := pcert.LoadCSR(name + ".csr") - return csr, err -} + if stderr.Len() != 0 { + t.Fatalf("stderr not empty '%s'", stderr.String()) + } -func Test_request(t *testing.T) { - name := "csr1" - csr, err := runRequestAndLoad(name, []string{}, nil) + csr, err := pcert.ParseCSR(stdout.Bytes()) if err != nil { - t.Error(err) - return + t.Fatal(err) } if csr.Subject.CommonName != name {