diff --git a/pkg/types/layer.go b/pkg/types/layer.go index 231e3109..232303bd 100644 --- a/pkg/types/layer.go +++ b/pkg/types/layer.go @@ -2,7 +2,6 @@ package types import ( "encoding/json" - "fmt" "io/fs" "os" "path/filepath" @@ -73,40 +72,6 @@ type Bom struct { Packages []Package `yaml:"packages" json:"packages,omitempty"` } -func validateDataAsBind(i interface{}) (map[interface{}]interface{}, error) { - bindMap, ok := i.(map[interface{}]interface{}) - if !ok { - return nil, errors.Errorf("unable to cast into map[interface{}]interface{}: %T", i) - } - - // validations - bindSource, ok := bindMap["Source"] - if !ok { - return nil, errors.Errorf("bind source missing: %v", i) - } - - _, ok = bindSource.(string) - if !ok { - return nil, errors.Errorf("unknown bind source type, expected string: %T", i) - } - - bindDest, ok := bindMap["Dest"] - if !ok { - return nil, errors.Errorf("bind dest missing: %v", i) - } - - _, ok = bindDest.(string) - if !ok { - return nil, errors.Errorf("unknown bind dest type, expected string: %T", i) - } - - if bindSource == "" || bindDest == "" { - return nil, errors.Errorf("empty source or dest: %v", i) - } - - return bindMap, nil -} - func getStringOrStringSlice(data interface{}, xform func(string) ([]string, error)) ([]string, error) { // The user didn't supply run: at all, so let's not do anything. if data == nil { @@ -125,14 +90,6 @@ func getStringOrStringSlice(data interface{}, xform func(string) ([]string, erro switch v := i.(type) { case string: s = v - case interface{}: - bindMap, err := validateDataAsBind(i) - if err != nil { - return nil, err - } - - // validations passed, return as string in form: source -> dest - s = fmt.Sprintf("%s -> %s", bindMap["Source"], bindMap["Dest"]) default: return nil, errors.Errorf("unknown run array type: %T", i) } @@ -205,87 +162,77 @@ func (c *Command) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } -type Bind struct { - Source string `yaml:"source" json:"source"` - Dest string `yaml:"dest" json:"dest"` +type bindType struct { + Source string `yaml:"source" json:"source,omitempty"` + Dest string `yaml:"dest" json:"dest,omitempty"` } -type Binds []Bind - -func (bs *Bind) MarshalJSON() ([]byte, error) { - var sb strings.Builder +// toBind - copy to Bind type and check. +func (b *bindType) toBind(bs *Bind) error { + if b.Source == "" { + return errors.Errorf("unexpected 'bind': missing required field 'source': %#v", b) + } + bs.Source = b.Source + bs.Dest = b.Dest if bs.Dest == "" { - sb.WriteString(fmt.Sprintf("%q", bs.Source)) - } else { - var sbt strings.Builder - sbt.WriteString(fmt.Sprintf("%s -> %s", bs.Source, bs.Dest)) - sb.WriteString(fmt.Sprintf("%q", sbt.String())) + bs.Dest = bs.Source } - - return []byte(sb.String()), nil + return nil } -func (bs *Binds) UnmarshalJSON(data []byte) error { - var rawBinds []string - - if err := json.Unmarshal(data, &rawBinds); err != nil { - return err - } - - *bs = Binds{} - for _, bind := range rawBinds { - parts := strings.Split(bind, "->") - if len(parts) != 1 && len(parts) != 2 { - return errors.Errorf("invalid bind mount %s", bind) - } - - source := strings.TrimSpace(parts[0]) - target := source - - if len(parts) == 2 { - target = strings.TrimSpace(parts[1]) - } - - *bs = append(*bs, Bind{Source: source, Dest: target}) - } +func (b *bindType) toBindFromString(bind *Bind, asStr string) error { + toks := strings.Fields(asStr) + if len(toks) == 1 { + bind.Source = toks[0] + bind.Dest = toks[0] + return nil + } else if len(toks) == 3 && toks[1] == "->" { + bind.Source = toks[0] + bind.Dest = toks[2] + return nil + } + return errors.Errorf("invalid Bind: %s", string(asStr)) +} - return nil +type Bind struct { + Source string `yaml:"source,omitempty" json:"source,omitempty"` + Dest string `yaml:"dest,omitempty" json:"dest,omitempty"` } -func (bs *Binds) UnmarshalYAML(unmarshal func(interface{}) error) error { - var data interface{} - err := unmarshal(&data) - if err != nil { - return errors.WithStack(err) - } +type Binds []Bind - xform := func(s string) ([]string, error) { - return []string{s}, nil +func (bs *Bind) UnmarshalJSON(data []byte) error { + btype := bindType{} + if err := json.Unmarshal(data, &btype); err == nil { + return btype.toBind(bs) } - rawBinds, err := getStringOrStringSlice(data, xform) + asStr := "" + err := json.Unmarshal(data, &asStr) if err != nil { - return err + return errors.Errorf("invalid Bind: %s", string(data)) } - *bs = Binds{} - for _, bind := range rawBinds { - parts := strings.Split(bind, "->") - if len(parts) != 1 && len(parts) != 2 { - return errors.Errorf("invalid bind mount %s", bind) - } + return btype.toBindFromString(bs, asStr) +} - source := strings.TrimSpace(parts[0]) - target := source +func (bs *Bind) UnmarshalYAML(unmarshal func(interface{}) error) error { + btype := bindType{} + if err := unmarshal(&btype); err == nil { + return btype.toBind(bs) + } - if len(parts) == 2 { - target = strings.TrimSpace(parts[1]) - } + asStr := "" + if err := unmarshal(&asStr); err == nil { + return btype.toBindFromString(bs, asStr) + } - *bs = append(*bs, Bind{Source: source, Dest: target}) + var data interface{} + if err := unmarshal(&data); err != nil { + return errors.Errorf("unexpected error unmarshaling bind yaml: %v", err) } - return nil + return errors.Errorf("unexpected 'bind' data of type: %s: %#v", reflect.TypeOf(data), data) } type Layer struct { diff --git a/pkg/types/layer_bind_test.go b/pkg/types/layer_bind_test.go new file mode 100644 index 00000000..ef39fa8f --- /dev/null +++ b/pkg/types/layer_bind_test.go @@ -0,0 +1,139 @@ +package types + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" +) + +func TestUnmarshalYamlAndJSON(t *testing.T) { + assert := assert.New(t) + tables := []struct { + desc string + yblob string + jblob string + expected Binds + errstr string + }{ + {desc: "proper array of source/dest bind allowed", + yblob: "- source: src1\n dest: dest1\n", + jblob: `[{"source": "src1", "dest": "dest1"}]`, + expected: Binds{ + Bind{Source: "src1", Dest: "dest1"}, + }}, + {desc: "array of bind ascii art", + yblob: "- src1 -> dest1\n- src2 -> dest2", + jblob: `["src1 -> dest1", "src2 -> dest2"]`, + expected: Binds{ + Bind{Source: "src1", Dest: "dest1"}, + Bind{Source: "src2", Dest: "dest2"}, + }}, + {desc: "example mixed valid ascii art and dict", + yblob: "- src1 -> dest1\n- source: src2\n dest: dest2\n", + jblob: `["src1 -> dest1", {"source": "src2", "dest": "dest2"}]`, + expected: Binds{ + Bind{Source: "src1", Dest: "dest1"}, + Bind{Source: "src2", Dest: "dest2"}, + }}, + // golang encoding/json is case insensitive + {desc: "capital Source/Dest is not allowed as yaml", + yblob: "- Source: src1\n Dest: dest1\n", + expected: Binds{}, + errstr: "xpected 'bind'"}, + {desc: "source is required", + yblob: "- Dest: dest1\n", + jblob: `[{"Dest": "dest1"}]`, + expected: Binds{}, + errstr: "xpected 'bind'"}, + {desc: "must be an array", + yblob: "source: src1\ndest: dest1\n", + jblob: `{"source": "src1", "dest": "dest1"}`, + expected: Binds{}, + errstr: "unmarshal"}, + } + var err error + found := Binds{} + for _, t := range tables { + err = yaml.Unmarshal([]byte(t.yblob), &found) + if t.errstr == "" { + if !assert.NoError(err, t.desc) { + continue + } + assert.Equal(t.expected, found) + } else { + assert.ErrorContains(err, t.errstr, t.desc) + } + } + + for _, t := range tables { + if t.jblob == "" { + continue + } + err = json.Unmarshal([]byte(t.jblob), &found) + if t.errstr == "" { + if !assert.NoError(err, t.desc) { + continue + } + assert.Equal(t.expected, found) + } else { + assert.ErrorContains(err, t.errstr, t.desc) + } + } +} + +func TestUnmarshalJSON(t *testing.T) { + assert := assert.New(t) + tables := []struct { + desc string + jblob string + expected Binds + errstr string + }{ + {desc: "proper array of source/dest bind allowed", + jblob: `[{"source": "src1", "dest": "dest1"}]`, + expected: Binds{ + Bind{Source: "src1", Dest: "dest1"}, + }}, + /* + {desc: "array of bind ascii art", + jblob: "- src1 -> dest1\n- src2 -> dest2", + expected: Binds{ + Bind{Source: "src1", Dest: "dest1"}, + Bind{Source: "src2", Dest: "dest2"}, + }}, + {desc: "example mixed valid ascii art and dict", + jblob: "- src1 -> dest1\n- source: src2\n dest: dest2\n", + expected: Binds{ + Bind{Source: "src1", Dest: "dest1"}, + Bind{Source: "src2", Dest: "dest2"}, + }}, + {desc: "capital Source/Dest is not allowed", + jblob: "- Source: src1\n Dest: dest1\n", + expected: Binds{}, + errstr: "xpected 'bind'"}, + {desc: "source is required", + jblob: "- Dest: dest1\n", + expected: Binds{}, + errstr: "xpected 'bind'"}, + {desc: "must be an array", + jblob: "source: src1\ndest: dest1\n", + expected: Binds{}, + errstr: "unmarshal"}, + */ + } + var err error + found := Binds{} + for _, t := range tables { + err = json.Unmarshal([]byte(t.jblob), &found) + if t.errstr == "" { + if !assert.NoError(err, t.desc) { + continue + } + assert.Equal(t.expected, found) + } else { + assert.ErrorContains(err, t.errstr, t.desc) + } + } +} diff --git a/test/binds.bats b/test/binds.bats index fbab665f..b34f0a34 100644 --- a/test/binds.bats +++ b/test/binds.bats @@ -43,12 +43,15 @@ bind-test: type: oci url: ${{CENTOS_OCI}} binds: - - Source: ${{bind_path}} - Dest: /root/tree1/foo + - source: ${{bind_path1}} + dest: /root/tree1/foo + - source: ${{bind_path2}} run: | touch /root/tree1/foo/bar + [ -f "${bind_path2}/file1" ] EOF - mkdir -p tree1/foo + mkdir -p tree1/foo tree2/bar + touch tree/bar/file1 # since we are creating directory as # real root and then `touch`-ing a file @@ -56,9 +59,13 @@ EOF # for others chmod +666 tree1/foo - bind_path=$(realpath tree1/foo) + bind_path1=$(realpath tree1/foo) + bind_path2=$(realpath tree2/bar) - out=$(stacker build --substitute bind_path=$bind_path --substitute CENTOS_OCI=$CENTOS_OCI) + out=$(stacker build \ + "--substitute=bind_path1=${bind_path1}" \ + "--substitute=bind_path2=${bind_path2}" \ + --substitute CENTOS_OCI=$CENTOS_OCI) [[ "${out}" =~ ^(.*filesystem bind-test built successfully)$ ]] stat tree1/foo/bar