Skip to content

Commit

Permalink
Typed string enums (#107)
Browse files Browse the repository at this point in the history
* Added open for typed string enums

* Fix comments

* Address review comments
  • Loading branch information
jaredoconnell authored Nov 15, 2024
1 parent 5ed0055 commit 2b68585
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 62 deletions.
130 changes: 77 additions & 53 deletions schema/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"strings"
)

type enumValue interface {
type serializedEnumValue interface {
int64 | string
}
type enumValue interface {
~int64 | ~string
}

// Enum is an abstract schema for enumerated types.
type Enum[T enumValue] interface {
Expand All @@ -18,84 +21,99 @@ type Enum[T enumValue] interface {
ValidValues() map[T]*DisplayValue
}

type EnumSchema[T enumValue] struct {
type EnumSchema[S serializedEnumValue, T enumValue] struct {
ScalarType
ValidValuesMap map[T]*DisplayValue `json:"values"`
}

func (e EnumSchema[T]) ValidValues() map[T]*DisplayValue {
func (e EnumSchema[S, T]) ValidValues() map[T]*DisplayValue {
return e.ValidValuesMap
}

func (e EnumSchema[T]) ReflectedType() reflect.Type {
func (e EnumSchema[S, T]) ReflectedType() reflect.Type {
var defaultValue T
return reflect.TypeOf(defaultValue)
}

func (e EnumSchema[T]) ValidateCompatibility(typeOrData any) error {
func (e EnumSchema[S, T]) ValidateCompatibility(typeOrData any) error {
// Check if it's a schema type. If it is, verify it. If not, verify it as data.
value := reflect.ValueOf(typeOrData)
if reflect.Indirect(value).Kind() != reflect.Struct {
// Validate as data
return e.Validate(typeOrData)
return e.Validate(typeOrData) // Validate as data
}
field := reflect.Indirect(value).FieldByName("EnumSchema")

if !field.IsValid() {
// Validate as data
return e.Validate(typeOrData)
enumField := reflect.Indirect(value).FieldByName("EnumSchema")
if !enumField.IsValid() {
return e.Validate(typeOrData) // Validate as data
}

// Validate the type of EnumSchema
fieldAsInterface := field.Interface()
schemaType, ok := fieldAsInterface.(EnumSchema[T])
if !ok {
return &ConstraintError{
Message: fmt.Sprintf(
"validation failed for enum. Found type (%T) does not match expected type (%T)",
fieldAsInterface, e),
}
validValuesMapField := enumField.FieldByName("ValidValuesMap")
if !validValuesMapField.IsValid() {
return fmt.Errorf("failed to get values map in enum %T", e)
}

// Validate the valid values
for key, display := range e.ValidValuesMap {
matchingInputDisplay := schemaType.ValidValuesMap[key]
if matchingInputDisplay == nil {
foundValues := reflect.ValueOf(schemaType.ValidValuesMap).MapKeys()
expectedValues := reflect.ValueOf(e.ValidValuesMap).MapKeys()
for _, reflectKey := range validValuesMapField.MapKeys() {
var defaultValue T
defaultType := reflect.TypeOf(defaultValue)
if !reflectKey.CanConvert(defaultType) {
return fmt.Errorf("invalid enum value type %s", reflectKey.Type())
}
keyToCompare := reflectKey.Convert(defaultType).Interface()
// Validate that the key in the data under test is present in the self enum schema.
selfDisplayValue, found := e.ValidValuesMap[keyToCompare.(T)]
if !found {
return &ConstraintError{
Message: fmt.Sprintf("invalid enum values for type '%T' for custom enum of type %T. "+
"Found key %v. Expected values: %v",
e, typeOrData, keyToCompare, e.ValidValuesMap),
}
}
// Validate that the displays are compatible.
otherDisplay := validValuesMapField.MapIndex(reflectKey)
if !otherDisplay.IsValid() {
return fmt.Errorf("failed to get value at key in ValidateCompatibility")
}
otherDisplayValue := otherDisplay.Interface().(*DisplayValue)
switch {
case (selfDisplayValue == nil || selfDisplayValue.Name() == nil) &&
(otherDisplayValue == nil || otherDisplayValue.Name() == nil):
return nil
case otherDisplayValue == nil || otherDisplayValue.Name() == nil:
return &ConstraintError{
Message: fmt.Sprintf("display values for key %s is missing in compared data %T",
keyToCompare, typeOrData),
}
case selfDisplayValue == nil || selfDisplayValue.Name() == nil:
return &ConstraintError{
Message: fmt.Sprintf("invalid enum values for type '%T' for custom enum. Missing key %v (and potentially others). Expected values: %s, Has values: %s",
typeOrData, key, expectedValues, foundValues),
Message: fmt.Sprintf("display values for key %s is missing in the schema for %T, but present"+
" in compared data %T", keyToCompare, e, typeOrData),
}
} else if *display.Name() != *matchingInputDisplay.Name() {
case *selfDisplayValue.Name() != *otherDisplayValue.Name():
return &ConstraintError{
Message: fmt.Sprintf(
"invalid enum value. Mismatched name for key %v. Expected %s, got %s",
key, *display.Name(), *matchingInputDisplay.Name()),
keyToCompare, *selfDisplayValue.Name(), *otherDisplayValue.Name()),
}
}
}
return nil

}

func (e EnumSchema[T]) Validate(d any) error {
data, err := e.asType(d)
func (e EnumSchema[S, T]) Validate(d any) error {
_, data, err := e.asType(d)
if err != nil {
return err
}
return e.ValidateType(data)
}

func (e EnumSchema[T]) Serialize(d any) (any, error) {
data, err := e.asType(d)
func (e EnumSchema[S, T]) Serialize(d any) (any, error) {
serializedData, data, err := e.asType(d)
if err != nil {
return data, err
return serializedData, err
}
return data, e.Validate(data)
return serializedData, e.Validate(data)
}

func (e EnumSchema[T]) ValidateType(data T) error {
func (e EnumSchema[S, T]) ValidateType(data T) error {
for validValue := range e.ValidValuesMap {
if validValue == data {
return nil
Expand All @@ -119,22 +137,28 @@ func (e EnumSchema[T]) ValidateType(data T) error {
}
}

func (e EnumSchema[T]) SerializeType(data T) (any, error) {
func (e EnumSchema[S, T]) SerializeType(data T) (any, error) {
return data, e.Validate(data)
}

func (e EnumSchema[T]) asType(d any) (T, error) {
data, ok := d.(T)
if !ok {
var defaultValue T
tType := reflect.TypeOf(defaultValue)
dValue := reflect.ValueOf(d)
if !dValue.CanConvert(tType) {
return defaultValue, &ConstraintError{
Message: fmt.Sprintf("%T is not a valid data type for an int schema.", d),
}
func (e EnumSchema[S, T]) asType(d any) (S, T, error) {
var serializedDefaultValue S
serializedType := reflect.TypeOf(serializedDefaultValue)
dValue := reflect.ValueOf(d)
var unserializedDefaultValue T
unserializedType := reflect.TypeOf(unserializedDefaultValue)

if !dValue.CanConvert(serializedType) {
return serializedDefaultValue, unserializedDefaultValue, &ConstraintError{
Message: fmt.Sprintf("%T is not a valid data type for an %T schema.", d, serializedDefaultValue),
}
}
if !dValue.CanConvert(unserializedType) {
return serializedDefaultValue, unserializedDefaultValue, &ConstraintError{
Message: fmt.Sprintf("%T is not a valid data type for an %T schema's unserialized type %T", d, e, unserializedType),
}
data = dValue.Convert(tType).Interface().(T)
}
return data, nil
serializedData := dValue.Convert(serializedType).Interface().(S)
unserializedData := dValue.Convert(unserializedType).Interface().(T)
return serializedData, unserializedData, nil
}
6 changes: 3 additions & 3 deletions schema/enum_int.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import "fmt"
// NewIntEnumSchema creates a new enum of integer values.
func NewIntEnumSchema(validValues map[int64]*DisplayValue, units *UnitsDefinition) *IntEnumSchema {
return &IntEnumSchema{
EnumSchema[int64]{
EnumSchema[int64, int64]{
ValidValuesMap: validValues,
},
units,
Expand All @@ -20,8 +20,8 @@ type IntEnum interface {

// IntEnumSchema is an enum type with integer values.
type IntEnumSchema struct {
EnumSchema[int64] `json:",inline"`
IntUnits *UnitsDefinition `json:"units"`
EnumSchema[int64, int64] `json:",inline"`
IntUnits *UnitsDefinition `json:"units"`
}

func (i IntEnumSchema) TypeID() TypeID {
Expand Down
31 changes: 25 additions & 6 deletions schema/enum_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,19 @@ import "fmt"
// NewStringEnumSchema creates a new enum of string values.
func NewStringEnumSchema(validValues map[string]*DisplayValue) *StringEnumSchema {
return &StringEnumSchema{
EnumSchema[string]{
TypedStringEnumSchema[string]{
EnumSchema[string, string]{
ValidValuesMap: validValues,
},
},
}
}

// NewTypedStringEnumSchema allows the use of a type with string as an underlying type.
// Useful for external APIs that are being mapped to a schema that use string enums.
func NewTypedStringEnumSchema[T ~string](validValues map[T]*DisplayValue) *TypedStringEnumSchema[T] {
return &TypedStringEnumSchema[T]{
EnumSchema[string, T]{
ValidValuesMap: validValues,
},
}
Expand All @@ -18,15 +30,22 @@ type StringEnum interface {

// StringEnumSchema is an enum type with string values.
type StringEnumSchema struct {
EnumSchema[string] `json:",inline"`
TypedStringEnumSchema[string] `json:",inline"`
}

// TypedStringEnumSchema is an enum type with string values, but with a generic
// element for golang enums that have an underlying string type.
type TypedStringEnumSchema[T ~string] struct {
EnumSchema[string, T] `json:",inline"`
}

func (s StringEnumSchema) TypeID() TypeID {
func (s TypedStringEnumSchema[T]) TypeID() TypeID {
return TypeIDStringEnum
}

func (s StringEnumSchema) Unserialize(data any) (any, error) {
typedData, err := stringInputMapper(data)
func (s TypedStringEnumSchema[T]) Unserialize(data any) (any, error) {
strData, err := stringInputMapper(data)
typedData := T(strData)
if err != nil {
return "", &ConstraintError{
Message: fmt.Sprintf("'%v' (type %T) is not a valid type for a '%T' enum", data, data, typedData),
Expand All @@ -35,7 +54,7 @@ func (s StringEnumSchema) Unserialize(data any) (any, error) {
return typedData, s.Validate(typedData)
}

func (s StringEnumSchema) UnserializeType(data any) (string, error) {
func (s TypedStringEnumSchema[T]) UnserializeType(data any) (string, error) {
unserialized, err := s.Unserialize(data)
if err != nil {
return "", err
Expand Down
28 changes: 28 additions & 0 deletions schema/enum_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ func TestStringEnumSerialization(t *testing.T) {
}

func TestStringEnumTypedSerialization(t *testing.T) {
// In this test a typed enum is being passed into Serialize, but
// it is not using a typed enum schema.
type Size string
s := schema.NewStringEnumSchema(map[string]*schema.DisplayValue{
"small": {NameValue: schema.PointerTo("Small")},
Expand All @@ -91,6 +93,22 @@ func TestStringEnumTypedSerialization(t *testing.T) {
assert.Equals(t, serializedData.(string), "small")
}

func TestTypedStringEnum(t *testing.T) {
type TypedString string
s := schema.NewTypedStringEnumSchema[TypedString](map[TypedString]*schema.DisplayValue{
"a": {NameValue: schema.PointerTo("a")},
"b": {NameValue: schema.PointerTo("b")},
})
serializedData, err := s.Serialize(TypedString("a"))
assert.NoError(t, err)
assert.Equals(t, serializedData.(string), "a")

unserialiedData, err := s.Unserialize(serializedData)
assert.NoError(t, err)
assert.InstanceOf[TypedString](t, unserialiedData)
assert.Equals(t, unserialiedData.(TypedString), "a")
}

func TestStringEnumJSONMarshal(t *testing.T) {
typeUnderTest := schema.NewStringEnumSchema(map[string]*schema.DisplayValue{
"small": {NameValue: schema.PointerTo("Small")},
Expand Down Expand Up @@ -142,13 +160,23 @@ func TestStringEnumSchemaCompatibilityValidation(t *testing.T) {
"b": {NameValue: schema.PointerTo("B")},
"c": {NameValue: schema.PointerTo("C")},
})
type TypedString string
s1Typed := schema.NewTypedStringEnumSchema[TypedString](map[TypedString]*schema.DisplayValue{
"a": {NameValue: schema.PointerTo("a")},
"b": {NameValue: schema.PointerTo("b")},
"c": {NameValue: schema.PointerTo("c")},
})

assert.NoError(t, s1.ValidateCompatibility(s1))
assert.NoError(t, s2.ValidateCompatibility(s2))
assert.NoError(t, s1Typed.ValidateCompatibility(s1Typed))
// Mismatched keys
assert.Error(t, s1.ValidateCompatibility(s2))
assert.Error(t, s2.ValidateCompatibility(s1))
// Mismatched names
assert.Error(t, s1.ValidateCompatibility(S1))
assert.Error(t, S1.ValidateCompatibility(s1))
// Different types but same schema.
assert.NoError(t, s1.ValidateCompatibility(s1Typed))
assert.NoError(t, s1Typed.ValidateCompatibility(s1))
}
53 changes: 53 additions & 0 deletions schema/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package schema_test

import (
"go.arcalot.io/assert"
"go.flow.arcalot.io/pluginsdk/schema/testdata"
"testing"

"go.flow.arcalot.io/pluginsdk/schema"
Expand Down Expand Up @@ -362,3 +363,55 @@ func TestMap_UnSerialize_Reversible(t *testing.T) {
// test reversiblity
assert.Equals(t, unserialized2, unserialized)
}

// Test type aliased map of type aliased string to a struct.
type TypedStringEnumTestType string

const (
testA TypedStringEnumTestType = "testA"
testB TypedStringEnumTestType = "testB"
)

func TestMap_AliasedTypes(t *testing.T) {
schemaForPrivateFieldStruct := schema.NewStructMappedObjectSchema[testdata.TestStructWithPrivateField](
"structWithPrivateField",
map[string]*schema.PropertySchema{
"field1": schema.NewPropertySchema(
schema.NewStringSchema(nil, nil, nil),
nil,
false,
nil,
nil,
nil,
schema.PointerTo("\"Hello world!\""),
nil,
),
},
)
schemaForEnum := schema.NewTypedStringEnumSchema[TypedStringEnumTestType](map[TypedStringEnumTestType]*schema.DisplayValue{
testA: nil,
testB: nil,
})
mapType := schema.NewMapSchema(
schemaForEnum,
schemaForPrivateFieldStruct,
nil,
nil,
)
serializedInput := map[any]any{
string(testA): map[string]any{
"field1": "test_field_value",
},
}
// Unserialize and validate
unserializedData, err := mapType.Unserialize(serializedInput)
assert.NoError(t, err)
assert.InstanceOf[map[TypedStringEnumTestType]testdata.TestStructWithPrivateField](t, unserializedData)
typedUnserializedData := unserializedData.(map[TypedStringEnumTestType]testdata.TestStructWithPrivateField)
assert.MapContainsKey(t, testA, typedUnserializedData)
assert.Equals(t, typedUnserializedData[testA], testdata.TestStructWithPrivateField{Field1: "test_field_value"})
// Re-serialize and validate
serializedOutput, err := mapType.Serialize(unserializedData)
assert.NoError(t, err)
assert.Equals[map[any]any](t, serializedOutput.(map[any]any), serializedInput)
}
Loading

0 comments on commit 2b68585

Please sign in to comment.