Skip to content

Commit

Permalink
Add mocking integration
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Feb 11, 2025
1 parent 9166ce4 commit 8025b4e
Show file tree
Hide file tree
Showing 27 changed files with 2,883 additions and 123 deletions.
16 changes: 16 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
with-expecter: true
issue-845-fix: true
resolve-type-alias: false
dir: mocks
outPkg: mocks
packages:
github.com/restatedev/sdk-go/internal/state:
interfaces:
Context:
Client:
Selector:
AwakeableFuture:
DurablePromise:
github.com/restatedev/sdk-go/internal/rand:
interfaces:
Rand:
17 changes: 2 additions & 15 deletions context.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,17 @@
package restate

import (
"context"
"log/slog"

"github.com/restatedev/sdk-go/internal/state"
)

// RunContext is passed to [Run] closures and provides the limited set of Restate operations that are safe to use there.
type RunContext interface {
context.Context

// Log obtains a handle on a slog.Logger which already has some useful fields (invocationID and method)
// By default, this logger will not output messages if the invocation is currently replaying
// The log handler can be set with `.WithLogger()` on the server object
Log() *slog.Logger

// Request gives extra information about the request that started this invocation
Request() *state.Request
}
type RunContext = state.RunContext

// Context is an extension of [RunContext] which is passed to Restate service handlers and enables
// interaction with Restate
type Context interface {
RunContext
inner() *state.Context
inner() state.Context
}

// ObjectSharedContext is an extension of [Context] which is passed to shared-mode Virtual Object handlers,
Expand Down
140 changes: 140 additions & 0 deletions examples/ticketreservation/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package main

import (
"log/slog"
"testing"
"time"

"github.com/google/uuid"
restate "github.com/restatedev/sdk-go"
"github.com/restatedev/sdk-go/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestPayment(t *testing.T) {
mockCtx := mocks.NewMockContext(t)
mockRand := mocks.NewMockRand(t)

mockRand.EXPECT().UUID().Return(uuid.UUID([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}))

mockCtx.EXPECT().Rand().Return(mockRand)
mockCtx.EXPECT().RunAndExpect(t, mockCtx, true, nil)
mockCtx.EXPECT().Log().Return(slog.Default())

resp, err := (&checkout{}).Payment(restate.WithMockContext(mockCtx), PaymentRequest{Tickets: []string{"abc"}})
assert.NoError(t, err)
assert.Equal(t, resp, PaymentResponse{ID: "01020304-0506-0708-090a-0b0c0d0e0f10", Price: 30})
}

func TestReserve(t *testing.T) {
mockCtx := mocks.NewMockContext(t)

mockCtx.EXPECT().GetAndReturn("status", TicketAvailable)
mockCtx.EXPECT().Set("status", TicketReserved)

ok, err := (&ticketService{}).Reserve(restate.WithMockContext(mockCtx), restate.Void{})
assert.NoError(t, err)
assert.True(t, ok)
}

func TestUnreserve(t *testing.T) {
mockCtx := mocks.NewMockContext(t)

mockCtx.EXPECT().Key().Return("foo")
mockCtx.EXPECT().Log().Return(slog.Default())
mockCtx.EXPECT().GetAndReturn("status", TicketAvailable)
mockCtx.EXPECT().Clear("status")

_, err := (&ticketService{}).Unreserve(restate.WithMockContext(mockCtx), restate.Void{})
assert.NoError(t, err)
}

func TestMarkAsSold(t *testing.T) {
mockCtx := mocks.NewMockContext(t)

mockCtx.EXPECT().Key().Return("foo")
mockCtx.EXPECT().Log().Return(slog.Default())
mockCtx.EXPECT().GetAndReturn("status", TicketReserved)
mockCtx.EXPECT().Set("status", TicketSold)

_, err := (&ticketService{}).MarkAsSold(restate.WithMockContext(mockCtx), restate.Void{})
assert.NoError(t, err)
}

func TestStatus(t *testing.T) {
mockCtx := mocks.NewMockContext(t)

mockCtx.EXPECT().Key().Return("foo")
mockCtx.EXPECT().Log().Return(slog.Default())
mockCtx.EXPECT().GetAndReturn("status", TicketReserved)

status, err := (&ticketService{}).Status(restate.WithMockContext(mockCtx), restate.Void{})
assert.NoError(t, err)
assert.Equal(t, status, TicketReserved)
}

