Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡ fetch org repositories in parallel #4970

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions internal/workerpool/collector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

import (
"sync"
"sync/atomic"
)

// collector gathers the values that arrive on the results and errors channels
// and keeps count of how many messages it has consumed so far.
type collector[R any] struct {
// resultsCh is the receive-only side of the channel carrying task results.
resultsCh <-chan R
// results accumulates every value received from resultsCh.
results []R
// read guards both the results and the errors slices.
read sync.Mutex

// errorsCh is the receive-only side of the channel carrying task errors.
errorsCh <-chan error
// errors accumulates every value received from errorsCh.
errors []error

// requestsRead counts the messages consumed from either channel. It is only
// accessed through sync/atomic.
requestsRead int64
}

// start launches the goroutine that drains resultsCh and errorsCh.
//
// Every received value is appended to the matching slice while holding the
// lock, and only afterwards is the read counter incremented, so a reader that
// observes the new count will also find the stored value. The loop has no exit
// condition: once started, the goroutine keeps running.
func (c *collector[R]) start() {
	// record runs store under the lock and then counts the message as read.
	record := func(store func()) {
		c.read.Lock()
		store()
		c.read.Unlock()

		atomic.AddInt64(&c.requestsRead, 1)
	}

	go func() {
		for {
			select {
			case res := <-c.resultsCh:
				record(func() { c.results = append(c.results, res) })
			case taskErr := <-c.errorsCh:
				record(func() { c.errors = append(c.errors, taskErr) })
			}
		}
	}()
}
// GetResults returns the results collected so far.
//
// The slice is read while holding the lock and returned as is; it is not
// copied.
func (c *collector[R]) GetResults() []R {
	c.read.Lock()
	collected := c.results
	c.read.Unlock()

	return collected
}

// GetErrors returns the errors collected so far.
//
// The slice is read while holding the lock and returned as is; it is not
// copied.
func (c *collector[R]) GetErrors() []error {
	c.read.Lock()
	collected := c.errors
	c.read.Unlock()

	return collected
}

// RequestsRead reports how many messages (results and errors combined) the
// collector has consumed. The counter is loaded atomically.
func (c *collector[R]) RequestsRead() int64 {
	consumed := atomic.LoadInt64(&c.requestsRead)
	return consumed
}
112 changes: 112 additions & 0 deletions internal/workerpool/pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

import (
"sync"
"sync/atomic"
"time"

"github.com/cockroachdb/errors"
)

// Task is a unit of work that can be submitted to a Pool. It returns the
// produced result, or an error when the work failed.
type Task[R any] func() (R, error)

// Pool is a generic pool of workers.
//
// Tasks are handed to the workers through queueCh. Each worker reports the
// outcome on resultsCh or errorsCh, and the embedded collector drains both.
type Pool[R any] struct {
	// requestsSent counts the submitted tasks and is only accessed through
	// sync/atomic. It must stay the first field: on 32-bit platforms only the
	// first word of an allocated struct is guaranteed to be 64-bit aligned, and
	// a 64-bit atomic operation on a misaligned address panics there.
	requestsSent int64

	// queueCh carries the submitted tasks to the workers.
	queueCh chan Task[R]
	// resultsCh and errorsCh carry the task outcomes to the collector.
	resultsCh chan R
	errorsCh  chan error

	// once makes sure the workers and the collector are started a single time.
	once sync.Once

	// workers holds the started workers; workerCount is how many to start.
	workers     []*worker[R]
	workerCount int

	collector[R]
}

// New builds a Pool configured to run `count` workers. Nothing runs until
// `Start()` is called. The pool is generic and accepts any Task with the
// signature `func() (R, error)`.
//
// For example, a Pool[int] will accept Tasks similar to:
//
//	task := func() (int, error) {
//		return 42, nil
//	}
//
// All channels are unbuffered, so a pool created with a count below one has no
// worker to receive from the queue and `Submit()` will block.
func New[R any](count int) *Pool[R] {
	p := &Pool[R]{
		queueCh:     make(chan Task[R]),
		resultsCh:   make(chan R),
		errorsCh:    make(chan error),
		workerCount: count,
	}

	// the collector reads from the same channels the workers write to
	p.collector.resultsCh = p.resultsCh
	p.collector.errorsCh = p.errorsCh

	return p
}

