diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index 6bc2221..35df036 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -21,9 +21,10 @@ import ( "io" "net" "sync" - "sync/atomic" "time" + "github.com/q191201771/naza/pkg/nazaatomic" + "github.com/q191201771/naza/pkg/nazalog" ) @@ -75,6 +76,11 @@ type Stat struct { WroteBytesSum uint64 } +type StatAtomic struct { + ReadBytesSum nazaatomic.Uint64 + WroteBytesSum nazaatomic.Uint64 +} + type WriteChanFullBehavior int const ( @@ -115,7 +121,6 @@ type ModOption func(option *Option) func New(conn net.Conn, modOptions ...ModOption) Connection { c := new(connection) c.doneChan = make(chan error, 1) - c.closedFlag = 0 c.Conn = conn c.option = defaultOption @@ -169,9 +174,9 @@ type connection struct { flushDoneChan chan struct{} exitChan chan struct{} doneChan chan error - closedFlag uint32 + closedFlag nazaatomic.Bool closeOnce sync.Once - stat Stat + stat StatAtomic } func (c *connection) ModWriteChanSize(n int) { @@ -221,7 +226,7 @@ func (c *connection) ReadAtLeast(buf []byte, min int) (n int, err error) { if err != nil { c.close(err) } - atomic.AddUint64(&c.stat.ReadBytesSum, uint64(n)) + c.stat.ReadBytesSum.Add(uint64(n)) return n, err } @@ -243,7 +248,7 @@ func (c *connection) ReadLine() (line []byte, isPrefix bool, err error) { if err != nil { c.close(err) } - atomic.AddUint64(&c.stat.ReadBytesSum, uint64(len(line))) + c.stat.ReadBytesSum.Add(uint64(len(line))) return line, isPrefix, err } @@ -259,12 +264,12 @@ func (c *connection) Read(b []byte) (n int, err error) { if err != nil { c.close(err) } - atomic.AddUint64(&c.stat.ReadBytesSum, uint64(n)) + c.stat.ReadBytesSum.Add(uint64(n)) return n, err } func (c *connection) Write(b []byte) (n int, err error) { - if atomic.LoadUint32(&c.closedFlag) == 1 { + if c.closedFlag.Load() { return 0, ErrClosedAlready } if c.option.WriteChanSize > 0 { @@ -285,7 +290,7 @@ func (c *connection) Write(b []byte) (n int, err error) { } func (c *connection) Flush() error { - if atomic.LoadUint32(&c.closedFlag) == 1 { + if c.closedFlag.Load() { return ErrClosedAlready } if c.option.WriteChanSize > 0 { @@ -340,8 +345,8 @@ func (c *connection) SetWriteDeadline(t time.Time) error { } func (c *connection) GetStat() (s Stat) { - s.ReadBytesSum = atomic.LoadUint64(&c.stat.ReadBytesSum) - s.WroteBytesSum = atomic.LoadUint64(&c.stat.WroteBytesSum) + s.ReadBytesSum = c.stat.ReadBytesSum.Load() + s.WroteBytesSum = c.stat.WroteBytesSum.Load() return } @@ -357,7 +362,7 @@ func (c *connection) write(b []byte) (n int, err error) { if err != nil { c.close(err) } - atomic.AddUint64(&c.stat.WroteBytesSum, uint64(n)) + c.stat.WroteBytesSum.Add(uint64(n)) return n, err } @@ -405,7 +410,7 @@ func (c *connection) flush() error { func (c *connection) close(err error) { nazalog.Debugf("naza connection close. err=%v, conn=%p", err, c) c.closeOnce.Do(func() { - atomic.StoreUint32(&c.closedFlag, 1) + c.closedFlag.Store(true) if c.option.WriteChanSize > 0 { c.exitChan <- struct{}{} } diff --git a/pkg/ratelimit/tokenbucket.go b/pkg/ratelimit/tokenbucket.go index 7eb242f..711ee0f 100644 --- a/pkg/ratelimit/tokenbucket.go +++ b/pkg/ratelimit/tokenbucket.go @@ -12,8 +12,9 @@ import ( "errors" "fmt" "sync" - "sync/atomic" "time" + + "github.com/q191201771/naza/pkg/nazaatomic" ) var ErrTokenNotEnough = errors.New("naza.ratelimit: token not enough") @@ -24,7 +25,7 @@ type TokenBucket struct { prodTokenInterval time.Duration prodTokenNumEveryInterval int - disposeFlag int32 + disposeFlag nazaatomic.Bool mu sync.Mutex available int @@ -89,7 +90,7 @@ func (tb *TokenBucket) WaitUntilAquireWithNum(num int) { // 销毁令牌桶 func (tb *TokenBucket) Dispose() { - atomic.StoreInt32(&tb.disposeFlag, 1) + tb.disposeFlag.Store(true) } func (tb *TokenBucket) asyncProdToken() { @@ -97,7 +98,7 @@ func (tb *TokenBucket) asyncProdToken() { t := time.NewTicker(tb.prodTokenInterval) defer t.Stop() for { - if atomic.LoadInt32(&tb.disposeFlag) == 1 { + if tb.disposeFlag.Load() { break } select { diff --git a/pkg/slicebytepool/slicebytepool.go b/pkg/slicebytepool/slicebytepool.go index a664e07..9b5b916 100644 --- a/pkg/slicebytepool/slicebytepool.go +++ b/pkg/slicebytepool/slicebytepool.go @@ -8,9 +8,7 @@ package slicebytepool -import ( - "sync/atomic" -) +import "github.com/q191201771/naza/pkg/nazaatomic" var ( minSize = 1024 @@ -20,11 +18,18 @@ var ( type sliceBytePool struct { strategy Strategy capToFreeBucket map[int]Bucket - status Status + status statusAtomic +} + +type statusAtomic struct { + getCount nazaatomic.Int64 + putCount nazaatomic.Int64 + hitCount nazaatomic.Int64 + sizeBytes nazaatomic.Int64 } func (bp *sliceBytePool) Get(size int) []byte { - atomic.AddInt64(&bp.status.getCount, 1) + bp.status.getCount.Increment() ss := up2power(size) if ss < minSize { @@ -38,15 +43,15 @@ func (bp *sliceBytePool) Get(size int) []byte { return buf } - atomic.AddInt64(&bp.status.hitCount, 1) - atomic.AddInt64(&bp.status.sizeBytes, int64(-cap(buf))) + bp.status.hitCount.Increment() + bp.status.sizeBytes.Sub(int64(cap(buf))) return buf } func (bp *sliceBytePool) Put(buf []byte) { c := cap(buf) - atomic.AddInt64(&bp.status.putCount, 1) - atomic.AddInt64(&bp.status.sizeBytes, int64(c)) + bp.status.putCount.Increment() + bp.status.sizeBytes.Add(int64(c)) size := down2power(c) if size < minSize { @@ -60,10 +65,10 @@ func (bp *sliceBytePool) Put(buf []byte) { func (bp *sliceBytePool) RetrieveStatus() Status { return Status{ - getCount: atomic.LoadInt64(&bp.status.getCount), - putCount: atomic.LoadInt64(&bp.status.putCount), - hitCount: atomic.LoadInt64(&bp.status.hitCount), - sizeBytes: atomic.LoadInt64(&bp.status.sizeBytes), + getCount: bp.status.getCount.Load(), + putCount: bp.status.putCount.Load(), + hitCount: bp.status.hitCount.Load(), + sizeBytes: bp.status.sizeBytes.Load(), } }