diff --git a/pkg/base/buffer.go b/pkg/base/buffer.go index 1b2b27d..0518dcc 100644 --- a/pkg/base/buffer.go +++ b/pkg/base/buffer.go @@ -15,10 +15,10 @@ import ( ) // TODO(chef): refactor 移入naza中 -// TODO(chef): 实现Reader和Writer接口;增加Next函数 // TODO(chef): 增加options: growRoundThreshold; 是否做检查 +// TODO(chef): 扩容策略函数可由外部传入 -const growRoundThreshold = 1048576 +const growRoundThreshold = 1048576 // 1MB // Buffer 先进先出可扩容流式buffer,可直接读写内部切片避免拷贝 // @@ -26,7 +26,15 @@ const growRoundThreshold = 1048576 // 读取方式1 // buf := Bytes() // ... // 读取buf的内容 -// Skip() +// Skip(n) +// +// 读取方式2 +// buf := Peek(n) +// ... +// +// 读取方式3 +// buf := make([]byte, n) +// nn, err := Read(buf) // // 写入方式1 // Grow(n) @@ -39,6 +47,9 @@ const growRoundThreshold = 1048576 // ... // 向buf中写入内容 // Flush(n) // +// 写入方式3 +// n, err := Write(buf) +// type Buffer struct { core []byte rpos int @@ -51,9 +62,19 @@ func NewBuffer(initCap int) *Buffer { } } +// NewBufferRefBytes +// +// 注意,不拷贝参数`b`的内存块,仅持有 +// +func NewBufferRefBytes(b []byte) *Buffer { + return &Buffer{ + core: b, + } +} + // --------------------------------------------------------------------------------------------------------------------- -// Bytes Buffer中未读数据 +// Bytes Buffer中所有未读数据,类似于PeekAll,不拷贝 // func (b *Buffer) Bytes() []byte { if b.rpos == b.wpos { @@ -62,11 +83,23 @@ func (b *Buffer) Bytes() []byte { return b.core[b.rpos:b.wpos] } +// Peek 查看指定长度的未读数据,不拷贝,类似于Next,但是不会修改读取偏移位置 +// +func (b *Buffer) Peek(n int) []byte { + if b.rpos == b.wpos { + return nil + } + if b.Len() < n { + return b.Bytes() + } + return b.core[b.rpos : b.rpos+n] +} + // Skip 将前`n`未读数据标记为已读(也即消费完成) // func (b *Buffer) Skip(n int) { if n > b.wpos-b.rpos { - nazalog.Warnf("[%p] Buffer::Skip too large. n=%d, %s", b, n, b.debugString()) + nazalog.Warnf("[%p] Buffer::Skip too large. n=%d, %s", b, n, b.DebugString()) b.Reset() return } @@ -76,9 +109,10 @@ func (b *Buffer) Skip(n int) { // --------------------------------------------------------------------------------------------------------------------- -// Grow 确保Buffer中至少有`n`大小的空间可写 +// Grow 确保Buffer中至少有`n`大小的空间可写,类似于Reserve // func (b *Buffer) Grow(n int) { + //nazalog.Debugf("[%p] > Buffer::Grow. n=%d, %s", b, n, b.DebugString()) tail := len(b.core) - b.wpos if tail >= n { // 尾部空闲空间足够 @@ -87,6 +121,7 @@ func (b *Buffer) Grow(n int) { if b.rpos+tail >= n { // 头部加上尾部空闲空间足够,将可读数据移动到头部,回收头部空闲空间 + nazalog.Debugf("[%p] Buffer::Grow. move, n=%d, copy=%d", b, n, b.Len()) copy(b.core, b.core[b.rpos:b.wpos]) b.wpos -= b.rpos b.rpos = 0 @@ -101,6 +136,7 @@ func (b *Buffer) Grow(n int) { needed = roundUpPowerOfTwo(needed) } + nazalog.Debugf("[%p] Buffer::Grow. realloc, n=%d, copy=%d, cap=(%d, %d)", b, n, b.Len(), b.Cap(), needed) core := make([]byte, needed, needed) copy(core, b.core[b.rpos:b.wpos]) b.core = core @@ -119,6 +155,8 @@ func (b *Buffer) WritableBytes() []byte { // ReserveBytes 返回可写入`n`大小的字节切片,如果空闲空间不够,内部会进行扩容 // +// 注意,返回值空间大小只会为`n`, +// func (b *Buffer) ReserveBytes(n int) []byte { b.Grow(n) return b.WritableBytes()[:n] @@ -128,24 +166,17 @@ func (b *Buffer) ReserveBytes(n int) []byte { // func (b *Buffer) Flush(n int) { if len(b.core)-b.wpos < n { - nazalog.Warnf("[%p] Buffer::Flush too large. n=%d, %s", b, n, b.debugString()) + nazalog.Warnf("[%p] Buffer::Flush too large. n=%d, %s", b, n, b.DebugString()) b.wpos = len(b.core) return } b.wpos += n } -// ----- implement io.Writer interface --------------------------------------------------------------------------------- - -func (b *Buffer) Write(p []byte) (n int, err error) { - b.Grow(len(p)) - copy(b.core[b.wpos:], p) - b.wpos += n - return len(p), nil -} - // ----- implement io.Reader interface --------------------------------------------------------------------------------- +// Read 拷贝,`p`空间由外部申请 +// func (b *Buffer) Read(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil @@ -158,13 +189,24 @@ func (b *Buffer) Read(p []byte) (n int, err error) { return n, nil } +// ----- implement io.Writer interface --------------------------------------------------------------------------------- + +// Write 拷贝 +// +func (b *Buffer) Write(p []byte) (n int, err error) { + b.Grow(len(p)) + copy(b.core[b.wpos:], p) + b.wpos += n + return len(p), nil +} + // --------------------------------------------------------------------------------------------------------------------- -// Truncate 丢弃可读数据的末尾`n`大小的数据 +// Truncate 丢弃可读数据的末尾`n`大小的数据,或者理解为取消写 // func (b *Buffer) Truncate(n int) { if b.Len() < n { - nazalog.Warnf("[%p] Buffer::Truncate too large. n=%d, %s", b, n, b.debugString()) + nazalog.Warnf("[%p] Buffer::Truncate too large. n=%d, %s", b, n, b.DebugString()) b.Reset() return } @@ -197,16 +239,18 @@ func (b *Buffer) Cap() int { // --------------------------------------------------------------------------------------------------------------------- +func (b *Buffer) DebugString() string { + return fmt.Sprintf("len(core)=%d, rpos=%d, wpos=%d", len(b.core), b.rpos, b.wpos) +} + +// --------------------------------------------------------------------------------------------------------------------- + func (b *Buffer) resetIfEmpty() { if b.rpos == b.wpos { b.Reset() } } -func (b *Buffer) debugString() string { - return fmt.Sprintf("len=%d, rpos=%d, wpos=%d", len(b.core), b.rpos, b.wpos) -} - // TODO(chef): refactor 移入naza中 func roundUpPowerOfTwo(n int) int { if n <= 2 { diff --git a/pkg/rtmp/chunk_composer.go b/pkg/rtmp/chunk_composer.go index 14c9dde..e3a365e 100644 --- a/pkg/rtmp/chunk_composer.go +++ b/pkg/rtmp/chunk_composer.go @@ -99,7 +99,7 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error { stream.header.MsgTypeId = bootstrap[6] stream.header.MsgStreamId = int(bele.LeUint32(bootstrap[7:])) - stream.msg.reserve(stream.header.MsgLen) + stream.msg.Grow(stream.header.MsgLen) case 1: if _, err := io.ReadAtLeast(reader, bootstrap[:7], 7); err != nil { return err @@ -110,7 +110,7 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error { stream.header.MsgLen = bele.BeUint24(bootstrap[3:]) stream.header.MsgTypeId = bootstrap[6] - stream.msg.reserve(stream.header.MsgLen) + stream.msg.Grow(stream.header.MsgLen) case 2: if _, err := io.ReadAtLeast(reader, bootstrap[:3], 3); err != nil { return err @@ -153,24 +153,21 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error { if stream.header.MsgLen <= c.peerChunkSize { neededSize = stream.header.MsgLen } else { - neededSize = stream.header.MsgLen - stream.msg.len() + neededSize = stream.header.MsgLen - stream.msg.Len() if neededSize > c.peerChunkSize { neededSize = c.peerChunkSize } } - // 因为上面已经对整个msg的长度reserve过了,所以这里就不需要reserve了 - //stream.msg.reserve(neededSize) - - if _, err := io.ReadAtLeast(reader, stream.msg.buf[stream.msg.e:stream.msg.e+neededSize], int(neededSize)); err != nil { + if _, err := io.ReadFull(reader, stream.msg.buff.ReserveBytes(int(neededSize))); err != nil { return err } - stream.msg.produced(neededSize) + stream.msg.Flush(neededSize) - if stream.msg.len() == stream.header.MsgLen { + if stream.msg.Len() == stream.header.MsgLen { // 对端设置了chunk size if stream.header.MsgTypeId == base.RtmpTypeIdSetChunkSize { - val := bele.BeUint32(stream.msg.buf) + val := bele.BeUint32(stream.msg.buff.Bytes()) c.SetPeerChunkSize(val) } @@ -193,21 +190,21 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error { } aggregateStream.header.Csid = stream.header.Csid - for stream.msg.len() != 0 { + for stream.msg.Len() != 0 { // 读取sub message的头 - if stream.msg.len() < 11 { + if stream.msg.Len() < 11 { return ErrRtmp } - aggregateStream.header.MsgTypeId = stream.msg.buf[stream.msg.b] - stream.msg.consumed(1) - aggregateStream.header.MsgLen = bele.BeUint24(stream.msg.buf[stream.msg.b:]) - stream.msg.consumed(3) - aggregateStream.timestamp = bele.BeUint24(stream.msg.buf[stream.msg.b:]) - stream.msg.consumed(3) - aggregateStream.timestamp += uint32(stream.msg.buf[stream.msg.b]) << 24 - stream.msg.consumed(1) - aggregateStream.header.MsgStreamId = int(bele.BeUint24(stream.msg.buf[stream.msg.b:])) - stream.msg.consumed(3) + aggregateStream.header.MsgTypeId = stream.msg.buff.Bytes()[0] + stream.msg.Skip(1) + aggregateStream.header.MsgLen = bele.BeUint24(stream.msg.buff.Bytes()) + stream.msg.Skip(3) + aggregateStream.timestamp = bele.BeUint24(stream.msg.buff.Bytes()) + stream.msg.Skip(3) + aggregateStream.timestamp += uint32(stream.msg.buff.Bytes()[0]) << 24 + stream.msg.Skip(1) + aggregateStream.header.MsgStreamId = int(bele.BeUint24(stream.msg.buff.Bytes())) + stream.msg.Skip(3) // 计算时间戳 if firstSubMessage { @@ -217,13 +214,11 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error { aggregateStream.header.TimestampAbs = stream.header.TimestampAbs + aggregateStream.timestamp - baseTimestamp // message包体 - if stream.msg.len() < aggregateStream.header.MsgLen { + if stream.msg.Len() < aggregateStream.header.MsgLen { return ErrRtmp } - aggregateStream.msg.buf = stream.msg.buf[stream.msg.b : stream.msg.b+aggregateStream.header.MsgLen] - //aggregateStream.msg.b = 0 - aggregateStream.msg.e = aggregateStream.header.MsgLen - stream.msg.consumed(aggregateStream.header.MsgLen) + aggregateStream.msg.buff = base.NewBufferRefBytes(stream.msg.buff.Peek(int(aggregateStream.header.MsgLen))) + stream.msg.Skip(aggregateStream.header.MsgLen) // sub message回调给上层 if err := cb(aggregateStream); err != nil { @@ -231,21 +226,21 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error { } // 跳过prev size字段 - if stream.msg.len() < 4 { + if stream.msg.Len() < 4 { return ErrRtmp } - stream.msg.consumed(4) + stream.msg.Skip(4) } } else { if err := cb(stream); err != nil { return err } - stream.msg.clear() + stream.msg.Reset() } } - if stream.msg.len() > stream.header.MsgLen { - nazalog.Warnf("stream msg len should not greater than len field in header. stream.msg.len=%d, header=%+v", stream.msg.len(), stream.header) + if stream.msg.Len() > stream.header.MsgLen { + nazalog.Warnf("stream msg len should not greater than len field in header. stream.msg.len=%d, header=%+v", stream.msg.Len(), stream.header) return ErrRtmp } } diff --git a/pkg/rtmp/client_session.go b/pkg/rtmp/client_session.go index ed777d2..89a2a7c 100644 --- a/pkg/rtmp/client_session.go +++ b/pkg/rtmp/client_session.go @@ -17,8 +17,6 @@ import ( "sync" "time" - "github.com/q191201771/naza/pkg/nazastring" - "github.com/q191201771/lal/pkg/base" "github.com/q191201771/naza/pkg/bele" @@ -362,7 +360,7 @@ func (s *ClientSession) doMsg(stream *Stream) error { s.debugLogReadUserCtrlMsgCount++ if s.debugLogReadUserCtrlMsgCount <= s.debugLogReadUserCtrlMsgMax { nazalog.Warnf("[%s] read user control message, ignore. buf=%s", - s.uniqueKey, hex.Dump(nazastring.SubSliceSafety(stream.msg.buf[stream.msg.b:stream.msg.e], 32))) + s.uniqueKey, hex.Dump(stream.msg.buff.Peek(32))) } case base.RtmpTypeIdAudio: fallthrough @@ -376,7 +374,7 @@ func (s *ClientSession) doMsg(stream *Stream) error { } func (s *ClientSession) doAck(stream *Stream) error { - seqNum := bele.BeUint32(stream.msg.buf[stream.msg.b:stream.msg.e]) + seqNum := bele.BeUint32(stream.msg.buff.Bytes()) nazalog.Infof("[%s] < R Acknowledgement. ignore. sequence number=%d.", s.uniqueKey, seqNum) return nil } @@ -509,10 +507,10 @@ func (s *ClientSession) doResultMessage(stream *Stream, tid int) error { return nil } func (s *ClientSession) doProtocolControlMessage(stream *Stream) error { - if stream.msg.len() < 4 { + if stream.msg.Len() < 4 { return ErrRtmp } - val := int(bele.BeUint32(stream.msg.buf)) + val := int(bele.BeUint32(stream.msg.buff.Bytes())) switch stream.header.MsgTypeId { case base.RtmpTypeIdWinAckSize: diff --git a/pkg/rtmp/message_packer.go b/pkg/rtmp/message_packer.go index c2254b2..0fb6b55 100644 --- a/pkg/rtmp/message_packer.go +++ b/pkg/rtmp/message_packer.go @@ -73,7 +73,6 @@ func (packer *MessagePacker) ChunkAndWrite(writer io.Writer, csid int, typeid ui h.MsgStreamId = streamid h.TimestampAbs = 0 chunks := Message2Chunks(packer.b.Bytes()[12:], &h) - nazalog.Debugf("CHEFERASEME %d %d", packer.b.Len(), len(chunks)) packer.b.Reset() _, err := writer.Write(chunks) return err diff --git a/pkg/rtmp/server_session.go b/pkg/rtmp/server_session.go index 03eb470..f74cb65 100644 --- a/pkg/rtmp/server_session.go +++ b/pkg/rtmp/server_session.go @@ -233,7 +233,7 @@ func (s *ServerSession) doMsg(stream *Stream) error { } func (s *ServerSession) doAck(stream *Stream) error { - seqNum := bele.BeUint32(stream.msg.buf[stream.msg.b:stream.msg.e]) + seqNum := bele.BeUint32(stream.msg.buff.Bytes()) nazalog.Infof("[%s] < R Acknowledgement. ignore. sequence number=%d.", s.uniqueKey, seqNum) return nil } @@ -329,7 +329,7 @@ func (s *ServerSession) doCommandMessage(stream *Stream) error { func (s *ServerSession) doCommandAmf3Message(stream *Stream) error { //去除前面的0就是Amf0的数据 - stream.msg.consumed(1) + stream.msg.Skip(1) return s.doCommandMessage(stream) } diff --git a/pkg/rtmp/stream.go b/pkg/rtmp/stream.go index 51b4add..d77ad95 100644 --- a/pkg/rtmp/stream.go +++ b/pkg/rtmp/stream.go @@ -11,22 +11,11 @@ package rtmp import ( "encoding/hex" "fmt" - - "github.com/q191201771/naza/pkg/nazastring" - "github.com/q191201771/lal/pkg/base" - "github.com/q191201771/naza/pkg/nazalog" ) -const initMsgLen = 4096 - -// TODO chef: 将这个buffer实现和bytes.Buffer做比较,考虑将它放入naza package中 -type StreamMsg struct { - buf []byte - b uint32 // 读取起始位置 - e uint32 // 读取结束位置,写入起始位置 -} +// ----- Stream -------------------------------------------------------------------------------------------------------- type Stream struct { header base.RtmpHeader @@ -38,103 +27,88 @@ type Stream struct { func NewStream() *Stream { return &Stream{ msg: StreamMsg{ - buf: make([]byte, initMsgLen), + buff: base.NewBuffer(initMsgLen), }, } } // 序列化成可读字符串,一般用于发生错误时打印日志 func (stream *Stream) toDebugString() string { - // 注意,这里打印的二进制数据的其实位置是从 0 开始,而不是 msg.b 位置 - return fmt.Sprintf("header=%+v, b=%d, hex=%s", - stream.header, stream.msg.b, hex.Dump(nazastring.SubSliceSafety(stream.msg.buf[:stream.msg.e], 4096))) + return fmt.Sprintf("header=%+v, b=%s, hex=%s", + stream.header, stream.msg.buff.DebugString(), hex.Dump(stream.msg.buff.Peek(4096))) } func (stream *Stream) toAvMsg() base.RtmpMsg { // TODO chef: 考虑可能出现header中的len和buf的大小不一致的情况 - if stream.header.MsgLen != uint32(len(stream.msg.buf[stream.msg.b:stream.msg.e])) { - nazalog.Errorf("toAvMsg. headerMsgLen=%d, bufLen=%d", stream.header.MsgLen, len(stream.msg.buf[stream.msg.b:stream.msg.e])) + if stream.header.MsgLen != uint32(stream.msg.buff.Len()) { + nazalog.Errorf("toAvMsg. headerMsgLen=%d, bufLen=%d", stream.header.MsgLen, stream.msg.buff.Len()) } return base.RtmpMsg{ Header: stream.header, - Payload: stream.msg.buf[stream.msg.b:stream.msg.e], + Payload: stream.msg.buff.Bytes(), } } -// 确保可写空间,如果不够会扩容 -func (msg *StreamMsg) reserve(n uint32) { - bufCap := uint32(cap(msg.buf)) - nn := bufCap - msg.e // 剩余空闲空间 - if nn > n { // 足够 - return - } - for nn < n { // 不够,空闲空间翻倍,直到大于需求空间 - nn <<= 1 - } - nb := make([]byte, bufCap+nn) // 当前容量加扩充容量 - copy(nb, msg.buf[msg.b:msg.e]) // 老数据拷贝 - msg.buf = nb // 替换 - nazalog.Debugf("reserve. newLen=%d(%d, %d), need=(%d -> %d), cap=(%d -> %d)", len(msg.buf), msg.b, msg.e, n, nn, bufCap, cap(msg.buf)) +// ----- StreamMsg ----------------------------------------------------------------------------------------------------- + +type StreamMsg struct { + buff *base.Buffer } -// 可读长度 -func (msg *StreamMsg) len() uint32 { - return msg.e - msg.b +// 确保可写空间,如果不够会扩容 +func (msg *StreamMsg) Grow(n uint32) { + msg.buff.Grow(int(n)) } -// 写入数据后调用 -func (msg *StreamMsg) produced(n uint32) { - msg.e += n +func (msg *StreamMsg) Len() uint32 { + return uint32(msg.buff.Len()) } -// 读取数据后调用 -func (msg *StreamMsg) consumed(n uint32) { - msg.b += n +func (msg *StreamMsg) Flush(n uint32) { + msg.buff.Flush(int(n)) } -// 清空,空闲内存空间保留不释放 -func (msg *StreamMsg) clear() { - msg.b = 0 - msg.e = 0 +func (msg *StreamMsg) Skip(n uint32) { + msg.buff.Skip(int(n)) } -//func (msg *StreamMsg) bytes() []byte { -// return msg.buf[msg.b: msg.e] -//} +func (msg *StreamMsg) Reset() { + msg.buff.Reset() +} func (msg *StreamMsg) peekStringWithType() (string, error) { - str, _, err := Amf0.ReadString(msg.buf[msg.b:msg.e]) + str, _, err := Amf0.ReadString(msg.buff.Bytes()) return str, err } func (msg *StreamMsg) readStringWithType() (string, error) { - str, l, err := Amf0.ReadString(msg.buf[msg.b:msg.e]) + str, l, err := Amf0.ReadString(msg.buff.Bytes()) if err == nil { - msg.consumed(uint32(l)) + msg.Skip(uint32(l)) } return str, err } func (msg *StreamMsg) readNumberWithType() (int, error) { - val, l, err := Amf0.ReadNumber(msg.buf[msg.b:msg.e]) + val, l, err := Amf0.ReadNumber(msg.buff.Bytes()) if err == nil { - msg.consumed(uint32(l)) + msg.Skip(uint32(l)) } return int(val), err } func (msg *StreamMsg) readObjectWithType() (ObjectPairArray, error) { - opa, l, err := Amf0.ReadObject(msg.buf[msg.b:msg.e]) + opa, l, err := Amf0.ReadObject(msg.buff.Bytes()) if err == nil { - msg.consumed(uint32(l)) + msg.Skip(uint32(l)) } return opa, err } func (msg *StreamMsg) readNull() error { - l, err := Amf0.ReadNull(msg.buf[msg.b:msg.e]) + l, err := Amf0.ReadNull(msg.buff.Bytes()) if err == nil { - msg.consumed(uint32(l)) + msg.Skip(uint32(l)) } return err } diff --git a/pkg/rtmp/var.go b/pkg/rtmp/var.go index 54ff75d..ca51bc0 100644 --- a/pkg/rtmp/var.go +++ b/pkg/rtmp/var.go @@ -34,3 +34,7 @@ var ( windowAcknowledgementSize = 5000000 peerBandwidth = 5000000 ) + +// 接收rtmp数据时,msg的初始内存块大小 +// 注意,该值只影响性能,不影响功能(大小不够会自动扩容) +const initMsgLen = 4096