diff --git a/cmd/config/dest.go b/cmd/config/dest.go index 2b683e9..4e39918 100644 --- a/cmd/config/dest.go +++ b/cmd/config/dest.go @@ -12,46 +12,31 @@ import ( "github.com/tlmiller/disttrust/file" ) -func ToDest(id string, opts json.RawMessage) (dest.Dest, error) { - if id != "file" { - return nil, fmt.Errorf("unknown dest type '%s'", id) - } - uopts := map[string]string{} - err := json.Unmarshal(opts, &uopts) - if err != nil { - return nil, errors.Wrap(err, "parsing dest json") - } - fdest := dest.File{} +type Dest struct { + Dest string `json:"dest"` + DestOptions json.RawMessage `json:"destOpts"` +} - caFile, err := destBuilder(uopts["caFile"], uopts["caFileMode"], - uopts["caFileGid"], uopts["caFileUid"]) - if err != nil { - return nil, errors.Wrap(err, "caFile") - } - fdest.CA = caFile +type DestMapper func(json.RawMessage) (dest.Dest, error) - cFile, err := destBuilder(uopts["certFile"], uopts["certFileMode"], - uopts["certFileGid"], uopts["certFileUid"]) - if err != nil { - return nil, errors.Wrap(err, "certFile") - } - fdest.Certificate = cFile +var ( + destMappings = make(map[string]DestMapper) +) - cbFile, err := destBuilder(uopts["certBundleFile"], uopts["certBundleFileMode"], - uopts["certBundleFileGid"], uopts["certBundleFileUid"]) - if err != nil { - return nil, errors.Wrap(err, "certBundleFile") +func MapDest(id string, mapper DestMapper) error { + if _, exists := destMappings[id]; exists { + return fmt.Errorf("dest mapping already registered for id '%s'", id) } - fdest.CertificateBundle = cbFile + destMappings[id] = mapper + return nil +} - pkfile, err := destBuilder(uopts["privKeyFile"], uopts["privKeyFileMode"], - uopts["privKeyFileGid"], uopts["privKeyFileUid"]) - if err != nil { - return nil, errors.Wrap(err, "privKeyFile") +func ToDest(id string, opts json.RawMessage) (dest.Dest, error) { + mapper, exists := destMappings[id] + if !exists { + return nil, fmt.Errorf("dest mapper does not exist for id '%s'", id) } - fdest.PrivateKey = pkfile - - return &fdest, nil + return mapper(opts) } func destBuilder(path, mode, gid, uid string) (file.File, error) { diff --git a/cmd/config/dest_aggregate_mapper.go b/cmd/config/dest_aggregate_mapper.go new file mode 100644 index 0000000..dcbeb78 --- /dev/null +++ b/cmd/config/dest_aggregate_mapper.go @@ -0,0 +1,37 @@ +package config + +import ( + "encoding/json" + + "github.com/pkg/errors" + + "github.com/tlmiller/disttrust/dest" +) + +func destAggregateMapper(opts json.RawMessage) (dest.Dest, error) { + uopts := []Dest{} + err := json.Unmarshal(opts, &uopts) + if err != nil { + return nil, errors.Wrap(err, "parsing dest aggregate json") + } + + dests := []dest.Dest{} + for _, rawDest := range uopts { + if rawDest.Dest == "" { + return nil, errors.New("aggregate dest missing 'dest' key") + } + dest, err := ToDest(rawDest.Dest, rawDest.DestOptions) + if err != nil { + return nil, errors.Wrap(err, "aggregate dest failed parsing") + } + dests = append(dests, dest) + } + return dest.NewAggregate(dests...), nil +} + +func init() { + err := MapDest("aggregate", destAggregateMapper) + if err != nil { + panic(err) + } +} diff --git a/cmd/config/dest_file_mapper.go b/cmd/config/dest_file_mapper.go new file mode 100644 index 0000000..fa69d81 --- /dev/null +++ b/cmd/config/dest_file_mapper.go @@ -0,0 +1,59 @@ +package config + +import ( + "encoding/json" + + "github.com/pkg/errors" + + "github.com/tlmiller/disttrust/dest" +) + +func destFileMapper(opts json.RawMessage) (dest.Dest, error) { + fileDests := []dest.Dest{} + uopts := map[string]string{} + err := json.Unmarshal(opts, &uopts) + if err != nil { + return nil, errors.Wrap(err, "parsing dest file json") + } + + if uopts["caFile"] != "" { + caFile, err := destBuilder(uopts["caFile"], uopts["caFileMode"], + uopts["caFileGid"], uopts["caFileUid"]) + if err != nil { + return nil, errors.Wrap(err, "caFile") + } + fileDests = append(fileDests, dest.NewTemplateFile(dest.CAFile, caFile)) + } + if uopts["certFile"] != "" { + cFile, err := destBuilder(uopts["certFile"], uopts["certFileMode"], + uopts["certFileGid"], uopts["certFileUid"]) + if err != nil { + return nil, errors.Wrap(err, "certFile") + } + fileDests = append(fileDests, dest.NewTemplateFile(dest.CertificateFile, cFile)) + } + if uopts["certBundleFile"] != "" { + cbFile, err := destBuilder(uopts["certBundleFile"], uopts["certBundleFileMode"], + uopts["certBundleFileGid"], uopts["certBundleFileUid"]) + if err != nil { + return nil, errors.Wrap(err, "certBundleFile") + } + fileDests = append(fileDests, dest.NewTemplateFile(dest.CertificateBundleFile, cbFile)) + } + if uopts["privKeyFile"] != "" { + pkFile, err := destBuilder(uopts["privKeyFile"], uopts["privKeyFileMode"], + uopts["privKeyFileGid"], uopts["privKeyFileUid"]) + if err != nil { + return nil, errors.Wrap(err, "privKeyFile") + } + fileDests = append(fileDests, dest.NewTemplateFile(dest.PrivateKeyFile, pkFile)) + } + return dest.NewAggregate(fileDests...), nil +} + +func init() { + err := MapDest("file", DestMapper(destFileMapper)) + if err != nil { + panic(err) + } +} diff --git a/cmd/config/dest_template_file_mapper.go b/cmd/config/dest_template_file_mapper.go new file mode 100644 index 0000000..6240fc4 --- /dev/null +++ b/cmd/config/dest_template_file_mapper.go @@ -0,0 +1,37 @@ +package config + +import ( + "encoding/json" + + "github.com/pkg/errors" + + "github.com/tlmiller/disttrust/dest" +) + +func destTemplateFileMapper(opts json.RawMessage) (dest.Dest, error) { + uopts := map[string]string{} + err := json.Unmarshal(opts, &uopts) + if err != nil { + return nil, errors.Wrap(err, "parsing dest template json") + } + + if uopts["source"] == "" { + return nil, errors.New("dest template does not have a source template path specificed") + } + if uopts["out"] == "" { + return nil, errors.New("dest template does not have an output path specificed") + } + outFile, err := destBuilder(uopts["out"], uopts["mode"], uopts["gid"], + uopts["uid"]) + if err != nil { + return nil, errors.Wrap(err, "dest template output path") + } + return dest.NewTemplateFile(dest.TemplateFileLoader(uopts["source"]), outFile), nil +} + +func init() { + err := MapDest("template", DestMapper(destTemplateFileMapper)) + if err != nil { + panic(err) + } +} diff --git a/dest/aggregate_dest.go b/dest/aggregate_dest.go new file mode 100644 index 0000000..6960ce5 --- /dev/null +++ b/dest/aggregate_dest.go @@ -0,0 +1,25 @@ +package dest + +import ( + "github.com/tlmiller/disttrust/provider" +) + +type Aggregate struct { + Dests []Dest +} + +func NewAggregate(dests ...Dest) *Aggregate { + return &Aggregate{ + Dests: dests, + } +} + +func (a *Aggregate) Send(res *provider.Response) error { + for _, dest := range a.Dests { + err := dest.Send(res) + if err != nil { + return err + } + } + return nil +} diff --git a/dest/aggregate_dest_test.go b/dest/aggregate_dest_test.go new file mode 100644 index 0000000..67046ed --- /dev/null +++ b/dest/aggregate_dest_test.go @@ -0,0 +1,48 @@ +package dest + +import ( + "errors" + "testing" + + "github.com/tlmiller/disttrust/provider" +) + +type mockDest struct { + counter *int + fail bool +} + +func (m *mockDest) Send(_ *provider.Response) error { + (*m.counter)++ + if m.fail { + return errors.New("I was told to fail") + } + return nil +} + +func TestAggregateCallsAllDests(t *testing.T) { + var counter int + agg := NewAggregate(&mockDest{&counter, false}, &mockDest{&counter, false}, + &mockDest{&counter, false}) + err := agg.Send(nil) + if err != nil { + t.Fatalf("unexpected error for aggregate send: %v", err) + } + if counter != 3 { + t.Fatalf("aggregate did not call all dests, expected 3 got %d", counter) + } +} + +func TestAggregateCallDestFailure(t *testing.T) { + var counter int + agg := NewAggregate(&mockDest{&counter, false}, &mockDest{&counter, false}, + &mockDest{&counter, true}, &mockDest{&counter, false}, + &mockDest{&counter, false}) + err := agg.Send(nil) + if err == nil { + t.Fatal("expected non nill err for aggregate failure send") + } + if counter != 3 { + t.Fatalf("aggregate did not call all dests, expected 3 got %d", counter) + } +} diff --git a/dest/file_dest.go b/dest/file_dest.go index 99179b5..477a22c 100644 --- a/dest/file_dest.go +++ b/dest/file_dest.go @@ -1,74 +1,14 @@ package dest -import ( - "io/ioutil" - "os" +type FileDestType string - "github.com/pkg/errors" - - "github.com/tlmiller/disttrust/file" - "github.com/tlmiller/disttrust/provider" +var ( + CAFile FileDestType = "{{ .CA }}" + CertificateFile FileDestType = "{{ .Certificate }}" + CertificateBundleFile FileDestType = "{{ .Certificate }}\n{{ .CABundle }}" + PrivateKeyFile FileDestType = "{{ .PrivateKey }}" ) -type File struct { - CA file.File - Certificate file.File - CertificateBundle file.File - PrivateKey file.File -} - -func (f *File) Send(res *provider.Response) error { - - if res.CA != "" && f.CA.HasPath() { - err := ioutil.WriteFile(f.CA.Path, []byte(res.CA), f.CA.Mode) - if err != nil { - return errors.Wrap(err, "writing ca file") - } - err = f.CA.Chown() - if err != nil { - return errors.Wrap(err, "chown ca file") - } - } - - if res.Certificate != "" && f.Certificate.HasPath() { - err := ioutil.WriteFile(f.Certificate.Path, []byte(res.Certificate), - f.Certificate.Mode) - if err != nil { - return errors.Wrap(err, "writing certificate file") - } - err = f.Certificate.Chown() - if err != nil { - return errors.Wrap(err, "chown certificate file") - } - } - - if res.CABundle != "" && f.CertificateBundle.HasPath() { - s, err := os.OpenFile(f.CertificateBundle.Path, - os.O_WRONLY|os.O_TRUNC|os.O_CREATE, f.CertificateBundle.Mode) - defer s.Close() - if err != nil { - return errors.Wrap(err, "writing certificate bundle file") - } - _, err = s.WriteString(res.Certificate + "\n" + res.CABundle) - if err != nil { - return errors.Wrap(err, "writing certificate bundle file") - } - err = f.CertificateBundle.Chown() - if err != nil { - return errors.Wrap(err, "chown certificate bundle file") - } - } - - if res.PrivateKey != "" && f.PrivateKey.HasPath() { - err := ioutil.WriteFile(f.PrivateKey.Path, []byte(res.PrivateKey), - f.PrivateKey.Mode) - if err != nil { - return errors.Wrap(err, "writing private key file") - } - err = f.PrivateKey.Chown() - if err != nil { - return errors.Wrap(err, "chown private key file") - } - } - return nil +func (f FileDestType) Load() (string, error) { + return string(f), nil } diff --git a/dest/template_dest.go b/dest/template_dest.go new file mode 100644 index 0000000..1de0f7d --- /dev/null +++ b/dest/template_dest.go @@ -0,0 +1,68 @@ +package dest + +import ( + "io" + "io/ioutil" + "text/template" + + "github.com/pkg/errors" + + "github.com/tlmiller/disttrust/provider" +) + +type TemplateLoader interface { + Load() (string, error) +} + +type TemplateLoaderFunc func() (string, error) + +func TemplateFileLoader(path string) TemplateLoader { + return TemplateLoaderFunc(func() (string, error) { + buf, err := ioutil.ReadFile(path) + if err != nil { + return "", errors.Wrap(err, "reading template file") + } + return string(buf), nil + }) +} + +type TemplateString string + +type Template struct { + Loader TemplateLoader + Dest io.WriteCloser +} + +func (f TemplateLoaderFunc) Load() (string, error) { + return f() +} + +func (s TemplateString) Load() (string, error) { + return string(s), nil +} + +func NewTemplate(loader TemplateLoader, dest io.WriteCloser) *Template { + return &Template{ + Loader: loader, + Dest: dest, + } +} + +func (t *Template) Send(res *provider.Response) error { + defer t.Dest.Close() + tmplBody, err := t.Loader.Load() + if err != nil { + return errors.Wrap(err, "loading template") + } + + tmpl, err := template.New("template").Parse(tmplBody) + if err != nil { + return errors.Wrap(err, "parsing dest template") + } + + err = tmpl.Execute(t.Dest, res) + if err != nil { + return errors.Wrap(err, "writing dest template") + } + return nil +} diff --git a/dest/template_dest_test.go b/dest/template_dest_test.go new file mode 100644 index 0000000..579ea85 --- /dev/null +++ b/dest/template_dest_test.go @@ -0,0 +1,72 @@ +package dest + +import ( + "strings" + "testing" + + "github.com/tlmiller/disttrust/provider" +) + +type testDest struct { + strings.Builder +} + +func (d *testDest) Close() error { + return nil +} + +func (d *testDest) Write(p []byte) (n int, err error) { + return d.Builder.Write(p) +} + +func TestTemplateOutputs(t *testing.T) { + tests := []struct { + loader TemplateLoader + res *provider.Response + expect string + }{ + { + TemplateString("{{ .CA }}"), + &provider.Response{CA: "test-ca"}, + "test-ca", + }, + { + TemplateString("{{ .Certificate }}"), + &provider.Response{Certificate: "test-cert"}, + "test-cert", + }, + { + TemplateString("{{ .CABundle }}"), + &provider.Response{CABundle: "test-bundle"}, + "test-bundle", + }, + { + TemplateString("{{ .PrivateKey}}"), + &provider.Response{PrivateKey: "private-key"}, + "private-key", + }, + { + TemplateString("{{ .Serial}}"), + &provider.Response{Serial: "serial"}, + "serial", + }, + { + TemplateString("{{ .Certificate }}\n{{ .CABundle }}"), + &provider.Response{CABundle: "ca-bundle", Certificate: "certificate"}, + "certificate\nca-bundle", + }, + } + + for _, test := range tests { + var d testDest + tmpl := NewTemplate(test.loader, &d) + err := tmpl.Send(test.res) + if err != nil { + t.Fatalf("unexpected error when using dest template: %v", err) + } + if d.Builder.String() != test.expect { + t.Fatalf("dest template output \"%s\" was not expected \"%s\"", + d.Builder.String(), test.expect) + } + } +} diff --git a/dest/template_file_dest.go b/dest/template_file_dest.go new file mode 100644 index 0000000..8014004 --- /dev/null +++ b/dest/template_file_dest.go @@ -0,0 +1,55 @@ +package dest + +import ( + "os" + + "github.com/tlmiller/disttrust/file" + "github.com/tlmiller/disttrust/provider" +) + +type fileWrapper struct { + Dest file.File + ofile *os.File +} + +type TemplateFile struct { + Loader TemplateLoader + dest *fileWrapper +} + +func (t *fileWrapper) Close() error { + if t.ofile == nil { + return nil + } + err := t.ofile.Close() + if err != nil { + return err + } + err = t.Dest.Chown() + return err +} + +func NewTemplateFile(loader TemplateLoader, dest file.File) *TemplateFile { + return &TemplateFile{ + Loader: loader, + dest: &fileWrapper{ + Dest: dest, + }, + } +} + +func (t *TemplateFile) Send(res *provider.Response) error { + return NewTemplate(t.Loader, t.dest).Send(res) +} + +func (t *fileWrapper) Write(p []byte) (n int, err error) { + if t.ofile == nil { + of, err := os.OpenFile(t.Dest.Path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, + t.Dest.Mode) + if err != nil { + return 0, err + } + t.ofile = of + } + return t.ofile.Write(p) +}