Skip to content

Commit

Permalink
feat: support sum types on ingress
Browse files Browse the repository at this point in the history
fixes #1388
  • Loading branch information
worstell committed May 8, 2024
1 parent 090442a commit 320e3ed
Show file tree
Hide file tree
Showing 18 changed files with 703 additions and 391 deletions.
2 changes: 1 addition & 1 deletion backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ func (s *Service) callWithRequest(
return nil, err
}

err = ingress.ValidateCallBody(req.Msg.Body, verb, sch)
err = ingress.ValidateCallBody(ctx, req.Msg.Body, verb, sch)
if err != nil {
return nil, err
}
Expand Down
95 changes: 75 additions & 20 deletions backend/controller/ingress/alias.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,75 @@
package ingress

import (
"context"
"fmt"

"github.com/TBD54566975/ftl/backend/schema"
)

func transformAliasedFields(sch *schema.Schema, t schema.Type, obj any, aliaser func(obj map[string]any, field *schema.Field) string) error {
func transformAliasedFields(
ctx context.Context,
sch *schema.Schema,
t schema.Type,
obj any,
aliaser func(obj map[string]any, field *schema.Field) string,
) error {
if obj == nil {
return nil
}
switch t := t.(type) {
case *schema.Ref:
data, err := sch.ResolveRefMonomorphised(t)
if err != nil {
return fmt.Errorf("%s: failed to resolve data type: %w", t.Pos, err)
}
m, ok := obj.(map[string]any)
if !ok {
return fmt.Errorf("%s: expected map, got %T", t.Pos, obj)
}
for _, field := range data.Fields {
name := aliaser(m, field)
if err := transformAliasedFields(sch, field.Type, m[name], aliaser); err != nil {
switch decl := sch.ResolveRef(t).(type) {
case *schema.Data:
data, err := sch.ResolveRefMonomorphised(t)
if err != nil {
return fmt.Errorf("%s: failed to resolve data type: %w", t.Pos, err)
}
m, ok := obj.(map[string]any)
if !ok {
return fmt.Errorf("%s: expected map, got %T", t.Pos, obj)
}
for _, field := range data.Fields {
name := aliaser(m, field)
if err := transformAliasedFields(ctx, sch, field.Type, m[name], aliaser); err != nil {
return err
}
}
case *schema.Enum:
if decl.IsValueEnum() {
return nil
}
tr, ok := schema.TypeRegistryFromContext(ctx).Get()
if !ok {
return fmt.Errorf("%s: type registry not found in context, cannot process enum variant %s", t.Pos, t)
}

// type enum
m, ok := obj.(map[string]any)
if !ok {
return fmt.Errorf("%s: expected map, got %T", t.Pos, obj)
}
name, ok := m["name"]
if !ok {
return fmt.Errorf("%s: expected type enum request to have 'name' field", t.Pos)
}
nameStr, ok := name.(string)
if !ok {
return fmt.Errorf("%s: expected 'name' field to be a string, got %T", t.Pos, name)
}
vType, ok := tr.GetVariantType(t.ToRefKey().String(), nameStr).Get()
if !ok {
return fmt.Errorf("%s: unknown variant %s", t.Pos, nameStr)
}
value, ok := m["value"]
if !ok {
return fmt.Errorf("%s: expected type enum request to have 'value' field", t.Pos)
}
if err := transformAliasedFields(ctx, sch, vType, value, aliaser); err != nil {
return err
}
case *schema.Config, *schema.Database, *schema.FSM, *schema.Secret, *schema.Verb:
return fmt.Errorf("%s: unsupported ref type %T", t.Pos, decl)
}

case *schema.Array:
Expand All @@ -33,7 +78,7 @@ func transformAliasedFields(sch *schema.Schema, t schema.Type, obj any, aliaser
return fmt.Errorf("%s: expected array, got %T", t.Pos, obj)
}
for _, elem := range a {
if err := transformAliasedFields(sch, t.Element, elem, aliaser); err != nil {
if err := transformAliasedFields(ctx, sch, t.Element, elem, aliaser); err != nil {
return err
}
}
Expand All @@ -44,25 +89,30 @@ func transformAliasedFields(sch *schema.Schema, t schema.Type, obj any, aliaser
return fmt.Errorf("%s: expected map, got %T", t.Pos, obj)
}
for key, value := range m {
if err := transformAliasedFields(sch, t.Key, key, aliaser); err != nil {
if err := transformAliasedFields(ctx, sch, t.Key, key, aliaser); err != nil {
return err
}
if err := transformAliasedFields(sch, t.Value, value, aliaser); err != nil {
if err := transformAliasedFields(ctx, sch, t.Value, value, aliaser); err != nil {
return err
}
}

case *schema.Optional:
return transformAliasedFields(sch, t.Type, obj, aliaser)
return transformAliasedFields(ctx, sch, t.Type, obj, aliaser)

case *schema.Any, *schema.Bool, *schema.Bytes, *schema.Float, *schema.Int,
*schema.String, *schema.Time, *schema.Unit:
}
return nil
}

func transformFromAliasedFields(ref *schema.Ref, sch *schema.Schema, request map[string]any) (map[string]any, error) {
return request, transformAliasedFields(sch, ref, request, func(obj map[string]any, field *schema.Field) string {
func transformFromAliasedFields(
ctx context.Context,
ref *schema.Ref,
sch *schema.Schema,
request map[string]any,
) (map[string]any, error) {
return request, transformAliasedFields(ctx, sch, ref, request, func(obj map[string]any, field *schema.Field) string {
if jsonAlias, ok := field.Alias(schema.AliasKindJSON).Get(); ok {
if _, ok := obj[field.Name]; !ok && obj[jsonAlias] != nil {
obj[field.Name] = obj[jsonAlias]
Expand All @@ -73,8 +123,13 @@ func transformFromAliasedFields(ref *schema.Ref, sch *schema.Schema, request map
})
}

func transformToAliasedFields(ref *schema.Ref, sch *schema.Schema, request map[string]any) (map[string]any, error) {
return request, transformAliasedFields(sch, ref, request, func(obj map[string]any, field *schema.Field) string {
func transformToAliasedFields(
ctx context.Context,
ref *schema.Ref,
sch *schema.Schema,
request map[string]any,
) (map[string]any, error) {
return request, transformAliasedFields(ctx, sch, ref, request, func(obj map[string]any, field *schema.Field) string {
if jsonAlias, ok := field.Alias(schema.AliasKindJSON).Get(); ok && field.Name != jsonAlias {
obj[jsonAlias] = obj[field.Name]
delete(obj, field.Name)
Expand Down
50 changes: 48 additions & 2 deletions backend/controller/ingress/alias_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ingress

import (
"context"
"testing"

"github.com/alecthomas/assert/v2"
Expand All @@ -11,6 +12,11 @@ import (
func TestTransformFromAliasedFields(t *testing.T) {
schemaText := `
module test {
enum TypeEnum {
A test.Inner
B String
}
data Inner {
waz String +alias json "foo"
}
Expand All @@ -21,12 +27,21 @@ func TestTransformFromAliasedFields(t *testing.T) {
array [test.Inner]
map {String: test.Inner}
optional test.Inner
typeEnum test.TypeEnum
}
}
`
ctx := context.Background()
tr := schema.NewTypeRegistry()
tr.RegisterSumType("test.TypeEnum", map[string]schema.Type{
"A": &schema.Ref{Module: "test", Name: "Inner"},
"B": &schema.String{},
})
ctx = schema.ContextWithTypeRegistry(ctx, tr.ToProto())

sch, err := schema.ParseString("test", schemaText)
assert.NoError(t, err)
actual, err := transformFromAliasedFields(&schema.Ref{Module: "test", Name: "Test"}, sch, map[string]any{
actual, err := transformFromAliasedFields(ctx, &schema.Ref{Module: "test", Name: "Test"}, sch, map[string]any{
"bar": "value",
"inner": map[string]any{
"foo": "value",
Expand All @@ -44,6 +59,10 @@ func TestTransformFromAliasedFields(t *testing.T) {
"optional": map[string]any{
"foo": "value",
},
"typeEnum": map[string]any{
"name": "A",
"value": map[string]any{"foo": "value"},
},
})
expected := map[string]any{
"scalar": "value",
Expand All @@ -63,6 +82,10 @@ func TestTransformFromAliasedFields(t *testing.T) {
"optional": map[string]any{
"waz": "value",
},
"typeEnum": map[string]any{
"name": "A",
"value": map[string]any{"waz": "value"},
},
}
assert.NoError(t, err)
assert.Equal(t, expected, actual)
Expand All @@ -71,6 +94,11 @@ func TestTransformFromAliasedFields(t *testing.T) {
func TestTransformToAliasedFields(t *testing.T) {
schemaText := `
module test {
enum TypeEnum {
A test.Inner
B String
}
data Inner {
waz String +alias json "foo"
}
Expand All @@ -81,12 +109,22 @@ func TestTransformToAliasedFields(t *testing.T) {
array [test.Inner]
map {String: test.Inner}
optional test.Inner
typeEnum test.TypeEnum
}
}
`

ctx := context.Background()
tr := schema.NewTypeRegistry()
tr.RegisterSumType("test.TypeEnum", map[string]schema.Type{
"A": &schema.Ref{Module: "test", Name: "Inner"},
"B": &schema.String{},
})
ctx = schema.ContextWithTypeRegistry(ctx, tr.ToProto())

sch, err := schema.ParseString("test", schemaText)
assert.NoError(t, err)
actual, err := transformToAliasedFields(&schema.Ref{Module: "test", Name: "Test"}, sch, map[string]any{
actual, err := transformToAliasedFields(ctx, &schema.Ref{Module: "test", Name: "Test"}, sch, map[string]any{
"scalar": "value",
"inner": map[string]any{
"waz": "value",
Expand All @@ -104,6 +142,10 @@ func TestTransformToAliasedFields(t *testing.T) {
"optional": map[string]any{
"waz": "value",
},
"typeEnum": map[string]any{
"name": "A",
"value": map[string]any{"waz": "value"},
},
})
expected := map[string]any{
"bar": "value",
Expand All @@ -123,6 +165,10 @@ func TestTransformToAliasedFields(t *testing.T) {
"optional": map[string]any{
"foo": "value",
},
"typeEnum": map[string]any{
"name": "A",
"value": map[string]any{"foo": "value"},
},
}
assert.NoError(t, err)
assert.Equal(t, expected, actual)
Expand Down
2 changes: 1 addition & 1 deletion backend/controller/ingress/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func Handle(
}

var responseHeaders http.Header
responseBody, responseHeaders, err = ResponseForVerb(sch, verb, response)
responseBody, responseHeaders, err = ResponseForVerb(r.Context(), sch, verb, response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
Loading

0 comments on commit 320e3ed

Please sign in to comment.