From f0331c8a30369db5805cd218e1d1f5bdfcf39b36 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Tue, 20 Aug 2024 23:11:32 +0800 Subject: [PATCH] feat: add worker_pool --- go.work | 1 + worker_pool/go.mod | 15 ++++++ worker_pool/worker_pool.go | 70 ++++++++++++++++++++++++++ worker_pool/worker_pool_test.go | 88 +++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+) create mode 100644 worker_pool/go.mod create mode 100644 worker_pool/worker_pool.go create mode 100644 worker_pool/worker_pool_test.go diff --git a/go.work b/go.work index aa5172a..f91f2e2 100644 --- a/go.work +++ b/go.work @@ -2,3 +2,4 @@ go 1.22.3 use ./errors use ./validator +use worker_pool \ No newline at end of file diff --git a/worker_pool/go.mod b/worker_pool/go.mod new file mode 100644 index 0000000..87d2917 --- /dev/null +++ b/worker_pool/go.mod @@ -0,0 +1,15 @@ +module code.30cm.net/digimon/library-go/worker_pool + +go 1.22.3 + +require ( + github.com/panjf2000/ants/v2 v2.10.0 + github.com/stretchr/testify v1.8.2 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sync v0.3.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/worker_pool/worker_pool.go b/worker_pool/worker_pool.go new file mode 100644 index 0000000..ce7113b --- /dev/null +++ b/worker_pool/worker_pool.go @@ -0,0 +1,70 @@ +package worker_pool + +import ( + "github.com/panjf2000/ants/v2" + "sync" +) + +const defaultWorkerPoolSize = 2000 + +type WorkerPool interface { + Submit(task func()) error + SubmitAndWaitAll(tasks ...func() error) (taskErr chan error, submitErr error) +} + +type workerPool struct { + p *ants.Pool +} + +func NewWorkerPool(size int) WorkerPool { + if size <= 0 { + size = defaultWorkerPoolSize + } + + p, err := ants.NewPool( + size, + ants.WithDisablePurge(true), + ) + if err != nil { + return &workerPool{p: nil} + } + + return &workerPool{p: p} +} + +func (p *workerPool) Submit(task func()) error { + if p.p == nil { + return ants.Submit(task) + } + + return p.p.Submit(task) +} + +func (p *workerPool) SubmitAndWaitAll(tasks ...func() error) (chan error, error) { + taskErrCh := make(chan error, len(tasks)) + submitErrCh := make(chan error, len(tasks)) + wg := sync.WaitGroup{} + wg.Add(len(tasks)) + + for i := range tasks { + task := tasks[i] + err := p.Submit(func() { + defer wg.Done() + if err := task(); err != nil { + taskErrCh <- err + } + }) + if err != nil { + submitErrCh <- err + wg.Done() + } + } + + wg.Wait() + + if len(submitErrCh) != 0 { + return nil, <-submitErrCh + } + + return taskErrCh, nil +} diff --git a/worker_pool/worker_pool_test.go b/worker_pool/worker_pool_test.go new file mode 100644 index 0000000..0122ab6 --- /dev/null +++ b/worker_pool/worker_pool_test.go @@ -0,0 +1,88 @@ +package worker_pool + +import ( + "errors" + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" +) + +func TestNewWorkerPool(t *testing.T) { + t.Run("default size pool", func(t *testing.T) { + pool := NewWorkerPool(0) + assert.NotNil(t, pool) + }) + + t.Run("custom size pool", func(t *testing.T) { + size := 100 + pool := NewWorkerPool(size) + assert.NotNil(t, pool) + }) +} + +func TestSubmit(t *testing.T) { + t.Run("submit task to worker pool", func(t *testing.T) { + pool := NewWorkerPool(10) + var wg sync.WaitGroup + wg.Add(1) + + err := pool.Submit(func() { + defer wg.Done() + time.Sleep(100 * time.Millisecond) + }) + assert.NoError(t, err) + + wg.Wait() + }) +} + +func TestSubmitAndWaitAll(t *testing.T) { + t.Run("submit and wait all tasks succeed", func(t *testing.T) { + pool := NewWorkerPool(10) + tasks := []func() error{ + func() error { + time.Sleep(100 * time.Millisecond) + return nil + }, + func() error { + time.Sleep(50 * time.Millisecond) + return nil + }, + } + + taskErrCh, submitErr := pool.SubmitAndWaitAll(tasks...) + assert.NoError(t, submitErr) + close(taskErrCh) + for err := range taskErrCh { + assert.NoError(t, err) + } + }) + + t.Run("submit and wait all tasks with errors", func(t *testing.T) { + pool := NewWorkerPool(10) + expectedError := errors.New("task error") + tasks := []func() error{ + func() error { + time.Sleep(100 * time.Millisecond) + return nil + }, + func() error { + time.Sleep(50 * time.Millisecond) + return expectedError + }, + } + + taskErrCh, submitErr := pool.SubmitAndWaitAll(tasks...) + assert.NoError(t, submitErr) + close(taskErrCh) + foundError := false + for err := range taskErrCh { + if err != nil { + foundError = true + assert.Equal(t, expectedError, err) + } + } + assert.True(t, foundError) + }) +}