Skip to content

Commit

Permalink
fix: Improvements to yaml.Unmarshal for types.Imports and add some go…
Browse files Browse the repository at this point in the history
… tests

Improvements here:
 * Drop duplicate code blocks from getImportFromInterface
   There were 2 identical codeblocks, one with a type convert to
   map[string]interface{} and one with type convert to
   map[interface{}]interface{}.  I cannot find any reason for the
   former
 * Add some golang tests to for unmarshalling.
 * Type check the values of all string Import fields, rather than just
   converting them to a string with 'fmt.Sprintf("%v")'

Signed-off-by: Scott Moser <[email protected]>
  • Loading branch information
smoser committed Sep 22, 2023
1 parent db8b6aa commit 6be6698
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 93 deletions.
162 changes: 69 additions & 93 deletions pkg/types/layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/anmitsu/go-shlex"
"github.com/pkg/errors"
"gopkg.in/yaml.v2"

"stackerbuild.io/stacker/pkg/lib"
)

const (
Expand Down Expand Up @@ -486,100 +488,83 @@ func requireImportHash(imports Imports) error {
return nil
}

// getImportFromInterface -
//
// an Import (an entry in 'imports'), can be written in yaml as either a string or a map[string]:
// imports:
// - /path/to-file
// - path: /path/f2
// This function gets a single entry in that list and returns the Import.
func getImportFromInterface(v interface{}) (Import, error) {
var hash, dest string
var mode *fs.FileMode
uid := -1
gid := -1
mode := -1
ret := Import{Mode: nil, Uid: lib.UidEmpty, Gid: lib.GidEmpty}

m, ok := v.(map[interface{}]interface{})
// if it is a simple string, that is the path
s, ok := v.(string)
if ok {
// check for nil hash so that we won't end up with "nil" string values
if m["hash"] == nil {
hash = ""
} else {
hash = fmt.Sprintf("%v", m["hash"])
}

if m["dest"] != nil {
if !filepath.IsAbs(m["dest"].(string)) {
return Import{}, errors.Errorf("Dest path cannot be relative for: %#v", v)
}
ret.Path = s
return ret, nil
}

dest = fmt.Sprintf("%s", m["dest"])
} else {
dest = ""
}
m, ok := v.(map[interface{}]interface{})
if !ok {
return Import{}, errors.Errorf("Didn't find a matching type for: %#v", v)
}

if m["mode"] != nil {
val := fs.FileMode(m["mode"].(int))
mode = &val
for k, _ := range m {
if _, ok := k.(string); !ok {
return Import{}, errors.Errorf("key '%s' in import is not a string: %#v", k, v)
}
}

if _, ok := m["uid"]; ok {
uid = m["uid"].(int)
if uid < 0 {
return Import{}, errors.Errorf("Uid cannot be negative: %v", uid)
}
// if present, these must have string values.
for name, dest := range map[string]*string{"hash": &ret.Hash, "path": &ret.Path, "dest": &ret.Dest} {
val, found := m[name]
if !found {
continue
}

if _, ok := m["gid"]; ok {
gid = m["gid"].(int)
if gid < 0 {
return Import{}, errors.Errorf("Gid cannot be negative: %v", gid)
}
s, ok := val.(string)
if !ok {
return Import{}, errors.Errorf("value for '%s' in import is not a string: %#v", name, v)
}

return Import{Hash: hash, Path: fmt.Sprintf("%v", m["path"]), Dest: dest, Mode: mode, Uid: uid, Gid: gid}, nil
*dest = s
}

m2, ok := v.(map[string]interface{})
if ok {
// check for nil hash so that we won't end up with "nil" string values
if m2["hash"] == nil {
hash = ""
} else {
hash = fmt.Sprintf("%v", m2["hash"])
}

if m2["dest"] != nil {
if !filepath.IsAbs(m2["dest"].(string)) {
return Import{}, errors.Errorf("Dest path cannot be relative for: %#v", v)
}

dest = fmt.Sprintf("%s", m["dest"])
} else {
dest = ""
// if present, these must have int values
for name, dest := range map[string]*int{"mode": &mode, "uid": &ret.Uid, "gid": &ret.Gid} {
val, found := m[name]
if !found {
continue
}

if m2["mode"] != nil {
val := fs.FileMode(m2["mode"].(int))
mode = &val
i, ok := val.(int)
if !ok {
return Import{}, errors.Errorf("value for '%s' in import is not an integer: %#v", name, v)
}
*dest = i
}

if _, ok := m2["uid"]; ok {
uid = m2["uid"].(int)
if uid < 0 {
return Import{}, errors.Errorf("Uid cannot be negative: %v", uid)
}
}
if ret.Path == "" {
return ret, errors.Errorf("No 'path' entry found in import: %#v", v)
}

if _, ok := m2["gid"]; ok {
gid = m2["gid"].(int)
if gid < 0 {
return Import{}, errors.Errorf("Gid cannot be negative: %v", gid)
}
}
if ret.Dest != "" && !filepath.IsAbs(ret.Dest) {
return Import{}, errors.Errorf("'dest' path cannot be relative for: %#v", v)
}

return Import{Hash: hash, Path: fmt.Sprintf("%v", m2["path"]), Dest: dest, Mode: mode, Uid: uid, Gid: gid}, nil
if mode != -1 {
m := fs.FileMode(mode)
ret.Mode = &m
}

// if it's not a map then it's a string
s, ok := v.(string)
if ok {
return Import{Hash: "", Path: fmt.Sprintf("%v", s), Dest: "", Uid: uid, Gid: gid}, nil
// Empty values are -1
if ret.Uid != lib.UidEmpty && ret.Uid < 0 {
return Import{}, errors.Errorf("'uid' (%d) cannot be negative: %v", ret.Uid, v)
}
return Import{}, errors.Errorf("Didn't find a matching type for: %#v", v)
if ret.Gid != lib.GidEmpty && ret.Gid < 0 {
return Import{}, errors.Errorf("'gid' (%d) cannot be negative: %v", ret.Gid, v)
}

return ret, nil
}

// Custom UnmarshalYAML from string/map/slice of strings/slice of maps into Imports
Expand All @@ -590,26 +575,17 @@ func (im *Imports) UnmarshalYAML(unmarshal func(interface{}) error) error {
}

imports, ok := data.([]interface{})
if ok {
// imports are a list of either strings or maps
for _, v := range imports {
imp, err := getImportFromInterface(v)
if err != nil {
return err
}
*im = append(*im, imp)
}
} else {
if data != nil {
// import are either string or map
imp, err := getImportFromInterface(data)
if err != nil {
return err
}
*im = append(*im, imp)
if !ok {
return errors.Errorf("'imports' expected an array, found %s: %#v", reflect.TypeOf(data), data)
}
// imports are a list of either strings or maps
for _, v := range imports {
imp, err := getImportFromInterface(v)
if err != nil {
return err
}
*im = append(*im, imp)
}

return nil
}

Expand Down
111 changes: 111 additions & 0 deletions pkg/types/layer_import_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package types

import (
"io/fs"
"testing"

"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)

func modePtr(mode int) *fs.FileMode {
m := fs.FileMode(mode)
return &m
}

func TestGetImportFromInterface(t *testing.T) {
assert := assert.New(t)
hash1 := "b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
tables := []struct {
desc string
val interface{}
expected Import
errstr string
}{
{desc: "basic string",
val: "/path/to/file",
expected: Import{Path: "/path/to/file", Uid: -1, Gid: -1}},
{desc: "relative string",
val: "path/to/file",
expected: Import{Path: "path/to/file", Uid: -1, Gid: -1}},
{desc: "dict no dest",
val: map[interface{}]interface{}{
"path": "/path/to/file",
"hash": hash1,
},
expected: Import{Path: "/path/to/file", Dest: "", Hash: hash1, Uid: -1, Gid: -1}},
{desc: "dest cannot be relative",
val: map[interface{}]interface{}{
"path": "src1",
"dest": "dest1",
},
errstr: "cannot be relative",
},
{desc: "guid cannot be negative",
val: map[interface{}]interface{}{
"path": "src1",
"uid": -2,
},
errstr: "cannot be negative",
},
{desc: "mode present",
val: map[interface{}]interface{}{
"path": "src1",
"mode": 0755,
},
expected: Import{Path: "src1", Dest: "", Mode: modePtr(0755), Uid: -1, Gid: -1}},
}

var found Import
var err error
for _, t := range tables {
found, err = getImportFromInterface(t.val)
if t.errstr == "" {
assert.NoError(err, t.desc)
assert.Equal(t.expected, found, t.desc)
} else {
assert.ErrorContains(err, t.errstr, t.desc)
}
}
}

func TestUnmarshalImports(t *testing.T) {
assert := assert.New(t)
type importsContainer struct {
Imports []Import `yaml:"imports"`
}
tables := []struct {
desc string
yblob string
expected Imports
errstr string
}{
{desc: "imports should not be a string",
yblob: "/path/to/file",
expected: Imports{},
errstr: "xpected an array"},
{desc: "imports should not be a dict",
yblob: "path: /path/to/file\ndest: /path/to/dest\n",
expected: Imports{},
errstr: "xpected an array"},
{desc: "example valid mixed string and dict",
yblob: "- f1\n- path: f2\n",
expected: Imports{
Import{Path: "f1", Uid: -1, Gid: -1},
Import{Path: "f2", Uid: -1, Gid: -1},
}},
}
var err error
found := Imports{}
for _, t := range tables {
err = yaml.Unmarshal([]byte(t.yblob), &found)
if t.errstr == "" {
if !assert.NoError(err, t.desc) {
continue
}
assert.Equal(t.expected, found)
} else {
assert.ErrorContains(err, t.errstr, t.desc)
}
}
}

0 comments on commit 6be6698

Please sign in to comment.