package worker_pool

import (
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestNewWorkerPool checks that NewWorkerPool returns a non-nil pool both for
// size 0 (which the subtest name treats as "use the default size") and for an
// explicit size.
func TestNewWorkerPool(t *testing.T) {
	t.Run("default size pool", func(t *testing.T) {
		pool := NewWorkerPool(0)
		assert.NotNil(t, pool)
	})

	t.Run("custom size pool", func(t *testing.T) {
		size := 100
		pool := NewWorkerPool(size)
		assert.NotNil(t, pool)
	})
}

// TestSubmit checks that a task handed to Submit is accepted without error
// and is actually executed by the pool.
func TestSubmit(t *testing.T) {
	t.Run("submit task to worker pool", func(t *testing.T) {
		// Upper bound on how long we wait for the task, so a pool that never
		// runs it fails this test instead of hanging the whole test binary.
		const timeout = 5 * time.Second

		pool := NewWorkerPool(10)

		var wg sync.WaitGroup
		wg.Add(1)
		err := pool.Submit(func() {
			defer wg.Done()
			time.Sleep(100 * time.Millisecond)
		})
		if !assert.NoError(t, err) {
			return
		}

		done := make(chan struct{})
		go func() {
			wg.Wait()
			close(done)
		}()

		select {
		case <-done:
		case <-time.After(timeout):
			t.Fatalf("submitted task did not complete within %v", timeout)
		}
	})
}

// TestSubmitAndWaitAll checks that SubmitAndWaitAll runs every task before
// returning and reports each task's error on the returned channel.
func TestSubmitAndWaitAll(t *testing.T) {
	t.Run("submit and wait all tasks succeed", func(t *testing.T) {
		pool := NewWorkerPool(10)

		// Counts tasks that ran to completion; updated from worker goroutines.
		var ran int32
		tasks := []func() error{
			func() error {
				time.Sleep(100 * time.Millisecond)
				atomic.AddInt32(&ran, 1)
				return nil
			},
			func() error {
				time.Sleep(50 * time.Millisecond)
				atomic.AddInt32(&ran, 1)
				return nil
			},
		}

		taskErrCh, submitErr := pool.SubmitAndWaitAll(tasks...)
		// Stop on a submit failure: the channel may be nil in that case and
		// close(nil) would panic, hiding the real error.
		if !assert.NoError(t, submitErr) {
			return
		}

		// "WaitAll": every task must have finished by the time the call returns.
		assert.Equal(t, int32(len(tasks)), atomic.LoadInt32(&ran))

		// NOTE(review): closing here assumes SubmitAndWaitAll has finished all
		// sends and does not close the channel itself — confirm against the
		// implementation.
		close(taskErrCh)
		for err := range taskErrCh {
			assert.NoError(t, err)
		}
	})

	t.Run("submit and wait all tasks with errors", func(t *testing.T) {
		pool := NewWorkerPool(10)

		expectedError := errors.New("task error")
		tasks := []func() error{
			func() error {
				time.Sleep(100 * time.Millisecond)
				return nil
			},
			func() error {
				time.Sleep(50 * time.Millisecond)
				return expectedError
			},
		}

		taskErrCh, submitErr := pool.SubmitAndWaitAll(tasks...)
		// Same guard as above: avoid close(nil) when submission fails.
		if !assert.NoError(t, submitErr) {
			return
		}

		// NOTE(review): same assumption as above about channel ownership.
		close(taskErrCh)

		foundError := false
		for err := range taskErrCh {
			if err != nil {
				foundError = true
				// errors.Is tolerates the pool wrapping the task's error.
				assert.True(t, errors.Is(err, expectedError), "unexpected task error: %v", err)
			}
		}
		assert.True(t, foundError, "expected the failing task's error on the channel")
	})
}