func TestAddTicket(t *testing.T) {
mockCtx := mocks.NewMockContext(t)
mockTicketClient := mocks.NewMockClient(t)
mockSessionClient := mocks.NewMockClient(t)

mockCtx.EXPECT().Key().Return("userID")
mockCtx.EXPECT().Object(TicketServiceName, "ticket2", "Reserve").Once().Return(mockTicketClient)
mockTicketClient.EXPECT().RequestAndReturn("userID", true, nil)

mockCtx.EXPECT().GetAndReturn("tickets", []string{"ticket1"})
mockCtx.EXPECT().Set("tickets", []string{"ticket1", "ticket2"})
mockCtx.EXPECT().Object(UserSessionServiceName, "userID", "ExpireTicket").Once().Return(mockSessionClient)
mockSessionClient.EXPECT().Send("ticket2", restate.WithDelay(15*time.Minute))

ok, err := (&userSession{}).AddTicket(restate.WithMockContext(mockCtx), "ticket2")
assert.NoError(t, err)
assert.True(t, ok)
}

func TestExpireTicket(t *testing.T) {
mockCtx := mocks.NewMockContext(t)
mockTicketClient := mocks.NewMockClient(t)

mockCtx.EXPECT().GetAndReturn("tickets", []string{"ticket1", "ticket2"})
mockCtx.EXPECT().Set("tickets", []string{"ticket1"})

mockCtx.EXPECT().Object(TicketServiceName, "ticket2", "Unreserve").Once().Return(mockTicketClient)
mockTicketClient.EXPECT().Send(restate.Void{})

_, err := (&userSession{}).ExpireTicket(restate.WithMockContext(mockCtx), "ticket2")
assert.NoError(t, err)
}

