Skip to content

Commit

Permalink
[wip] validation framework
Browse files Browse the repository at this point in the history
Signed-off-by: Moritz Sanft <[email protected]>
  • Loading branch information
msanft committed Oct 19, 2023
1 parent ee54b71 commit 95a2384
Show file tree
Hide file tree
Showing 5 changed files with 337 additions and 0 deletions.
57 changes: 57 additions & 0 deletions internal/validation/constraints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

package validation

import (
"fmt"
"regexp"
)

// Constraint is a constraint on a document or a field of a document.
type Constraint func() (valid bool, err error)

// MatchRegex is a constraint that if s matches regex.
func MatchRegex(s string, regex string) Constraint {
return func() (valid bool, err error) {
if !regexp.MustCompile(regex).MatchString(s) {
return false, fmt.Errorf("%s must match the pattern %s", s, regex)
}
return true, nil
}
}

// Equal is a constraint that if s is equal to t.
func Equal[T comparable](s T, t T) Constraint {
return func() (valid bool, err error) {
if s != t {
return false, fmt.Errorf("%v must be equal to %v", s, t)
}
return true, nil
}
}

// NotEmpty is a constraint that if s is not empty.
func NotEmpty[T comparable](s T) Constraint {
return func() (valid bool, err error) {
var zero T
if s == zero {
return false, fmt.Errorf("%v must not be empty", s)
}
return true, nil
}
}

// Empty is a constraint that if s is empty.
func Empty[T comparable](s T) Constraint {
return func() (valid bool, err error) {
var zero T
if s != zero {
return false, fmt.Errorf("%v must be empty", s)
}
return true, nil
}
}
123 changes: 123 additions & 0 deletions internal/validation/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package validation

import (
"errors"
"fmt"
"reflect"
"strings"
)

type ValidationError struct {
Path string
Err error
}

// NewValidationError creates a new ValidationError.
//
// To find the path to the field that failed validation, it traverses the
// top level struct recursively until it finds a field that matches the
// reference to the field that failed validation.
func NewValidationError(topLevelStruct any, field any, errMsg error) *ValidationError {
path, err := getDocumentPath(topLevelStruct, field)
if err != nil {
panic(fmt.Sprintf("cannot find path to field: %v", err))
}

return &ValidationError{
Path: path,
Err: errMsg,
}
}

// Error implements the error interface.
func (e *ValidationError) Error() string {
return fmt.Sprintf("validating %s: %s", e.Path, e.Err)
}

// Unwrap implements the error interface.
func (e *ValidationError) Unwrap() error {
return e.Err
}

// getDocumentPath finds the JSON / YAML path of field in doc.
func getDocumentPath(doc any, field any) (string, error) {
needleAddr := reflect.ValueOf(field).Elem().UnsafeAddr()
needleType := reflect.TypeOf(field)

// traverse the top level struct (i.e. the "haystack") until addr (i.e. the "needle") is found
return traverse(doc, needleAddr, needleType, []string{})
}

// traverse reverses haystack recursively until it finds a field that matches
// the reference in needle.
//
// If it traverses a level down, it
// appends the name of the struct tag of the field to path.
//
// When a field matches the reference to the given field, it returns the
// path to the field, joined with ".".
func traverse(haystack any, needleAddr uintptr, needleType reflect.Type, path []string) (string, error) {
// recursion anchor: doc is the field we are looking for.
// Join the path and return. Since the first value of a struct has
// the same address as the struct itself, we need to check the type as well.
haystackAddr := reflect.ValueOf(haystack).Elem().UnsafeAddr()
haystackType := reflect.TypeOf(haystack)
if haystackAddr == needleAddr && haystackType == needleType {
return strings.Join(path, "."), nil
}

haystackVal := reflect.ValueOf(haystack)
kind := reflect.TypeOf(haystack).Kind()
switch kind {
case reflect.Pointer, reflect.UnsafePointer:
// Dereference pointer and continue.
return traverse(haystackVal.Elem(), needleAddr, needleType, path)
case reflect.Struct:
// Traverse all visible struct fields.
for _, field := range reflect.VisibleFields(reflect.TypeOf(haystack)) {
// skip unexported fields
if field.IsExported() {
// When a field is not the needle and cannot be traversed further,
// a errCannotTraverse is returned. Therefore, we only want to handle
// the case where the field is the needle.
if path, err := traverse(field, needleAddr, needleType, appendByStructTag(path, field)); err == nil {
return path, nil
}
}
}
case reflect.Slice, reflect.Array:
// Traverse slice / Array elements
for i := 0; i < haystackVal.Len(); i++ {
// see struct case
if path, err := traverse(haystackVal.Index(i), needleAddr, needleType, append(path, fmt.Sprintf("%d", i))); err == nil {
return path, nil
}
}
case reflect.Map:
// Traverse map elements
for _, key := range haystackVal.MapKeys() {
// see struct case
if path, err := traverse(haystackVal.MapIndex(key), needleAddr, needleType, append(path, key.String())); err == nil {
return path, nil
}
}
}
// Primitive type, but not the value we are looking for.
// Return a
return "", errCannotTraverse
}

