From 0797b3e01cc2bf628a735d5c589586362826ec57 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Wed, 20 Sep 2023 12:32:10 -0700 Subject: [PATCH] fix: Improvements to yaml.Unmarshal for types.Imports and add some go tests Improvements here: * Drop duplicate code blocks from getImportFromInterface There were 2 identical codeblocks, one with a type convert to map[string]interface{} and one with type convert to map[interface{}]interface{}. I cannot find any reason for the former * Add some golang tests to for unmarshalling. * Type check the values of all string Import fields, rather than just converting them to a string with 'fmt.Sprintf("%v")' Signed-off-by: Scott Moser --- pkg/types/layer.go | 161 ++++++++++++++------------------- pkg/types/layer_import_test.go | 111 +++++++++++++++++++++++ 2 files changed, 179 insertions(+), 93 deletions(-) create mode 100644 pkg/types/layer_import_test.go diff --git a/pkg/types/layer.go b/pkg/types/layer.go index a779a3a3..1c004b07 100644 --- a/pkg/types/layer.go +++ b/pkg/types/layer.go @@ -14,6 +14,8 @@ import ( "github.com/anmitsu/go-shlex" "github.com/pkg/errors" "gopkg.in/yaml.v2" + + "stackerbuild.io/stacker/pkg/lib" ) const ( @@ -486,100 +488,82 @@ func requireImportHash(imports Imports) error { return nil } +// getImportFromInterface - +// an Import (an entry in 'imports'), can be written in yaml as either a string or a map[string]: +// imports: +// - /path/to-file +// - path: /path/f2 +// This function gets a single entry in that list and returns the Import. func getImportFromInterface(v interface{}) (Import, error) { - var hash, dest string - var mode *fs.FileMode - uid := -1 - gid := -1 + mode := -1 + ret := Import{Mode: nil, Uid: lib.UidEmpty, Gid: lib.GidEmpty} - m, ok := v.(map[interface{}]interface{}) + // if it is a simple string, that is the path + s, ok := v.(string) if ok { - // check for nil hash so that we won't end up with "nil" string values - if m["hash"] == nil { - hash = "" - } else { - hash = fmt.Sprintf("%v", m["hash"]) - } - - if m["dest"] != nil { - if !filepath.IsAbs(m["dest"].(string)) { - return Import{}, errors.Errorf("Dest path cannot be relative for: %#v", v) - } + ret.Path = s + return ret, nil + } - dest = fmt.Sprintf("%s", m["dest"]) - } else { - dest = "" - } + m, ok := v.(map[interface{}]interface{}) + if !ok { + return Import{}, errors.Errorf("Didn't find a matching type for: %#v", v) + } - if m["mode"] != nil { - val := fs.FileMode(m["mode"].(int)) - mode = &val + for k, _ := range m { + if _, ok := k.(string); !ok { + return Import{}, errors.Errorf("key '%s' in import is not a string: %#v", k, v) } + } - if _, ok := m["uid"]; ok { - uid = m["uid"].(int) - if uid < 0 { - return Import{}, errors.Errorf("Uid cannot be negative: %v", uid) - } + // if present, these must have string values. + for name, dest := range map[string]*string{"hash": &ret.Hash, "path": &ret.Path, "dest": &ret.Dest} { + val, found := m[name] + if !found { + continue } - - if _, ok := m["gid"]; ok { - gid = m["gid"].(int) - if gid < 0 { - return Import{}, errors.Errorf("Gid cannot be negative: %v", gid) - } + s, ok := val.(string) + if !ok { + return Import{}, errors.Errorf("value for '%s' in import is not a string: %#v", name, v) } - - return Import{Hash: hash, Path: fmt.Sprintf("%v", m["path"]), Dest: dest, Mode: mode, Uid: uid, Gid: gid}, nil + *dest = s } - m2, ok := v.(map[string]interface{}) - if ok { - // check for nil hash so that we won't end up with "nil" string values - if m2["hash"] == nil { - hash = "" - } else { - hash = fmt.Sprintf("%v", m2["hash"]) - } - - if m2["dest"] != nil { - if !filepath.IsAbs(m2["dest"].(string)) { - return Import{}, errors.Errorf("Dest path cannot be relative for: %#v", v) - } - - dest = fmt.Sprintf("%s", m["dest"]) - } else { - dest = "" + // if present, these must have int values + for name, dest := range map[string]*int{"mode": &mode, "uid": &ret.Uid, "gid": &ret.Gid} { + val, found := m[name] + if !found { + continue } - - if m2["mode"] != nil { - val := fs.FileMode(m2["mode"].(int)) - mode = &val + i, ok := val.(int) + if !ok { + return Import{}, errors.Errorf("value for '%s' in import is not an integer: %#v", name, v) } + *dest = i + } - if _, ok := m2["uid"]; ok { - uid = m2["uid"].(int) - if uid < 0 { - return Import{}, errors.Errorf("Uid cannot be negative: %v", uid) - } - } + if ret.Path == "" { + return ret, errors.Errorf("No 'path' entry found in import: %#v", v) + } - if _, ok := m2["gid"]; ok { - gid = m2["gid"].(int) - if gid < 0 { - return Import{}, errors.Errorf("Gid cannot be negative: %v", gid) - } - } + if ret.Dest != "" && !filepath.IsAbs(ret.Dest) { + return Import{}, errors.Errorf("'dest' path cannot be relative for: %#v", v) + } - return Import{Hash: hash, Path: fmt.Sprintf("%v", m2["path"]), Dest: dest, Mode: mode, Uid: uid, Gid: gid}, nil + if mode != -1 { + m := fs.FileMode(mode) + ret.Mode = &m } - // if it's not a map then it's a string - s, ok := v.(string) - if ok { - return Import{Hash: "", Path: fmt.Sprintf("%v", s), Dest: "", Uid: uid, Gid: gid}, nil + // Empty values are -1 + if ret.Uid != lib.UidEmpty && ret.Uid < 0 { + return Import{}, errors.Errorf("'uid' (%d) cannot be negative: %v", ret.Uid, v) } - return Import{}, errors.Errorf("Didn't find a matching type for: %#v", v) + if ret.Gid != lib.GidEmpty && ret.Gid < 0 { + return Import{}, errors.Errorf("'gid' (%d) cannot be negative: %v", ret.Gid, v) + } + + return ret, nil } // Custom UnmarshalYAML from string/map/slice of strings/slice of maps into Imports @@ -590,26 +574,17 @@ func (im *Imports) UnmarshalYAML(unmarshal func(interface{}) error) error { } imports, ok := data.([]interface{}) - if ok { - // imports are a list of either strings or maps - for _, v := range imports { - imp, err := getImportFromInterface(v) - if err != nil { - return err - } - *im = append(*im, imp) - } - } else { - if data != nil { - // import are either string or map - imp, err := getImportFromInterface(data) - if err != nil { - return err - } - *im = append(*im, imp) + if !ok { + return errors.Errorf("'imports' expected an array, found %s: %#v", reflect.TypeOf(data), data) + } + // imports are a list of either strings or maps + for _, v := range imports { + imp, err := getImportFromInterface(v) + if err != nil { + return err } + *im = append(*im, imp) } - return nil } diff --git a/pkg/types/layer_import_test.go b/pkg/types/layer_import_test.go new file mode 100644 index 00000000..6fd54fb0 --- /dev/null +++ b/pkg/types/layer_import_test.go @@ -0,0 +1,111 @@ +package types + +import ( + "io/fs" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" +) + +func modePtr(mode int) *fs.FileMode { + m := fs.FileMode(mode) + return &m +} + +func TestGetImportFromInterface(t *testing.T) { + assert := assert.New(t) + hash1 := "b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c" + tables := []struct { + desc string + val interface{} + expected Import + errstr string + }{ + {desc: "basic string", + val: "/path/to/file", + expected: Import{Path: "/path/to/file", Uid: -1, Gid: -1}}, + {desc: "relative string", + val: "path/to/file", + expected: Import{Path: "path/to/file", Uid: -1, Gid: -1}}, + {desc: "dict no dest", + val: map[interface{}]interface{}{ + "path": "/path/to/file", + "hash": hash1, + }, + expected: Import{Path: "/path/to/file", Dest: "", Hash: hash1, Uid: -1, Gid: -1}}, + {desc: "dest cannot be relative", + val: map[interface{}]interface{}{ + "path": "src1", + "dest": "dest1", + }, + errstr: "cannot be relative", + }, + {desc: "guid cannot be negative", + val: map[interface{}]interface{}{ + "path": "src1", + "uid": -2, + }, + errstr: "cannot be negative", + }, + {desc: "mode present", + val: map[interface{}]interface{}{ + "path": "src1", + "mode": 0755, + }, + expected: Import{Path: "src1", Dest: "", Mode: modePtr(0755), Uid: -1, Gid: -1}}, + } + + var found Import + var err error + for _, t := range tables { + found, err = getImportFromInterface(t.val) + if t.errstr == "" { + assert.NoError(err, t.desc) + assert.Equal(t.expected, found, t.desc) + } else { + assert.ErrorContains(err, t.errstr, t.desc) + } + } +} + +func TestUnmarshalImports(t *testing.T) { + assert := assert.New(t) + type importsContainer struct { + Imports []Import `yaml:"imports"` + } + tables := []struct { + desc string + yblob string + expected Imports + errstr string + }{ + {desc: "imports should not be a string", + yblob: "/path/to/file", + expected: Imports{}, + errstr: "xpected an array"}, + {desc: "imports should not be a dict", + yblob: "path: /path/to/file\ndest: /path/to/dest\n", + expected: Imports{}, + errstr: "xpected an array"}, + {desc: "example valid mixed string and dict", + yblob: "- f1\n- path: f2\n", + expected: Imports{ + Import{Path: "f1", Uid: -1, Gid: -1}, + Import{Path: "f2", Uid: -1, Gid: -1}, + }}, + } + var err error + found := Imports{} + for _, t := range tables { + err = yaml.Unmarshal([]byte(t.yblob), &found) + if t.errstr == "" { + if !assert.NoError(err, t.desc) { + continue + } + assert.Equal(t.expected, found) + } else { + assert.ErrorContains(err, t.errstr, t.desc) + } + } +}