From 0437993a247e10ee879b8452a5fb1bf9785dc1e0 Mon Sep 17 00:00:00 2001 From: q191201771 <191201771@qq.com> Date: Tue, 27 Aug 2019 11:00:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BF=A1=E6=81=AF=EF=BC=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * rtmp音频直接转发,不等待视频 * 新增 /pkg/util/assert 用于替换单元测试中的 stretchr/testify/assert * 补充一些单元测试 --- .gitignore | 3 + app/httpflvpull/httpflvpull.go | 3 +- pkg/httpflv/client_pull_session.go | 6 +- pkg/rtmp/amf0_test.go | 8 +-- pkg/rtmp/chunk_composer.go | 1 - pkg/rtmp/client_session.go | 9 ++- pkg/rtmp/group.go | 39 +++++++------ pkg/rtmp/handshake_test.go | 2 +- pkg/rtmp/rtmp.go | 2 +- pkg/rtmp/server_session.go | 1 - pkg/util/assert/assert.go | 53 +++++++++++++++++ pkg/util/assert/assert_test.go | 26 +++++++++ pkg/util/bele/bele_test.go | 23 +++++++- pkg/util/log/log_test.go | 92 ++++++++++++++++++------------ pkg/util/unique/unique_test.go | 2 +- 15 files changed, 195 insertions(+), 75 deletions(-) create mode 100644 pkg/util/assert/assert.go create mode 100644 pkg/util/assert/assert_test.go diff --git a/.gitignore b/.gitignore index 78cbbc4..cf9979a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +profile.out +coverage.html + /pre-commit.sh /coverage.txt /TODO.md diff --git a/app/httpflvpull/httpflvpull.go b/app/httpflvpull/httpflvpull.go index 2387a3c..fa5fe65 100644 --- a/app/httpflvpull/httpflvpull.go +++ b/app/httpflvpull/httpflvpull.go @@ -19,7 +19,7 @@ func (obs *Obs) ReadFlvHeaderCB(flvHeader []byte) { } func (obs *Obs) ReadFlvTagCB(tag *httpflv.Tag) { - log.Infof("ReadFlvTagCB %+v %d %d", tag.Header, tag.IsAVCKeySeqHeader(), tag.IsAVCKeyNalu()) + log.Infof("ReadFlvTagCB %+v %t %t", tag.Header, tag.IsAVCKeySeqHeader(), tag.IsAVCKeyNalu()) } func main() { @@ -35,7 +35,6 @@ func main() { log.Error(err) } - func parseFlag() string { url := flag.String("i", "", "specify rtmp url") flag.Parse() diff --git a/pkg/httpflv/client_pull_session.go b/pkg/httpflv/client_pull_session.go index af1b23e..920d6c5 100644 --- a/pkg/httpflv/client_pull_session.go +++ b/pkg/httpflv/client_pull_session.go @@ -44,6 +44,8 @@ type PullSessionObserver interface { ReadFlvTagCB(tag *Tag) // after cb, PullSession won't use this tag data } +// @param connectTimeout TCP连接时超时,单位秒,如果为0,则不设置超时 +// @param readTimeout 接收数据超时 func NewPullSession(obs PullSessionObserver, connectTimeout int64, readTimeout int64) *PullSession { uk := unique.GenUniqueKey("FLVPULL") log.Infof("lifecycle new PullSession. [%s]", uk) @@ -55,7 +57,9 @@ func NewPullSession(obs PullSessionObserver, connectTimeout int64, readTimeout i } } -// @param timeout: timeout for connect operate. if 0, then no timeout +// 支持如下两种格式。当然,前提是对端支持 +// http://{domain}/{app_name}/{stream_name}.flv +// http://{ip}/{domain}/{app_name}/{stream_name}.flv func (session *PullSession) Connect(rawURL string) error { session.ConnStat.Start(session.readTimeout, 0) diff --git a/pkg/rtmp/amf0_test.go b/pkg/rtmp/amf0_test.go index b487e17..4cad7c2 100644 --- a/pkg/rtmp/amf0_test.go +++ b/pkg/rtmp/amf0_test.go @@ -2,7 +2,7 @@ package rtmp import ( "bytes" - "github.com/stretchr/testify/assert" + "github.com/q191201771/lal/pkg/util/assert" "strings" "testing" ) @@ -21,7 +21,7 @@ func TestAmf0_WriteNumber_ReadNumber(t *testing.T) { assert.Equal(t, nil, err, "fxxk.") v, l, err := AMF0.ReadNumber(out.Bytes()) assert.Equal(t, item, v, "fxxk.") - assert.Equal(t, l, 9) + assert.Equal(t, l, 9, "fxxk.") assert.Equal(t, nil, err, "fxxk.") } } @@ -39,7 +39,7 @@ func TestAmf0_WriteString_ReadString(t *testing.T) { assert.Equal(t, nil, err, "fxxk.") v, l, err := AMF0.ReadString(out.Bytes()) assert.Equal(t, item, v, "fxxk.") - assert.Equal(t, l, len(item)+3) + assert.Equal(t, l, len(item)+3, "fxxk.") assert.Equal(t, nil, err, "fxxk.") } @@ -49,7 +49,7 @@ func TestAmf0_WriteString_ReadString(t *testing.T) { assert.Equal(t, nil, err, "fxxk.") v, l, err := AMF0.ReadString(out.Bytes()) assert.Equal(t, longStr, v, "fxxk.") - assert.Equal(t, l, len(longStr)+5) + assert.Equal(t, l, len(longStr)+5, "fxxk.") assert.Equal(t, nil, err, "fxxk.") } diff --git a/pkg/rtmp/chunk_composer.go b/pkg/rtmp/chunk_composer.go index 818977b..2f12bd1 100644 --- a/pkg/rtmp/chunk_composer.go +++ b/pkg/rtmp/chunk_composer.go @@ -9,7 +9,6 @@ import ( "io" ) - type ChunkComposer struct { peerChunkSize int csid2stream map[int]*Stream diff --git a/pkg/rtmp/client_session.go b/pkg/rtmp/client_session.go index eafec6f..215adcb 100644 --- a/pkg/rtmp/client_session.go +++ b/pkg/rtmp/client_session.go @@ -12,7 +12,6 @@ import ( "time" ) - // rtmp客户端类型连接的底层实现 // rtmp包的使用者应该优先使用基于ClientSession实现的PushSession和PullSession type ClientSession struct { @@ -32,9 +31,9 @@ type ClientSession struct { hs HandshakeClient peerWinAckSize int - Conn net.Conn - rb *bufio.Reader - wb *bufio.Writer + Conn net.Conn + rb *bufio.Reader + wb *bufio.Writer wChan chan []byte } @@ -64,7 +63,7 @@ func NewClientSession(t ClientSessionType, obs PullSessionObserver, connectTimeo packer: NewMessagePacker(), chunkComposer: NewChunkComposer(), UniqueKey: unique.GenUniqueKey(uk), - wChan: make(chan []byte, wChanSize), + wChan: make(chan []byte, wChanSize), } } diff --git a/pkg/rtmp/group.go b/pkg/rtmp/group.go index ac81b0a..182a534 100644 --- a/pkg/rtmp/group.go +++ b/pkg/rtmp/group.go @@ -138,13 +138,12 @@ func (group *Group) ReadRTMPAVMsgCB(header Header, timestampAbs int, message []b func (group *Group) broadcastRTMP2RTMP(header Header, timestampAbs int, message []byte) { //log.Infof("%+v", header) + // # 1. 设置好头部信息 var currHeader Header currHeader.MsgLen = len(message) currHeader.Timestamp = timestampAbs currHeader.MsgTypeID = header.MsgTypeID currHeader.MsgStreamID = MSID1 - //var prevHeader *Header - switch header.MsgTypeID { case TypeidDataMessageAMF0: currHeader.CSID = CSIDAMF @@ -159,12 +158,14 @@ func (group *Group) broadcastRTMP2RTMP(header Header, timestampAbs int, message var absChunks []byte + // # 2. 广播。遍历所有sub session,决定是否转发 for session := range group.subSessionSet { + // ## 2.1. 一个message广播给多个sub session时,只做一次chunk切割 if absChunks == nil { absChunks = Message2Chunks(message, &currHeader, LocalChunkSize) } - // 是新连接 + // ## 2.2. 如果是新的sub session,发送已缓存的信息 if session.isFresh { // 发送缓存的头部信息 if group.metadata != nil { @@ -176,34 +177,34 @@ func (group *Group) broadcastRTMP2RTMP(header Header, timestampAbs int, message if group.aacSeqHeader != nil { session.AsyncWrite(group.aacSeqHeader) } - session.isFresh = false + } - } else { - // 首次发送,从I帧开始 + // ## 2.3. 判断当前包的类型,以及sub session的状态,决定是否发送并更新sub session的状态 + switch header.MsgTypeID { + case TypeidDataMessageAMF0: + session.AsyncWrite(absChunks) + case TypeidAudio: + session.AsyncWrite(absChunks) + case TypeidVideo: if session.waitKeyNalu { - if header.MsgTypeID == TypeidDataMessageAMF0 { + if message[0] == 0x17 && message[1] == 0x0 { session.AsyncWrite(absChunks) - } else if header.MsgTypeID == TypeidAudio { - if (message[0]>>4) == 0x0a && message[1] == 0x0 { - session.AsyncWrite(absChunks) - } - } else if header.MsgTypeID == TypeidVideo { - if message[0] == 0x17 && message[1] == 0x0 { - session.AsyncWrite(absChunks) - } - if message[0] == 0x17 && message[1] == 0x1 { - session.AsyncWrite(absChunks) - session.waitKeyNalu = false - } + } + if message[0] == 0x17 && message[1] == 0x1 { + session.AsyncWrite(absChunks) + session.waitKeyNalu = false } } else { session.AsyncWrite(absChunks) } + } } + // # 3. 缓存 metadata 和 avc key seq header 和 aac seq header + // 由于可能没有订阅者,所以message可能还没做chunk切割,所以这里要做判断是否做chunk切割 switch header.MsgTypeID { case TypeidDataMessageAMF0: if absChunks == nil { diff --git a/pkg/rtmp/handshake_test.go b/pkg/rtmp/handshake_test.go index 58b5417..f858071 100644 --- a/pkg/rtmp/handshake_test.go +++ b/pkg/rtmp/handshake_test.go @@ -2,7 +2,7 @@ package rtmp import ( "bytes" - "github.com/stretchr/testify/assert" + "github.com/q191201771/lal/pkg/util/assert" "testing" ) diff --git a/pkg/rtmp/rtmp.go b/pkg/rtmp/rtmp.go index e0d037d..0b5d396 100644 --- a/pkg/rtmp/rtmp.go +++ b/pkg/rtmp/rtmp.go @@ -52,7 +52,7 @@ const ( var ( readBufSize = 4096 writeBufSize = 4096 - wChanSize = 1024 + wChanSize = 1024 ) var windowAcknowledgementSize = 5000000 diff --git a/pkg/rtmp/server_session.go b/pkg/rtmp/server_session.go index d5b01cf..14bc055 100644 --- a/pkg/rtmp/server_session.go +++ b/pkg/rtmp/server_session.go @@ -15,7 +15,6 @@ import ( // TODO chef: 没有进化成Pub Sub时的超时释放 - type ServerSessionObserver interface { NewRTMPPubSessionCB(session *ServerSession) // 上层代码应该在这个事件回调中注册音视频数据的监听 NewRTMPSubSessionCB(session *ServerSession) diff --git a/pkg/util/assert/assert.go b/pkg/util/assert/assert.go new file mode 100644 index 0000000..e2b9b4d --- /dev/null +++ b/pkg/util/assert/assert.go @@ -0,0 +1,53 @@ +// Package assert 提供了单元测试时的断言功能 +// +// 代码参考了 https://github.com/stretchr/testify +// +package assert + +import ( + "bytes" + "reflect" +) + +type TestingT interface { + Errorf(format string, args ...interface{}) +} + +func Equal(t TestingT, expected interface{}, actual interface{}, msg string) { + if !equal(expected, actual) { + t.Errorf("%s expected=%+v, actual=%+v", msg, expected, actual) + } + return +} + +func isNil(actual interface{}) bool { + if actual == nil { + return true + } + v := reflect.ValueOf(actual) + k := v.Kind() + if k == reflect.Chan || k == reflect.Map || k == reflect.Ptr || k == reflect.Interface || k == reflect.Slice { + return v.IsNil() + } + return false +} + +func equal(expected, actual interface{}) bool { + if expected == nil { + return isNil(actual) + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + //if exp == nil || act == nil { + // return exp == nil && act == nil + //} + return bytes.Equal(exp, act) +} \ No newline at end of file diff --git a/pkg/util/assert/assert_test.go b/pkg/util/assert/assert_test.go new file mode 100644 index 0000000..eb45d05 --- /dev/null +++ b/pkg/util/assert/assert_test.go @@ -0,0 +1,26 @@ +package assert + +import "testing" +//import aaa "github.com/stretchr/testify/assert" + +func TestEqual(t *testing.T) { + Equal(t, nil, nil, "fxxk.") + Equal(t, 1, 1, "fxxk.") + Equal(t, "aaa", "aaa", "fxxk.") + var ch chan struct{} + Equal(t, nil, ch, "fxxk.") + var m map[string]string + Equal(t, nil, m, "fxxk.") + var p *int + Equal(t, nil, p, "fxxk.") + var i interface{} + Equal(t, nil, i, "fxxk.") + var b []byte + Equal(t, nil, b, "fxxk.") + + Equal(t, true, isNil(nil), "fxxk.") + Equal(t, false, isNil("aaa"), "fxxk.") + Equal(t, false, equal([]byte{}, "aaa"), "fxxk.") + Equal(t, true, equal([]byte{}, []byte{}), "fxxk.") + Equal(t, true, equal([]byte{0, 1, 2}, []byte{0, 1, 2}), "fxxk.") +} diff --git a/pkg/util/bele/bele_test.go b/pkg/util/bele/bele_test.go index 195c41d..60551a6 100644 --- a/pkg/util/bele/bele_test.go +++ b/pkg/util/bele/bele_test.go @@ -3,7 +3,7 @@ package bele import ( "bytes" "encoding/binary" - "github.com/stretchr/testify/assert" + "github.com/q191201771/lal/pkg/util/assert" "testing" ) @@ -59,8 +59,6 @@ func TestBEUint32(t *testing.T) { } } -// TODO chef: test BEFloat64 - func TestBEFloat64(t *testing.T) { vector := []int{ 1, @@ -131,6 +129,25 @@ func TestBEPutUint32(t *testing.T) { } } +func TestLEPutUint32(t *testing.T) { + vector := []struct { + input uint32 + output []byte + }{ + {input: 0, output: []byte{0, 0, 0, 0}}, + {input: 1 * 256 * 256, output: []byte{0, 0, 1, 0}}, + {input: 1 * 256, output: []byte{0, 1, 0, 0}}, + {input: 1, output: []byte{1, 0, 0, 0}}, + {input: 78*256*256*256 + 56*256*256 + 34*256 + 12, output: []byte{12, 34, 56, 78}}, + } + + out := make([]byte, 4) + for i := 0; i < len(vector); i++ { + LEPutUint32(out, vector[i].input) + assert.Equal(t, vector[i].output, out, "fxxk.") + } +} + func TestWriteBEUint24(t *testing.T) { vector := []struct { input uint32 diff --git a/pkg/util/log/log_test.go b/pkg/util/log/log_test.go index c54cda9..d93c8a8 100644 --- a/pkg/util/log/log_test.go +++ b/pkg/util/log/log_test.go @@ -1,21 +1,19 @@ -package log_test +package log import ( - "github.com/q191201771/lal/pkg/util/log" "testing" + "github.com/q191201771/lal/pkg/util/assert" ) func TestLogger(t *testing.T) { - c := log.Config{ - Level: log.LevelInfo, + c := Config{ + Level: LevelInfo, Filename: "/tmp/lallogtest/aaa.log", IsToStdout: true, RotateMByte: 10, } - l, err := log.New(c) - if err != nil { - panic(err) - } + l, err := New(c) + assert.Equal(t, nil, err, "fxxk.") l.Debugf("test msg by Debug%s", "f") l.Infof("test msg by Info%s", "f") l.Warnf("test msg by Warn%s", "f") @@ -27,47 +25,69 @@ func TestLogger(t *testing.T) { } func TestGlobal(t *testing.T) { - log.Debugf("test msg by Debug%s", "f") - log.Infof("test msg by Info%s", "f") - log.Warnf("test msg by Warn%s", "f") - log.Errorf("test msg by Error%s", "f") - log.Debug("test msg by Debug") - log.Info("test msg by Info") - log.Warn("test msg by Warn") - log.Error("test msg by Error") + Debugf("test msg by Debug%s", "f") + Infof("test msg by Info%s", "f") + Warnf("test msg by Warn%s", "f") + Errorf("test msg by Error%s", "f") + Debug("test msg by Debug") + Info("test msg by Info") + Warn("test msg by Warn") + Error("test msg by Error") - c := log.Config{ - Level: log.LevelInfo, + c := Config{ + Level: LevelInfo, Filename: "/tmp/lallogtest/bbb.log", IsToStdout: true, RotateMByte: 10, } - err := log.Init(c) - if err != nil { - panic(err) + err := Init(c) + assert.Equal(t, nil, err, "fxxk.") + Debugf("test msg by Debug%s", "f") + Infof("test msg by Info%s", "f") + Warnf("test msg by Warn%s", "f") + Errorf("test msg by Error%s", "f") + Debug("test msg by Debug") + Info("test msg by Info") + Warn("test msg by Warn") + Error("test msg by Error") + Output(LevelInfo, 3, "test msg by Output") + Outputf(LevelInfo, 3, "test msg by Output%s", "f") +} + +func TestNew(t *testing.T) { + l, err := New(Config{Level:LevelError+1}) + assert.Equal(t, nil, l, "fxxk.") + assert.Equal(t, logErr, err, "fxxk.") +} + +func TestRotate(t *testing.T) { + c := Config{ + Level: LevelInfo, + Filename: "/tmp/lallogtest/ccc.log", + IsToStdout: false, + RotateMByte: 1, + } + err := Init(c) + assert.Equal(t, nil, err, "fxxk.") + b := make([]byte, 1024) + for i := 0; i < 2 * 1024; i++ { + Info(b) + } + for i := 0; i < 2 * 1024; i++ { + Infof("%+v", b) } - log.Debugf("test msg by Debug%s", "f") - log.Infof("test msg by Info%s", "f") - log.Warnf("test msg by Warn%s", "f") - log.Errorf("test msg by Error%s", "f") - log.Debug("test msg by Debug") - log.Info("test msg by Info") - log.Warn("test msg by Warn") - log.Error("test msg by Error") } func BenchmarkStdout(b *testing.B) { - c := log.Config{ - Level: log.LevelInfo, + c := Config{ + Level: LevelInfo, Filename: "/tmp/lallogtest/ccc.log", IsToStdout: true, RotateMByte: 10, } - err := log.Init(c) - if err != nil { - panic(err) - } + err := Init(c) + assert.Equal(b, nil, err, "fxxk.") for i := 0; i < b.N; i++ { - log.Infof("hello %s %d", "world", i) + Infof("hello %s %d", "world", i) } } diff --git a/pkg/util/unique/unique_test.go b/pkg/util/unique/unique_test.go index a72c895..b7509f1 100644 --- a/pkg/util/unique/unique_test.go +++ b/pkg/util/unique/unique_test.go @@ -1,7 +1,7 @@ package unique import ( - "github.com/stretchr/testify/assert" + "github.com/q191201771/lal/pkg/util/assert" "sync" "testing" )