Skip to content

Commit

Permalink
⚡ fetch github org repositories in parallel (#4970)
Browse files Browse the repository at this point in the history
Introduce an internal Go package called `workerpool` that we can use to send requests
in parallel when needed.

:zap: For this change, I am making the fetching of repositories for an organization faster.

Tested this code with an **organization that has around 3k repositories**

### Before (~2 Minutes)
```
TRC logger.FuncDur> func=provider.github.repositories took=102803.621667
```
### After (~5 seconds)
```
TRC logger.FuncDur> func=provider.github.repositories took=4567.576542
```

* :zap: fetch org repositories in parallel
* ⚙️  add a collector to the workerpool

This will help us submit as many requests as we want without knowing
about the workers.

* :rotating_light: fix race conditions
---------

Signed-off-by: Salim Afiune Maya <[email protected]>
  • Loading branch information
afiune authored Dec 12, 2024
1 parent ee1123c commit 949f0ce
Show file tree
Hide file tree
Showing 15 changed files with 469 additions and 62 deletions.
55 changes: 55 additions & 0 deletions internal/workerpool/collector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

import (
"sync"
"sync/atomic"
)

// collector receives everything published on the results and errors channels
// and stores it so it can be read back through GetResults and GetErrors.
type collector[R any] struct {
	// resultsCh delivers task results.
	resultsCh <-chan R
	results   []R
	// read guards both the results and the errors slices.
	read sync.Mutex

	// errorsCh delivers task errors; nil values are stored exactly as received.
	errorsCh <-chan error
	errors   []error

	// requestsRead counts the values received from either channel. The typed
	// atomic is used because it is always 64-bit aligned, even on 32-bit
	// platforms and even when this struct is embedded in another one; a bare
	// int64 field carries no such guarantee.
	requestsRead atomic.Int64
}

// start launches the goroutine that drains both channels. Every received value
// is stored first and counted afterwards, so once RequestsRead reports a value
// the corresponding result or error is already visible to the getters.
//
// The goroutine exits once both channels have been closed. A closed channel is
// disabled instead of being read again, which would otherwise yield an endless
// stream of zero values.
func (c *collector[R]) start() {
	go func() {
		resultsCh, errorsCh := c.resultsCh, c.errorsCh

		for resultsCh != nil || errorsCh != nil {
			select {
			case result, ok := <-resultsCh:
				if !ok {
					// receiving from a nil channel blocks forever, which
					// removes this case from the select
					resultsCh = nil
					continue
				}
				c.read.Lock()
				c.results = append(c.results, result)
				c.read.Unlock()

			case err, ok := <-errorsCh:
				if !ok {
					errorsCh = nil
					continue
				}
				c.read.Lock()
				c.errors = append(c.errors, err)
				c.read.Unlock()
			}

			c.requestsRead.Add(1)
		}
	}()
}

// GetResults returns a snapshot of the results collected so far. The returned
// slice is a copy, so callers may keep or modify it while collection goes on.
func (c *collector[R]) GetResults() []R {
	c.read.Lock()
	defer c.read.Unlock()
	return append([]R(nil), c.results...)
}

// GetErrors returns a snapshot of the errors collected so far, including any
// nil values that were received. The returned slice is a copy.
func (c *collector[R]) GetErrors() []error {
	c.read.Lock()
	defer c.read.Unlock()
	return append([]error(nil), c.errors...)
}

// RequestsRead returns how many values (results plus errors) have been
// collected so far.
func (c *collector[R]) RequestsRead() int64 {
	return c.requestsRead.Load()
}
112 changes: 112 additions & 0 deletions internal/workerpool/pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

import (
"sync"
"sync/atomic"
"time"

"github.com/cockroachdb/errors"
)

// Task is a unit of work that can be submitted to a Pool. It returns either a
// result of type R or an error.
type Task[R any] func() (result R, err error)

// Pool is a generic pool of workers.
type Pool[R any] struct {
	// requestsSent counts the submitted tasks and is accessed with the 64-bit
	// functions of sync/atomic. It must stay the FIRST field: on 32-bit
	// platforms (386, ARM, 32-bit MIPS) only the first word of an allocated
	// struct is guaranteed to be 64-bit aligned, and an unaligned 64-bit atomic
	// operation panics.
	requestsSent int64

	// queueCh carries the submitted tasks to the workers.
	queueCh chan Task[R]
	// resultsCh and errorsCh carry the task outcomes from the workers to the
	// collector.
	resultsCh chan R
	errorsCh  chan error

	// once makes Start idempotent.
	once sync.Once

	workers     []*worker[R]
	workerCount int

	collector[R]
}

// New initializes a new Pool with the provided number of workers. The pool is generic and can
// accept any type of Task that returns the signature `func() (R, error)`.
//
// A count lower than one is raised to one: a pool without workers would make the
// first call to `Submit()` block forever.
//
// For example, a Pool[int] will accept Tasks similar to:
//
//	task := func() (int, error) {
//		return 42, nil
//	}
func New[R any](count int) *Pool[R] {
	if count < 1 {
		count = 1
	}

	// The same two channels are handed to the workers (as senders, in Start) and
	// to the collector (as receiver).
	resultsCh := make(chan R)
	errorsCh := make(chan error)

	return &Pool[R]{
		queueCh:     make(chan Task[R]),
		resultsCh:   resultsCh,
		errorsCh:    errorsCh,
		workerCount: count,
		collector:   collector[R]{resultsCh: resultsCh, errorsCh: errorsCh},
	}
}

// Start launches the pool workers and the collector. Calling it more than once has
// no further effect. Make sure to call `Close()` when the pool is no longer needed.
//
//	pool := workerpool.New[int](10)
//	pool.Start()
//	defer pool.Close()
func (p *Pool[R]) Start() {
	p.once.Do(func() {
		for id := 0; id < p.workerCount; id++ {
			// every worker reads from the shared queue and publishes to the
			// shared results/errors channels
			w := &worker[R]{
				id:        id,
				queueCh:   p.queueCh,
				resultsCh: p.resultsCh,
				errorsCh:  p.errorsCh,
			}
			p.workers = append(p.workers, w)
			w.start()
		}

		p.collector.start()
	})
}

// Submit sends a task to the workers. It blocks until a worker picks the task up.
//
// Submitting to a pool that has been closed panics ("send on closed channel").
func (p *Pool[R]) Submit(t Task[R]) {
	// Count the request BEFORE handing it over. Otherwise a worker could finish
	// the task, and the collector could count it, before the counter is bumped,
	// and a concurrent Wait()/Processing() would see no pending work while a
	// task is still in flight.
	atomic.AddInt64(&p.requestsSent, 1)

	delivered := false
	defer func() {
		if !delivered {
			// The send panicked (closed pool). Undo the count so a later Wait()
			// does not block on a request that never reached a worker. The
			// panic itself keeps propagating untouched.
			atomic.AddInt64(&p.requestsSent, -1)
		}
	}()

	p.queueCh <- t
	delivered = true
}

// GetErrors combines the errors of all processed tasks into a single error.
func (p *Pool[R]) GetErrors() error {
	collected := p.collector.GetErrors()
	return errors.Join(collected...)
}

// GetResults returns the results of the tasks collected so far.
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetResults() []R {
	collected := p.collector.GetResults()
	return collected
}

// Close waits for workers and collector to process all the requests, and then closes
// the task queue channel. After closing the pool, calling `Submit()` will panic.
//
// Close must be called only once: a second call closes the already closed queue
// channel, which panics. Only the task queue is closed here; the results and errors
// channels stay open.
func (p *Pool[R]) Close() {
	// wait first, so no task is still being handed over when the queue is closed
	p.Wait()
	close(p.queueCh)
}

// Wait blocks until all submitted tasks have been processed. It returns right away
// when nothing is pending; otherwise it re-checks every 100 milliseconds.
func (p *Pool[R]) Wait() {
	ticker := time.NewTicker(100 * time.Millisecond)
	// release the ticker's resources once we are done polling
	defer ticker.Stop()

	for p.Processing() {
		<-ticker.C
	}
}

// PendingRequests returns the number of pending requests, that is, the number of
// submitted tasks minus the number of outcomes the collector has read.
func (p *Pool[R]) PendingRequests() int64 {
	sent := atomic.LoadInt64(&p.requestsSent)
	read := p.collector.RequestsRead()
	return sent - read
}

// Processing returns true if tasks are being processed.
func (p *Pool[R]) Processing() bool {
	pending := p.PendingRequests()
	return pending != 0
}
185 changes: 185 additions & 0 deletions internal/workerpool/pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool_test

import (
"errors"
"testing"
"time"

"math/rand"

"github.com/stretchr/testify/assert"
"go.mondoo.com/cnquery/v11/internal/workerpool"
)

// TestPoolSubmitAndRetrieveResult submits one successful task and checks that its
// value, and nothing else, is collected.
func TestPoolSubmitAndRetrieveResult(t *testing.T) {
	p := workerpool.New[int](2)
	p.Start()
	defer p.Close()

	// nothing has been submitted yet, so nothing can have been collected
	assert.Empty(t, p.GetResults())

	// hand over a single task and block until it has been processed
	p.Submit(func() (int, error) {
		return 42, nil
	})
	p.Wait()

	// exactly one result is expected ...
	got := p.GetResults()
	if assert.Len(t, got, 1) {
		assert.Equal(t, 42, got[0])
	}

	// ... and no error
	assert.Nil(t, p.GetErrors())
}

// TestPoolHandleErrors checks that the error of a failing task is reported by
// GetErrors.
func TestPoolHandleErrors(t *testing.T) {
	p := workerpool.New[int](5)
	p.Start()
	defer p.Close()

	// a task that always fails
	p.Submit(func() (int, error) {
		return 0, errors.New("task error")
	})

	// block until the collector has picked the error up
	p.Wait()

	if got := p.GetErrors(); assert.Error(t, got) {
		assert.Contains(t, got.Error(), "task error")
	}
}

// TestPoolMultipleTasksWithErrors mixes successful and failing tasks and checks
// that results and errors are collected separately.
func TestPoolMultipleTasksWithErrors(t *testing.T) {
	type test struct {
		data int
	}

	p := workerpool.New[*test](5)
	p.Start()
	defer p.Close()

	// succeed builds a task that returns a fresh *test holding n
	succeed := func(n int) workerpool.Task[*test] {
		return func() (*test, error) { return &test{n}, nil }
	}
	fail := func() (*test, error) {
		return nil, errors.New("task error")
	}

	for _, task := range []workerpool.Task[*test]{succeed(1), succeed(2), fail, succeed(3)} {
		p.Submit(task)
	}

	// block until every outcome has been collected
	p.Wait()

	// the three successful values in any order
	want := []*test{{1}, {2}, {3}}
	assert.ElementsMatch(t, want, p.GetResults())

	if got := p.GetErrors(); assert.Error(t, got) {
		assert.Contains(t, got.Error(), "task error")
	}
}

// TestPoolHandlesNilTasks checks that submitting a nil task neither blocks Wait
// nor surfaces as an error.
func TestPoolHandlesNilTasks(t *testing.T) {
	p := workerpool.New[int](2)
	p.Start()
	defer p.Close()

	// a nil task must still be accounted for, otherwise Wait would never return
	p.Submit(nil)
	p.Wait()

	assert.NoError(t, p.GetErrors())
}

// TestPoolProcessing checks that Processing reports true while a task is running
// and false once it has been collected.
func TestPoolProcessing(t *testing.T) {
	pool := workerpool.New[int](2)
	pool.Start()
	defer pool.Close()

	// the task sleeps so that it is still running when Processing is checked below
	task := func() (int, error) {
		time.Sleep(50 * time.Millisecond)
		return 10, nil
	}

	pool.Submit(task)

	// should be processing
	assert.True(t, pool.Processing())

	// wait
	pool.Wait()

	// read results
	result := pool.GetResults()
	assert.Equal(t, []int{10}, result)

	// should no longer be processing
	assert.False(t, pool.Processing())
}

// TestPoolClosesGracefully checks that Close waits for the running task and that
// submitting to a closed pool panics.
func TestPoolClosesGracefully(t *testing.T) {
	pool := workerpool.New[int](1)
	pool.Start()

	task := func() (int, error) {
		time.Sleep(100 * time.Millisecond)
		return 42, nil
	}

	pool.Submit(task)

	// Close waits for the pending task before closing the task queue
	pool.Close()

	// The task queue is closed now, so submitting again must panic
	assert.PanicsWithError(t, "send on closed channel", func() {
		pool.Submit(task)
	})
}

// TestPoolWithManyTasks pushes a large number of short tasks through the pool and
// checks that every one of them yields a result.
func TestPoolWithManyTasks(t *testing.T) {
	// 30k requests with a pool of 100 workers
	// should be around 15 seconds
	requestCount := 30000
	pool := workerpool.New[int](100)
	pool.Start()
	defer pool.Close()

	// every task sleeps a random 0-99 milliseconds and returns that number
	task := func() (int, error) {
		random := rand.Intn(100)
		time.Sleep(time.Duration(random) * time.Millisecond)
		return random, nil
	}

	for i := 0; i < requestCount; i++ {
		pool.Submit(task)
	}

	// should be processing
	assert.True(t, pool.Processing())

	// wait
	pool.Wait()

	// read results
	assert.Equal(t, requestCount, len(pool.GetResults()))

	// should no longer be processing
	assert.False(t, pool.Processing())
}
30 changes: 30 additions & 0 deletions internal/workerpool/worker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

// worker runs the tasks it receives from a queue and publishes the outcome of
// each one on either the results or the errors channel.
type worker[R any] struct {
	id        int
	queueCh   <-chan Task[R]
	resultsCh chan<- R
	errorsCh  chan<- error
}

// start launches the worker goroutine. The goroutine ends when the queue
// channel is closed.
func (w *worker[R]) start() {
	go func() {
		for task := range w.queueCh {
			w.handle(task)
		}
	}()
}

// handle runs a single task and publishes exactly one value for it: the result
// on success, the error on failure, or a nil error for a nil task so that the
// receiving side can still account for the request.
func (w *worker[R]) handle(task Task[R]) {
	if task == nil {
		w.errorsCh <- nil
		return
	}

	switch data, err := task(); {
	case err != nil:
		w.errorsCh <- err
	default:
		w.resultsCh <- data
	}
}
4 changes: 2 additions & 2 deletions providers-sdk/v1/inventory/inventory.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 949f0ce

Please sign in to comment.