// Start launches the pool workers and the collector. Calling it more than once
// has no further effect. Make sure to call `Close()` when done with the pool.
//
//	pool := workerpool.New[int](10)
//	pool.Start()
//	defer pool.Close()
func (p *Pool[R]) Start() {
	p.once.Do(func() {
		for id := 0; id < p.workerCount; id++ {
			w := &worker[R]{
				id:        id,
				queueCh:   p.queueCh,
				resultsCh: p.resultsCh,
				errorsCh:  p.errorsCh,
			}
			w.start()
			p.workers = append(p.workers, w)
		}

		p.collector.start()
	})
}

// Submit sends a task to the workers and returns once the task has been handed
// to the queue channel.
//
// The request is counted before the hand-off. Counting it afterwards would let
// a worker finish the task, and the collector record it, before the sent
// counter moves, so a concurrent `Wait()`, `Processing()` or
// `PendingRequests()` could see zero or a negative number of pending requests
// while work is still in flight.
//
// Submitting to a closed pool still panics; in that case the counter is rolled
// back so the failed request is not reported as pending.
func (p *Pool[R]) Submit(t Task[R]) {
	atomic.AddInt64(&p.requestsSent, 1)

	queued := false
	defer func() {
		// only reached with queued == false when the send below panicked
		if !queued {
			atomic.AddInt64(&p.requestsSent, -1)
		}
	}()

	p.queueCh <- t
	queued = true
}

// GetErrors returns the errors of the processed tasks, combined into a single
// error with errors.Join.
func (p *Pool[R]) GetErrors() error {
	collected := p.collector.GetErrors()
	return errors.Join(collected...)
}

// GetResults returns the results of the tasks processed so far.
//
// It is recommended to call `Wait()` first, otherwise tasks that are still
// running will be missing from the returned slice.
func (p *Pool[R]) GetResults() []R {
	collected := p.collector.GetResults()
	return collected
}

// Close blocks until every submitted request has been processed and then
// closes the task queue channel. After closing the pool, calling `Submit()`
// will panic, and so will a second call to `Close()`, because the queue
// channel is already closed.
func (p *Pool[R]) Close() {
	defer close(p.queueCh)
	p.Wait()
}

// Wait blocks until all submitted tasks have been processed.
//
// The pending-request counter is checked right away and then again every
// 100 milliseconds. The ticker used for polling is stopped before returning so
// that its resources are released.
func (p *Pool[R]) Wait() {
	if !p.Processing() {
		return
	}

	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for range ticker.C {
		if !p.Processing() {
			return
		}
	}
}

// PendingRequests returns the number of pending requests: the count of
// submitted tasks minus the count of outcomes the collector has consumed.
//
// The two counters are loaded one after the other, so the value is a snapshot
// that may already be outdated when tasks are being processed concurrently.
func (p *Pool[R]) PendingRequests() int64 {
	sent := atomic.LoadInt64(&p.requestsSent)
	read := p.collector.RequestsRead()

	return sent - read
}

// Processing returns true while the number of pending requests is not zero.
func (p *Pool[R]) Processing() bool {
	pending := p.PendingRequests()
	return pending != 0
}
185 changes: 185 additions & 0 deletions internal/workerpool/pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool_test

import (
"errors"
"testing"
"time"

"math/rand"

"github.com/stretchr/testify/assert"
"go.mondoo.com/cnquery/v11/internal/workerpool"
)

// TestPoolSubmitAndRetrieveResult checks that a single successful task ends up
// as the only result and that no error is reported for it.
func TestPoolSubmitAndRetrieveResult(t *testing.T) {
	p := workerpool.New[int](2)
	p.Start()
	defer p.Close()

	// nothing has been submitted yet, so there is nothing to read
	assert.Empty(t, p.GetResults())

	// hand one task to the pool and block until it has been collected
	p.Submit(func() (int, error) { return 42, nil })
	p.Wait()

	// exactly one result, holding the value the task returned
	got := p.GetResults()
	if assert.Len(t, got, 1) {
		assert.Equal(t, 42, got[0])
	}

	// the task succeeded, so there is no error to report
	assert.Nil(t, p.GetErrors())
}

// TestPoolHandleErrors checks that the error returned by a failing task is
// reported by the pool.
func TestPoolHandleErrors(t *testing.T) {
	p := workerpool.New[int](5)
	p.Start()
	defer p.Close()

	// a task that always fails
	failing := func() (int, error) {
		return 0, errors.New("task error")
	}
	p.Submit(failing)

	// block until the collector has consumed the error
	p.Wait()

	got := p.GetErrors()
	if assert.Error(t, got) {
		assert.Contains(t, got.Error(), "task error")
	}
}

