diff --git a/demo/taskpool/main.go b/demo/taskpool/main.go index a428251..d7304ca 100644 --- a/demo/taskpool/main.go +++ b/demo/taskpool/main.go @@ -54,7 +54,7 @@ func taskPool() { //b.StartTimer() wg.Add(taskNum) for i := 0; i < taskNum; i++ { - ps[i%poolNum].Go(func() { + ps[i%poolNum].Go(func(param ...interface{}) { time.Sleep(10 * time.Millisecond) wg.Done() }) @@ -68,8 +68,7 @@ func taskPool() { //nazalog.Debugf("killed, worker num. idle=%d, busy=%d", idle, busy) } nazalog.Debug("< BenchmarkTaskPool") - idle, busy := ps[0].Status() - nazalog.Debugf("killed, worker num. idle=%d, busy=%d", idle, busy) + nazalog.Debugf("killed, worker num. status=%+v", ps[0].GetCurrentStatus()) } func main() { diff --git a/pkg/taskpool/example_test.go b/pkg/taskpool/example_test.go new file mode 100644 index 0000000..8ebe9fb --- /dev/null +++ b/pkg/taskpool/example_test.go @@ -0,0 +1,41 @@ +// Copyright 2020, Chef. All rights reserved. +// https://github.com/q191201771/naza +// +// Use of this source code is governed by a MIT-style license +// that can be found in the License file. +// +// Author: Chef (191201771@qq.com) + +package taskpool_test + +import ( + "fmt" + "sync" + "sync/atomic" + + "github.com/q191201771/naza/pkg/taskpool" +) + +// 并发计算0+1+2+...+1000 +// 演示怎么向协程池中添加带参数的函数任务 +func ExampleNewPool() { + pool, _ := taskpool.NewPool(func(option *taskpool.Option) { + // 限制最大并发数 + option.MaxWorkerNum = 16 + }) + var sum int32 + var wg sync.WaitGroup + n := 1000 + wg.Add(n) + for i := 0; i < n; i++ { + pool.Go(func(param ...interface{}) { + ii := param[0].(int) + atomic.AddInt32(&sum, int32(ii)) + wg.Done() + }, i) + } + wg.Wait() + fmt.Println(sum) + // Output: + // 499500 +} diff --git a/pkg/taskpool/global.go b/pkg/taskpool/global.go index 967694c..2c3478e 100644 --- a/pkg/taskpool/global.go +++ b/pkg/taskpool/global.go @@ -10,8 +10,8 @@ package taskpool var global Pool -func Go(task Task) { - global.Go(task) +func Go(task TaskFn, param ...interface{}) { + global.Go(task, param) } func GetCurrentStatus() Status { diff --git a/pkg/taskpool/interface.go b/pkg/taskpool/interface.go index 188631a..c3f98ee 100644 --- a/pkg/taskpool/interface.go +++ b/pkg/taskpool/interface.go @@ -18,7 +18,7 @@ import ( var ErrTaskPool = errors.New("naza.taskpool: fxxk") -type Task func() +type TaskFn func(param ...interface{}) type Status struct { TotalWorkerNum int // 总协程数量 @@ -29,7 +29,7 @@ type Status struct { type Pool interface { // 向池内放入任务 // 非阻塞函数,不会等待task执行 - Go(task Task) + Go(task TaskFn, param ...interface{}) // 获取当前的状态,注意,只是一个瞬时值 GetCurrentStatus() Status diff --git a/pkg/taskpool/pool.go b/pkg/taskpool/pool.go index 7d0475e..b1602eb 100644 --- a/pkg/taskpool/pool.go +++ b/pkg/taskpool/pool.go @@ -12,13 +12,18 @@ import ( "sync" ) +type taskWrapper struct { + taskFn TaskFn + param []interface{} +} + type pool struct { maxWorkerNum int m sync.Mutex totalWorkerNum int idleWorkerList []*worker - blockTaskList []Task + blockTaskList []taskWrapper } func newPool(option Option) *pool { @@ -31,7 +36,11 @@ func newPool(option Option) *pool { return &p } -func (p *pool) Go(task Task) { +func (p *pool) Go(task TaskFn, param ...interface{}) { + tw := taskWrapper{ + taskFn: task, + param: param, + } var w *worker p.m.Lock() if len(p.idleWorkerList) != 0 { @@ -39,7 +48,7 @@ func (p *pool) Go(task Task) { w = p.idleWorkerList[len(p.idleWorkerList)-1] p.idleWorkerList = p.idleWorkerList[0 : len(p.idleWorkerList)-1] - w.Go(task) + w.Go(tw) } else { // 无空闲worker @@ -47,11 +56,11 @@ func (p *pool) Go(task Task) { (p.maxWorkerNum > 0 && p.totalWorkerNum < p.maxWorkerNum) { // 无最大worker限制,或还未达到限制 - p.newWorkerWithTask(task) + p.newWorkerWithTask(tw) } else { // 已达到限制 - p.blockTaskList = append(p.blockTaskList, task) + p.blockTaskList = append(p.blockTaskList, tw) } } p.m.Unlock() @@ -85,7 +94,7 @@ func (p *pool) newWorker() *worker { return w } -func (p *pool) newWorkerWithTask(task Task) { +func (p *pool) newWorkerWithTask(task taskWrapper) { w := NewWorker(p) w.Start() w.Go(task) diff --git a/pkg/taskpool/taskpool_test.go b/pkg/taskpool/taskpool_test.go index 48e5498..d9e978f 100644 --- a/pkg/taskpool/taskpool_test.go +++ b/pkg/taskpool/taskpool_test.go @@ -10,6 +10,7 @@ package taskpool_test import ( "sync" + "sync/atomic" "testing" "time" @@ -52,7 +53,7 @@ func BenchmarkTaskPool(b *testing.B) { //b.StartTimer() wg.Add(taskNum) for i := 0; i < taskNum; i++ { - p.Go(func() { + p.Go(func(param ...interface{}) { time.Sleep(10 * time.Millisecond) wg.Done() }) @@ -79,7 +80,7 @@ func TestTaskPool(t *testing.T) { wg.Add(n) nazalog.Debug("start.") for i := 0; i < n; i++ { - p.Go(func() { + p.Go(func(param ...interface{}) { time.Sleep(10 * time.Millisecond) wg.Done() }) @@ -93,7 +94,7 @@ func TestTaskPool(t *testing.T) { wg.Add(n) for i := 0; i < n; i++ { - p.Go(func() { + p.Go(func(param ...interface{}) { time.Sleep(10 * time.Millisecond) wg.Done() }) @@ -121,11 +122,14 @@ func TestMaxWorker(t *testing.T) { wg.Add(n) nazalog.Debugf("start.") for i := 0; i < n; i++ { - p.Go(func() { - //atomic.AddInt32(&sum, int32(i)) + p.Go(func(param ...interface{}) { + a := param[0].(int) + b := param[1].(int) + atomic.AddInt32(&sum, int32(a)) + atomic.AddInt32(&sum, int32(b)) time.Sleep(10 * time.Millisecond) wg.Done() - }) + }, i, i) } wg.Wait() nazalog.Debugf("end. sum=%d", sum) @@ -138,7 +142,7 @@ func TestGlobal(t *testing.T) { assert.Equal(t, 0, s.TotalWorkerNum) assert.Equal(t, 0, s.IdleWorkerNum) assert.Equal(t, 0, s.BlockTaskNum) - taskpool.Go(func() { + taskpool.Go(func(param ...interface{}) { }) taskpool.KillIdleWorkers() } diff --git a/pkg/taskpool/worker.go b/pkg/taskpool/worker.go index 7025481..7c9eca7 100644 --- a/pkg/taskpool/worker.go +++ b/pkg/taskpool/worker.go @@ -9,13 +9,13 @@ package taskpool type worker struct { - taskChan chan Task + taskChan chan taskWrapper p *pool } func NewWorker(p *pool) *worker { return &worker{ - taskChan: make(chan Task, 1), + taskChan: make(chan taskWrapper, 1), p: p, } } @@ -27,7 +27,7 @@ func (w *worker) Start() { if !ok { break } - task() + task.taskFn(task.param...) w.p.onIdle(w) } }() @@ -37,6 +37,6 @@ func (w *worker) Stop() { close(w.taskChan) } -func (w *worker) Go(task Task) { - w.taskChan <- task +func (w *worker) Go(t taskWrapper) { + w.taskChan <- t }