Skip to content

Commit

Permalink
feat: Handle sum types in backend encoder
Browse files Browse the repository at this point in the history
fixes #1387
  • Loading branch information
worstell committed May 8, 2024
1 parent acdb53a commit 22b416a
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 95 deletions.
2 changes: 1 addition & 1 deletion backend/controller/ingress/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestIngress(t *testing.T) {
req.URL.RawQuery = test.query.Encode()
reqKey := model.NewRequestKey(model.OriginIngress, "test")
ingress.Handle(sch, reqKey, routes, rec, req, func(ctx context.Context, r *connect.Request[ftlv1.CallRequest], requestKey optional.Option[model.RequestKey], requestSource string) (*connect.Response[ftlv1.CallResponse], error) {
body, err := encoding.Marshal(response)
body, err := encoding.Marshal(ctx, response)
assert.NoError(t, err)
return connect.NewResponse(&ftlv1.CallResponse{Response: &ftlv1.CallResponse_Body{Body: body}}), nil
})
Expand Down
5 changes: 3 additions & 2 deletions backend/controller/ingress/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ingress

import (
"bytes"
"context"
"net/http"
"net/url"
"reflect"
Expand Down Expand Up @@ -162,7 +163,7 @@ func TestBuildRequestBody(t *testing.T) {
if test.body == nil {
test.body = obj{}
}
body, err := encoding.Marshal(test.body)
body, err := encoding.Marshal(context.Background(), test.body)
assert.NoError(t, err)
requestURL := "http://127.0.0.1" + test.path
if test.query != nil {
Expand All @@ -182,7 +183,7 @@ func TestBuildRequestBody(t *testing.T) {
assert.NoError(t, err)
actualrv := reflect.New(reflect.TypeOf(test.expected))
actual := actualrv.Interface()
err = encoding.Unmarshal(requestBody, actual)
err = encoding.Unmarshal(context.Background(), requestBody, actual)
assert.NoError(t, err)
assert.Equal(t, test.expected, actualrv.Elem().Interface(), assert.OmitEmpty())
})
Expand Down
152 changes: 118 additions & 34 deletions go-runtime/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package encoding

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
Expand All @@ -13,6 +14,7 @@ import (
"time"

"github.com/TBD54566975/ftl/backend/schema/strcase"
"github.com/TBD54566975/ftl/go-runtime/ftl/typeregistry"
)

var (
Expand All @@ -21,19 +23,19 @@ var (
)

type OptionMarshaler interface {
Marshal(w *bytes.Buffer, encode func(v reflect.Value, w *bytes.Buffer) error) error
Marshal(ctx context.Context, w *bytes.Buffer, encode func(ctx context.Context, v reflect.Value, w *bytes.Buffer) error) error
}
type OptionUnmarshaler interface {
Unmarshal(d *json.Decoder, isNull bool, decode func(d *json.Decoder, v reflect.Value) error) error
Unmarshal(ctx context.Context, d *json.Decoder, isNull bool, decode func(ctx context.Context, d *json.Decoder, v reflect.Value) error) error
}

func Marshal(v any) ([]byte, error) {
func Marshal(ctx context.Context, v any) ([]byte, error) {
w := &bytes.Buffer{}
err := encodeValue(reflect.ValueOf(v), w)
err := encodeValue(ctx, reflect.ValueOf(v), w)
return w.Bytes(), err
}

func encodeValue(v reflect.Value, w *bytes.Buffer) error {
func encodeValue(ctx context.Context, v reflect.Value, w *bytes.Buffer) error {
if !v.IsValid() {
w.WriteString("null")
return nil
Expand All @@ -56,7 +58,7 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {

case t.Implements(optionMarshaler):
enc := v.Interface().(OptionMarshaler) //nolint:forcetypeassert
return enc.Marshal(w, encodeValue)
return enc.Marshal(ctx, w, encodeValue)

//TODO: Remove once we support `omitempty` tag
case t == reflect.TypeFor[json.RawMessage]():
Expand All @@ -70,16 +72,16 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {

switch v.Kind() {
case reflect.Struct:
return encodeStruct(v, w)
return encodeStruct(ctx, v, w)

case reflect.Slice:
if v.Type().Elem().Kind() == reflect.Uint8 {
return encodeBytes(v, w)
}
return encodeSlice(v, w)
return encodeSlice(ctx, v, w)

case reflect.Map:
return encodeMap(v, w)
return encodeMap(ctx, v, w)

case reflect.String:
return encodeString(v, w)
Expand All @@ -93,18 +95,29 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {
case reflect.Bool:
return encodeBool(v, w)

case reflect.Interface: // any
if t != reflect.TypeOf((*any)(nil)).Elem() {
return fmt.Errorf("the only interface type supported is any, not %s", t)
case reflect.Interface:
if t == reflect.TypeFor[any]() {
return encodeValue(ctx, v.Elem(), w)
}

if tr, ok := typeregistry.FromContext(ctx).Get(); ok {
if vName, ok := tr.GetVariantByType(v.Type(), v.Elem().Type()).Get(); ok {
stv := struct {
Name string
Value any
}{Name: vName, Value: v.Elem().Interface()}
return encodeValue(ctx, reflect.ValueOf(stv), w)
}
}
return encodeValue(v.Elem(), w)

return fmt.Errorf("the only interface types are enums or any, not %s", t)

default:
panic(fmt.Sprintf("unsupported type: %s", v.Type()))
}
}

func encodeStruct(v reflect.Value, w *bytes.Buffer) error {
func encodeStruct(ctx context.Context, v reflect.Value, w *bytes.Buffer) error {
w.WriteRune('{')
afterFirst := false
for i := range v.NumField() {
Expand All @@ -129,7 +142,7 @@ func encodeStruct(v reflect.Value, w *bytes.Buffer) error {
}
afterFirst = true
w.WriteString(`"` + strcase.ToLowerCamel(ft.Name) + `":`)
if err := encodeValue(fv, w); err != nil {
if err := encodeValue(ctx, fv, w); err != nil {
return err
}
}
Expand All @@ -143,21 +156,21 @@ func encodeBytes(v reflect.Value, w *bytes.Buffer) error {
return nil
}

func encodeSlice(v reflect.Value, w *bytes.Buffer) error {
func encodeSlice(ctx context.Context, v reflect.Value, w *bytes.Buffer) error {
w.WriteRune('[')
for i := range v.Len() {
if i > 0 {
w.WriteRune(',')
}
if err := encodeValue(v.Index(i), w); err != nil {
if err := encodeValue(ctx, v.Index(i), w); err != nil {
return err
}
}
w.WriteRune(']')
return nil
}

func encodeMap(v reflect.Value, w *bytes.Buffer) error {
func encodeMap(ctx context.Context, v reflect.Value, w *bytes.Buffer) error {
w.WriteRune('{')
for i, key := range v.MapKeys() {
if i > 0 {
Expand All @@ -166,7 +179,7 @@ func encodeMap(v reflect.Value, w *bytes.Buffer) error {
w.WriteRune('"')
w.WriteString(key.String())
w.WriteString(`":`)
if err := encodeValue(v.MapIndex(key), w); err != nil {
if err := encodeValue(ctx, v.MapIndex(key), w); err != nil {
return err
}
}
Expand Down Expand Up @@ -198,17 +211,17 @@ func encodeString(v reflect.Value, w *bytes.Buffer) error {
return nil
}

func Unmarshal(data []byte, v any) error {
func Unmarshal(ctx context.Context, data []byte, v any) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return fmt.Errorf("unmarshal expects a non-nil pointer")
}

d := json.NewDecoder(bytes.NewReader(data))
return decodeValue(d, rv.Elem())
return decodeValue(ctx, d, rv.Elem())
}

func decodeValue(d *json.Decoder, v reflect.Value) error {
func decodeValue(ctx context.Context, d *json.Decoder, v reflect.Value) error {
if !v.CanSet() {
return fmt.Errorf("cannot set value: %s", v.Type())
}
Expand All @@ -233,28 +246,34 @@ func decodeValue(d *json.Decoder, v reflect.Value) error {
}
dec := v.Interface().(OptionUnmarshaler) //nolint:forcetypeassert
return handleIfNextTokenIsNull(d, func(d *json.Decoder) error {
return dec.Unmarshal(d, true, decodeValue)
return dec.Unmarshal(ctx, d, true, decodeValue)
}, func(d *json.Decoder) error {
return dec.Unmarshal(d, false, decodeValue)
return dec.Unmarshal(ctx, d, false, decodeValue)
})
}

switch v.Kind() {
case reflect.Struct:
return decodeStruct(d, v)
return decodeStruct(ctx, d, v)

case reflect.Slice:
if v.Type().Elem().Kind() == reflect.Uint8 {
return decodeBytes(d, v)
}
return decodeSlice(d, v)
return decodeSlice(ctx, d, v)

case reflect.Map:
return decodeMap(d, v)
return decodeMap(ctx, d, v)

case reflect.Interface:
if tr, ok := typeregistry.FromContext(ctx).Get(); ok {
if tr.IsSumTypeDiscriminator(v.Type()) {
return decodeSumType(ctx, d, v)
}
}

if v.Type().NumMethod() != 0 {
return fmt.Errorf("the only interface type supported is any, not %s", v.Type())
return fmt.Errorf("the only interface types supported are enums or any, not %s", v.Type())
}
fallthrough

Expand All @@ -263,7 +282,7 @@ func decodeValue(d *json.Decoder, v reflect.Value) error {
}
}

func decodeStruct(d *json.Decoder, v reflect.Value) error {
func decodeStruct(ctx context.Context, d *json.Decoder, v reflect.Value) error {
if err := expectDelim(d, '{'); err != nil {
return err
}
Expand Down Expand Up @@ -291,7 +310,7 @@ func decodeStruct(d *json.Decoder, v reflect.Value) error {
field.Set(reflect.New(field.Type().Elem()))
}
default:
if err := decodeValue(d, field); err != nil {
if err := decodeValue(ctx, d, field); err != nil {
return err
}
}
Expand All @@ -311,14 +330,14 @@ func decodeBytes(d *json.Decoder, v reflect.Value) error {
return nil
}

func decodeSlice(d *json.Decoder, v reflect.Value) error {
func decodeSlice(ctx context.Context, d *json.Decoder, v reflect.Value) error {
if err := expectDelim(d, '['); err != nil {
return err
}

for d.More() {
newElem := reflect.New(v.Type().Elem()).Elem()
if err := decodeValue(d, newElem); err != nil {
if err := decodeValue(ctx, d, newElem); err != nil {
return err
}
v.Set(reflect.Append(v, newElem))
Expand All @@ -328,7 +347,7 @@ func decodeSlice(d *json.Decoder, v reflect.Value) error {
return err
}

func decodeMap(d *json.Decoder, v reflect.Value) error {
func decodeMap(ctx context.Context, d *json.Decoder, v reflect.Value) error {
if err := expectDelim(d, '{'); err != nil {
return err
}
Expand All @@ -345,7 +364,7 @@ func decodeMap(d *json.Decoder, v reflect.Value) error {
}

newElem := reflect.New(valType).Elem()
if err := decodeValue(d, newElem); err != nil {
if err := decodeValue(ctx, d, newElem); err != nil {
return err
}

Expand All @@ -356,6 +375,71 @@ func decodeMap(d *json.Decoder, v reflect.Value) error {
return err
}

func decodeSumType(ctx context.Context, d *json.Decoder, v reflect.Value) error {
tr, ok := typeregistry.FromContext(ctx).Get()
if !ok {
return fmt.Errorf("no type registry found in context")
}

if err := expectDelim(d, '{'); err != nil {
return err
}

var name string
var valueJSON json.RawMessage
for d.More() {
token, err := d.Token()
if err != nil {
return err
}
key, ok := token.(string)
if !ok {
return fmt.Errorf("expected string key, got %T", token)
}

switch key {
case "value":
if err := d.Decode(&valueJSON); err != nil {
return err
}

case "name":
// Decode name as a string
if err := d.Decode(&name); err != nil {
return err
}
default:
return fmt.Errorf(`unexpected key %q in sum type variant, expected "name" or "value"`, key)
}
}

if name == "" {
return fmt.Errorf("no name found for type enum variant")
}

if valueJSON == nil {
return fmt.Errorf("no value found for type enum variant")
}

variantType, ok := tr.GetVariantByName(v.Type(), name).Get()
if !ok {
return fmt.Errorf("no enum variant found by name %s", name)
}

out := reflect.New(variantType)
if err := decodeValue(ctx, json.NewDecoder(bytes.NewReader(valueJSON)), out.Elem()); err != nil {
return err
}
if !out.Type().AssignableTo(v.Type()) {
return fmt.Errorf("cannot assign %s to %s", out.Type(), v.Type())
}
v.Set(out.Elem())

// consume the closing delimiter of the object
_, err := d.Token()
return err
}

func expectDelim(d *json.Decoder, expected json.Delim) error {
token, err := d.Token()
if err != nil {
Expand Down
Loading

0 comments on commit 22b416a

Please sign in to comment.