// errCannotTraverse is returned when a field cannot be traversed further.
var errCannotTraverse = errors.New("cannot traverse anymore")

// appendByStructTag appends the name of the JSON / YAML struct tag of field to path.
// If no struct tag is present, path is returned unchanged.
func appendByStructTag(path []string, field reflect.StructField) []string {
switch {
case field.Tag.Get("json") != "":
return append(path, field.Tag.Get("json"))
case field.Tag.Get("yaml") != "":
return append(path, field.Tag.Get("yaml"))
}
return path
}
23 changes: 23 additions & 0 deletions internal/validation/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package validation

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestNewValidationErrorSingleField(t *testing.T) {
st := &ErrorTestDoc{
ExportedField: "abc",
OtherField: 42,
}

err := NewValidationError(st, &st.OtherField, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validating otherField: <nil>")
}

type ErrorTestDoc struct {
ExportedField string `json:"exportedField" yaml:"exportedField"`
OtherField int `json:"otherField" yaml:"otherField"`
}
47 changes: 47 additions & 0 deletions internal/validation/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

/*
Package validation provides a unified document validation interface for use within the Constellation CLI.
It validates documents that specify a set of constraints on their content.
*/
package validation

import "errors"

// NewValidator creates a new Validator.
func NewValidator() *Validator {
return &Validator{}
}

// Validator validates documents.
type Validator struct{}

// Validatable is implemented by documents that can be validated.
type Validatable interface {
Constraints() []Constraint
}

// ValidateOptions are the options to use when validating a document.
type ValidateOptions struct {
// FailFast stops validation on the first error.
FailFast bool
}

// Validate validates a document using the given options.
func (v *Validator) Validate(doc Validatable, opts ValidateOptions) error {
var retErr error
for _, c := range doc.Constraints() {
if valid, err := c(); !valid {
if opts.FailFast {
return err
}
retErr = errors.Join(retErr, err)
}
}
return retErr
}
87 changes: 87 additions & 0 deletions internal/validation/validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

package validation

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestValidate(t *testing.T) {
testCases := map[string]struct {
doc Validatable
opts ValidateOptions
wantErr bool
errAssertion func(*assert.Assertions, error) bool
}{
"valid": {
doc: &exampleDoc{
strField: "abc",
},
opts: ValidateOptions{},
},
"invalid": {
doc: &exampleDoc{
strField: "def",
},
wantErr: true,
errAssertion: func(assert *assert.Assertions, err error) bool {
return assert.Contains(err.Error(), "strField must be abc")
},
opts: ValidateOptions{},
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

err := NewValidator().Validate(tc.doc, tc.opts)
if tc.wantErr {
require.Error(err)
if !tc.errAssertion(assert, err) {
t.Fatalf("unexpected error: %v", err)
}
} else {
require.NoError(err)
}
})
}

}

type exampleDoc struct {
strField string
numField int
nested nestedExampleDoc
nestedPtr *nestedExampleDoc
}

type nestedExampleDoc struct {
strField string
numField int
}

func (d *exampleDoc) Constraints() []Constraint {
return []Constraint{
d.strFieldNeedsToBeAbc,
MatchRegex(d.strField, "^[a-z]+$"),
Equal(d.numField, 42),
}
}

// StrFieldNeedsToBeAbc is an example for a custom constraint.
func (d *exampleDoc) strFieldNeedsToBeAbc() (bool, error) {
if d.strField != "abc" {
return false, fmt.Errorf("%s must be abc", d.strField)
}
return true, nil
}

0 comments on commit 95a2384

Please sign in to comment.