[refactor] rtmp: 使用base.Buffer

pull/114/head
q191201771 3 years ago
parent 746aa9092a
commit 39ccc1b911

@ -15,10 +15,10 @@ import (
)
// TODO(chef): refactor 移入naza中
// TODO(chef): 实现Reader和Writer接口增加Next函数
// TODO(chef): 增加options: growRoundThreshold; 是否做检查
// TODO(chef): 扩容策略函数可由外部传入
const growRoundThreshold = 1048576
const growRoundThreshold = 1048576 // 1MB
// Buffer 先进先出可扩容流式buffer可直接读写内部切片避免拷贝
//
@ -26,7 +26,15 @@ const growRoundThreshold = 1048576
// 读取方式1
// buf := Bytes()
// ... // 读取buf的内容
// Skip()
// Skip(n)
//
// 读取方式2
// buf := Peek(n)
// ...
//
// 读取方式3
// buf := make([]byte, n)
// nn, err := Read(buf)
//
// 写入方式1
// Grow(n)
@ -39,6 +47,9 @@ const growRoundThreshold = 1048576
// ... // 向buf中写入内容
// Flush(n)
//
// 写入方式3
// n, err := Write(buf)
//
type Buffer struct {
core []byte
rpos int
@ -51,9 +62,19 @@ func NewBuffer(initCap int) *Buffer {
}
}
// NewBufferRefBytes
//
// 注意,不拷贝参数`b`的内存块,仅持有
//
func NewBufferRefBytes(b []byte) *Buffer {
return &Buffer{
core: b,
}
}
// ---------------------------------------------------------------------------------------------------------------------
// Bytes Buffer中未读数据
// Bytes Buffer中所有未读数据类似于PeekAll不拷贝
//
func (b *Buffer) Bytes() []byte {
if b.rpos == b.wpos {
@ -62,11 +83,23 @@ func (b *Buffer) Bytes() []byte {
return b.core[b.rpos:b.wpos]
}
// Peek 查看指定长度的未读数据不拷贝类似于Next但是不会修改读取偏移位置
//
func (b *Buffer) Peek(n int) []byte {
if b.rpos == b.wpos {
return nil
}
if b.Len() < n {
return b.Bytes()
}
return b.core[b.rpos : b.rpos+n]
}
// Skip 将前`n`未读数据标记为已读(也即消费完成)
//
func (b *Buffer) Skip(n int) {
if n > b.wpos-b.rpos {
nazalog.Warnf("[%p] Buffer::Skip too large. n=%d, %s", b, n, b.debugString())
nazalog.Warnf("[%p] Buffer::Skip too large. n=%d, %s", b, n, b.DebugString())
b.Reset()
return
}
@ -76,9 +109,10 @@ func (b *Buffer) Skip(n int) {
// ---------------------------------------------------------------------------------------------------------------------
// Grow 确保Buffer中至少有`n`大小的空间可写
// Grow 确保Buffer中至少有`n`大小的空间可写类似于Reserve
//
func (b *Buffer) Grow(n int) {
//nazalog.Debugf("[%p] > Buffer::Grow. n=%d, %s", b, n, b.DebugString())
tail := len(b.core) - b.wpos
if tail >= n {
// 尾部空闲空间足够
@ -87,6 +121,7 @@ func (b *Buffer) Grow(n int) {
if b.rpos+tail >= n {
// 头部加上尾部空闲空间足够,将可读数据移动到头部,回收头部空闲空间
nazalog.Debugf("[%p] Buffer::Grow. move, n=%d, copy=%d", b, n, b.Len())
copy(b.core, b.core[b.rpos:b.wpos])
b.wpos -= b.rpos
b.rpos = 0
@ -101,6 +136,7 @@ func (b *Buffer) Grow(n int) {
needed = roundUpPowerOfTwo(needed)
}
nazalog.Debugf("[%p] Buffer::Grow. realloc, n=%d, copy=%d, cap=(%d, %d)", b, n, b.Len(), b.Cap(), needed)
core := make([]byte, needed, needed)
copy(core, b.core[b.rpos:b.wpos])
b.core = core
@ -119,6 +155,8 @@ func (b *Buffer) WritableBytes() []byte {
// ReserveBytes 返回可写入`n`大小的字节切片,如果空闲空间不够,内部会进行扩容
//
// 注意,返回值空间大小只会为`n`
//
func (b *Buffer) ReserveBytes(n int) []byte {
b.Grow(n)
return b.WritableBytes()[:n]
@ -128,24 +166,17 @@ func (b *Buffer) ReserveBytes(n int) []byte {
//
func (b *Buffer) Flush(n int) {
if len(b.core)-b.wpos < n {
nazalog.Warnf("[%p] Buffer::Flush too large. n=%d, %s", b, n, b.debugString())
nazalog.Warnf("[%p] Buffer::Flush too large. n=%d, %s", b, n, b.DebugString())
b.wpos = len(b.core)
return
}
b.wpos += n
}
// ----- implement io.Writer interface ---------------------------------------------------------------------------------
func (b *Buffer) Write(p []byte) (n int, err error) {
b.Grow(len(p))
copy(b.core[b.wpos:], p)
b.wpos += n
return len(p), nil
}
// ----- implement io.Reader interface ---------------------------------------------------------------------------------
// Read 拷贝,`p`空间由外部申请
//
func (b *Buffer) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
@ -158,13 +189,24 @@ func (b *Buffer) Read(p []byte) (n int, err error) {
return n, nil
}
// ----- implement io.Writer interface ---------------------------------------------------------------------------------
// Write 拷贝
//
func (b *Buffer) Write(p []byte) (n int, err error) {
b.Grow(len(p))
copy(b.core[b.wpos:], p)
b.wpos += n
return len(p), nil
}
// ---------------------------------------------------------------------------------------------------------------------
// Truncate 丢弃可读数据的末尾`n`大小的数据
// Truncate 丢弃可读数据的末尾`n`大小的数据,或者理解为取消写
//
func (b *Buffer) Truncate(n int) {
if b.Len() < n {
nazalog.Warnf("[%p] Buffer::Truncate too large. n=%d, %s", b, n, b.debugString())
nazalog.Warnf("[%p] Buffer::Truncate too large. n=%d, %s", b, n, b.DebugString())
b.Reset()
return
}
@ -197,16 +239,18 @@ func (b *Buffer) Cap() int {
// ---------------------------------------------------------------------------------------------------------------------
func (b *Buffer) DebugString() string {
return fmt.Sprintf("len(core)=%d, rpos=%d, wpos=%d", len(b.core), b.rpos, b.wpos)
}
// ---------------------------------------------------------------------------------------------------------------------
func (b *Buffer) resetIfEmpty() {
if b.rpos == b.wpos {
b.Reset()
}
}
func (b *Buffer) debugString() string {
return fmt.Sprintf("len=%d, rpos=%d, wpos=%d", len(b.core), b.rpos, b.wpos)
}
// TODO(chef): refactor 移入naza中
func roundUpPowerOfTwo(n int) int {
if n <= 2 {

@ -99,7 +99,7 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error {
stream.header.MsgTypeId = bootstrap[6]
stream.header.MsgStreamId = int(bele.LeUint32(bootstrap[7:]))
stream.msg.reserve(stream.header.MsgLen)
stream.msg.Grow(stream.header.MsgLen)
case 1:
if _, err := io.ReadAtLeast(reader, bootstrap[:7], 7); err != nil {
return err
@ -110,7 +110,7 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error {
stream.header.MsgLen = bele.BeUint24(bootstrap[3:])
stream.header.MsgTypeId = bootstrap[6]
stream.msg.reserve(stream.header.MsgLen)
stream.msg.Grow(stream.header.MsgLen)
case 2:
if _, err := io.ReadAtLeast(reader, bootstrap[:3], 3); err != nil {
return err
@ -153,24 +153,21 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error {
if stream.header.MsgLen <= c.peerChunkSize {
neededSize = stream.header.MsgLen
} else {
neededSize = stream.header.MsgLen - stream.msg.len()
neededSize = stream.header.MsgLen - stream.msg.Len()
if neededSize > c.peerChunkSize {
neededSize = c.peerChunkSize
}
}
// 因为上面已经对整个msg的长度reserve过了所以这里就不需要reserve了
//stream.msg.reserve(neededSize)
if _, err := io.ReadAtLeast(reader, stream.msg.buf[stream.msg.e:stream.msg.e+neededSize], int(neededSize)); err != nil {
if _, err := io.ReadFull(reader, stream.msg.buff.ReserveBytes(int(neededSize))); err != nil {
return err
}
stream.msg.produced(neededSize)
stream.msg.Flush(neededSize)
if stream.msg.len() == stream.header.MsgLen {
if stream.msg.Len() == stream.header.MsgLen {
// 对端设置了chunk size
if stream.header.MsgTypeId == base.RtmpTypeIdSetChunkSize {
val := bele.BeUint32(stream.msg.buf)
val := bele.BeUint32(stream.msg.buff.Bytes())
c.SetPeerChunkSize(val)
}
@ -193,21 +190,21 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error {
}
aggregateStream.header.Csid = stream.header.Csid
for stream.msg.len() != 0 {
for stream.msg.Len() != 0 {
// 读取sub message的头
if stream.msg.len() < 11 {
if stream.msg.Len() < 11 {
return ErrRtmp
}
aggregateStream.header.MsgTypeId = stream.msg.buf[stream.msg.b]
stream.msg.consumed(1)
aggregateStream.header.MsgLen = bele.BeUint24(stream.msg.buf[stream.msg.b:])
stream.msg.consumed(3)
aggregateStream.timestamp = bele.BeUint24(stream.msg.buf[stream.msg.b:])
stream.msg.consumed(3)
aggregateStream.timestamp += uint32(stream.msg.buf[stream.msg.b]) << 24
stream.msg.consumed(1)
aggregateStream.header.MsgStreamId = int(bele.BeUint24(stream.msg.buf[stream.msg.b:]))
stream.msg.consumed(3)
aggregateStream.header.MsgTypeId = stream.msg.buff.Bytes()[0]
stream.msg.Skip(1)
aggregateStream.header.MsgLen = bele.BeUint24(stream.msg.buff.Bytes())
stream.msg.Skip(3)
aggregateStream.timestamp = bele.BeUint24(stream.msg.buff.Bytes())
stream.msg.Skip(3)
aggregateStream.timestamp += uint32(stream.msg.buff.Bytes()[0]) << 24
stream.msg.Skip(1)
aggregateStream.header.MsgStreamId = int(bele.BeUint24(stream.msg.buff.Bytes()))
stream.msg.Skip(3)
// 计算时间戳
if firstSubMessage {
@ -217,13 +214,11 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error {
aggregateStream.header.TimestampAbs = stream.header.TimestampAbs + aggregateStream.timestamp - baseTimestamp
// message包体
if stream.msg.len() < aggregateStream.header.MsgLen {
if stream.msg.Len() < aggregateStream.header.MsgLen {
return ErrRtmp
}
aggregateStream.msg.buf = stream.msg.buf[stream.msg.b : stream.msg.b+aggregateStream.header.MsgLen]
//aggregateStream.msg.b = 0
aggregateStream.msg.e = aggregateStream.header.MsgLen
stream.msg.consumed(aggregateStream.header.MsgLen)
aggregateStream.msg.buff = base.NewBufferRefBytes(stream.msg.buff.Peek(int(aggregateStream.header.MsgLen)))
stream.msg.Skip(aggregateStream.header.MsgLen)
// sub message回调给上层
if err := cb(aggregateStream); err != nil {
@ -231,21 +226,21 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb OnCompleteMessage) error {
}
// 跳过prev size字段
if stream.msg.len() < 4 {
if stream.msg.Len() < 4 {
return ErrRtmp
}
stream.msg.consumed(4)
stream.msg.Skip(4)
}
} else {
if err := cb(stream); err != nil {
return err
}
stream.msg.clear()
stream.msg.Reset()
}
}
if stream.msg.len() > stream.header.MsgLen {
nazalog.Warnf("stream msg len should not greater than len field in header. stream.msg.len=%d, header=%+v", stream.msg.len(), stream.header)
if stream.msg.Len() > stream.header.MsgLen {
nazalog.Warnf("stream msg len should not greater than len field in header. stream.msg.len=%d, header=%+v", stream.msg.Len(), stream.header)
return ErrRtmp
}
}

@ -17,8 +17,6 @@ import (
"sync"
"time"
"github.com/q191201771/naza/pkg/nazastring"
"github.com/q191201771/lal/pkg/base"
"github.com/q191201771/naza/pkg/bele"
@ -362,7 +360,7 @@ func (s *ClientSession) doMsg(stream *Stream) error {
s.debugLogReadUserCtrlMsgCount++
if s.debugLogReadUserCtrlMsgCount <= s.debugLogReadUserCtrlMsgMax {
nazalog.Warnf("[%s] read user control message, ignore. buf=%s",
s.uniqueKey, hex.Dump(nazastring.SubSliceSafety(stream.msg.buf[stream.msg.b:stream.msg.e], 32)))
s.uniqueKey, hex.Dump(stream.msg.buff.Peek(32)))
}
case base.RtmpTypeIdAudio:
fallthrough
@ -376,7 +374,7 @@ func (s *ClientSession) doMsg(stream *Stream) error {
}
func (s *ClientSession) doAck(stream *Stream) error {
seqNum := bele.BeUint32(stream.msg.buf[stream.msg.b:stream.msg.e])
seqNum := bele.BeUint32(stream.msg.buff.Bytes())
nazalog.Infof("[%s] < R Acknowledgement. ignore. sequence number=%d.", s.uniqueKey, seqNum)
return nil
}
@ -509,10 +507,10 @@ func (s *ClientSession) doResultMessage(stream *Stream, tid int) error {
return nil
}
func (s *ClientSession) doProtocolControlMessage(stream *Stream) error {
if stream.msg.len() < 4 {
if stream.msg.Len() < 4 {
return ErrRtmp
}
val := int(bele.BeUint32(stream.msg.buf))
val := int(bele.BeUint32(stream.msg.buff.Bytes()))
switch stream.header.MsgTypeId {
case base.RtmpTypeIdWinAckSize:

@ -73,7 +73,6 @@ func (packer *MessagePacker) ChunkAndWrite(writer io.Writer, csid int, typeid ui
h.MsgStreamId = streamid
h.TimestampAbs = 0
chunks := Message2Chunks(packer.b.Bytes()[12:], &h)
nazalog.Debugf("CHEFERASEME %d %d", packer.b.Len(), len(chunks))
packer.b.Reset()
_, err := writer.Write(chunks)
return err

@ -233,7 +233,7 @@ func (s *ServerSession) doMsg(stream *Stream) error {
}
func (s *ServerSession) doAck(stream *Stream) error {
seqNum := bele.BeUint32(stream.msg.buf[stream.msg.b:stream.msg.e])
seqNum := bele.BeUint32(stream.msg.buff.Bytes())
nazalog.Infof("[%s] < R Acknowledgement. ignore. sequence number=%d.", s.uniqueKey, seqNum)
return nil
}
@ -329,7 +329,7 @@ func (s *ServerSession) doCommandMessage(stream *Stream) error {
func (s *ServerSession) doCommandAmf3Message(stream *Stream) error {
//去除前面的0就是Amf0的数据
stream.msg.consumed(1)
stream.msg.Skip(1)
return s.doCommandMessage(stream)
}

@ -11,22 +11,11 @@ package rtmp
import (
"encoding/hex"
"fmt"
"github.com/q191201771/naza/pkg/nazastring"
"github.com/q191201771/lal/pkg/base"
"github.com/q191201771/naza/pkg/nazalog"
)
const initMsgLen = 4096
// TODO chef: 将这个buffer实现和bytes.Buffer做比较考虑将它放入naza package中
type StreamMsg struct {
buf []byte
b uint32 // 读取起始位置
e uint32 // 读取结束位置,写入起始位置
}
// ----- Stream --------------------------------------------------------------------------------------------------------
type Stream struct {
header base.RtmpHeader
@ -38,103 +27,88 @@ type Stream struct {
func NewStream() *Stream {
return &Stream{
msg: StreamMsg{
buf: make([]byte, initMsgLen),
buff: base.NewBuffer(initMsgLen),
},
}
}
// 序列化成可读字符串,一般用于发生错误时打印日志
func (stream *Stream) toDebugString() string {
// 注意,这里打印的二进制数据的其实位置是从 0 开始,而不是 msg.b 位置
return fmt.Sprintf("header=%+v, b=%d, hex=%s",
stream.header, stream.msg.b, hex.Dump(nazastring.SubSliceSafety(stream.msg.buf[:stream.msg.e], 4096)))
return fmt.Sprintf("header=%+v, b=%s, hex=%s",
stream.header, stream.msg.buff.DebugString(), hex.Dump(stream.msg.buff.Peek(4096)))
}
func (stream *Stream) toAvMsg() base.RtmpMsg {
// TODO chef: 考虑可能出现header中的len和buf的大小不一致的情况
if stream.header.MsgLen != uint32(len(stream.msg.buf[stream.msg.b:stream.msg.e])) {
nazalog.Errorf("toAvMsg. headerMsgLen=%d, bufLen=%d", stream.header.MsgLen, len(stream.msg.buf[stream.msg.b:stream.msg.e]))
if stream.header.MsgLen != uint32(stream.msg.buff.Len()) {
nazalog.Errorf("toAvMsg. headerMsgLen=%d, bufLen=%d", stream.header.MsgLen, stream.msg.buff.Len())
}
return base.RtmpMsg{
Header: stream.header,
Payload: stream.msg.buf[stream.msg.b:stream.msg.e],
Payload: stream.msg.buff.Bytes(),
}
}
// 确保可写空间,如果不够会扩容
func (msg *StreamMsg) reserve(n uint32) {
bufCap := uint32(cap(msg.buf))
nn := bufCap - msg.e // 剩余空闲空间
if nn > n { // 足够
return
}
for nn < n { // 不够,空闲空间翻倍,直到大于需求空间
nn <<= 1
}
nb := make([]byte, bufCap+nn) // 当前容量加扩充容量
copy(nb, msg.buf[msg.b:msg.e]) // 老数据拷贝
msg.buf = nb // 替换
nazalog.Debugf("reserve. newLen=%d(%d, %d), need=(%d -> %d), cap=(%d -> %d)", len(msg.buf), msg.b, msg.e, n, nn, bufCap, cap(msg.buf))
// ----- StreamMsg -----------------------------------------------------------------------------------------------------
type StreamMsg struct {
buff *base.Buffer
}
// 可读长度
func (msg *StreamMsg) len() uint32 {
return msg.e - msg.b
// 确保可写空间,如果不够会扩容
func (msg *StreamMsg) Grow(n uint32) {
msg.buff.Grow(int(n))
}
// 写入数据后调用
func (msg *StreamMsg) produced(n uint32) {
msg.e += n
func (msg *StreamMsg) Len() uint32 {
return uint32(msg.buff.Len())
}
// 读取数据后调用
func (msg *StreamMsg) consumed(n uint32) {
msg.b += n
func (msg *StreamMsg) Flush(n uint32) {
msg.buff.Flush(int(n))
}
// 清空,空闲内存空间保留不释放
func (msg *StreamMsg) clear() {
msg.b = 0
msg.e = 0
func (msg *StreamMsg) Skip(n uint32) {
msg.buff.Skip(int(n))
}
//func (msg *StreamMsg) bytes() []byte {
// return msg.buf[msg.b: msg.e]
//}
func (msg *StreamMsg) Reset() {
msg.buff.Reset()
}
func (msg *StreamMsg) peekStringWithType() (string, error) {
str, _, err := Amf0.ReadString(msg.buf[msg.b:msg.e])
str, _, err := Amf0.ReadString(msg.buff.Bytes())
return str, err
}
func (msg *StreamMsg) readStringWithType() (string, error) {
str, l, err := Amf0.ReadString(msg.buf[msg.b:msg.e])
str, l, err := Amf0.ReadString(msg.buff.Bytes())
if err == nil {
msg.consumed(uint32(l))
msg.Skip(uint32(l))
}
return str, err
}
func (msg *StreamMsg) readNumberWithType() (int, error) {
val, l, err := Amf0.ReadNumber(msg.buf[msg.b:msg.e])
val, l, err := Amf0.ReadNumber(msg.buff.Bytes())
if err == nil {
msg.consumed(uint32(l))
msg.Skip(uint32(l))
}
return int(val), err
}
func (msg *StreamMsg) readObjectWithType() (ObjectPairArray, error) {
opa, l, err := Amf0.ReadObject(msg.buf[msg.b:msg.e])
opa, l, err := Amf0.ReadObject(msg.buff.Bytes())
if err == nil {
msg.consumed(uint32(l))
msg.Skip(uint32(l))
}
return opa, err
}
func (msg *StreamMsg) readNull() error {
l, err := Amf0.ReadNull(msg.buf[msg.b:msg.e])
l, err := Amf0.ReadNull(msg.buff.Bytes())
if err == nil {
msg.consumed(uint32(l))
msg.Skip(uint32(l))
}
return err
}

@ -34,3 +34,7 @@ var (
windowAcknowledgementSize = 5000000
peerBandwidth = 5000000
)
// 接收rtmp数据时msg的初始内存块大小
// 注意,该值只影响性能,不影响功能(大小不够会自动扩容)
const initMsgLen = 4096

Loading…
Cancel
Save