diff --git a/internal/bundle/bundle_ext.yac.go b/internal/bundle/bundle_ext.yac.go new file mode 100644 index 00000000..6e6e70cd --- /dev/null +++ b/internal/bundle/bundle_ext.yac.go @@ -0,0 +1,52 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package bundle + +import "sort" + +// Sorts the queries, policies and queries' variants in the bundle. +func (p *Bundle) SortContents() { + sort.SliceStable(p.Queries, func(i, j int) bool { + if p.Queries[i].Mrn == "" || p.Queries[j].Mrn == "" { + return p.Queries[i].Uid < p.Queries[j].Uid + } + return p.Queries[i].Mrn < p.Queries[j].Mrn + }) + + sort.SliceStable(p.Policies, func(i, j int) bool { + if p.Policies[i].Mrn == "" || p.Policies[j].Mrn == "" { + return p.Policies[i].Uid < p.Policies[j].Uid + } + return p.Policies[i].Mrn < p.Policies[j].Mrn + }) + + for _, q := range p.Queries { + sort.SliceStable(q.Variants, func(i, j int) bool { + if q.Variants[i].Mrn == "" || q.Variants[j].Mrn == "" { + return q.Variants[i].Uid < q.Variants[j].Uid + } + return q.Variants[i].Mrn < q.Variants[j].Mrn + }) + } + for _, pl := range p.Policies { + for _, g := range pl.Groups { + for _, q := range g.Queries { + sort.SliceStable(q.Variants, func(i, j int) bool { + if q.Variants[i].Mrn == "" || q.Variants[j].Mrn == "" { + return q.Variants[i].Uid < q.Variants[j].Uid + } + return q.Variants[i].Mrn < q.Variants[j].Mrn + }) + } + for _, c := range g.Checks { + sort.SliceStable(c.Variants, func(i, j int) bool { + if c.Variants[i].Mrn == "" || c.Variants[j].Mrn == "" { + return c.Variants[i].Uid < c.Variants[j].Uid + } + return c.Variants[i].Mrn < c.Variants[j].Mrn + }) + } + } + } +} diff --git a/internal/bundle/fmt.go b/internal/bundle/fmt.go index 9401940a..a4208357 100644 --- a/internal/bundle/fmt.go +++ b/internal/bundle/fmt.go @@ -74,26 +74,16 @@ func FormatFile(filename string, sort bool) error { if err != nil { return err } - - if sort { - b, err := policy.BundleFromYAML(data) - if err != nil { - return err - } - - b.SortContents() - data, err = b.ToYAML() - if err != nil { - return err - } + b, err := ParseYaml(data) + if err != nil { + return err } - - data, err = FormatBundleData(data) + fmtData, err := FormatBundle(b, sort) if err != nil { return err } - err = os.WriteFile(filename, data, 0o644) + err = os.WriteFile(filename, fmtData, 0o644) if err != nil { return err } @@ -101,13 +91,8 @@ func FormatFile(filename string, sort bool) error { return nil } -// Format formats the .mql.yaml bundle -func FormatBundleData(data []byte) ([]byte, error) { - b, err := ParseYaml(data) - if err != nil { - return nil, err - } - +// Format formats the Bundle +func FormatBundle(b *Bundle, sort bool) ([]byte, error) { // to improve the formatting we need to remove the whitespace at the end of the lines for i := range b.Queries { query := b.Queries[i] @@ -138,5 +123,9 @@ func FormatBundleData(data []byte) ([]byte, error) { } } + if sort { + b.SortContents() + } + return Format(b) } diff --git a/internal/bundle/fmt_test.go b/internal/bundle/fmt_test.go index 7a347a9f..448f14cf 100644 --- a/internal/bundle/fmt_test.go +++ b/internal/bundle/fmt_test.go @@ -8,11 +8,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.mondoo.com/cnspec/v9/policy" ) func TestBundleFormatter(t *testing.T) { data := ` +# This is a comment policies: - uid: sshd-server-policy authors: @@ -43,10 +43,13 @@ queries: title: Ensure Secure Boot is enabled ` - formatted, err := FormatBundleData([]byte(data)) + b, err := ParseYaml([]byte(data)) + require.NoError(t, err) + formatted, err := FormatBundle(b, false) require.NoError(t, err) - expected := `policies: + expected := `# This is a comment +policies: - uid: sshd-server-policy name: SSH Server Policy version: 1.0.0 @@ -128,12 +131,9 @@ queries: title: Ensure Secure Boot is enabled ` - b, err := policy.BundleFromYAML([]byte(data)) - require.NoError(t, err) - b.SortContents() - byteData, err := b.ToYAML() + b, err := ParseYaml([]byte(data)) require.NoError(t, err) - formatted, err := FormatBundleData(byteData) + formatted, err := FormatBundle(b, true) require.NoError(t, err) expected := `policies: - uid: sshd-server-policy