Skip to content

Commit

Permalink
Merge pull request #393 from Jefftree/list-cache
Browse files Browse the repository at this point in the history
Add ListMerger to cached package
  • Loading branch information
k8s-ci-robot authored May 24, 2023
2 parents df37dd0 + b5722cd commit 7828149
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 36 deletions.
99 changes: 63 additions & 36 deletions pkg/cached/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,6 @@ type Data[T any] interface {
Get() Result[T]
}

// merger is the map-keyed merge cache: one dependency cache per key.
// T is the source type, V is the destination type.
type merger[K comparable, T, V any] struct {
// mergeFn combines the per-key dependency results into one result.
mergeFn func(map[K]Result[T]) Result[V]
// caches holds the dependency caches, by key.
caches map[K]Data[T]
// cacheResults holds the dependency results saved by Get the last time
// needsRunning reported true; needsRunning compares new results to these.
cacheResults map[K]Result[T]
// result is the value Get returns.
result Result[V]
}

// NewMerger creates a new merge cache, a cache that merges the result
// of other caches. The function only gets called if any of the
// dependency has changed.
Expand All @@ -135,27 +127,71 @@ type merger[K comparable, T, V any] struct {
// function will remerge all the dependencies together everytime. Since
// the list of dependencies is constant, there is no way to save some
// partial merge information either.
//
// Also note that Golang map iteration order is not stable. If the
// mergeFn depends on a stable iteration order, it will need to
// implement its own sorting or iteration order.
func NewMerger[K comparable, T, V any](mergeFn func(results map[K]Result[T]) Result[V], caches map[K]Data[T]) Data[V] {
return &merger[K, T, V]{
listCaches := make([]Data[T], 0, len(caches))
// maps from index to key
indexes := make(map[int]K, len(caches))
i := 0
for k := range caches {
listCaches = append(listCaches, caches[k])
indexes[i] = k
i++
}

return NewListMerger(func(results []Result[T]) Result[V] {
if len(results) != len(indexes) {
panic(fmt.Errorf("invalid result length %d, expected %d", len(results), len(indexes)))
}
m := make(map[K]Result[T], len(results))
for i := range results {
m[indexes[i]] = results[i]
}
return mergeFn(m)
}, listCaches)
}

// listMerger is the cache returned by NewListMerger. T is the source
// type, V is the destination type.
type listMerger[T, V any] struct {
// mergeFn is the merge function supplied to NewListMerger.
mergeFn func([]Result[T]) Result[V]
// caches are the dependency caches, in the order supplied to NewListMerger.
caches []Data[T]
// cacheResults are the dependency results saved by Get the last time
// needsRunning reported true; nil until then.
cacheResults []Result[T]
// result is the value Get returns.
result Result[V]
}

// NewListMerger builds a merge cache whose dependencies are given as a
// list. mergeFn is only invoked when one of the dependencies has
// changed.
//
// Compared to the basic Merger, the caches are kept in an ordered
// list, so the results handed to mergeFn appear in the same order as
// the caches passed in here.
//
// mergeFn is rerun if any dependency returned an error previously, if
// any dependency returns an error now, or if mergeFn itself failed on
// the previous run.
//
// There is no notion of a "partial" merge: every run merges all the
// dependencies again from scratch. Because the list of dependencies
// is constant, no partial merge information can be kept between runs
// either.
func NewListMerger[T, V any](mergeFn func(results []Result[T]) Result[V], caches []Data[T]) Data[V] {
	lm := new(listMerger[T, V])
	lm.mergeFn = mergeFn
	lm.caches = caches
	return lm
}

func (c *merger[K, T, V]) prepareResults() map[K]Result[T] {
cacheResults := make(map[K]Result[T], len(c.caches))
for key, cache := range c.caches {
cacheResults[key] = cache.Get()
func (c *listMerger[T, V]) prepareResults() []Result[T] {
cacheResults := make([]Result[T], 0, len(c.caches))
for _, cache := range c.caches {
cacheResults = append(cacheResults, cache.Get())
}
return cacheResults
}

// Rerun if:
// - The last run resulted in an error
// - Any of the dependencies previously returned an error
// - Any of the dependencies just returned an error
// - Any of the dependencies' etags changed
func (c *merger[K, T, V]) needsRunning(results map[K]Result[T]) bool {
func (c *listMerger[T, V]) needsRunning(results []Result[T]) bool {
if c.cacheResults == nil {
return true
}
Expand All @@ -165,20 +201,16 @@ func (c *merger[K, T, V]) needsRunning(results map[K]Result[T]) bool {
if len(results) != len(c.cacheResults) {
panic(fmt.Errorf("invalid number of results: %v (expected %v)", len(results), len(c.cacheResults)))
}
for key, oldResult := range c.cacheResults {
newResult, ok := results[key]
if !ok {
panic(fmt.Errorf("unknown cache entry: %v", key))
}

for i, oldResult := range c.cacheResults {
newResult := results[i]
if newResult.Etag != oldResult.Etag || newResult.Err != nil || oldResult.Err != nil {
return true
}
}
return false
}

func (c *merger[K, T, V]) Get() Result[V] {
func (c *listMerger[T, V]) Get() Result[V] {
cacheResults := c.prepareResults()
if c.needsRunning(cacheResults) {
c.cacheResults = cacheResults
Expand All @@ -187,8 +219,6 @@ func (c *merger[K, T, V]) Get() Result[V] {
return c.result
}

// transformerCacheKeyType is the key type of the single-entry map that the
// map-based NewTransformer hands to NewMerger.
type transformerCacheKeyType struct{}

// NewTransformer creates a new cache that transforms the result of
// another cache. The transformFn will only be called if the source
// cache has updated the output, otherwise, the cached result will be
Expand All @@ -198,15 +228,12 @@ type transformerCacheKeyType struct{}
// this time, or if the transformerFn failed before, the function is
// reran.
func NewTransformer[T, V any](transformerFn func(Result[T]) Result[V], source Data[T]) Data[V] {
return NewMerger(func(caches map[transformerCacheKeyType]Result[T]) Result[V] {
cache, ok := caches[transformerCacheKeyType{}]
if len(caches) != 1 || !ok {
return NewListMerger(func(caches []Result[T]) Result[V] {
if len(caches) != 1 {
panic(fmt.Errorf("invalid cache for transformer cache: %v", caches))
}
return transformerFn(cache)
}, map[transformerCacheKeyType]Data[T]{
{}: source,
})
return transformerFn(caches[0])
}, []Data[T]{source})
}

// NewSource creates a new cache that generates some data. This
Expand Down
170 changes: 170 additions & 0 deletions pkg/cached/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,173 @@ func Example() {
"replaceable": &replaceable,
}))
}

// TestListMerger checks that a list merger passes the dependency results to
// the merge function in the order the caches were supplied, and that a second
// Get with unchanged dependencies reuses the merged result instead of calling
// the merge function again.
func TestListMerger(t *testing.T) {
	source1Count := 0
	source1 := cached.NewSource(func() cached.Result[[]byte] {
		source1Count++
		return cached.NewResultOK([]byte("source1"), "source1")
	})
	source2Count := 0
	source2 := cached.NewSource(func() cached.Result[[]byte] {
		source2Count++
		return cached.NewResultOK([]byte("source2"), "source2")
	})
	mergerCount := 0
	merger := cached.NewListMerger(func(results []cached.Result[[]byte]) cached.Result[[]byte] {
		mergerCount++
		d := make([]string, 0, len(results))
		e := make([]string, 0, len(results))
		for _, result := range results {
			if result.Err != nil {
				return cached.NewResultErr[[]byte](result.Err)
			}
			d = append(d, string(result.Data))
			e = append(e, result.Etag)
		}
		// Deliberately NOT sorted: the point of the list merger is that the
		// results arrive in the order of the caches, so the assertions below
		// must fail if that order is ever scrambled.
		return cached.NewResultOK([]byte("merged "+strings.Join(d, " and ")), "merged "+strings.Join(e, " and "))
	}, []cached.Data[[]byte]{
		source1, source2,
	})
	if err := merger.Get().Err; err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	result := merger.Get()
	if result.Err != nil {
		t.Fatalf("unexpected error: %v", result.Err)
	}
	if want := "merged source1 and source2"; string(result.Data) != want {
		t.Fatalf("expected data = %v, got %v", want, string(result.Data))
	}
	if want := "merged source1 and source2"; result.Etag != want {
		t.Fatalf("expected etag = %v, got %v", want, result.Etag)
	}

	// Every Get consults the sources, but the merge only runs once because
	// the etags did not change between the two calls.
	if source1Count != 2 {
		t.Fatalf("Expected source function called twice, called: %v", source1Count)
	}
	if source2Count != 2 {
		t.Fatalf("Expected source function called twice, called: %v", source2Count)
	}
	if mergerCount != 1 {
		t.Fatalf("Expected merger function called once, called: %v", mergerCount)
	}
}

// TestListMergerSourceError checks that when one dependency always fails,
// every Get surfaces the error and the merge function is run again each time.
func TestListMergerSourceError(t *testing.T) {
	var failingCalls, healthyCalls, mergeCalls int

	failing := cached.NewSource(func() cached.Result[[]byte] {
		failingCalls++
		return cached.NewResultErr[[]byte](errors.New("source1 error"))
	})
	healthy := cached.NewSource(func() cached.Result[[]byte] {
		healthyCalls++
		return cached.NewResultOK([]byte("source2"), "source2")
	})

	// mergeFn propagates the first dependency error it sees; otherwise it
	// joins the sorted payloads and etags.
	mergeFn := func(results []cached.Result[[]byte]) cached.Result[[]byte] {
		mergeCalls++
		var data, etags []string
		for i := range results {
			if err := results[i].Err; err != nil {
				return cached.NewResultErr[[]byte](err)
			}
			data = append(data, string(results[i].Data))
			etags = append(etags, results[i].Etag)
		}
		sort.Strings(data)
		sort.Strings(etags)
		body := "merged " + strings.Join(data, " and ")
		etag := "merged " + strings.Join(etags, " and ")
		return cached.NewResultOK([]byte(body), etag)
	}
	merger := cached.NewListMerger(mergeFn, []cached.Data[[]byte]{failing, healthy})

	// Two consecutive reads must both report the error.
	for attempt := 0; attempt < 2; attempt++ {
		if merger.Get().Err == nil {
			t.Fatalf("expected error, none found")
		}
	}

	if failingCalls != 2 {
		t.Fatalf("Expected source function called twice, called: %v", failingCalls)
	}
	if healthyCalls != 2 {
		t.Fatalf("Expected source function called twice, called: %v", healthyCalls)
	}
	if mergeCalls != 2 {
		t.Fatalf("Expected merger function called twice, called: %v", mergeCalls)
	}
}

// TestListMergerAlternateSourceError checks a dependency that alternates
// between success and failure: the merger must follow it on every Get, and
// the merge function must run on each of the four reads.
func TestListMergerAlternateSourceError(t *testing.T) {
	var flakyCalls, steadyCalls, mergeCalls int

	// Odd-numbered calls succeed, even-numbered calls fail.
	flaky := cached.NewSource(func() cached.Result[[]byte] {
		flakyCalls++
		if flakyCalls%2 != 0 {
			return cached.NewResultOK([]byte("source1"), "source1")
		}
		return cached.NewResultErr[[]byte](errors.New("source1 error"))
	})
	steady := cached.NewSource(func() cached.Result[[]byte] {
		steadyCalls++
		return cached.NewResultOK([]byte("source2"), "source2")
	})

	mergeFn := func(results []cached.Result[[]byte]) cached.Result[[]byte] {
		mergeCalls++
		var data, etags []string
		for i := range results {
			if err := results[i].Err; err != nil {
				return cached.NewResultErr[[]byte](err)
			}
			data = append(data, string(results[i].Data))
			etags = append(etags, results[i].Etag)
		}
		sort.Strings(data)
		sort.Strings(etags)
		body := "merged " + strings.Join(data, " and ")
		etag := "merged " + strings.Join(etags, " and ")
		return cached.NewResultOK([]byte(body), etag)
	}
	merger := cached.NewListMerger(mergeFn, []cached.Data[[]byte]{flaky, steady})

	// expectMerged performs one Get and requires the successful merged value.
	expectMerged := func() {
		t.Helper()
		got := merger.Get()
		if got.Err != nil {
			t.Fatalf("unexpected error: %v", got.Err)
		}
		const want = "merged source1 and source2"
		if string(got.Data) != want {
			t.Fatalf("expected data = %v, got %v", want, string(got.Data))
		}
		if got.Etag != want {
			t.Fatalf("expected etag = %v, got %v", want, got.Etag)
		}
	}
	// expectFailure performs one Get and requires it to report an error.
	expectFailure := func() {
		t.Helper()
		if merger.Get().Err == nil {
			t.Fatalf("expected error, none found")
		}
	}

	expectMerged()
	expectFailure()
	expectMerged()
	expectFailure()

	if flakyCalls != 4 {
		t.Fatalf("Expected source function called 4x, called: %v", flakyCalls)
	}
	if steadyCalls != 4 {
		t.Fatalf("Expected source function called 4x, called: %v", steadyCalls)
	}
	if mergeCalls != 4 {
		t.Fatalf("Expected merger function called 4x, called: %v", mergeCalls)
	}
}

0 comments on commit 7828149

Please sign in to comment.