You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
naza/pkg/connection/connection.go

359 lines
7.7 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// package connection
//
// 对 net.Conn 接口的二次封装,目的有两个:
// 1. 在流媒体传输这种特定的长连接场景下提供更方便、高性能的接口
// 2. 便于后续将 TCPConn 替换成其他传输协议
package connection
import (
"bufio"
"errors"
"fmt"
"github.com/q191201771/nezha/pkg/log"
"io"
"net"
"sync"
"time"
)
var connectionErr = errors.New("connection: fxxk")
type Connection interface {
// 包含 interface net.Conn 的所有方法
// Read
// Write
// Close
// LocalAddr
// RemoteAddr
// SetDeadline
// SetReadDeadline
// SetWriteDeadline
net.Conn
ReadAtLeast(buf []byte, min int) (n int, err error)
ReadLine() (line []byte, isPrefix bool, err error)
// TODO chef: 这个接口是否不提供
Printf(fmt string, v ...interface{}) (n int, err error)
// 如果使用了 bufio 写缓冲,则将缓冲中的数据发送出去
// 如果使用了 channel 异步发送,则阻塞等待,直到之前 channel 中的数据全部发送完毕
Flush() error
// 阻塞直到连接主动或被动关闭
// @return 返回 nil 则是本端主动调用 Close 关闭
Done() <- chan error
// TODO chef: 这几个接口是否不提供
ModWriteChanSize(n int)
ModWriteBufSize(n int)
ModReadTimeoutMS(n int)
ModWriteTimeoutMS(n int)
}
type Config struct {
// 如果不为0则之后每次读/写使用 bufio 的缓冲
ReadBufSize int
WriteBufSize int
// 如果不为0则之后每次读/写都带超时
ReadTimeoutMS int
WriteTimeoutMS int
// 如果不过0则写使用 channel 将数据发送到后台协程中发送
WChanSize int
}
type wMsgT int
const (
_ wMsgT = iota
wMsgTWrite
wMsgTFlush
wMsgTClose // TODO chef: 没有使用
)
type wmsg struct {
t wMsgT
b []byte
}
func New(conn net.Conn, config Config) Connection {
var c connection
c.Conn = conn
if config.ReadBufSize > 0 {
c.r = bufio.NewReaderSize(conn, config.ReadBufSize)
} else {
c.r = conn
}
if config.WriteBufSize > 0 {
c.w = bufio.NewWriterSize(conn, config.WriteBufSize)
} else {
c.w = conn
}
if config.WChanSize > 0 {
c.wChan = make(chan wmsg, config.WChanSize)
c.flushDoneChan = make(chan struct{}, 1)
go c.runWriteLoop()
}
c.doneChan = make(chan error, 1)
c.exitChan = make(chan struct{}, 1)
c.config = config
return &c
}
type connection struct {
Conn net.Conn
r io.Reader
w io.Writer
config Config
wChan chan wmsg
flushDoneChan chan struct{}
doneChan chan error
exitChan chan struct{}
closeOnce sync.Once
}
// Mod类型函数不加锁
// 由调用方保证不和写操作并发执行
func (c *connection) ModWriteChanSize(n int) {
if c.config.WChanSize > 0 {
panic(connectionErr)
}
c.config.WChanSize = n
c.wChan = make(chan wmsg, n)
c.flushDoneChan = make(chan struct{}, 1)
go c.runWriteLoop()
}
func (c *connection) ModWriteBufSize(n int) {
if c.config.WriteBufSize > 0 {
// 如果之前已经设置过写缓冲,直接 panic
// 这里改成 flush 后替换成新缓冲也行,暂时没这个必要
panic(connectionErr)
}
c.config.WriteBufSize = n
c.w = bufio.NewWriterSize(c.Conn, n)
}
func (c *connection) ModReadTimeoutMS(n int) {
if c.config.ReadTimeoutMS > 0 {
panic(connectionErr)
}
c.config.ReadTimeoutMS = n
}
func (c *connection) ModWriteTimeoutMS(n int) {
if c.config.WriteTimeoutMS > 0 {
panic(connectionErr)
}
c.config.WriteTimeoutMS = n
}
func (c *connection) ReadAtLeast(buf []byte, min int) (n int, err error) {
if c.config.ReadTimeoutMS > 0 {
err = c.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond))
if err != nil {
log.Debugf("error=%v", err)
return 0, err
}
}
n, err = io.ReadAtLeast(c.r, buf, min)
if err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
}
return n, err
}
func (c *connection) ReadLine() (line []byte, isPrefix bool, err error) {
bufioReader, ok := c.r.(*bufio.Reader)
if !ok {
// 目前只有使用了 bufio.Reader 时才能执行 ReadLine 操作
panic(connectionErr)
}
if c.config.ReadTimeoutMS > 0 {
err = c.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond))
if err != nil {
log.Debugf("error=%v", err)
return nil, false, err
}
}
line, isPrefix, err = bufioReader.ReadLine()
if err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
}
return line, isPrefix, err
}
func (c *connection) Printf(format string, v ...interface{}) (n int, err error) {
if c.config.WriteTimeoutMS > 0 {
_ = c.Conn.SetWriteDeadline(time.Now().Add(time.Duration(c.config.WriteTimeoutMS) * time.Millisecond))
}
return fmt.Fprintf(c.Conn, format, v...)
}
func (c *connection) Read(b []byte) (n int, err error) {
if c.config.ReadTimeoutMS > 0 {
err = c.SetReadDeadline(time.Now().Add(time.Duration(c.config.ReadTimeoutMS) * time.Millisecond))
if err != nil {
log.Debugf("error=%v", err)
return 0, err
}
}
n, err = c.r.Read(b)
if err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
}
return n, err
}
func (c *connection) Write(b []byte) (n int, err error) {
if c.config.WChanSize > 0 {
c.wChan <- wmsg{t:wMsgTWrite, b:b}
return len(b), nil
}
return c.write(b)
}
func (c *connection) write(b []byte) (n int, err error) {
if c.config.WriteTimeoutMS > 0 {
err = c.SetWriteDeadline(time.Now().Add(time.Duration(c.config.WriteTimeoutMS) * time.Millisecond))
if err != nil {
log.Debugf("error=%v", err)
return 0, err
}
}
n, err = c.w.Write(b)
if err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
}
return n, err
}
func (c *connection) runWriteLoop() {
for {
select {
case <- c.exitChan:
log.Debug("exitChan recv, exit write loop.")
return
case msg := <- c.wChan:
switch msg.t {
case wMsgTWrite:
if _, err := c.write(msg.b); err != nil {
log.Debugf("error=%v", err)
return
}
case wMsgTFlush:
if err := c.flush(); err != nil {
log.Debugf("error=%v", err)
c.flushDoneChan <- struct{}{}
return
}
c.flushDoneChan <- struct{}{}
case wMsgTClose:
// TODO chef: 是否需要
}
}
}
}
func (c *connection) Flush() error {
if c.config.WChanSize > 0 {
c.wChan <- wmsg{t:wMsgTFlush}
<- c.flushDoneChan
return nil
}
return c.flush()
}
func (c *connection) flush() error {
w, ok := c.w.(*bufio.Writer)
if ok {
if c.config.WriteTimeoutMS > 0 {
err := c.SetWriteDeadline(time.Now().Add(time.Duration(c.config.WriteTimeoutMS) * time.Millisecond))
if err != nil {
log.Debugf("error=%v", err)
return err
}
}
if err := w.Flush(); err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
return err
}
}
return nil
}
func (c *connection) Close() error {
log.Debugf("Close.")
c.close(nil)
return nil
}
func (c *connection) close(err error) {
log.Debugf("close. err=%v", err)
c.closeOnce.Do(func() {
if c.config.WChanSize > 0 {
c.exitChan <- struct{}{}
}
if err != nil {
c.doneChan <- err
}
_ = c.Conn.Close()
})
}
func (c *connection) Done() <- chan error {
err := <- c.doneChan
log.Debugf("Done. err=%v", err)
if err != nil {
c.close(err)
}
ch := make(chan error, 1)
ch <- err
return ch
}
func (c *connection) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *connection) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
}
func (c *connection) SetDeadline(t time.Time) error {
err := c.Conn.SetDeadline(t)
if err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
}
return err
}
func (c *connection) SetReadDeadline(t time.Time) error {
err := c.Conn.SetReadDeadline(t)
if err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
}
return err
}
func (c *connection) SetWriteDeadline(t time.Time) error {
err := c.Conn.SetWriteDeadline(t)
if err != nil {
c.doneChan <- err
log.Debugf("error=%v", err)
}
return err
}