diff --git a/backend/controller/ingress/handler_test.go b/backend/controller/ingress/handler_test.go index 72af762c4f..012571d390 100644 --- a/backend/controller/ingress/handler_test.go +++ b/backend/controller/ingress/handler_test.go @@ -100,7 +100,7 @@ func TestIngress(t *testing.T) { req.URL.RawQuery = test.query.Encode() reqKey := model.NewRequestKey(model.OriginIngress, "test") ingress.Handle(sch, reqKey, routes, rec, req, func(ctx context.Context, r *connect.Request[ftlv1.CallRequest], requestKey optional.Option[model.RequestKey], requestSource string) (*connect.Response[ftlv1.CallResponse], error) { - body, err := encoding.Marshal(response) + body, err := encoding.Marshal(ctx, response) assert.NoError(t, err) return connect.NewResponse(&ftlv1.CallResponse{Response: &ftlv1.CallResponse_Body{Body: body}}), nil }) diff --git a/backend/controller/ingress/request_test.go b/backend/controller/ingress/request_test.go index 775534b238..b4ceb6af05 100644 --- a/backend/controller/ingress/request_test.go +++ b/backend/controller/ingress/request_test.go @@ -2,6 +2,7 @@ package ingress import ( "bytes" + "context" "net/http" "net/url" "reflect" @@ -162,7 +163,7 @@ func TestBuildRequestBody(t *testing.T) { if test.body == nil { test.body = obj{} } - body, err := encoding.Marshal(test.body) + body, err := encoding.Marshal(context.Background(), test.body) assert.NoError(t, err) requestURL := "http://127.0.0.1" + test.path if test.query != nil { @@ -182,7 +183,7 @@ func TestBuildRequestBody(t *testing.T) { assert.NoError(t, err) actualrv := reflect.New(reflect.TypeOf(test.expected)) actual := actualrv.Interface() - err = encoding.Unmarshal(requestBody, actual) + err = encoding.Unmarshal(context.Background(), requestBody, actual) assert.NoError(t, err) assert.Equal(t, test.expected, actualrv.Elem().Interface(), assert.OmitEmpty()) }) diff --git a/go-runtime/encoding/encoding.go b/go-runtime/encoding/encoding.go index f427a1d656..59611328f1 100644 --- a/go-runtime/encoding/encoding.go +++ b/go-runtime/encoding/encoding.go @@ -4,6 +4,7 @@ package encoding import ( "bytes" + "context" "encoding/base64" "encoding/json" "fmt" @@ -13,6 +14,7 @@ import ( "time" "github.com/TBD54566975/ftl/backend/schema/strcase" + "github.com/TBD54566975/ftl/go-runtime/ftl/typeregistry" ) var ( @@ -21,19 +23,19 @@ var ( ) type OptionMarshaler interface { - Marshal(w *bytes.Buffer, encode func(v reflect.Value, w *bytes.Buffer) error) error + Marshal(ctx context.Context, w *bytes.Buffer, encode func(ctx context.Context, v reflect.Value, w *bytes.Buffer) error) error } type OptionUnmarshaler interface { - Unmarshal(d *json.Decoder, isNull bool, decode func(d *json.Decoder, v reflect.Value) error) error + Unmarshal(ctx context.Context, d *json.Decoder, isNull bool, decode func(ctx context.Context, d *json.Decoder, v reflect.Value) error) error } -func Marshal(v any) ([]byte, error) { +func Marshal(ctx context.Context, v any) ([]byte, error) { w := &bytes.Buffer{} - err := encodeValue(reflect.ValueOf(v), w) + err := encodeValue(ctx, reflect.ValueOf(v), w) return w.Bytes(), err } -func encodeValue(v reflect.Value, w *bytes.Buffer) error { +func encodeValue(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { if !v.IsValid() { w.WriteString("null") return nil @@ -56,7 +58,7 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { case t.Implements(optionMarshaler): enc := v.Interface().(OptionMarshaler) //nolint:forcetypeassert - return enc.Marshal(w, encodeValue) + return enc.Marshal(ctx, w, encodeValue) //TODO: Remove once we support `omitempty` tag case t == reflect.TypeFor[json.RawMessage](): @@ -70,16 +72,16 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { switch v.Kind() { case reflect.Struct: - return encodeStruct(v, w) + return encodeStruct(ctx, v, w) case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { return encodeBytes(v, w) } - return encodeSlice(v, w) + return encodeSlice(ctx, v, w) case reflect.Map: - return encodeMap(v, w) + return encodeMap(ctx, v, w) case reflect.String: return encodeString(v, w) @@ -93,18 +95,29 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { case reflect.Bool: return encodeBool(v, w) - case reflect.Interface: // any - if t != reflect.TypeOf((*any)(nil)).Elem() { - return fmt.Errorf("the only interface type supported is any, not %s", t) + case reflect.Interface: + if t == reflect.TypeFor[any]() { + return encodeValue(ctx, v.Elem(), w) + } + + if tr, ok := typeregistry.FromContext(ctx).Get(); ok { + if vName, ok := tr.GetVariantByType(v.Type(), v.Elem().Type()).Get(); ok { + stv := struct { + Name string + Value any + }{Name: vName, Value: v.Elem().Interface()} + return encodeValue(ctx, reflect.ValueOf(stv), w) + } } - return encodeValue(v.Elem(), w) + + return fmt.Errorf("the only interface types are enums or any, not %s", t) default: panic(fmt.Sprintf("unsupported type: %s", v.Type())) } } -func encodeStruct(v reflect.Value, w *bytes.Buffer) error { +func encodeStruct(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { w.WriteRune('{') afterFirst := false for i := range v.NumField() { @@ -129,7 +142,7 @@ func encodeStruct(v reflect.Value, w *bytes.Buffer) error { } afterFirst = true w.WriteString(`"` + strcase.ToLowerCamel(ft.Name) + `":`) - if err := encodeValue(fv, w); err != nil { + if err := encodeValue(ctx, fv, w); err != nil { return err } } @@ -143,13 +156,13 @@ func encodeBytes(v reflect.Value, w *bytes.Buffer) error { return nil } -func encodeSlice(v reflect.Value, w *bytes.Buffer) error { +func encodeSlice(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { w.WriteRune('[') for i := range v.Len() { if i > 0 { w.WriteRune(',') } - if err := encodeValue(v.Index(i), w); err != nil { + if err := encodeValue(ctx, v.Index(i), w); err != nil { return err } } @@ -157,7 +170,7 @@ func encodeSlice(v reflect.Value, w *bytes.Buffer) error { return nil } -func encodeMap(v reflect.Value, w *bytes.Buffer) error { +func encodeMap(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { w.WriteRune('{') for i, key := range v.MapKeys() { if i > 0 { @@ -166,7 +179,7 @@ func encodeMap(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('"') w.WriteString(key.String()) w.WriteString(`":`) - if err := encodeValue(v.MapIndex(key), w); err != nil { + if err := encodeValue(ctx, v.MapIndex(key), w); err != nil { return err } } @@ -198,17 +211,17 @@ func encodeString(v reflect.Value, w *bytes.Buffer) error { return nil } -func Unmarshal(data []byte, v any) error { +func Unmarshal(ctx context.Context, data []byte, v any) error { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("unmarshal expects a non-nil pointer") } d := json.NewDecoder(bytes.NewReader(data)) - return decodeValue(d, rv.Elem()) + return decodeValue(ctx, d, rv.Elem()) } -func decodeValue(d *json.Decoder, v reflect.Value) error { +func decodeValue(ctx context.Context, d *json.Decoder, v reflect.Value) error { if !v.CanSet() { return fmt.Errorf("cannot set value: %s", v.Type()) } @@ -233,28 +246,34 @@ func decodeValue(d *json.Decoder, v reflect.Value) error { } dec := v.Interface().(OptionUnmarshaler) //nolint:forcetypeassert return handleIfNextTokenIsNull(d, func(d *json.Decoder) error { - return dec.Unmarshal(d, true, decodeValue) + return dec.Unmarshal(ctx, d, true, decodeValue) }, func(d *json.Decoder) error { - return dec.Unmarshal(d, false, decodeValue) + return dec.Unmarshal(ctx, d, false, decodeValue) }) } switch v.Kind() { case reflect.Struct: - return decodeStruct(d, v) + return decodeStruct(ctx, d, v) case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { return decodeBytes(d, v) } - return decodeSlice(d, v) + return decodeSlice(ctx, d, v) case reflect.Map: - return decodeMap(d, v) + return decodeMap(ctx, d, v) case reflect.Interface: + if tr, ok := typeregistry.FromContext(ctx).Get(); ok { + if tr.IsSumTypeDiscriminator(v.Type()) { + return decodeSumType(ctx, d, v) + } + } + if v.Type().NumMethod() != 0 { - return fmt.Errorf("the only interface type supported is any, not %s", v.Type()) + return fmt.Errorf("the only interface types supported are enums or any, not %s", v.Type()) } fallthrough @@ -263,7 +282,7 @@ func decodeValue(d *json.Decoder, v reflect.Value) error { } } -func decodeStruct(d *json.Decoder, v reflect.Value) error { +func decodeStruct(ctx context.Context, d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '{'); err != nil { return err } @@ -291,7 +310,7 @@ func decodeStruct(d *json.Decoder, v reflect.Value) error { field.Set(reflect.New(field.Type().Elem())) } default: - if err := decodeValue(d, field); err != nil { + if err := decodeValue(ctx, d, field); err != nil { return err } } @@ -311,14 +330,14 @@ func decodeBytes(d *json.Decoder, v reflect.Value) error { return nil } -func decodeSlice(d *json.Decoder, v reflect.Value) error { +func decodeSlice(ctx context.Context, d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '['); err != nil { return err } for d.More() { newElem := reflect.New(v.Type().Elem()).Elem() - if err := decodeValue(d, newElem); err != nil { + if err := decodeValue(ctx, d, newElem); err != nil { return err } v.Set(reflect.Append(v, newElem)) @@ -328,7 +347,7 @@ func decodeSlice(d *json.Decoder, v reflect.Value) error { return err } -func decodeMap(d *json.Decoder, v reflect.Value) error { +func decodeMap(ctx context.Context, d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '{'); err != nil { return err } @@ -345,7 +364,7 @@ func decodeMap(d *json.Decoder, v reflect.Value) error { } newElem := reflect.New(valType).Elem() - if err := decodeValue(d, newElem); err != nil { + if err := decodeValue(ctx, d, newElem); err != nil { return err } @@ -356,6 +375,71 @@ func decodeMap(d *json.Decoder, v reflect.Value) error { return err } +func decodeSumType(ctx context.Context, d *json.Decoder, v reflect.Value) error { + tr, ok := typeregistry.FromContext(ctx).Get() + if !ok { + return fmt.Errorf("no type registry found in context") + } + + if err := expectDelim(d, '{'); err != nil { + return err + } + + var name string + var valueJSON json.RawMessage + for d.More() { + token, err := d.Token() + if err != nil { + return err + } + key, ok := token.(string) + if !ok { + return fmt.Errorf("expected string key, got %T", token) + } + + switch key { + case "value": + if err := d.Decode(&valueJSON); err != nil { + return err + } + + case "name": + // Decode name as a string + if err := d.Decode(&name); err != nil { + return err + } + default: + return fmt.Errorf("unexpected key %q in sum type", key) + } + } + + if name == "" { + return fmt.Errorf("no name found for type enum variant") + } + + if valueJSON == nil { + return fmt.Errorf("no value found for type enum variant") + } + + variantType, ok := tr.GetVariantByName(v.Type(), name).Get() + if !ok { + return fmt.Errorf("no enum variant found by name %s", name) + } + + out := reflect.New(variantType) + if err := decodeValue(ctx, json.NewDecoder(bytes.NewReader(valueJSON)), out.Elem()); err != nil { + return err + } + if !out.Type().AssignableTo(v.Type()) { + return fmt.Errorf("cannot assign %s to %s", out.Type(), v.Type()) + } + v.Set(out.Elem()) + + // consume the closing delimiter of the object + _, err := d.Token() + return err +} + func expectDelim(d *json.Decoder, expected json.Delim) error { token, err := d.Token() if err != nil { diff --git a/go-runtime/encoding/encoding_test.go b/go-runtime/encoding/encoding_test.go index 2187e354f7..3d1a57cb8b 100644 --- a/go-runtime/encoding/encoding_test.go +++ b/go-runtime/encoding/encoding_test.go @@ -1,6 +1,8 @@ package encoding_test import ( + "context" + "github.com/TBD54566975/ftl/go-runtime/ftl/typeregistry" "reflect" "testing" "time" @@ -11,6 +13,15 @@ import ( "github.com/TBD54566975/ftl/go-runtime/ftl" ) +type discriminator interface { + tag() +} +type variant struct { + Message string +} + +func (variant) tag() {} + func TestMarshal(t *testing.T) { type inner struct { FooBar string @@ -39,11 +50,18 @@ func TestMarshal(t *testing.T) { Unit ftl.Unit }{String: "something", Unit: ftl.Unit{}}, expected: `{"string":"something","unit":{}}`}, {name: "Pointer", input: &struct{ String string }{"foo"}, err: `pointer types are not supported: *struct { String string }`}, + {name: "SumType", input: struct{ D discriminator }{variant{"hello"}}, expected: `{"d":{"name":"Variant","value":{"message":"hello"}}}`}, } + tr := typeregistry.NewTypeRegistry() + tr.RegisterSumType(reflect.TypeFor[discriminator](), map[string]reflect.Type{ + "Variant": reflect.TypeFor[variant](), + }) + ctx := typeregistry.ContextWithTypeRegistry(context.Background(), tr) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - actual, err := Marshal(tt.input) + actual, err := Marshal(ctx, tt.input) assert.EqualError(t, err, tt.err) assert.Equal(t, tt.expected, string(actual)) }) @@ -85,13 +103,20 @@ func TestUnmarshal(t *testing.T) { Bool bool }{ftl.None[int](), true}}, {name: "Pointer", input: `{"string":"foo"}`, expected: &struct{ String string }{}, err: `pointer types are not supported: *struct { String string }`}, + {name: "SumType", input: `{"d":{"name":"Variant","value":{"message":"hello"}}}`, expected: struct{ D discriminator }{variant{"hello"}}}, } + tr := typeregistry.NewTypeRegistry() + tr.RegisterSumType(reflect.TypeFor[discriminator](), map[string]reflect.Type{ + "Variant": reflect.TypeFor[variant](), + }) + ctx := typeregistry.ContextWithTypeRegistry(context.Background(), tr) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { eType := reflect.TypeOf(tt.expected) o := reflect.New(eType) - err := Unmarshal([]byte(tt.input), o.Interface()) + err := Unmarshal(ctx, []byte(tt.input), o.Interface()) assert.EqualError(t, err, tt.err) if err == nil { assert.Equal(t, tt.expected, o.Elem().Interface()) @@ -128,16 +153,23 @@ func TestRoundTrip(t *testing.T) { {name: "Aliased", input: struct { TokenID string `json:"token_id"` }{"123"}}, + {name: "SumType", input: struct{ D discriminator }{variant{"hello"}}}, } + tr := typeregistry.NewTypeRegistry() + tr.RegisterSumType(reflect.TypeFor[discriminator](), map[string]reflect.Type{ + "Variant": reflect.TypeFor[variant](), + }) + ctx := typeregistry.ContextWithTypeRegistry(context.Background(), tr) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - marshaled, err := Marshal(tt.input) + marshaled, err := Marshal(ctx, tt.input) assert.NoError(t, err) eType := reflect.TypeOf(tt.input) o := reflect.New(eType) - err = Unmarshal(marshaled, o.Interface()) + err = Unmarshal(ctx, marshaled, o.Interface()) assert.NoError(t, err) assert.Equal(t, tt.input, o.Elem().Interface()) diff --git a/go-runtime/ftl/call.go b/go-runtime/ftl/call.go index ed8e9c2a25..3d2f1f1210 100644 --- a/go-runtime/ftl/call.go +++ b/go-runtime/ftl/call.go @@ -31,7 +31,7 @@ func call[Req, Resp any](ctx context.Context, callee Ref, req Req, inline Verb[R return resp, fmt.Errorf("%s: overridden verb had invalid response type %T, expected %v", callee, uncheckedResp, reflect.TypeFor[Resp]()) } - reqData, err := encoding.Marshal(req) + reqData, err := encoding.Marshal(ctx, req) if err != nil { return resp, fmt.Errorf("%s: failed to marshal request: %w", callee, err) } @@ -46,7 +46,7 @@ func call[Req, Resp any](ctx context.Context, callee Ref, req Req, inline Verb[R return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) case *ftlv1.CallResponse_Body: - err = encoding.Unmarshal(cresp.Body, &resp) + err = encoding.Unmarshal(ctx, cresp.Body, &resp) if err != nil { return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) } diff --git a/go-runtime/ftl/option.go b/go-runtime/ftl/option.go index b932eb8d9b..184bbad0ec 100644 --- a/go-runtime/ftl/option.go +++ b/go-runtime/ftl/option.go @@ -3,14 +3,13 @@ package ftl import ( "bytes" + "context" "database/sql" "database/sql/driver" "encoding" "encoding/json" "fmt" "reflect" - - ftlencoding "github.com/TBD54566975/ftl/go-runtime/encoding" ) // Stdlib interfaces types implement. @@ -166,25 +165,6 @@ func (o Option[T]) Default(value T) T { return value } -func (o Option[T]) MarshalJSON() ([]byte, error) { - if o.ok { - return ftlencoding.Marshal(o.value) - } - return []byte("null"), nil -} - -func (o *Option[T]) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - o.ok = false - return nil - } - if err := ftlencoding.Unmarshal(data, &o.value); err != nil { - return err - } - o.ok = true - return nil -} - func (o Option[T]) String() string { if o.ok { return fmt.Sprintf("%v", o.value) @@ -199,20 +179,29 @@ func (o Option[T]) GoString() string { return fmt.Sprintf("None[%T]()", o.value) } -func (o Option[T]) Marshal(w *bytes.Buffer, encode func(v reflect.Value, w *bytes.Buffer) error) error { +func (o Option[T]) Marshal( + ctx context.Context, + w *bytes.Buffer, + encode func(ctx context.Context, v reflect.Value, w *bytes.Buffer) error, +) error { if o.ok { - return encode(reflect.ValueOf(&o.value).Elem(), w) + return encode(ctx, reflect.ValueOf(&o.value).Elem(), w) } w.WriteString("null") return nil } -func (o *Option[T]) Unmarshal(d *json.Decoder, isNull bool, decode func(d *json.Decoder, v reflect.Value) error) error { +func (o *Option[T]) Unmarshal( + ctx context.Context, + d *json.Decoder, + isNull bool, + decode func(ctx context.Context, d *json.Decoder, v reflect.Value) error, +) error { if isNull { o.ok = false return nil } - if err := decode(d, reflect.ValueOf(&o.value).Elem()); err != nil { + if err := decode(ctx, d, reflect.ValueOf(&o.value).Elem()); err != nil { return err } o.ok = true diff --git a/go-runtime/ftl/typeregistry/type_registry.go b/go-runtime/ftl/typeregistry/type_registry.go index 2ce7b0094c..58f634fb89 100644 --- a/go-runtime/ftl/typeregistry/type_registry.go +++ b/go-runtime/ftl/typeregistry/type_registry.go @@ -2,8 +2,9 @@ package typeregistry import ( "context" - "fmt" "reflect" + + "github.com/alecthomas/types/optional" ) type contextKeyTypeRegistry struct{} @@ -13,46 +14,88 @@ func ContextWithTypeRegistry(ctx context.Context, r *TypeRegistry) context.Conte return context.WithValue(ctx, contextKeyTypeRegistry{}, r) } -// TypeRegistry is a registry of types that can be instantiated by their qualified name. -// It also records sum types and their variants, for use in encoding and decoding. +// FromContext retrieves the secrets schema.TypeRegistry previously +// added to the context with [ContextWithTypeRegistry]. +func FromContext(ctx context.Context) optional.Option[*TypeRegistry] { + t, ok := ctx.Value(contextKeyTypeRegistry{}).(*TypeRegistry) + if ok { + return optional.Some(t) + } + return optional.None[*TypeRegistry]() +} + +// TypeRegistry is used for dynamic type resolution at runtime. It stores associations between sum type discriminators +// and their variants, for use in encoding and decoding. // -// FTL manages the type registry for you, so you don't need to create one yourself +// FTL manages the type registry for you, so you don't need to create one yourself. type TypeRegistry struct { - // GoTypes associates a type name with a Go type. - GoTypes map[string]reflect.Type - // SumTypes associates a sum type discriminator type name with its variant type names. - SumTypes map[string][]string + sumTypes map[string][]sumTypeVariant +} + +type sumTypeVariant struct { + name string + goType reflect.Type } // NewTypeRegistry creates a new type registry. // The type registry is used to instantiate types by their qualified name at runtime. func NewTypeRegistry() *TypeRegistry { return &TypeRegistry{ - GoTypes: make(map[string]reflect.Type), - SumTypes: make(map[string][]string), - } -} - -// New creates a new instance of the type from the qualified type name. -func (t *TypeRegistry) New(name string) (any, error) { - typ, ok := t.GoTypes[name] - if !ok { - return nil, fmt.Errorf("type %q not registered", name) + sumTypes: make(map[string][]sumTypeVariant), } - return reflect.New(typ).Interface(), nil } // RegisterSumType registers a Go sum type with the type registry. Sum types are represented as enums in the // FTL schema. func (t *TypeRegistry) RegisterSumType(discriminator reflect.Type, variants map[string]reflect.Type) { dFqName := discriminator.PkgPath() + "." + discriminator.Name() - t.GoTypes[dFqName] = discriminator - var values []string + var values []sumTypeVariant for name, v := range variants { - values = append(values, name) - vFqName := v.PkgPath() + "." + v.Name() - t.GoTypes[vFqName] = v + values = append(values, sumTypeVariant{ + name: name, + goType: v, + }) + } + t.sumTypes[dFqName] = values +} + +func (t *TypeRegistry) IsSumTypeDiscriminator(discriminator reflect.Type) bool { + return t.getSumTypeVariants(discriminator).Ok() +} + +func (t *TypeRegistry) GetVariantByName(discriminator reflect.Type, name string) optional.Option[reflect.Type] { + variants, ok := t.getSumTypeVariants(discriminator).Get() + if !ok { + return optional.None[reflect.Type]() + } + for _, v := range variants { + if v.name == name { + return optional.Some(v.goType) + } + } + return optional.None[reflect.Type]() +} + +func (t *TypeRegistry) GetVariantByType(discriminator reflect.Type, variantType reflect.Type) optional.Option[string] { + variants, ok := t.getSumTypeVariants(discriminator).Get() + if !ok { + return optional.None[string]() } - t.SumTypes[dFqName] = values + for _, v := range variants { + if v.goType == variantType { + return optional.Some(v.name) + } + } + return optional.None[string]() +} + +func (t *TypeRegistry) getSumTypeVariants(discriminator reflect.Type) optional.Option[[]sumTypeVariant] { + dFqName := discriminator.PkgPath() + "." + discriminator.Name() + variants, ok := t.sumTypes[dFqName] + if !ok { + return optional.None[[]sumTypeVariant]() + } + + return optional.Some(variants) } diff --git a/go-runtime/server/server.go b/go-runtime/server/server.go index c9c099106d..841404c5bf 100644 --- a/go-runtime/server/server.go +++ b/go-runtime/server/server.go @@ -67,7 +67,7 @@ func handler[Req, Resp any](ref ftl.Ref, verb func(ctx context.Context, req Req) fn: func(ctx context.Context, reqdata []byte) ([]byte, error) { // Decode request. var req Req - err := encoding.Unmarshal(reqdata, &req) + err := encoding.Unmarshal(ctx, reqdata, &req) if err != nil { return nil, fmt.Errorf("invalid request to verb %s: %w", ref, err) } @@ -78,7 +78,7 @@ func handler[Req, Resp any](ref ftl.Ref, verb func(ctx context.Context, req Req) return nil, fmt.Errorf("call to verb %s failed: %w", ref, err) } - respdata, err := encoding.Marshal(resp) + respdata, err := encoding.Marshal(ctx, resp) if err != nil { return nil, err } diff --git a/integration/actions_test.go b/integration/actions_test.go index 18e1acd32e..5bc2076a97 100644 --- a/integration/actions_test.go +++ b/integration/actions_test.go @@ -24,6 +24,7 @@ import ( ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" schemapb "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/schema" + "github.com/TBD54566975/ftl/go-runtime/encoding" ftlexec "github.com/TBD54566975/ftl/internal/exec" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/scaffolder" @@ -209,7 +210,7 @@ type obj map[string]any func call(module, verb string, request obj, check func(response obj) error) action { return func(t testing.TB, ic testContext) error { infof("Calling %s.%s", module, verb) - data, err := json.Marshal(request) + data, err := encoding.Marshal(ic, request) if err != nil { return fmt.Errorf("failed to marshal request: %w", err) } @@ -224,7 +225,7 @@ func call(module, verb string, request obj, check func(response obj) error) acti if resp.Msg.GetError() != nil { return fmt.Errorf("verb failed: %s", resp.Msg.GetError().GetMessage()) } - err = json.Unmarshal(resp.Msg.GetBody(), &response) + err = encoding.Unmarshal(ic, resp.Msg.GetBody(), &response) if err != nil { return fmt.Errorf("failed to unmarshal response: %w", err) }