Skip to content
This repository has been archived by the owner on Jan 22, 2025. It is now read-only.

Commit

Permalink
feat(cloudformation): add an option to override parameters (#36)
Browse files Browse the repository at this point in the history
* feat(cloudformation): add an option to override parameters

* chore: update comments

* refactor: improve error handling

* refactor: use strict mode when parsing a format like CF
  • Loading branch information
nikpivkin authored Oct 26, 2023
1 parent 00033a7 commit 97ac24e
Show file tree
Hide file tree
Showing 7 changed files with 497 additions and 14 deletions.
8 changes: 8 additions & 0 deletions pkg/scanners/cloudformation/parser/file_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,11 @@ func (t *FileContext) Metadata() defsecTypes.Metadata {

return defsecTypes.NewMetadata(rng, NewCFReference("Template", rng).String())
}

func (t *FileContext) OverrideParameters(params map[string]any) {
for key := range t.Parameters {
if val, ok := params[key]; ok {
t.Parameters[key].UpdateDefault(val)
}
}
}
61 changes: 61 additions & 0 deletions pkg/scanners/cloudformation/parser/file_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package parser

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFileContext_OverrideParameters(t *testing.T) {
tests := []struct {
name string
ctx FileContext
arg map[string]any
expected map[string]*Parameter
}{
{
name: "happy",
ctx: FileContext{
Parameters: map[string]*Parameter{
"BucketName": {
inner: parameterInner{
Type: "String",
Default: "test",
},
},
"QueueName": {
inner: parameterInner{
Type: "String",
},
},
},
},
arg: map[string]any{
"BucketName": "test2",
"QueueName": "test",
"SomeKey": "some_value",
},
expected: map[string]*Parameter{
"BucketName": {
inner: parameterInner{
Type: "String",
Default: "test2",
},
},
"QueueName": {
inner: parameterInner{
Type: "String",
Default: "test",
},
},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.ctx.OverrideParameters(tt.arg)
assert.Equal(t, tt.expected, tt.ctx.Parameters)
})
}
}
74 changes: 72 additions & 2 deletions pkg/scanners/cloudformation/parser/parameter.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package parser

import (
"bytes"
"encoding/json"
"fmt"
"strconv"

"github.com/aquasecurity/trivy-iac/pkg/scanners/cloudformation/cftypes"
"strings"

"github.com/liamg/jfather"
"gopkg.in/yaml.v3"

"github.com/aquasecurity/trivy-iac/pkg/scanners/cloudformation/cftypes"
)

type Parameter struct {
Expand Down Expand Up @@ -57,3 +61,69 @@ func (p *Parameter) UpdateDefault(inVal interface{}) {
p.inner.Default = passedVal
}
}

type Parameters map[string]any

func (p *Parameters) Merge(other Parameters) {
for k, v := range other {
(*p)[k] = v
}
}

func (p *Parameters) UnmarshalJSON(data []byte) error {
(*p) = make(Parameters)

if len(data) == 0 {
return nil
}

switch {
case data[0] == '{' && data[len(data)-1] == '}': // object
// CodePipeline like format
var params struct {
Params map[string]any `json:"Parameters"`
}

if err := json.Unmarshal(data, &params); err != nil {
return err
}

(*p) = params.Params
case data[0] == '[' && data[len(data)-1] == ']': // array
{
// Original format
var params []string

if err := json.Unmarshal(data, &params); err == nil {
for _, param := range params {
parts := strings.Split(param, "=")
if len(parts) != 2 {
return fmt.Errorf("invalid key-value parameter: %q", param)
}
(*p)[parts[0]] = parts[1]
}
return nil
}

// CloudFormation like format
var cfparams []struct {
ParameterKey string `json:"ParameterKey"`
ParameterValue string `json:"ParameterValue"`
}

d := json.NewDecoder(bytes.NewReader(data))
d.DisallowUnknownFields()
if err := d.Decode(&cfparams); err != nil {
return err
}

for _, param := range cfparams {
(*p)[param.ParameterKey] = param.ParameterValue
}
}
default:
return fmt.Errorf("unsupported parameters format")
}

return nil
}
89 changes: 89 additions & 0 deletions pkg/scanners/cloudformation/parser/parameters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package parser

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestParameters_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
source string
expected Parameters
wantErr bool
}{
{
name: "original format",
source: `[
"Key1=Value1",
"Key2=Value2"
]`,
expected: map[string]any{
"Key1": "Value1",
"Key2": "Value2",
},
},
{
name: "CloudFormation like format",
source: `[
{
"ParameterKey": "Key1",
"ParameterValue": "Value1"
},
{
"ParameterKey": "Key2",
"ParameterValue": "Value2"
}
]`,
expected: map[string]any{
"Key1": "Value1",
"Key2": "Value2",
},
},
{
name: "CloudFormation like format, with unknown fields",
source: `[
{
"ParameterKey": "Key1",
"ParameterValue": "Value1"
},
{
"ParameterKey": "Key2",
"ParameterValue": "Value2",
"UsePreviousValue": true
}
]`,
wantErr: true,
},
{
name: "CodePipeline like format",
source: `{
"Parameters": {
"Key1": "Value1",
"Key2": "Value2"
}
}`,
expected: map[string]any{
"Key1": "Value1",
"Key2": "Value2",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var params Parameters

err := json.Unmarshal([]byte(tt.source), &params)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, params)
})
}
}
Loading

0 comments on commit 97ac24e

Please sign in to comment.