From 059da41295e0591e65573eb8d311e0c21c2338c3 Mon Sep 17 00:00:00 2001 From: q191201771 <191201771@qq.com> Date: Fri, 27 Sep 2019 14:07:08 +0800 Subject: [PATCH] =?UTF-8?q?1.=20package=20connection:=20Config=20=E4=B8=AD?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=20WChanSize.=20=E5=A2=9E=E5=8A=A0=20Flush,?= =?UTF-8?q?=20Done,=20ModWriteChanSize=20=E4=B8=89=E4=B8=AA=E6=96=B9?= =?UTF-8?q?=E6=B3=95=202.=20package=20log:=20=E5=A2=9E=E5=8A=A0=20panic=20?= =?UTF-8?q?=E6=96=B9=E6=B3=95=203.=20test.sh=20=E4=B8=AD=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=20gofmt=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/assert/assert.go | 2 +- pkg/assert/assert_test.go | 1 + pkg/connection/connection.go | 227 ++++++++++++++++++++++++++++++++--- pkg/log/global.go | 21 +++- pkg/log/log.go | 51 ++++++-- pkg/log/log_test.go | 67 +++++++++-- test.sh | 27 ++++- 7 files changed, 346 insertions(+), 50 deletions(-) diff --git a/pkg/assert/assert.go b/pkg/assert/assert.go index 6d6f693..bf03e44 100644 --- a/pkg/assert/assert.go +++ b/pkg/assert/assert.go @@ -1,4 +1,4 @@ -// Package assert 提供了单元测试时的断言功能,减少一些模板代码 +// package assert 提供了单元测试时的断言功能,减少一些模板代码 // // 代码参考了 https://github.com/stretchr/testify // diff --git a/pkg/assert/assert_test.go b/pkg/assert/assert_test.go index 7e67020..0ab2017 100644 --- a/pkg/assert/assert_test.go +++ b/pkg/assert/assert_test.go @@ -29,6 +29,7 @@ func TestEqual(t *testing.T) { Equal(t, nil, i) var b []byte Equal(t, nil, b) + //Equal(t, nil, errors.New("mock error")) // 测试isNil Equal(t, true, isNil(nil)) diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index 4c024f3..e1c3e4a 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -9,8 +9,10 @@ import ( "bufio" "errors" "fmt" + "github.com/q191201771/nezha/pkg/log" "io" "net" + "sync" "time" ) @@ -18,26 +20,63 @@ var connectionErr = errors.New("connection: fxxk") type Connection interface { // 包含 interface net.Conn 的所有方法 + // Read + // Write + // Close + // LocalAddr + // RemoteAddr + // SetDeadline + // SetReadDeadline + // SetWriteDeadline net.Conn ReadAtLeast(buf []byte, min int) (n int, err error) ReadLine() (line []byte, isPrefix bool, err error) + // TODO chef: 这个接口是否不提供 Printf(fmt string, v ...interface{}) (n int, err error) + // 如果使用了 bufio 写缓冲,则将缓冲中的数据发送出去 + // 如果使用了 channel 异步发送,则阻塞等待,直到之前 channel 中的数据全部发送完毕 + // 一般在 Close 前,想要将剩余数据发送完毕时调用 + Flush() error + + // 阻塞直到连接关闭或发生错误 + // @return 返回 nil 则是本端主动调用 Close 关闭 + Done() <-chan error + + // TODO chef: 这几个接口是否不提供 + // Mod类型函数不加锁,需要调用方保证不发生竞态调用 + ModWriteChanSize(n int) ModWriteBufSize(n int) ModReadTimeoutMS(n int) ModWriteTimeoutMS(n int) } type Config struct { - // 如果不为0,则之后每次读/写使用 buffer 缓冲 + // 如果不为0,则之后每次读/写使用 bufio 的缓冲 ReadBufSize int WriteBufSize int // 如果不为0,则之后每次读/写都带超时 ReadTimeoutMS int WriteTimeoutMS int + + // 如果不为0,则写使用 channel 将数据发送到后台协程中发送 + WChanSize int +} + +type wMsgT int + +const ( + _ wMsgT = iota + wMsgTWrite + wMsgTFlush +) + +type wmsg struct { + t wMsgT + b []byte } func New(conn net.Conn, config Config) Connection { @@ -53,20 +92,39 @@ func New(conn net.Conn, config Config) Connection { } else { c.w = conn } + if config.WChanSize > 0 { + c.wChan = make(chan wmsg, config.WChanSize) + c.flushDoneChan = make(chan struct{}, 1) + go c.runWriteLoop() + } + c.doneChan = make(chan error, 1) + c.exitChan = make(chan struct{}, 1) c.config = config return &c } type connection struct { - Conn net.Conn - r io.Reader - w io.Writer - config Config + Conn net.Conn + r io.Reader + w io.Writer + config Config + wChan chan wmsg + flushDoneChan chan struct{} + doneChan chan error + exitChan chan struct{} + closeOnce sync.Once } -// Mod类型函数不加锁 +func (c *connection) ModWriteChanSize(n int) { + if c.config.WChanSize > 0 { + panic(connectionErr) + } + c.config.WChanSize = n + c.wChan = make(chan wmsg, n) + c.flushDoneChan = make(chan struct{}, 1) + go c.runWriteLoop() +} -// 由调用方保证不和写操作并发执行 func (c *connection) ModWriteBufSize(n int) { if c.config.WriteBufSize > 0 { // 如果之前已经设置过写缓冲,直接 panic @@ -93,9 +151,18 @@ func (c *connection) ModWriteTimeoutMS(n int) { func (c *connection) ReadAtLeast(buf []byte, min int) (n int, err error) { if c.config.ReadTimeoutMS > 0 { - _ = c.Conn.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond)) + err = c.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond)) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + return 0, err + } + } + n, err = io.ReadAtLeast(c.r, buf, min) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) } - return io.ReadAtLeast(c.r, buf, min) + return n, err } func (c *connection) ReadLine() (line []byte, isPrefix bool, err error) { @@ -105,9 +172,18 @@ func (c *connection) ReadLine() (line []byte, isPrefix bool, err error) { panic(connectionErr) } if c.config.ReadTimeoutMS > 0 { - _ = c.Conn.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond)) + err = c.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond)) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + return nil, false, err + } } - return bufioReader.ReadLine() + line, isPrefix, err = bufioReader.ReadLine() + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) + } + return line, isPrefix, err } func (c *connection) Printf(format string, v ...interface{}) (n int, err error) { @@ -119,20 +195,117 @@ func (c *connection) Printf(format string, v ...interface{}) (n int, err error) func (c *connection) Read(b []byte) (n int, err error) { if c.config.ReadTimeoutMS > 0 { - _ = c.Conn.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond)) + err = c.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond)) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + return 0, err + } + } + n, err = c.r.Read(b) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) } - return c.r.Read(b) + return n, err } func (c *connection) Write(b []byte) (n int, err error) { + if c.config.WChanSize > 0 { + c.wChan <- wmsg{t: wMsgTWrite, b: b} + return len(b), nil + } + return c.write(b) +} + +func (c *connection) write(b []byte) (n int, err error) { if c.config.WriteTimeoutMS > 0 { - _ = c.Conn.SetWriteDeadline(time.Now().Add(time.Duration(c.config.WriteTimeoutMS) * time.Millisecond)) + err = c.SetWriteDeadline(time.Now().Add(time.Duration(c.config.WriteTimeoutMS) * time.Millisecond)) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + return 0, err + } } - return c.w.Write(b) + n, err = c.w.Write(b) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) + } + return n, err +} + +func (c *connection) runWriteLoop() { + for { + select { + case <-c.exitChan: + log.Debug("exitChan recv, exit write loop.") + return + case msg := <-c.wChan: + switch msg.t { + case wMsgTWrite: + if _, err := c.write(msg.b); err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + return + } + case wMsgTFlush: + if err := c.flush(); err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.flushDoneChan <- struct{}{} + return + } + c.flushDoneChan <- struct{}{} + } + } + } +} + +func (c *connection) Flush() error { + if c.config.WChanSize > 0 { + c.wChan <- wmsg{t: wMsgTFlush} + <-c.flushDoneChan + return nil + } + + return c.flush() +} + +func (c *connection) flush() error { + w, ok := c.w.(*bufio.Writer) + if ok { + if c.config.WriteTimeoutMS > 0 { + err := c.SetWriteDeadline(time.Now().Add(time.Duration(c.config.WriteTimeoutMS) * time.Millisecond)) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + return err + } + } + if err := w.Flush(); err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) + return err + } + } + return nil } func (c *connection) Close() error { - return c.Conn.Close() + log.Debugf("nezha connection Close. conn=%p", c) + c.close(nil) + return nil +} + +func (c *connection) close(err error) { + log.Debugf("nezha connection close. err=%v, conn=%p", err, c) + c.closeOnce.Do(func() { + if c.config.WChanSize > 0 { + c.exitChan <- struct{}{} + } + c.doneChan <- err + _ = c.Conn.Close() + }) +} + +func (c *connection) Done() <-chan error { + return c.doneChan } func (c *connection) LocalAddr() net.Addr { @@ -144,12 +317,28 @@ func (c *connection) RemoteAddr() net.Addr { } func (c *connection) SetDeadline(t time.Time) error { - return c.Conn.SetDeadline(t) + err := c.Conn.SetDeadline(t) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) + } + return err } func (c *connection) SetReadDeadline(t time.Time) error { - return c.Conn.SetReadDeadline(t) + err := c.Conn.SetReadDeadline(t) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) + } + return err } + func (c *connection) SetWriteDeadline(t time.Time) error { - return c.Conn.SetWriteDeadline(t) + err := c.Conn.SetWriteDeadline(t) + if err != nil { + log.Debugf("nezha connection. error=%v, conn=%p", err, c) + c.close(err) + } + return err } diff --git a/pkg/log/global.go b/pkg/log/global.go index 795bf50..90856fd 100644 --- a/pkg/log/global.go +++ b/pkg/log/global.go @@ -31,6 +31,10 @@ func Fatalf(format string, v ...interface{}) { global.Out(LevelFatal, 3, fmt.Sprintf(format, v...)) } +func Panicf(format string, v ...interface{}) { + global.Out(LevelPanic, 3, fmt.Sprintf(format, v...)) +} + func Output(level Level, calldepth int, v ...interface{}) { global.Out(level, 3, fmt.Sprint(v...)) } @@ -55,6 +59,10 @@ func Fatal(v ...interface{}) { global.Out(LevelFatal, 3, fmt.Sprint(v...)) } +func Panic(v ...interface{}) { + global.Out(LevelPanic, 3, fmt.Sprint(v...)) +} + func FatalIfErrorNotNil(err error) { if err != nil { global.Out(LevelError, 3, fmt.Sprintf("fatal since error not nil. err=%+v", err)) @@ -62,6 +70,13 @@ func FatalIfErrorNotNil(err error) { } } +func PanicIfErrorNotNil(err error) { + if err != nil { + global.Out(LevelPanic, 3, fmt.Sprintf("fatal since error not nil. err=%+v", err)) + panic(err) + } +} + func Out(level Level, calldepth int, s string) { global.Out(level, calldepth, s) } @@ -75,8 +90,8 @@ func Init(c Config) error { func init() { global, _ = New(Config{ - Level: LevelDebug, - IsToStdout: true, - ShortFileFlag:true, + Level: LevelDebug, + IsToStdout: true, + ShortFileFlag: true, }) } diff --git a/pkg/log/log.go b/pkg/log/log.go index 4aff64f..6660291 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -27,14 +27,17 @@ type Logger interface { Warnf(format string, v ...interface{}) Errorf(format string, v ...interface{}) Fatalf(format string, v ...interface{}) // 打印日志并退出程序 + Panicf(format string, v ...interface{}) Debug(v ...interface{}) Info(v ...interface{}) Warn(v ...interface{}) Error(v ...interface{}) Fatal(v ...interface{}) + Panic(v ...interface{}) FatalIfErrorNotNil(err error) + PanicIfErrorNotNil(err error) Outputf(level Level, calldepth int, format string, v ...interface{}) Output(level Level, calldepth int, v ...interface{}) @@ -57,12 +60,13 @@ type Config struct { type Level uint8 const ( - _ = iota + _ Level = iota LevelDebug LevelInfo LevelWarn LevelError LevelFatal + LevelPanic ) func New(c Config) (Logger, error) { @@ -72,7 +76,7 @@ func New(c Config) (Logger, error) { console io.Writer err error ) - if c.Level < LevelDebug || c.Level > LevelFatal { + if c.Level < LevelDebug || c.Level > LevelPanic { return nil, LogErr } if c.Filename != "" { @@ -90,11 +94,11 @@ func New(c Config) (Logger, error) { } l := &logger{ - c: c, - dir: dir, - fp: fp, - console: console, - currRoundTime:time.Now(), + c: c, + dir: dir, + fp: fp, + console: console, + currRoundTime: time.Now(), } return l, nil } @@ -105,12 +109,14 @@ const ( levelWarnString = " WARN " levelErrorString = "ERROR " levelFatalString = "FATAL " + levelPanicString = "PANIC " levelDebugColorString = "\033[22;37mDEBUG\033[0m " levelInfoColorString = "\033[22;36m INFO\033[0m " levelWarnColorString = "\033[22;33m WARN\033[0m " levelErrorColorString = "\033[22;31mERROR\033[0m " - levelFatalColorString = "\033[22;31mFATAL\033[0m " // 颜色和error的一样 + levelFatalColorString = "\033[22;31mFATAL\033[0m " // 颜色和 error 级别一样 + levelPanicColorString = "\033[22;31mPANIC\033[0m " // 颜色和 error 级别一样 ) var ( @@ -120,6 +126,7 @@ var ( LevelWarn: levelWarnString, LevelError: levelErrorString, LevelFatal: levelFatalString, + LevelPanic: levelPanicString, } levelToColorString = map[Level]string{ LevelDebug: levelDebugColorString, @@ -127,6 +134,7 @@ var ( LevelWarn: levelWarnColorString, LevelError: levelErrorColorString, LevelFatal: levelFatalColorString, + LevelPanic: levelPanicColorString, } ) @@ -135,10 +143,10 @@ type logger struct { dir string - m sync.Mutex - fp *os.File - console io.Writer - buf bytes.Buffer + m sync.Mutex + fp *os.File + console io.Writer + buf bytes.Buffer currRoundTime time.Time } @@ -167,6 +175,11 @@ func (l *logger) Fatalf(format string, v ...interface{}) { os.Exit(1) } +func (l *logger) Panicf(format string, v ...interface{}) { + l.Out(LevelPanic, 3, fmt.Sprintf(format, v...)) + panic(fmt.Sprintf(format, v...)) +} + func (l *logger) Output(level Level, calldepth int, v ...interface{}) { l.Out(level, 3, fmt.Sprint(v...)) } @@ -192,6 +205,11 @@ func (l *logger) Fatal(v ...interface{}) { os.Exit(1) } +func (l *logger) Panic(v ...interface{}) { + l.Out(LevelPanic, 3, fmt.Sprint(v...)) + panic(fmt.Sprint(v...)) +} + func (l *logger) FatalIfErrorNotNil(err error) { if err != nil { l.Out(LevelError, 3, fmt.Sprintf("fatal since error not nil. err=%+v", err)) @@ -199,6 +217,13 @@ func (l *logger) FatalIfErrorNotNil(err error) { } } +func (l *logger) PanicIfErrorNotNil(err error) { + if err != nil { + l.Out(LevelPanic, 3, fmt.Sprintf("panic since error not nil. err=%+v", err)) + panic(err) + } +} + func (l *logger) Out(level Level, calldepth int, s string) { if l.c.Level > level { return @@ -221,7 +246,7 @@ func (l *logger) Out(level Level, calldepth int, s string) { if l.c.ShortFileFlag { writeShortFile(&l.buf, calldepth) } - if len(s) == 0 || s[len(s)-1] != '\n' { + if l.buf.Len() == 0 || l.buf.Bytes()[l.buf.Len()-1] != '\n' { l.buf.WriteByte('\n') } diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go index ea29367..ce39f97 100644 --- a/pkg/log/log_test.go +++ b/pkg/log/log_test.go @@ -1,6 +1,8 @@ package log import ( + "encoding/hex" + "errors" "github.com/q191201771/nezha/pkg/assert" originLog "log" "os" @@ -9,13 +11,15 @@ import ( func TestLogger(t *testing.T) { c := Config{ - Level: LevelInfo, - Filename: "/tmp/lallogtest/aaa.log", - IsToStdout: true, + Level: LevelInfo, + Filename: "/tmp/lallogtest/aaa.log", + IsToStdout: true, IsRotateDaily: true, } l, err := New(c) assert.Equal(t, nil, err) + buf := []byte("1234567890987654321") + l.Error(hex.Dump(buf)) l.Debugf("l test msg by Debug%s", "f") l.Infof("l test msg by Info%s", "f") l.Warnf("l test msg by Warn%s", "f") @@ -30,6 +34,8 @@ func TestLogger(t *testing.T) { } func TestGlobal(t *testing.T) { + buf := []byte("1234567890987654321") + Error(hex.Dump(buf)) Debugf("g test msg by Debug%s", "f") Infof("g test msg by Info%s", "f") Warnf("g test msg by Warn%s", "f") @@ -40,9 +46,9 @@ func TestGlobal(t *testing.T) { Error("g test msg by Error") c := Config{ - Level: LevelInfo, - Filename: "/tmp/lallogtest/bbb.log", - IsToStdout: true, + Level: LevelInfo, + Filename: "/tmp/lallogtest/bbb.log", + IsToStdout: true, } err := Init(c) assert.Equal(t, nil, err) @@ -64,7 +70,7 @@ func TestNew(t *testing.T) { l Logger err error ) - l, err = New(Config{Level: LevelFatal + 1}) + l, err = New(Config{Level: LevelPanic + 1}) assert.Equal(t, nil, l) assert.Equal(t, LogErr, err) @@ -79,9 +85,9 @@ func TestNew(t *testing.T) { func TestRotate(t *testing.T) { c := Config{ - Level: LevelInfo, - Filename: "/tmp/lallogtest/ccc.log", - IsToStdout: false, + Level: LevelInfo, + Filename: "/tmp/lallogtest/ccc.log", + IsToStdout: false, IsRotateDaily: true, } err := Init(c) @@ -95,14 +101,49 @@ func TestRotate(t *testing.T) { } } +func withRecover(f func()) { + defer func() { + recover() + }() + f() +} + +func TestPanic(t *testing.T) { + withRecover(func() { + Debug("ddd") + Panic("aaa") + }) + withRecover(func() { + Panicf("%s", "bbb") + }) + withRecover(func() { + PanicIfErrorNotNil(errors.New("mock error")) + }) + withRecover(func() { + l, err := New(Config{Level: LevelDebug, IsToStdout: true}) + assert.Equal(t, nil, err) + l.Panic("aaa") + }) + withRecover(func() { + l, err := New(Config{Level: LevelDebug, IsToStdout: true}) + assert.Equal(t, nil, err) + l.Panicf("%s", "bbb") + }) + withRecover(func() { + l, err := New(Config{Level: LevelDebug, IsToStdout: true}) + assert.Equal(t, nil, err) + l.PanicIfErrorNotNil(errors.New("mock error")) + }) +} + func BenchmarkStdout(b *testing.B) { b.ReportAllocs() c := Config{ - Level: LevelInfo, + Level: LevelInfo, //Filename: "/tmp/lallogtest/ddd.log", - Filename: "/dev/null", + Filename: "/dev/null", //IsToStdout: true, - ShortFileFlag:true, + ShortFileFlag: true, } err := Init(c) assert.Equal(b, nil, err) diff --git a/test.sh b/test.sh index 7c3ff2b..1fd3db9 100755 --- a/test.sh +++ b/test.sh @@ -1,6 +1,31 @@ #!/usr/bin/env bash -set -e +# 在 macos 下运行 gofmt 检查 +uname=$(uname) +if [[ "$uname" == "Darwin" ]]; then + echo "CHEFERASEME run gofmt check..." + gofiles=$(git diff --name-only --diff-filter=ACM | grep '.go$') + if [ ! -z "$gofiles" ]; then + #echo "CHEFERASEME mod gofiles exist:" $gofiles + unformatted=$(gofmt -l $gofiles) + if [ ! -z "$unformatted" ]; then + echo "Go files should be formatted with gofmt. Please run:" + for fn in $unformatted; do + echo " gofmt -w $PWD/$fn" + done + #exit 1 + else + echo "Go files be formatted." + fi + else + echo "CHEFERASEME mod gofiles not exist." + fi +else + echo "CHEFERASEME not run gofmt check..." +fi + +# 跑 go test 生成测试覆盖率 +echo "CHEFERASEME run coverage test..." echo "" > coverage.txt for d in $(go list ./... | grep -v vendor | grep nezha/pkg); do