Skip to content

Commit

Permalink
Add ListMerger
Browse files Browse the repository at this point in the history
  • Loading branch information
Jefftree committed May 23, 2023
1 parent df37dd0 commit f2b4b17
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 8 deletions.
81 changes: 73 additions & 8 deletions pkg/cached/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ type merger[K comparable, T, V any] struct {
// function will remerge all the dependencies together everytime. Since
// the list of dependencies is constant, there is no way to save some
// partial merge information either.

// Also note that Golang map iteration is not stable. If the mergeFn
// depends on the order iteration to be stable, it will need to
// implement its own sorting or iteration order.
func NewMerger[K comparable, T, V any](mergeFn func(results map[K]Result[T]) Result[V], caches map[K]Data[T]) Data[V] {
return &merger[K, T, V]{
mergeFn: mergeFn,
Expand Down Expand Up @@ -187,7 +191,71 @@ func (c *merger[K, T, V]) Get() Result[V] {
return c.result
}

type transformerCacheKeyType struct{}
// NewListMerger creates a new merge cache that merges the results of
// other caches in list form. The function only gets called if any of
// the dependency has changed.

// The benefit of ListMerger over the basic Merger is that caches are
// stored in a list and the iteration order will be deterministic. The
// caller is not forced to sort before calling the mergeFn on the
// cache.

// If any of the dependency returned an error before, or any of the
// dependency returned an error this time, or if the mergeFn failed
// before, then the function is reran.

// Note that this assumes there is no "partial" merge, the merge
// function will remerge all the dependencies together everytime. Since
// the list of dependencies is constant, there is no way to save some
// partial merge information either.
type listMerger[T, V any] struct {
mergeFn func([]Result[T]) Result[V]
caches []Data[T]
cacheResults []Result[T]
result Result[V]
}

func NewListMerger[T, V any](mergeFn func(results []Result[T]) Result[V], caches []Data[T]) Data[V] {
return &listMerger[T, V]{
mergeFn: mergeFn,
caches: caches,
}
}
func (c *listMerger[T, V]) prepareResults() []Result[T] {
cacheResults := []Result[T]{}
for _, cache := range c.caches {
cacheResults = append(cacheResults, cache.Get())
}
return cacheResults
}

func (c *listMerger[T, V]) needsRunning(results []Result[T]) bool {
if len(c.cacheResults) == 0 {
return true
}
if c.result.Err != nil {
return true
}
if len(results) != len(c.cacheResults) {
panic(fmt.Errorf("invalid number of results: %v (expected %v)", len(results), len(c.cacheResults)))
}
for i, oldResult := range c.cacheResults {
newResult := results[i]
if newResult.Etag != oldResult.Etag || newResult.Err != nil || oldResult.Err != nil {
return true
}
}
return false
}

func (c *listMerger[T, V]) Get() Result[V] {
cacheResults := c.prepareResults()
if c.needsRunning(cacheResults) {
c.cacheResults = cacheResults
c.result = c.mergeFn(c.cacheResults)
}
return c.result
}

// NewTransformer creates a new cache that transforms the result of
// another cache. The transformFn will only be called if the source
Expand All @@ -198,15 +266,12 @@ type transformerCacheKeyType struct{}
// this time, or if the transformerFn failed before, the function is
// reran.
func NewTransformer[T, V any](transformerFn func(Result[T]) Result[V], source Data[T]) Data[V] {
return NewMerger(func(caches map[transformerCacheKeyType]Result[T]) Result[V] {
cache, ok := caches[transformerCacheKeyType{}]
if len(caches) != 1 || !ok {
return NewListMerger(func(caches []Result[T]) Result[V] {
if len(caches) != 1 {
panic(fmt.Errorf("invalid cache for transformer cache: %v", caches))
}
return transformerFn(cache)
}, map[transformerCacheKeyType]Data[T]{
{}: source,
})
return transformerFn(caches[0])
}, []Data[T]{source})
}

// NewSource creates a new cache that generates some data. This
Expand Down
170 changes: 170 additions & 0 deletions pkg/cached/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,173 @@ func Example() {
"replaceable": &replaceable,
}))
}

func TestListMerger(t *testing.T) {
source1Count := 0
source1 := cached.NewSource(func() cached.Result[[]byte] {
source1Count += 1
return cached.NewResultOK([]byte("source1"), "source1")
})
source2Count := 0
source2 := cached.NewSource(func() cached.Result[[]byte] {
source2Count += 1
return cached.NewResultOK([]byte("source2"), "source2")
})
mergerCount := 0
merger := cached.NewListMerger(func(results []cached.Result[[]byte]) cached.Result[[]byte] {
mergerCount += 1
d := []string{}
e := []string{}
for _, result := range results {
if result.Err != nil {
return cached.NewResultErr[[]byte](result.Err)
}
d = append(d, string(result.Data))
e = append(e, result.Etag)
}
sort.Strings(d)
sort.Strings(e)
return cached.NewResultOK([]byte("merged "+strings.Join(d, " and ")), "merged "+strings.Join(e, " and "))
}, []cached.Data[[]byte]{
source1, source2,
})
if err := merger.Get().Err; err != nil {
t.Fatalf("unexpected error: %v", err)
}
result := merger.Get()
if result.Err != nil {
t.Fatalf("unexpected error: %v", result.Err)
}
if want := "merged source1 and source2"; string(result.Data) != want {
t.Fatalf("expected data = %v, got %v", want, string(result.Data))
}
if want := "merged source1 and source2"; result.Etag != want {
t.Fatalf("expected etag = %v, got %v", want, result.Etag)
}

if source1Count != 2 {
t.Fatalf("Expected source function called twice, called: %v", source1Count)
}
if source2Count != 2 {
t.Fatalf("Expected source function called twice, called: %v", source2Count)
}
if mergerCount != 1 {
t.Fatalf("Expected merger function called once, called: %v", mergerCount)
}
}

func TestListMergerSourceError(t *testing.T) {
source1Count := 0
source1 := cached.NewSource(func() cached.Result[[]byte] {
source1Count += 1
return cached.NewResultErr[[]byte](errors.New("source1 error"))
})
source2Count := 0
source2 := cached.NewSource(func() cached.Result[[]byte] {
source2Count += 1
return cached.NewResultOK([]byte("source2"), "source2")
})
mergerCount := 0
merger := cached.NewListMerger(func(results []cached.Result[[]byte]) cached.Result[[]byte] {
mergerCount += 1
d := []string{}
e := []string{}
for _, result := range results {
if result.Err != nil {
return cached.NewResultErr[[]byte](result.Err)
}
d = append(d, string(result.Data))
e = append(e, result.Etag)
}
sort.Strings(d)
sort.Strings(e)
return cached.NewResultOK([]byte("merged "+strings.Join(d, " and ")), "merged "+strings.Join(e, " and "))
}, []cached.Data[[]byte]{
source1, source2,
})
if err := merger.Get().Err; err == nil {
t.Fatalf("expected error, none found")
}
if err := merger.Get().Err; err == nil {
t.Fatalf("expected error, none found")
}
if source1Count != 2 {
t.Fatalf("Expected source function called twice, called: %v", source1Count)
}
if source2Count != 2 {
t.Fatalf("Expected source function called twice, called: %v", source2Count)
}
if mergerCount != 2 {
t.Fatalf("Expected merger function called twice, called: %v", mergerCount)
}
}

func TestListMergerAlternateSourceError(t *testing.T) {
source1Count := 0
source1 := cached.NewSource(func() cached.Result[[]byte] {
source1Count += 1
if source1Count%2 == 0 {
return cached.NewResultErr[[]byte](errors.New("source1 error"))
} else {
return cached.NewResultOK([]byte("source1"), "source1")
}
})
source2Count := 0
source2 := cached.NewSource(func() cached.Result[[]byte] {
source2Count += 1
return cached.NewResultOK([]byte("source2"), "source2")
})
mergerCount := 0
merger := cached.NewListMerger(func(results []cached.Result[[]byte]) cached.Result[[]byte] {
mergerCount += 1
d := []string{}
e := []string{}
for _, result := range results {
if result.Err != nil {
return cached.NewResultErr[[]byte](result.Err)
}
d = append(d, string(result.Data))
e = append(e, result.Etag)
}
sort.Strings(d)
sort.Strings(e)
return cached.NewResultOK([]byte("merged "+strings.Join(d, " and ")), "merged "+strings.Join(e, " and "))
}, []cached.Data[[]byte]{
source1, source2,
})
result := merger.Get()
if result.Err != nil {
t.Fatalf("unexpected error: %v", result.Err)
}
if want := "merged source1 and source2"; string(result.Data) != want {
t.Fatalf("expected data = %v, got %v", want, string(result.Data))
}
if want := "merged source1 and source2"; result.Etag != want {
t.Fatalf("expected etag = %v, got %v", want, result.Etag)
}
if err := merger.Get().Err; err == nil {
t.Fatalf("expected error, none found")
}
result = merger.Get()
if result.Err != nil {
t.Fatalf("unexpected error: %v", result.Err)
}
if want := "merged source1 and source2"; string(result.Data) != want {
t.Fatalf("expected data = %v, got %v", want, string(result.Data))
}
if want := "merged source1 and source2"; result.Etag != want {
t.Fatalf("expected etag = %v, got %v", want, result.Etag)
}
if err := merger.Get().Err; err == nil {
t.Fatalf("expected error, none found")
}
if source1Count != 4 {
t.Fatalf("Expected source function called 4x, called: %v", source1Count)
}
if source2Count != 4 {
t.Fatalf("Expected source function called 4x, called: %v", source2Count)
}
if mergerCount != 4 {
t.Fatalf("Expected merger function called 4x, called: %v", mergerCount)
}
}

0 comments on commit f2b4b17

Please sign in to comment.