// TestPoolMultipleTasksWithErrors mixes three successful tasks with a failing
// one and checks that results and errors are both collected.
func TestPoolMultipleTasksWithErrors(t *testing.T) {
	type test struct {
		data int
	}

	// succeed builds a task that returns a *test holding n
	succeed := func(n int) workerpool.Task[*test] {
		return func() (*test, error) { return &test{n}, nil }
	}
	fail := func() (*test, error) {
		return nil, errors.New("task error")
	}

	p := workerpool.New[*test](5)
	p.Start()
	defer p.Close()

	for _, task := range []workerpool.Task[*test]{succeed(1), succeed(2), fail, succeed(3)} {
		p.Submit(task)
	}

	// block until every outcome has been collected
	p.Wait()

	// results arrive in no particular order, so compare as a set
	want := []*test{{1}, {2}, {3}}
	assert.ElementsMatch(t, want, p.GetResults())

	got := p.GetErrors()
	if assert.Error(t, got) {
		assert.Contains(t, got.Error(), "task error")
	}
}

// TestPoolHandlesNilTasks checks that submitting a nil task neither blocks the
// pool nor produces an error.
func TestPoolHandlesNilTasks(t *testing.T) {
	p := workerpool.New[int](2)
	p.Start()
	defer p.Close()

	// a nil task must still be accounted for, otherwise Wait never returns
	p.Submit(nil)
	p.Wait()

	assert.NoError(t, p.GetErrors())
}

// TestPoolProcessing checks that Processing reports true while a task is
// running and false once it has been collected.
func TestPoolProcessing(t *testing.T) {
	p := workerpool.New[int](2)
	p.Start()
	defer p.Close()

	// the task sleeps long enough to still be running when we check below
	slow := func() (int, error) {
		time.Sleep(50 * time.Millisecond)
		return 10, nil
	}
	p.Submit(slow)

	// the task is in flight
	assert.True(t, p.Processing())

	// block until it is done, then read what it produced
	p.Wait()
	assert.Equal(t, []int{10}, p.GetResults())

	// should no longer be processing
	assert.False(t, p.Processing())
}

// TestPoolClosesGracefully checks that Close waits for the running task and
// that submitting to a closed pool panics.
func TestPoolClosesGracefully(t *testing.T) {
	p := workerpool.New[int](1)
	p.Start()

	slow := func() (int, error) {
		time.Sleep(100 * time.Millisecond)
		return 42, nil
	}
	p.Submit(slow)

	// waits for the in-flight task and then closes the task queue
	p.Close()

	// the queue is closed now, so another Submit must panic
	assert.PanicsWithError(t, "send on closed channel", func() {
		p.Submit(slow)
	})
}

// TestPoolWithManyTasks pushes a large number of short tasks through the pool
// and checks that every one of them produces a result.
func TestPoolWithManyTasks(t *testing.T) {
	// 30k requests with a pool of 100 workers, each request sleeping for a
	// random duration below 100ms, should take around 15 seconds
	const workers = 100
	requestCount := 30000

	p := workerpool.New[int](workers)
	p.Start()
	defer p.Close()

	sleepy := func() (int, error) {
		ms := rand.Intn(100)
		time.Sleep(time.Duration(ms) * time.Millisecond)
		return ms, nil
	}

	for sent := 0; sent < requestCount; sent++ {
		p.Submit(sleepy)
	}

	// the last submitted tasks are still in flight
	assert.True(t, p.Processing())

	// block until everything has been collected
	p.Wait()

	// one result per request
	assert.Equal(t, requestCount, len(p.GetResults()))

	// should no longer be processing
	assert.False(t, p.Processing())
}
30 changes: 30 additions & 0 deletions internal/workerpool/worker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

// worker takes tasks from the queue and reports each outcome on one of its two
// output channels.
type worker[R any] struct {
	// id is the index given to the worker when it was created.
	id int

	// queueCh is the receive-only side of the task queue.
	queueCh <-chan Task[R]

	// resultsCh and errorsCh are the send-only sides of the outcome channels.
	resultsCh chan<- R
	errorsCh  chan<- error
}
// start launches the goroutine that runs the queued tasks one at a time.
//
// Exactly one message is sent for every task taken from the queue: the value
// on resultsCh when the task succeeded, the error on errorsCh when it failed.
// The goroutine ends when the queue channel is closed.
func (w *worker[R]) start() {
	go func() {
		for job := range w.queueCh {
			if job == nil {
				// there is nothing to run, but the collector still has to
				// see one message so the request counts as processed
				w.errorsCh <- nil
				continue
			}

			if data, err := job(); err != nil {
				w.errorsCh <- err
			} else {
				w.resultsCh <- data
			}
		}
	}()
}
4 changes: 2 additions & 2 deletions providers-sdk/v1/inventory/inventory.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading