Skip to content

Commit

Permalink
Add parallel steps (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r authored Sep 18, 2024
1 parent 546b1d2 commit 3e55d43
Show file tree
Hide file tree
Showing 11 changed files with 544 additions and 20 deletions.
23 changes: 21 additions & 2 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ jobs:
uses: actions/setup-go@v3
with:
go-version: '1.21'
- name: Test
run: go test -v -race -count=1
- name: Unit test
run: go test -v -race -count=1 -short
itest:
strategy:
matrix:
os: [ubuntu-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.21'

# Need npx to start the Dev Server
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: '18'

- name: Integration test
run: make itest
10 changes: 7 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
.PHONY: test
test:
go test -test.v
.PHONY: itest
itest:
go test ./tests -v -count=1

.PHONY: utest
utest:
go test -test.v -short

.PHONY: lint
lint:
Expand Down
54 changes: 54 additions & 0 deletions experimental/group/group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package group

import (
"context"

"github.com/inngest/inngestgo/step"
)

type Result struct {
Error error
Value any
}

func Parallel(
ctx context.Context,
fns ...func(ctx context.Context,
) (any, error)) []Result {
ctx = context.WithValue(ctx, step.ParallelKey, true)

results := []Result{}
isPlanned := false
ch := make(chan struct{}, 1)
var unexpectedPanic any
for _, fn := range fns {
fn := fn
go func(fn func(ctx context.Context) (any, error)) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(step.ControlHijack); ok {
isPlanned = true
} else {
unexpectedPanic = r
}
}
ch <- struct{}{}
}()

value, err := fn(ctx)
results = append(results, Result{Error: err, Value: value})
}(fn)
<-ch
}

if unexpectedPanic != nil {
// Repanic to let our normal panic recovery handle it
panic(unexpectedPanic)
}

if isPlanned {
panic(step.ControlHijack{})
}

return results
}
18 changes: 16 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -771,8 +771,13 @@ func (h *handler) invoke(w http.ResponseWriter, r *http.Request) error {
}()
}

var stepID *string
if rawStepID := r.URL.Query().Get("stepId"); rawStepID != "" && rawStepID != "step" {
stepID = &rawStepID
}

// Invoke the function, then immediately stop the streaming buffer.
resp, ops, err := invoke(r.Context(), fn, request)
resp, ops, err := invoke(r.Context(), fn, request, stepID)
streamCancel()

