From 2c61f344e648ae570aa87224d5366416022fc484 Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Sun, 9 Feb 2025 22:04:15 -0800 Subject: [PATCH 1/2] Fix update.Response getting skipped by the proxy --- cmd/proxygenerator/interceptor.go | 9 +++++++++ proxy/interceptor.go | 28 ++++++++++++++++++++++++++++ proxy/interceptor_test.go | 22 ++++++++++++++++++++-- 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/cmd/proxygenerator/interceptor.go b/cmd/proxygenerator/interceptor.go index e488d4ae..2185c7a3 100644 --- a/cmd/proxygenerator/interceptor.go +++ b/cmd/proxygenerator/interceptor.go @@ -542,6 +542,12 @@ func generateInterceptor(cfg config) error { } workflowExecutions := types.NewPointer(exportTypes[0]) + updateTypes, err := lookupTypes("go.temporal.io/api/update/v1", []string{"Response"}) + if err != nil { + return err + } + updateResponse := types.NewPointer(updateTypes[0]) + payloadRecords := map[string]*TypeRecord{} failureRecords := map[string]*TypeRecord{} @@ -572,6 +578,9 @@ func generateInterceptor(cfg config) error { walk(payloadTypes, workflowExecutions, &payloadRecords, true) walk(failureTypes, workflowExecutions, &failureRecords, false) + walk(payloadTypes, updateResponse, &payloadRecords, true) + walk(failureTypes, updateResponse, &failureRecords, false) + payloadRecords = pruneRecords(payloadRecords) failureRecords = pruneRecords(failureRecords) diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 0fd2072d..a617dc35 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -1691,6 +1691,21 @@ func visitPayloads( return err } + case *update.Response: + + if o == nil { + continue + } + + if err := visitPayloads( + ctx, + options, + o, + o.GetOutcome(), + ); err != nil { + return err + } + case []*workflow.CallbackInfo: for _, x := range o { if err := visitPayloads(ctx, options, parent, x); err != nil { @@ -3096,6 +3111,19 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj return err } + case *update.Response: + if o == nil { + continue + } + ctx.Parent = o + if err := visitFailures( + ctx, + options, + o.GetOutcome(), + ); err != nil { + return err + } + case []*workflow.CallbackInfo: for _, x := range o { if err := visitFailures(ctx, options, x); err != nil { diff --git a/proxy/interceptor_test.go b/proxy/interceptor_test.go index 1c28042d..3347d512 100644 --- a/proxy/interceptor_test.go +++ b/proxy/interceptor_test.go @@ -183,8 +183,14 @@ func TestVisitPayloads_Any(t *testing.T) { Payloads: []*common.Payload{{Data: []byte("orig-val-don't-touch")}}, }}}) require.NoError(t, err) + msg3, err := anypb.New(&update.Response{Outcome: &update.Outcome{Value: &update.Outcome_Success{ + Success: &common.Payloads{ + Payloads: []*common.Payload{{Data: []byte("orig-val")}}, + }, + }}}) + require.NoError(t, err) root := &workflowservice.PollWorkflowTaskQueueResponse{ - Messages: []*protocol.Message{{Body: msg1}, {Body: msg2}}, + Messages: []*protocol.Message{{Body: msg1}, {Body: msg2}, {Body: msg3}}, } // Visit with any recursion enabled and only change orig-val @@ -204,6 +210,9 @@ func TestVisitPayloads_Any(t *testing.T) { update2, err := root.Messages[1].Body.UnmarshalNew() require.NoError(t, err) require.Equal(t, "orig-val-don't-touch", string(update2.(*update.Request).Input.Args.Payloads[0].Data)) + update3, err := root.Messages[2].Body.UnmarshalNew() + require.NoError(t, err) + require.Equal(t, "new-val", string(update3.(*update.Response).GetOutcome().GetSuccess().Payloads[0].Data)) // Do the same test but with a do-nothing visitor and confirm unchanged msg1, err = anypb.New(&update.Request{Input: &update.Input{Args: &common.Payloads{ @@ -214,8 +223,14 @@ func TestVisitPayloads_Any(t *testing.T) { Payloads: []*common.Payload{{Data: []byte("orig-val-don't-touch")}}, }}}) require.NoError(t, err) + msg3, err = anypb.New(&update.Response{Outcome: &update.Outcome{Value: &update.Outcome_Success{ + Success: &common.Payloads{ + Payloads: []*common.Payload{{Data: []byte("orig-val")}}, + }, + }}}) + require.NoError(t, err) root = &workflowservice.PollWorkflowTaskQueueResponse{ - Messages: []*protocol.Message{{Body: msg1}, {Body: msg2}}, + Messages: []*protocol.Message{{Body: msg1}, {Body: msg2}, {Body: msg3}}, } err = VisitPayloads(context.Background(), root, VisitPayloadsOptions{ Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { @@ -234,6 +249,9 @@ func TestVisitPayloads_Any(t *testing.T) { update2, err = root.Messages[1].Body.UnmarshalNew() require.NoError(t, err) require.Equal(t, "orig-val-don't-touch", string(update2.(*update.Request).Input.Args.Payloads[0].Data)) + update3, err = root.Messages[2].Body.UnmarshalNew() + require.NoError(t, err) + require.Equal(t, "orig-val", string(update3.(*update.Response).GetOutcome().GetSuccess().Payloads[0].Data)) } func TestVisitFailures(t *testing.T) { From 211431e36bcb298043ca24019af6568b2a08601a Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Mon, 10 Feb 2025 08:47:59 -0800 Subject: [PATCH 2/2] Handle any + failures as well --- cmd/proxygenerator/interceptor.go | 59 ++++++++++++++--- proxy/interceptor.go | 103 ++++++++++++++++++++++++++++++ proxy/interceptor_test.go | 38 +++++++++++ 3 files changed, 190 insertions(+), 10 deletions(-) diff --git a/cmd/proxygenerator/interceptor.go b/cmd/proxygenerator/interceptor.go index 2185c7a3..1ae32b2e 100644 --- a/cmd/proxygenerator/interceptor.go +++ b/cmd/proxygenerator/interceptor.go @@ -122,6 +122,9 @@ type VisitFailuresOptions struct { // Context is the same for every call of a visit, callers should not store it. // Visitor is free to mutate the passed failure struct. Visitor func(*VisitFailuresContext, *failure.Failure) (error) + // Will be called for each Any encountered. If not set, the default is to recurse into the Any + // object, unmarshal it, visit, and re-marshal it always (even if there are no changes). + WellKnownAnyVisitor func(*VisitFailuresContext, *anypb.Any) error } // VisitFailures calls the options.Visitor function for every Failure proto within msg. @@ -162,6 +165,25 @@ func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grp }, nil } +func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresContext, p *anypb.Any) error { + child, err := p.UnmarshalNew() + if err != nil { + return fmt.Errorf("failed to unmarshal any: %w", err) + } + // We choose to visit and re-marshal always instead of cloning, visiting, + // and checking if anything changed before re-marshaling. It is assumed the + // clone + equality check is not much cheaper than re-marshal. + if err := visitFailures(ctx, o, child); err != nil { + return err + } + // Confirmed this replaces both Any fields on non-error, there is nothing + // left over + if err := p.MarshalFrom(child); err != nil { + return fmt.Errorf("failed to marshal any: %w", err) + } + return nil +} + func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error { child, err := p.UnmarshalNew() if err != nil { @@ -299,6 +321,20 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj if o == nil { continue } if err := options.Visitor(ctx, o); err != nil { return err } if err := visitFailures(ctx, options, o.GetCause()); err != nil { return err } + case *anypb.Any: + if o == nil { + continue + } + visitor := options.WellKnownAnyVisitor + if visitor == nil { + visitor = options.defaultWellKnownAnyVisitor + } + ctx.Parent = o + err := visitor(ctx, o) + ctx.Parent = nil + if err != nil { + return err + } {{range $type, $record := .FailureTypes}} {{if $record.Slice}} case []{{$type}}: @@ -508,17 +544,19 @@ func generateInterceptor(cfg config) error { if err != nil { return err } - // For the purposes of payloads, we also consider the Any well known type as + + failureTypes, err := lookupTypes("go.temporal.io/api/failure/v1", []string{"Failure"}) + if err != nil { + return err + } + + // For the purposes of payloads and failures, we also consider the Any well known type as // possible if anyTypes, err := lookupTypes("google.golang.org/protobuf/types/known/anypb", []string{"Any"}); err != nil { return err } else { payloadTypes = append(payloadTypes, anyTypes...) - } - - failureTypes, err := lookupTypes("go.temporal.io/api/failure/v1", []string{"Failure"}) - if err != nil { - return err + failureTypes = append(failureTypes, anyTypes...) } // UnimplementedWorkflowServiceServer is auto-generated via our API package @@ -542,11 +580,10 @@ func generateInterceptor(cfg config) error { } workflowExecutions := types.NewPointer(exportTypes[0]) - updateTypes, err := lookupTypes("go.temporal.io/api/update/v1", []string{"Response"}) + updateTypes, err := lookupTypes("go.temporal.io/api/update/v1", []string{"Acceptance", "Rejection", "Response"}) if err != nil { return err } - updateResponse := types.NewPointer(updateTypes[0]) payloadRecords := map[string]*TypeRecord{} failureRecords := map[string]*TypeRecord{} @@ -578,8 +615,10 @@ func generateInterceptor(cfg config) error { walk(payloadTypes, workflowExecutions, &payloadRecords, true) walk(failureTypes, workflowExecutions, &failureRecords, false) - walk(payloadTypes, updateResponse, &payloadRecords, true) - walk(failureTypes, updateResponse, &failureRecords, false) + for _, ut := range updateTypes { + walk(payloadTypes, types.NewPointer(ut), &payloadRecords, true) + walk(failureTypes, types.NewPointer(ut), &failureRecords, false) + } payloadRecords = pruneRecords(payloadRecords) failureRecords = pruneRecords(failureRecords) diff --git a/proxy/interceptor.go b/proxy/interceptor.go index a617dc35..bebd47a9 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -120,6 +120,9 @@ type VisitFailuresOptions struct { // Context is the same for every call of a visit, callers should not store it. // Visitor is free to mutate the passed failure struct. Visitor func(*VisitFailuresContext, *failure.Failure) error + // Will be called for each Any encountered. If not set, the default is to recurse into the Any + // object, unmarshal it, visit, and re-marshal it always (even if there are no changes). + WellKnownAnyVisitor func(*VisitFailuresContext, *anypb.Any) error } // VisitFailures calls the options.Visitor function for every Failure proto within msg. @@ -160,6 +163,25 @@ func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grp }, nil } +func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresContext, p *anypb.Any) error { + child, err := p.UnmarshalNew() + if err != nil { + return fmt.Errorf("failed to unmarshal any: %w", err) + } + // We choose to visit and re-marshal always instead of cloning, visiting, + // and checking if anything changed before re-marshaling. It is assumed the + // clone + equality check is not much cheaper than re-marshal. + if err := visitFailures(ctx, o, child); err != nil { + return err + } + // Confirmed this replaces both Any fields on non-error, there is nothing + // left over + if err := p.MarshalFrom(child); err != nil { + return fmt.Errorf("failed to marshal any: %w", err) + } + return nil +} + func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error { child, err := p.UnmarshalNew() if err != nil { @@ -1644,6 +1666,21 @@ func visitPayloads( o.Summary = no } + case *update.Acceptance: + + if o == nil { + continue + } + + if err := visitPayloads( + ctx, + options, + o, + o.GetAcceptedRequest(), + ); err != nil { + return err + } + case *update.Input: if o == nil { @@ -1676,6 +1713,22 @@ func visitPayloads( return err } + case *update.Rejection: + + if o == nil { + continue + } + + if err := visitPayloads( + ctx, + options, + o, + o.GetFailure(), + o.GetRejectedRequest(), + ); err != nil { + return err + } + case *update.Request: if o == nil { @@ -2755,6 +2808,20 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj if err := visitFailures(ctx, options, o.GetCause()); err != nil { return err } + case *anypb.Any: + if o == nil { + continue + } + visitor := options.WellKnownAnyVisitor + if visitor == nil { + visitor = options.defaultWellKnownAnyVisitor + } + ctx.Parent = o + err := visitor(ctx, o) + ctx.Parent = nil + if err != nil { + return err + } case []*command.Command: for _, x := range o { @@ -3078,6 +3145,26 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj return err } + case []*protocol.Message: + for _, x := range o { + if err := visitFailures(ctx, options, x); err != nil { + return err + } + } + + case *protocol.Message: + if o == nil { + continue + } + ctx.Parent = o + if err := visitFailures( + ctx, + options, + o.GetBody(), + ); err != nil { + return err + } + case map[string]*query.WorkflowQueryResult: for _, x := range o { if err := visitFailures(ctx, options, x); err != nil { @@ -3111,6 +3198,19 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj return err } + case *update.Rejection: + if o == nil { + continue + } + ctx.Parent = o + if err := visitFailures( + ctx, + options, + o.GetFailure(), + ); err != nil { + return err + } + case *update.Response: if o == nil { continue @@ -3328,6 +3428,7 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj ctx, options, o.GetHistory(), + o.GetMessages(), ); err != nil { return err } @@ -3406,6 +3507,7 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj ctx, options, o.GetCommands(), + o.GetMessages(), o.GetQueryResults(), ); err != nil { return err @@ -3433,6 +3535,7 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj ctx, options, o.GetFailure(), + o.GetMessages(), ); err != nil { return err } diff --git a/proxy/interceptor_test.go b/proxy/interceptor_test.go index 3347d512..872f90d2 100644 --- a/proxy/interceptor_test.go +++ b/proxy/interceptor_test.go @@ -292,6 +292,44 @@ func TestVisitFailures(t *testing.T) { require.Equal(2, failureCount) } +func TestVisitFailuresAny(t *testing.T) { + require := require.New(t) + + fail := &failure.Failure{ + Message: "test failure", + } + + msg, err := anypb.New(&update.Response{Outcome: &update.Outcome{Value: &update.Outcome_Failure{ + Failure: fail, + }}}) + require.NoError(err) + + req := &workflowservice.RespondWorkflowTaskCompletedRequest{ + Messages: []*protocol.Message{{Body: msg}}, + } + failureCount := 0 + err = VisitFailures( + context.Background(), + req, + VisitFailuresOptions{ + Visitor: func(vfc *VisitFailuresContext, f *failure.Failure) error { + failureCount += 1 + require.Equal("test failure", f.Message) + f.EncodedAttributes = &common.Payload{Data: []byte("test failure")} + f.Message = "encoded failure" + return nil + }, + }, + ) + require.NoError(err) + require.Equal(1, failureCount) + updateMsg, err := req.GetMessages()[0].GetBody().UnmarshalNew() + require.NoError(err) + require.Equal("encoded failure", updateMsg.(*update.Response).GetOutcome().GetFailure().GetMessage()) + require.Equal("test failure", string(updateMsg.(*update.Response).GetOutcome().GetFailure().EncodedAttributes.Data)) + +} + func TestClientInterceptor(t *testing.T) { require := require.New(t)