|
|
|
|
// Copyright 2020, Chef. All rights reserved.
|
|
|
|
|
// https://github.com/q191201771/naza
|
|
|
|
|
//
|
|
|
|
|
// Use of this source code is governed by a MIT-style license
|
|
|
|
|
// that can be found in the License file.
|
|
|
|
|
//
|
|
|
|
|
// Author: Chef (191201771@qq.com)
|
|
|
|
|
|
|
|
|
|
package ratelimit
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"sync"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/q191201771/naza/pkg/nazaatomic"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// ErrTokenNotEnough is the sentinel error reported when a non-blocking acquire
// finds fewer tokens available than were requested.
var ErrTokenNotEnough = errors.New("naza.ratelimit: token not enough")
|
|
|
|
|
|
|
|
|
|
// 令牌桶
|
|
|
|
|
type TokenBucket struct {
|
|
|
|
|
capacity int
|
|
|
|
|
prodTokenInterval time.Duration
|
|
|
|
|
prodTokenNumEveryInterval int
|
|
|
|
|
|
|
|
|
|
disposeFlag nazaatomic.Bool
|
|
|
|
|
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
available int
|
|
|
|
|
cond *sync.Cond
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// @param capacity: 桶容量大小
|
|
|
|
|
// @param prodTokenIntervalMSec: 生产令牌的时间间隔,单位毫秒
|
|
|
|
|
// @param prodTokenNumEveryInterval: 每次生产多少个令牌
|
|
|
|
|
func NewTokenBucket(capacity int, prodTokenIntervalMSec int, prodTokenNumEveryInterval int) *TokenBucket {
|
|
|
|
|
tb := &TokenBucket{
|
|
|
|
|
capacity: capacity,
|
|
|
|
|
prodTokenInterval: time.Duration(time.Duration(prodTokenIntervalMSec) * time.Millisecond),
|
|
|
|
|
prodTokenNumEveryInterval: prodTokenNumEveryInterval,
|
|
|
|
|
}
|
|
|
|
|
tb.cond = sync.NewCond(&tb.mu)
|
|
|
|
|
tb.asyncProdToken()
|
|
|
|
|
return tb
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tb *TokenBucket) TryAquire() error {
|
|
|
|
|
return tb.TryAquireWithNum(1)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tb *TokenBucket) WaitUntilAquire() {
|
|
|
|
|
tb.WaitUntilAquireWithNum(1)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 尝试获取相应数量的令牌,获取成功返回nil,获取失败返回ErrTokenNotEnough
|
|
|
|
|
// 如果获取失败,上层可自由选择多久后重试或丢弃本次任务
|
|
|
|
|
func (tb *TokenBucket) TryAquireWithNum(num int) error {
|
|
|
|
|
tb.checkAquireNum(num)
|
|
|
|
|
|
|
|
|
|
tb.mu.Lock()
|
|
|
|
|
defer tb.mu.Unlock()
|
|
|
|
|
if tb.available >= num {
|
|
|
|
|
tb.available -= num
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ErrTokenNotEnough
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 阻塞直到获取到相应数量的令牌
|
|
|
|
|
func (tb *TokenBucket) WaitUntilAquireWithNum(num int) {
|
|
|
|
|
tb.checkAquireNum(num)
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
tb.mu.Lock()
|
|
|
|
|
if tb.available >= num {
|
|
|
|
|
tb.available -= num
|
|
|
|
|
tb.mu.Unlock()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 等待下次令牌生产时被唤醒
|
|
|
|
|
// wait的内部会将自身添加到事件监听队列中然后释放锁,当接收到事件时,内部会重新获取锁然后返回
|
|
|
|
|
tb.cond.Wait()
|
|
|
|
|
tb.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 销毁令牌桶
|
|
|
|
|
func (tb *TokenBucket) Dispose() {
|
|
|
|
|
tb.disposeFlag.Store(true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tb *TokenBucket) asyncProdToken() {
|
|
|
|
|
go func() {
|
|
|
|
|
t := time.NewTicker(tb.prodTokenInterval)
|
|
|
|
|
defer t.Stop()
|
|
|
|
|
for {
|
|
|
|
|
if tb.disposeFlag.Load() {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
select {
|
|
|
|
|
case <-t.C:
|
|
|
|
|
tb.mu.Lock()
|
|
|
|
|
tb.available += tb.prodTokenNumEveryInterval
|
|
|
|
|
if tb.available > tb.capacity {
|
|
|
|
|
tb.available = tb.capacity
|
|
|
|
|
}
|
|
|
|
|
// It is allowed but not required for the caller to hold c.L
|
|
|
|
|
// during the call.
|
|
|
|
|
tb.cond.Broadcast()
|
|
|
|
|
tb.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (tb *TokenBucket) checkAquireNum(num int) {
|
|
|
|
|
if num > tb.capacity {
|
|
|
|
|
panic(fmt.Sprintf("aquire num should not bigger than capacity. num=%d, capacity=%d", num, tb.capacity))
|
|
|
|
|
}
|
|
|
|
|
}
|