// NOTE: When triggering step errors, we should have an OpcodeStepError
Expand Down Expand Up @@ -1025,7 +1030,12 @@ type StreamResponse struct {

// invoke calls a given servable function with the specified input event. The input event must
// be fully typed.
func invoke(ctx context.Context, sf ServableFunction, input *sdkrequest.Request) (any, []state.GeneratorOpcode, error) {
func invoke(
ctx context.Context,
sf ServableFunction,
input *sdkrequest.Request,
stepID *string,
) (any, []state.GeneratorOpcode, error) {
if sf.Func() == nil {
// This should never happen, but as sf.Func returns a nillable type we
// must check that the function exists.
Expand All @@ -1036,6 +1046,10 @@ func invoke(ctx context.Context, sf ServableFunction, input *sdkrequest.Request)
// within a step. This allows us to prevent any execution of future tools after a
// tool has run.
fCtx, cancel := context.WithCancel(context.Background())
if stepID != nil {
fCtx = step.SetTargetStepID(fCtx, *stepID)
}

// This must be a pointer so that it can be mutated from within function tools.
mgr := sdkrequest.NewManager(cancel, input)
fCtx = sdkrequest.SetManager(fCtx, mgr)
Expand Down
11 changes: 6 additions & 5 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestInvoke(t *testing.T) {
Register(a)

t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestInvoke(t *testing.T) {
Register(a)

t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createBatchRequest(t, input, 5))
actual, op, err := invoke(ctx, a, createBatchRequest(t, input, 5), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestInvoke(t *testing.T) {
ctx := context.Background()

t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -205,7 +205,7 @@ func TestInvoke(t *testing.T) {

ctx := context.Background()
t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestInvoke(t *testing.T) {

ctx := context.Background()
t.Run("it invokes the function with correct types", func(t *testing.T) {
actual, op, err := invoke(ctx, a, createRequest(t, input))
actual, op, err := invoke(ctx, a, createRequest(t, input), nil)
require.NoError(t, err)
require.Nil(t, op)
require.Equal(t, resp, actual)
Expand Down Expand Up @@ -282,6 +282,7 @@ func TestInvoke(t *testing.T) {
actual, op, err := invoke(
ctx, a,
createRequest(t, EventA{Name: "my-event"}),
nil,
)
r.Nil(actual)
r.Nil(op)
Expand Down
13 changes: 7 additions & 6 deletions internal/sdkrequest/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ type Request struct {
// CallCtx represents context for individual function calls. This logs the function ID, the
// specific run ID, and sep information.
type CallCtx struct {
Env string `json:"env"`
FunctionID string `json:"fn_id"`
RunID string `json:"run_id"`
StepID string `json:"step_id"`
Stack CallStack `json:"stack"`
Attempt int `json:"attempt"`
DisableImmediateExecution bool `json:"disable_immediate_execution"`
Env string `json:"env"`
FunctionID string `json:"fn_id"`
RunID string `json:"run_id"`
StepID string `json:"step_id"`
Stack CallStack `json:"stack"`
Attempt int `json:"attempt"`
}

type CallStack struct {
Expand Down
21 changes: 19 additions & 2 deletions step/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ func Run[T any](
id string,
f func(ctx context.Context) (T, error),
) (T, error) {
targetID := getTargetStepID(ctx)
mgr := preflight(ctx)
op := mgr.NewOp(enums.OpcodeStep, id, nil)
hashedID := op.MustHash()

if val, ok := mgr.Step(op); ok {
// Create a new empty type T in v
Expand Down Expand Up @@ -78,6 +80,21 @@ func Run[T any](
return val, nil
}

if targetID != nil && *targetID != hashedID {
panic(ControlHijack{})
}

planParallel := targetID == nil && isParallel(ctx)
planBeforeRun := targetID == nil && mgr.Request().CallCtx.DisableImmediateExecution
if planParallel || planBeforeRun {
mgr.AppendOp(state.GeneratorOpcode{
ID: hashedID,
Op: enums.OpcodeStepPlanned,
Name: id,
})
panic(ControlHijack{})
}

// We're calling a function, so always cancel the context afterwards so that no
// other tools run.
defer mgr.Cancel()
Expand All @@ -94,7 +111,7 @@ func Run[T any](

// Implement per-step errors.
mgr.AppendOp(state.GeneratorOpcode{
ID: op.MustHash(),
ID: hashedID,
Op: enums.OpcodeStepError,
Name: id,
Error: &state.UserError{
Expand All @@ -112,7 +129,7 @@ func Run[T any](
mgr.SetErr(fmt.Errorf("unable to marshal run respone for '%s': %w", id, err))
}
mgr.AppendOp(state.GeneratorOpcode{
ID: op.MustHash(),
ID: hashedID,
Op: enums.OpcodeStepRun,
Name: id,
Data: byt,
Expand Down
33 changes: 33 additions & 0 deletions step/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ import (

type ControlHijack struct{}

type ctxKey string

const (
targetStepIDKey = ctxKey("stepID")
ParallelKey = ctxKey("parallelKey")
)

var (
// ErrNotInFunction is called when a step tool is executed outside of an Inngest
// function call context.
Expand All @@ -23,6 +30,32 @@ func (errNotInFunction) Error() string {
return "step called without function context"
}

func getTargetStepID(ctx context.Context) *string {
if v := ctx.Value(targetStepIDKey); v != nil {
if c, ok := v.(string); ok {
return &c
}
}
return nil
}

func SetTargetStepID(ctx context.Context, id string) context.Context {
if id == "" || id == "step" {
return ctx
}

return context.WithValue(ctx, targetStepIDKey, id)
}

func isParallel(ctx context.Context) bool {
if v := ctx.Value(ParallelKey); v != nil {
if c, ok := v.(bool); ok {
return c
}
}
return false
}

func preflight(ctx context.Context) sdkrequest.InvocationManager {
if ctx.Err() != nil {
// Another tool has already ran and the context is closed. Return
Expand Down
Loading

0 comments on commit 3e55d43

Please sign in to comment.