diff --git a/go.mod b/go.mod index 8edf378..b4b859e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/0x5a17ed/itkit go 1.20 require ( - github.com/0x5a17ed/coro v1.0.0 + github.com/0x5a17ed/coro v1.1.0 github.com/stretchr/testify v1.9.0 go.uber.org/goleak v1.3.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 diff --git a/go.sum b/go.sum index 24540ad..3dd10c9 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/0x5a17ed/coro v1.0.0 h1:e7aCb2d+UEdxECfRgW7O82S85qcfagkKeYRmpbOivP4= github.com/0x5a17ed/coro v1.0.0/go.mod h1:63Q/S3kITlZaHEAz43017XiPio29TPqTRWNFY0AiFjo= +github.com/0x5a17ed/coro v1.1.0 h1:20ZVX8/Wk8UE1QMJz839dkk6dXVXv29ApxOfAwcwl5w= +github.com/0x5a17ed/coro v1.1.0/go.mod h1:qBhkDOIugmZNQ1JQkrHb9nqy5/Xgd+22fumabDWPp3w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= diff --git a/iters/ioit/ioit.go b/iters/ioit/ioit.go index 033dce4..49ab694 100644 --- a/iters/ioit/ioit.go +++ b/iters/ioit/ioit.go @@ -15,23 +15,66 @@ package ioit import ( + "errors" "sync/atomic" - "github.com/0x5a17ed/itkit/iters/genit" + "github.com/0x5a17ed/coro" + "github.com/0x5a17ed/itkit" ) -type GeneratorFn[T any] func(yield func(T)) error +var ( + ErrStopped = errors.New("stopped") +) + +// YieldFn is a function that is called by a generator to send back +// values generated by the same generator. +// +// The return value of YieldFn indicates whenever the generator was +// closed (false) or the generator is allowed to continue (true). +type YieldFn[O any] func(O) bool + +// GeneratorFn is a function that generates values which are sent back +// through the given [YieldFn] yield function. +// +// The cont parameter tells the generator whenever it is allowed to +// continue or not. The argument must be evaluated at the beginning +// of the generator before any call to yield. The generator must +// return before any call to yield if the argument evaluates to false. +// +// See [YieldFn] for the documentation of the yield argument. +type GeneratorFn[O any] func(cont bool, yield func(O) bool) error type Generator[T any] struct { - *genit.Generator[T] + *coro.C[bool, T] + + value T + err atomic.Pointer[error] +} - err atomic.Pointer[error] +// Next fetches the next value produced by the wrapped [GeneratorFn] +// and returns true whenever there is a new value available and false +// otherwise. +func (g *Generator[T]) Next() (ok bool) { + g.value, ok = g.Resume(true) + return } -// Close stops the generator and returns any error returned by +// Value returns the latest value produced by the wrapped [GeneratorFn]. +func (g *Generator[T]) Value() T { return g.value } + +// Iter returns the [Generator] as an [itkit.Iterator] value. +func (g *Generator[T]) Iter() itkit.Iterator[T] { + return g +} + +// Close signals the generator to close down and returns any error returned by // the [GeneratorFn] function. func (g *Generator[T]) Close() error { - g.Generator.Stop() + var ok bool + if g.value, ok = g.Resume(false); ok { + // The generator had its chance. + g.C.Stop() + } return g.Err() } @@ -46,8 +89,25 @@ func (g *Generator[T]) Err() error { func Run[T any](fn GeneratorFn[T]) *Generator[T] { iog := &Generator[T]{} - iog.Generator = genit.Run(func(yield func(T)) { - err := fn(yield) + iog.C = coro.NewSub[bool, T](func(cont bool, yield func(T) bool) { + defer func() { + if errP := iog.err.Load(); errP != nil { + return + } + // No err value stored means fn run into a panic. + r := recover() + + if coro.IsStopped(r) { + // Generator was stopped, pass that value down. + err := ErrStopped + iog.err.Store(&err) + } + + // Propagate the panic value in any case. + panic(r) + }() + + err := fn(cont, yield) iog.err.Store(&err) }) diff --git a/iters/ioit/ioit_test.go b/iters/ioit/ioit_test.go index 6d74a78..ad169b4 100644 --- a/iters/ioit/ioit_test.go +++ b/iters/ioit/ioit_test.go @@ -16,19 +16,99 @@ package ioit_test import ( "io/fs" + "sync" "testing" "github.com/0x5a17ed/itkit/iters/ioit" + "github.com/0x5a17ed/itkit/iters/sliceit" "github.com/stretchr/testify/assert" "go.uber.org/goleak" ) +type Log struct { + mx sync.Mutex + Entries []string +} + +func (l *Log) Add(s string) { + l.mx.Lock() + defer l.mx.Unlock() + + l.Entries = append(l.Entries, s) +} + +func (l *Log) Assert(t *testing.T, s []string) { + assert.Equal(t, s, l.Entries) +} + +func singleGeneratorFactory(l *Log) *ioit.Generator[int] { + l.Add("factory creating generator") + gen := ioit.Run(func(cont bool, yield func(int) bool) error { + l.Add("generator enter") + defer l.Add("generator leave") + + if !cont { + l.Add("generator closing down") + return nil + } + + l.Add("generator yielding") + cont = yield(1) + l.Add("generator yielded") + + if !cont { + l.Add("generator closing down") + return fs.ErrClosed + } + + l.Add("generator returning") + return fs.ErrNotExist + }) + l.Add("factory returning generator") + + return gen +} + +func badGeneratorFactory(l *Log) *ioit.Generator[int] { + l.Add("factory creating generator") + gen := ioit.Run(func(cont bool, yield func(int) bool) error { + l.Add("generator enter") + defer l.Add("generator leave") + + for i := 1; ; i++ { + l.Add("generator yielding") + yield(i) + l.Add("generator yielded") + } + }) + l.Add("factory returning generator") + + return gen +} + +func panicGeneratorFactory(l *Log) *ioit.Generator[int] { + l.Add("factory creating generator") + gen := ioit.Run(func(cont bool, yield func(int) bool) error { + l.Add("generator enter") + defer l.Add("generator leave") + + l.Add("generator panic") + panic(nil) + }) + l.Add("factory returning generator") + + return gen +} + func TestGenerator_Err(t *testing.T) { defer goleak.VerifyNone(t) asserter := assert.New(t) - g := ioit.Run(func(func(any)) error { + g := ioit.Run(func(cont bool, _ func(any) bool) error { + if !cont { + return nil + } return fs.ErrNotExist }) @@ -38,3 +118,135 @@ func TestGenerator_Err(t *testing.T) { asserter.ErrorIs(g.Err(), fs.ErrNotExist) } + +func TestGenerator_Close(t *testing.T) { + t.Run("after completion", func(t *testing.T) { + asserter := assert.New(t) + + var l Log + g := singleGeneratorFactory(&l) + + // Exhaust the iterator. + asserter.Equal([]int{1}, sliceit.To(g.Iter())) + + // Signal the generator to close down. + l.Add("caller closing") + asserter.ErrorIs(g.Close(), fs.ErrNotExist) + + l.Assert(t, []string{ + "factory creating generator", + "factory returning generator", + "generator enter", + "generator yielding", + "generator yielded", + "generator returning", + "generator leave", + "caller closing", + }) + }) + + t.Run("before start", func(t *testing.T) { + asserter := assert.New(t) + + var l Log + g := singleGeneratorFactory(&l) + + // Signal the generator to close down. + l.Add("caller closing") + asserter.ErrorIs(g.Close(), nil) + l.Add("caller closed") + + // Assert the generator has reported no error. + asserter.NoError(g.Err()) + + // assert the state at the end. + l.Assert(t, []string{ + "factory creating generator", + "factory returning generator", + "caller closing", + "generator enter", + "generator closing down", + "generator leave", + "caller closed", + }) + }) + + t.Run("running", func(t *testing.T) { + asserter := assert.New(t) + + var l Log + g := singleGeneratorFactory(&l) + + // Retrieve first value from the generator. + l.Add("caller resuming") + ok := g.Next() + l.Add("caller resumed") + asserter.True(ok) + + // Signal the generator to close down. + l.Add("caller closing") + asserter.ErrorIs(g.Close(), fs.ErrClosed) + l.Add("caller closed") + + // assert the state at the end. + l.Assert(t, []string{ + "factory creating generator", + "factory returning generator", + "caller resuming", + "generator enter", + "generator yielding", + "caller resumed", + "caller closing", + "generator yielded", + "generator closing down", + "generator leave", + "caller closed", + }) + }) + + t.Run("bogus", func(t *testing.T) { + asserter := assert.New(t) + + var l Log + g := badGeneratorFactory(&l) + + // Signal the generator to close down. + l.Add("caller closing") + asserter.ErrorIs(g.Close(), ioit.ErrStopped) + l.Add("caller closed") + + l.Assert(t, []string{ + "factory creating generator", + "factory returning generator", + "caller closing", + "generator enter", + "generator yielding", + "generator leave", + "caller closed", + }) + }) + + t.Run("panic", func(t *testing.T) { + asserter := assert.New(t) + + var l Log + g := panicGeneratorFactory(&l) + + // Signal the generator to close down. + l.Add("caller closing") + asserter.Panics(func() { + _ = g.Close() + }) + l.Add("caller closed") + + l.Assert(t, []string{ + "factory creating generator", + "factory returning generator", + "caller closing", + "generator enter", + "generator panic", + "generator leave", + "caller closed", + }) + }) +}