diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index 080dd7b..b1eaffa 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -127,7 +127,7 @@ func New(conn net.Conn, modOptions ...ModOption) Connection { } if c.option.WriteChanSize > 0 { - c.wChan = make(chan wMsg, c.option.WriteBufSize) + c.wChan = make(chan wMsg, c.option.WriteChanSize) c.flushDoneChan = make(chan struct{}, 1) c.exitChan = make(chan struct{}, 1) go c.runWriteLoop() diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go index abe5be9..0327a89 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/connection/connection_test.go @@ -6,11 +6,16 @@ // // Author: Chef (191201771@qq.com) -package connection +package connection_test import ( + "math/rand" "net" + "sync/atomic" "testing" + "time" + + "github.com/q191201771/naza/pkg/connection" "github.com/q191201771/naza/pkg/assert" "github.com/q191201771/naza/pkg/nazalog" @@ -26,12 +31,12 @@ func TestWriteTimeout(t *testing.T) { assert.Equal(t, nil, err) defer l.Close() go func() { - conn, _ := l.Accept() - defer conn.Close() + srvConn, _ := l.Accept() + defer srvConn.Close() <-ch }() conn, err := net.Dial("tcp", ":10027") - c := New(conn, func(opt *Option) { + c := connection.New(conn, func(opt *connection.Option) { opt.WriteTimeoutMS = 1000 }) assert.Equal(t, nil, err) @@ -45,3 +50,48 @@ func TestWriteTimeout(t *testing.T) { } ch <- struct{}{} } + +func TestWrite(t *testing.T) { + var sentN uint32 + var sentDone uint32 + + rand.Seed(time.Now().Unix()) + l, err := net.Listen("tcp", ":10027") + assert.Equal(t, nil, err) + go func() { + c, err := l.Accept() + srvConn := connection.New(c, func(option *connection.Option) { + option.WriteChanSize = 1024 + //option.WriteBufSize = 256 + option.WriteTimeoutMS = 10000 + }) + assert.Equal(t, nil, err) + for i := 0; i < 10; i++ { + b := make([]byte, rand.Intn(4096)) + n, err := srvConn.Write(b) + if err == nil { + nazalog.Debugf("sent. i=%d, n=%d", i, n) + } + assert.Equal(t, nil, err) + atomic.AddUint32(&sentN, uint32(n)) + } + err = srvConn.Flush() + assert.Equal(t, nil, err) + nazalog.Debugf("total sent:%d", sentN) + atomic.StoreUint32(&sentDone, 1) + }() + + conn, err := net.Dial("tcp", ":10027") + assert.Equal(t, nil, err) + b := make([]byte, 4096) + var readN uint32 + for { + n, err := conn.Read(b) + assert.Equal(t, nil, err) + readN += uint32(n) + nazalog.Debugf("total read:%d", readN) + if atomic.LoadUint32(&sentDone) == 1 && atomic.LoadUint32(&sentN) == readN { + break + } + } +}