-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Moritz Sanft <[email protected]>
- Loading branch information
Showing
5 changed files
with
337 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
Copyright (c) Edgeless Systems GmbH | ||
SPDX-License-Identifier: AGPL-3.0-only | ||
*/ | ||
|
||
package validation | ||
|
||
import ( | ||
"fmt" | ||
"regexp" | ||
) | ||
|
||
// Constraint is a constraint on a document or a field of a document. | ||
type Constraint func() (valid bool, err error) | ||
|
||
// MatchRegex is a constraint that if s matches regex. | ||
func MatchRegex(s string, regex string) Constraint { | ||
return func() (valid bool, err error) { | ||
if !regexp.MustCompile(regex).MatchString(s) { | ||
return false, fmt.Errorf("%s must match the pattern %s", s, regex) | ||
} | ||
return true, nil | ||
} | ||
} | ||
|
||
// Equal is a constraint that if s is equal to t. | ||
func Equal[T comparable](s T, t T) Constraint { | ||
return func() (valid bool, err error) { | ||
if s != t { | ||
return false, fmt.Errorf("%v must be equal to %v", s, t) | ||
} | ||
return true, nil | ||
} | ||
} | ||
|
||
// NotEmpty is a constraint that if s is not empty. | ||
func NotEmpty[T comparable](s T) Constraint { | ||
return func() (valid bool, err error) { | ||
var zero T | ||
if s == zero { | ||
return false, fmt.Errorf("%v must not be empty", s) | ||
} | ||
return true, nil | ||
} | ||
} | ||
|
||
// Empty is a constraint that if s is empty. | ||
func Empty[T comparable](s T) Constraint { | ||
return func() (valid bool, err error) { | ||
var zero T | ||
if s != zero { | ||
return false, fmt.Errorf("%v must be empty", s) | ||
} | ||
return true, nil | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package validation | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
) | ||
|
||
type ValidationError struct { | ||
Path string | ||
Err error | ||
} | ||
|
||
// NewValidationError creates a new ValidationError. | ||
// | ||
// To find the path to the field that failed validation, it traverses the | ||
// top level struct recursively until it finds a field that matches the | ||
// reference to the field that failed validation. | ||
func NewValidationError(topLevelStruct any, field any, errMsg error) *ValidationError { | ||
path, err := getDocumentPath(topLevelStruct, field) | ||
if err != nil { | ||
panic(fmt.Sprintf("cannot find path to field: %v", err)) | ||
} | ||
|
||
return &ValidationError{ | ||
Path: path, | ||
Err: errMsg, | ||
} | ||
} | ||
|
||
// Error implements the error interface. | ||
func (e *ValidationError) Error() string { | ||
return fmt.Sprintf("validating %s: %s", e.Path, e.Err) | ||
} | ||
|
||
// Unwrap implements the error interface. | ||
func (e *ValidationError) Unwrap() error { | ||
return e.Err | ||
} | ||
|
||
// getDocumentPath finds the JSON / YAML path of field in doc. | ||
func getDocumentPath(doc any, field any) (string, error) { | ||
needleAddr := reflect.ValueOf(field).Elem().UnsafeAddr() | ||
needleType := reflect.TypeOf(field) | ||
|
||
// traverse the top level struct (i.e. the "haystack") until addr (i.e. the "needle") is found | ||
return traverse(doc, needleAddr, needleType, []string{}) | ||
} | ||
|
||
// traverse reverses haystack recursively until it finds a field that matches | ||
// the reference in needle. | ||
// | ||
// If it traverses a level down, it | ||
// appends the name of the struct tag of the field to path. | ||
// | ||
// When a field matches the reference to the given field, it returns the | ||
// path to the field, joined with ".". | ||
func traverse(haystack any, needleAddr uintptr, needleType reflect.Type, path []string) (string, error) { | ||
// recursion anchor: doc is the field we are looking for. | ||
// Join the path and return. Since the first value of a struct has | ||
// the same address as the struct itself, we need to check the type as well. | ||
haystackAddr := reflect.ValueOf(haystack).Elem().UnsafeAddr() | ||
haystackType := reflect.TypeOf(haystack) | ||
if haystackAddr == needleAddr && haystackType == needleType { | ||
return strings.Join(path, "."), nil | ||
} | ||
|
||
haystackVal := reflect.ValueOf(haystack) | ||
kind := reflect.TypeOf(haystack).Kind() | ||
switch kind { | ||
case reflect.Pointer, reflect.UnsafePointer: | ||
// Dereference pointer and continue. | ||
return traverse(haystackVal.Elem(), needleAddr, needleType, path) | ||
case reflect.Struct: | ||
// Traverse all visible struct fields. | ||
for _, field := range reflect.VisibleFields(reflect.TypeOf(haystack)) { | ||
// skip unexported fields | ||
if field.IsExported() { | ||
// When a field is not the needle and cannot be traversed further, | ||
// a errCannotTraverse is returned. Therefore, we only want to handle | ||
// the case where the field is the needle. | ||
if path, err := traverse(field, needleAddr, needleType, appendByStructTag(path, field)); err == nil { | ||
return path, nil | ||
} | ||
} | ||
} | ||
case reflect.Slice, reflect.Array: | ||
// Traverse slice / Array elements | ||
for i := 0; i < haystackVal.Len(); i++ { | ||
// see struct case | ||
if path, err := traverse(haystackVal.Index(i), needleAddr, needleType, append(path, fmt.Sprintf("%d", i))); err == nil { | ||
return path, nil | ||
} | ||
} | ||
case reflect.Map: | ||
// Traverse map elements | ||
for _, key := range haystackVal.MapKeys() { | ||
// see struct case | ||
if path, err := traverse(haystackVal.MapIndex(key), needleAddr, needleType, append(path, key.String())); err == nil { | ||
return path, nil | ||
} | ||
} | ||
} | ||
// Primitive type, but not the value we are looking for. | ||
// Return a | ||
return "", errCannotTraverse | ||
} | ||
|
||
// errCannotTraverse is returned when a field cannot be traversed further. | ||
var errCannotTraverse = errors.New("cannot traverse anymore") | ||
|
||
// appendByStructTag appends the name of the JSON / YAML struct tag of field to path. | ||
// If no struct tag is present, path is returned unchanged. | ||
func appendByStructTag(path []string, field reflect.StructField) []string { | ||
switch { | ||
case field.Tag.Get("json") != "": | ||
return append(path, field.Tag.Get("json")) | ||
case field.Tag.Get("yaml") != "": | ||
return append(path, field.Tag.Get("yaml")) | ||
} | ||
return path | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package validation | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestNewValidationErrorSingleField(t *testing.T) { | ||
st := &ErrorTestDoc{ | ||
ExportedField: "abc", | ||
OtherField: 42, | ||
} | ||
|
||
err := NewValidationError(st, &st.OtherField, nil) | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "validating otherField: <nil>") | ||
} | ||
|
||
type ErrorTestDoc struct { | ||
ExportedField string `json:"exportedField" yaml:"exportedField"` | ||
OtherField int `json:"otherField" yaml:"otherField"` | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
Copyright (c) Edgeless Systems GmbH | ||
SPDX-License-Identifier: AGPL-3.0-only | ||
*/ | ||
|
||
/* | ||
Package validation provides a unified document validation interface for use within the Constellation CLI. | ||
It validates documents that specify a set of constraints on their content. | ||
*/ | ||
package validation | ||
|
||
import "errors" | ||
|
||
// NewValidator creates a new Validator. | ||
func NewValidator() *Validator { | ||
return &Validator{} | ||
} | ||
|
||
// Validator validates documents. | ||
type Validator struct{} | ||
|
||
// Validatable is implemented by documents that can be validated. | ||
type Validatable interface { | ||
Constraints() []Constraint | ||
} | ||
|
||
// ValidateOptions are the options to use when validating a document. | ||
type ValidateOptions struct { | ||
// FailFast stops validation on the first error. | ||
FailFast bool | ||
} | ||
|
||
// Validate validates a document using the given options. | ||
func (v *Validator) Validate(doc Validatable, opts ValidateOptions) error { | ||
var retErr error | ||
for _, c := range doc.Constraints() { | ||
if valid, err := c(); !valid { | ||
if opts.FailFast { | ||
return err | ||
} | ||
retErr = errors.Join(retErr, err) | ||
} | ||
} | ||
return retErr | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* | ||
Copyright (c) Edgeless Systems GmbH | ||
SPDX-License-Identifier: AGPL-3.0-only | ||
*/ | ||
|
||
package validation | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestValidate(t *testing.T) { | ||
testCases := map[string]struct { | ||
doc Validatable | ||
opts ValidateOptions | ||
wantErr bool | ||
errAssertion func(*assert.Assertions, error) bool | ||
}{ | ||
"valid": { | ||
doc: &exampleDoc{ | ||
strField: "abc", | ||
}, | ||
opts: ValidateOptions{}, | ||
}, | ||
"invalid": { | ||
doc: &exampleDoc{ | ||
strField: "def", | ||
}, | ||
wantErr: true, | ||
errAssertion: func(assert *assert.Assertions, err error) bool { | ||
return assert.Contains(err.Error(), "strField must be abc") | ||
}, | ||
opts: ValidateOptions{}, | ||
}, | ||
} | ||
|
||
for name, tc := range testCases { | ||
t.Run(name, func(t *testing.T) { | ||
assert := assert.New(t) | ||
require := require.New(t) | ||
|
||
err := NewValidator().Validate(tc.doc, tc.opts) | ||
if tc.wantErr { | ||
require.Error(err) | ||
if !tc.errAssertion(assert, err) { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
} else { | ||
require.NoError(err) | ||
} | ||
}) | ||
} | ||
|
||
} | ||
|
||
type exampleDoc struct { | ||
strField string | ||
numField int | ||
nested nestedExampleDoc | ||
nestedPtr *nestedExampleDoc | ||
} | ||
|
||
type nestedExampleDoc struct { | ||
strField string | ||
numField int | ||
} | ||
|
||
func (d *exampleDoc) Constraints() []Constraint { | ||
return []Constraint{ | ||
d.strFieldNeedsToBeAbc, | ||
MatchRegex(d.strField, "^[a-z]+$"), | ||
Equal(d.numField, 42), | ||
} | ||
} | ||
|
||
// StrFieldNeedsToBeAbc is an example for a custom constraint. | ||
func (d *exampleDoc) strFieldNeedsToBeAbc() (bool, error) { | ||
if d.strField != "abc" { | ||
return false, fmt.Errorf("%s must be abc", d.strField) | ||
} | ||
return true, nil | ||
} |