func TestCheckout(t *testing.T) {
mockCtx := mocks.NewMockContext(t)

mockCtx.EXPECT().Key().Return("userID")
mockCtx.EXPECT().GetAndReturn("tickets", []string{"ticket1"})
mockCtx.EXPECT().Log().Return(slog.Default())

mockAfter := mocks.NewMockAfterFuture(t)
mockCtx.EXPECT().After(time.Minute).Return(mockAfter)

mockCheckoutClient := mocks.NewMockClient(t)
mockCtx.EXPECT().Object(CheckoutServiceName, "", "Payment").Once().Return(mockCheckoutClient)
mockResponseFuture := mocks.NewMockResponseFuture(t)
mockCheckoutClient.EXPECT().RequestFuture(PaymentRequest{UserID: "userID", Tickets: []string{"ticket1"}}).Return(mockResponseFuture)

mockSelector := mocks.NewMockSelector(t)
mockCtx.EXPECT().Select(mockAfter, mock.AnythingOfType("restate.responseFuture[github.com/restatedev/sdk-go/examples/ticketreservation.PaymentResponse]")).Return(mockSelector)
mockSelector.EXPECT().Select().Return(mockResponseFuture)

mockResponseFuture.EXPECT().ResponseAndReturn(PaymentResponse{ID: "paymentID", Price: 30}, nil)

mockTicketClient := mocks.NewMockClient(t)
mockCtx.EXPECT().Object(TicketServiceName, "ticket1", "MarkAsSold").Once().Return(mockTicketClient)
mockTicketClient.EXPECT().Send(restate.Void{})

mockCtx.EXPECT().Clear("tickets")

ok, err := (&userSession{}).Checkout(restate.WithMockContext(mockCtx), restate.Void{})
assert.NoError(t, err)
assert.True(t, ok)
}
43 changes: 15 additions & 28 deletions facilitators.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
// Rand returns a random source which will give deterministic results for a given invocation
// The source wraps the stdlib rand.Rand but with some extra helper methods
// This source is not safe for use inside .Run()
func Rand(ctx Context) *rand.Rand {
func Rand(ctx Context) rand.Rand {
return ctx.inner().Rand()
}

Expand All @@ -30,14 +30,7 @@ func After(ctx Context, d time.Duration) AfterFuture {

// After is a handle on a Sleep operation which allows you to do other work concurrently
// with the sleep.
type AfterFuture interface {
// Done blocks waiting on the remaining duration of the sleep.
// It is *not* safe to call this in a goroutine - use Context.Select if you want to wait on multiple
// results at once. Can return a terminal error in the case where the invocation was cancelled mid-sleep,
// hence Done() should always be called, even after using Context.Select.
Done() error
futures.Selectable
}
type AfterFuture = state.AfterFuture

// Service gets a Service request client by service and method name
func Service[O any](ctx Context, service string, method string, options ...options.ClientOption) Client[any, O] {
Expand Down Expand Up @@ -85,11 +78,11 @@ type SendClient[I any] interface {
}

type outputClient[O any] struct {
inner *state.Client
inner state.Client
}

func (t outputClient[O]) Request(input any, options ...options.RequestOption) (output O, err error) {
err = t.inner.RequestFuture(input, options...).Response(&output)
err = t.inner.Request(input, &output, options...)
return
}

Expand Down Expand Up @@ -135,11 +128,11 @@ type ResponseFuture[O any] interface {
}

type responseFuture[O any] struct {
state.DecodingResponseFuture
state.ResponseFuture
}

func (t responseFuture[O]) Response() (output O, err error) {
err = t.DecodingResponseFuture.Response(&output)
err = t.ResponseFuture.Response(&output)
return
}

Expand All @@ -162,11 +155,11 @@ type AwakeableFuture[T any] interface {
}

type awakeable[T any] struct {
state.DecodingAwakeable
state.AwakeableFuture
}

func (t awakeable[T]) Result() (output T, err error) {
err = t.DecodingAwakeable.Result(&output)
err = t.AwakeableFuture.Result(&output)
return
}

Expand All @@ -186,26 +179,20 @@ func Select(ctx Context, futs ...futures.Selectable) Selector {
return ctx.inner().Select(futs...)
}

// Selectable is a marker interface for futures that can be selected over with [Select]
type Selectable = futures.Selectable

// Selector is an iterator over a list of blocking Restate operations that are running
// in the background.
type Selector interface {
// Remaining returns whether there are still operations that haven't been returned by Select().
// There will always be exactly the same number of results as there were operations
// given to Context.Select
Remaining() bool
// Select blocks on the next completed operation or returns nil if there are none left
Select() futures.Selectable
}
type Selector = state.Selector

// Run runs the function (fn), storing final results (including terminal errors)
// durably in the journal, or otherwise for transient errors stopping execution
// so Restate can retry the invocation. Replays will produce the same value, so
// all non-deterministic operations (eg, generating a unique ID) *must* happen
// inside Run blocks.
func Run[T any](ctx Context, fn func(ctx RunContext) (T, error), options ...options.RunOption) (output T, err error) {
err = ctx.inner().Run(func(ctx state.RunContext) (any, error) {
err = ctx.inner().Run(func(ctx RunContext) (any, error) {
return fn(ctx)
}, &output, options...)

Expand Down Expand Up @@ -274,19 +261,19 @@ type DurablePromise[T any] interface {
}

type durablePromise[T any] struct {
state.DecodingPromise
state.DurablePromise
}

func (t durablePromise[T]) Result() (output T, err error) {
err = t.DecodingPromise.Result(&output)
err = t.DurablePromise.Result(&output)
return
}

func (t durablePromise[T]) Peek() (output T, err error) {
_, err = t.DecodingPromise.Peek(&output)
_, err = t.DurablePromise.Peek(&output)
return
}

func (t durablePromise[T]) Resolve(value T) (err error) {
return t.DecodingPromise.Resolve(value)
return t.DurablePromise.Resolve(value)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ require (
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
go.opentelemetry.io/otel/metric v1.28.0 // indirect
go.opentelemetry.io/otel/trace v1.28.0 // indirect
golang.org/x/text v0.14.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo=
Expand Down
10 changes: 5 additions & 5 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O], opts ...options.
}
}

func (h *serviceHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
func (h *serviceHandler[I, O]) Call(ctx state.Context, bytes []byte) ([]byte, error) {
var input I
if err := encoding.Unmarshal(h.options.Codec, bytes, &input); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
Expand Down Expand Up @@ -133,18 +133,18 @@ func NewObjectSharedHandler[I any, O any](fn ObjectSharedHandlerFn[I, O], opts .
}

type ctxWrapper struct {
*state.Context
state.Context
}

func (o ctxWrapper) inner() *state.Context {
func (o ctxWrapper) inner() state.Context {
return o.Context
}
func (o ctxWrapper) object() {}
func (o ctxWrapper) exclusiveObject() {}
func (o ctxWrapper) workflow() {}
func (o ctxWrapper) runWorkflow() {}

func (h *objectHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
func (h *objectHandler[I, O]) Call(ctx state.Context, bytes []byte) ([]byte, error) {
var input I
if err := encoding.Unmarshal(h.options.Codec, bytes, &input); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
Expand Down Expand Up @@ -233,7 +233,7 @@ func NewWorkflowSharedHandler[I any, O any](fn WorkflowSharedHandlerFn[I, O], op
}
}

func (h *workflowHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
func (h *workflowHandler[I, O]) Call(ctx state.Context, bytes []byte) ([]byte, error) {
var input I
if err := encoding.Unmarshal(h.options.Codec, bytes, &input); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
Expand Down
Loading

0 comments on commit 8025b4e

Please sign in to comment.