diff --git a/app/demo/analyseflv/analyseflv.go b/app/demo/analyseflv/analyseflv.go index 0f2a33d..db6cbd1 100644 --- a/app/demo/analyseflv/analyseflv.go +++ b/app/demo/analyseflv/analyseflv.go @@ -199,45 +199,36 @@ func analysisVideoTag(tag httpflv.Tag) { buf.WriteString(" [HEVC SeqHeader] ") } } else { - body := tag.Raw[11:] + body := tag.Raw[httpflv.TagHeaderSize+5 : len(tag.Raw)-httpflv.PrevTagSizeFieldSize] + nals, err := avc.IterateNALUAVCC(body) + nazalog.Assert(nil, err) - i := 5 - for i != int(tag.Header.DataSize) { - if i+4 > int(tag.Header.DataSize) { - nazalog.Errorf("invalid nalu size. i=%d, tag size=%d", i, int(tag.Header.DataSize)) - break - } - naluLen := bele.BEUint32(body[i:]) - if i+int(naluLen) > int(tag.Header.DataSize) { - nazalog.Errorf("invalid nalu size. i=%d, naluLen=%d, tag size=%d", i, naluLen, int(tag.Header.DataSize)) - break - } + for _, nal := range nals { switch t { case typeAVC: - if avc.ParseNALUType(body[i+4]) == avc.NALUTypeIDRSlice { + if avc.ParseNALUType(nal[0]) == avc.NALUTypeIDRSlice { if prevIDRTS != int64(-1) { diffIDRTS = int64(tag.Header.Timestamp) - prevIDRTS } prevIDRTS = int64(tag.Header.Timestamp) } - if avc.ParseNALUType(body[i+4]) == avc.NALUTypeSEI { - delay := SEIDelayMS(body[i+4 : i+4+int(naluLen)]) + if avc.ParseNALUType(nal[0]) == avc.NALUTypeSEI { + delay := SEIDelayMS(nal) if delay != -1 { buf.WriteString(fmt.Sprintf("delay: %dms", delay)) } } - sliceTypeReadable, _ := avc.ParseSliceTypeReadable(body[i+4:]) - buf.WriteString(fmt.Sprintf(" [%s(%s)(%d)] ", avc.ParseNALUTypeReadable(body[i+4]), sliceTypeReadable, naluLen)) + sliceTypeReadable, _ := avc.ParseSliceTypeReadable(nal) + buf.WriteString(fmt.Sprintf(" [%s(%s)(%d)] ", avc.ParseNALUTypeReadable(nal[0]), sliceTypeReadable, len(nal))) case typeHEVC: - if hevc.ParseNALUType(body[i+4]) == hevc.NALUTypeSEI { - delay := SEIDelayMS(body[i+4 : i+4+int(naluLen)]) + if hevc.ParseNALUType(nal[0]) == hevc.NALUTypeSEI { + delay := SEIDelayMS(nal) if delay != -1 { buf.WriteString(fmt.Sprintf("delay: %dms", delay)) } } - buf.WriteString(fmt.Sprintf(" [%s(%d)] ", hevc.ParseNALUTypeReadable(body[i+4]), body[i+4])) + buf.WriteString(fmt.Sprintf(" [%s(%d)] ", hevc.ParseNALUTypeReadable(nal[0]), nal[0])) } - i = i + 4 + int(naluLen) } } if analysisVideoTagFlag { diff --git a/pkg/avc/avc.go b/pkg/avc/avc.go index 6d3c602..dc11540 100644 --- a/pkg/avc/avc.go +++ b/pkg/avc/avc.go @@ -36,13 +36,16 @@ var ( AUDNALU = []byte{0x00, 0x00, 0x00, 0x01, 0x09, 0xf0} ) +// H.264-AVC-ISO_IEC_14496-15.pdf +// Table 1 - NAL unit types in elementary streams var NALUTypeMapping = map[uint8]string{ - 1: "SLICE", - 5: "IDR", - 6: "SEI", - 7: "SPS", - 8: "PPS", - 9: "AUD", + 1: "SLICE", + 5: "IDR", + 6: "SEI", + 7: "SPS", + 8: "PPS", + 9: "AUD", + 12: "FD", } var SliceTypeMapping = map[uint8]string{ @@ -64,7 +67,8 @@ const ( NALUTypeSEI uint8 = 6 NALUTypeSPS uint8 = 7 NALUTypePPS uint8 = 8 - NALUTypeAUD uint8 = 9 // Access Unit Delimiter + NALUTypeAUD uint8 = 9 // Access Unit Delimiter + NALUTypeFD uint8 = 12 // Filler Data ) const ( @@ -194,7 +198,11 @@ func ParseSliceTypeReadable(nalu []byte) (string, error) { } // AVCC Seq Header -> AnnexB -// 注意,返回的内存块为独立的内存块,不依赖指向传输参数内存块 +// +// @param payload: rtmp message的payload部分或者flv tag的payload部分 +// 注意,包含了头部2字节类型以及3字节的cts +// +// @return 注意,返回的内存块为独立的内存块,不依赖指向传输参数内存块 // func SPSPPSSeqHeader2AnnexB(payload []byte) ([]byte, error) { sps, pps, err := ParseSPSPPSFromSeqHeader(payload) @@ -211,8 +219,8 @@ func SPSPPSSeqHeader2AnnexB(payload []byte) ([]byte, error) { // 从AVCC格式的Seq Header中得到SPS和PPS内存块 // -// @param rtmp message的payload部分或者flv tag的payload部分 -// 注意,包含了头部2字节类型以及3字节的cts +// @param payload: rtmp message的payload部分或者flv tag的payload部分 +// 注意,包含了头部2字节类型以及3字节的cts // // @return 注意,返回的sps,pps内存块指向的是传入参数内存块的内存 // @@ -310,8 +318,8 @@ func BuildSeqHeaderFromSPSPPS(sps, pps []byte) ([]byte, error) { // AVCC -> AnnexB // -// @param rtmp message的payload部分或者flv tag的payload部分 -// 注意,包含了头部2字节类型以及3字节的cts +// @param payload: rtmp message的payload部分或者flv tag的payload部分 +// 注意,包含了头部2字节类型以及3字节的cts // func CaptureAVCC2AnnexB(w io.Writer, payload []byte) error { // sps pps @@ -335,3 +343,112 @@ func CaptureAVCC2AnnexB(w io.Writer, payload []byte) error { } return nil } + +// 遍历直到找到第一个nalu start code的位置 +// +// @param start: 从`nalu`的start位置开始查找 +// +// @return pos: start code的起始位置(包含start code自身) +// length: start code的长度,可能是3或者4 +// 注意,如果找不到start code,则返回-1, -1 +// +func IterateNALUStartCode(nalu []byte, start int) (pos, length int) { + if nalu == nil || start >= len(nalu) { + return -1, -1 + } + count := 0 + for i := range nalu[start:] { + switch nalu[start+i] { + case 0: + count++ + case 1: + if count >= 2 { + return start + i - count, count + 1 + } + count = 0 + default: + count = 0 + } + } + return -1, -1 +} + +// 遍历AnnexB格式,去掉start code,获取nal包,正常情况下可能为1个或多个,异常情况下可能一个也没有 +// +// 具体见单元测试 +// +func IterateNALUAnnexB(nals []byte) (nalList [][]byte, err error) { + if nals == nil { + err = ErrAVC + return + } + prePos, preLength := IterateNALUStartCode(nals, 0) + if prePos == -1 { + nalList = append(nalList, nals) + err = ErrAVC + return + } + + for { + pos, length := IterateNALUStartCode(nals, prePos+preLength) + start := prePos + preLength + if pos == -1 { + if start < len(nals) { + nalList = append(nalList, nals[start:]) + } else { + err = ErrAVC + } + return + } + if start < pos { + nalList = append(nalList, nals[start:pos]) + } else { + err = ErrAVC + } + + prePos = pos + preLength = length + } +} + +// 遍历AVCC格式,去掉4字节长度,获取nal包,正常情况下可能返回1个或多个,异常情况下可能一个也没有 +// +// 具体见单元测试 +// +func IterateNALUAVCC(nals []byte) (nalList [][]byte, err error) { + if nals == nil { + err = ErrAVC + return + } + pos := 0 + for { + if len(nals[pos:]) < 4 { + err = ErrAVC + return + } + length := int(bele.BEUint32(nals[pos:])) + pos += 4 + if pos == len(nals) { + err = ErrAVC + return + } + epos := pos + length + if epos < len(nals) { + // 非最后一个 + nalList = append(nalList, nals[pos:epos]) + pos += length + } else if epos == len(nals) { + // 最后一个 + nalList = append(nalList, nals[pos:epos]) + return + } else { + nalList = append(nalList, nals[pos:]) + err = ErrAVC + return + } + } +} + +// TODO(chef) +// func NALUAVCC2AnnexB +// func NALUAnnexB2AVCC diff --git a/pkg/avc/avc_test.go b/pkg/avc/avc_test.go index ca15259..a2411e0 100644 --- a/pkg/avc/avc_test.go +++ b/pkg/avc/avc_test.go @@ -10,6 +10,7 @@ package avc_test import ( "bytes" + "fmt" "testing" "github.com/q191201771/naza/pkg/nazabits" @@ -153,3 +154,145 @@ func TestParsePPS_Case2(t *testing.T) { assert.Equal(t, uint32(1280), ctx.Width) assert.Equal(t, uint32(960), ctx.Height) } + +func TestIterateStartCode(t *testing.T) { + golden := []struct { + nalu []byte + pos int + length int + }{ + { + nalu: []byte{0, 0, 1}, + pos: 0, + length: 3, + }, + { + nalu: []byte{0, 0, 0, 1}, + pos: 0, + length: 4, + }, + { + nalu: []byte{0xa, 0, 0, 0, 1}, + pos: 1, + length: 4, + }, + { + nalu: []byte{0, 1}, + pos: -1, + length: -1, + }, + { + nalu: []byte{0xa, 0xb}, + pos: -1, + length: -1, + }, + } + + for _, v := range golden { + pos, length := avc.IterateNALUStartCode(v.nalu, 0) + assert.Equal(t, v.pos, pos) + assert.Equal(t, v.length, length) + } +} + +func TestIterateNALUAnnexB(t *testing.T) { + golden := []struct { + nals []byte + nalList [][]byte + err error + }{ + { + nals: []byte{0, 0, 1, 0xa, 0xb}, + nalList: [][]byte{ + {0xa, 0xb}, + }, + err: nil, + }, + { + nals: []byte{0, 0, 0, 1, 0xa, 0xb, 0, 0, 0, 1, 0xc, 0xd}, + nalList: [][]byte{ + {0xa, 0xb}, + {0xc, 0xd}, + }, + err: nil, + }, + { + nals: []byte{0xa, 0xb}, + nalList: [][]byte{ + {0xa, 0xb}, + }, + err: avc.ErrAVC, + }, + { + nals: []byte{0, 0, 1}, + nalList: nil, + err: avc.ErrAVC, + }, + { + nals: []byte{0, 0, 1, 0, 0, 1}, + nalList: nil, + err: avc.ErrAVC, + }, + { + nals: nil, + nalList: nil, + err: avc.ErrAVC, + }, + } + for _, v := range golden { + nalList, err := avc.IterateNALUAnnexB(v.nals) + assert.Equal(t, v.nalList, nalList) + assert.Equal(t, v.err, err, fmt.Sprintf("%+v", v)) + } +} + +func TestIterateNALUAVCC(t *testing.T) { + golden := []struct { + nals []byte + nalList [][]byte + err error + }{ + { + nals: []byte{0, 0, 0, 1, 0xa}, // 正常,1个 + nalList: [][]byte{ + {0xa}, + }, + err: nil, + }, + { + nals: []byte{0, 0, 0, 1, 0xa, 0, 0, 0, 2, 0xa, 0xb}, // 正常,2个 + nalList: [][]byte{ + {0xa}, + {0xa, 0xb}, + }, + err: nil, + }, + { + nals: []byte{0, 0}, // length不全 + nalList: nil, + err: avc.ErrAVC, + }, + { + nals: nil, + nalList: nil, + err: avc.ErrAVC, + }, + { + nals: []byte{0, 0, 0, 1}, // 只有length + nalList: nil, + err: avc.ErrAVC, + }, + { + nals: []byte{0, 0, 0, 2, 0xa}, // 包体数据不全 + nalList: [][]byte{ + {0xa}, + }, + err: avc.ErrAVC, + }, + } + for _, v := range golden { + nalList, err := avc.IterateNALUAVCC(v.nals) + assert.Equal(t, v.nalList, nalList) + assert.Equal(t, v.err, err) + } +} diff --git a/pkg/hls/streamer.go b/pkg/hls/streamer.go index 8fdb8be..61ce815 100644 --- a/pkg/hls/streamer.go +++ b/pkg/hls/streamer.go @@ -9,6 +9,8 @@ package hls import ( + "encoding/hex" + "github.com/q191201771/lal/pkg/aac" "github.com/q191201771/lal/pkg/avc" "github.com/q191201771/lal/pkg/base" @@ -16,6 +18,7 @@ import ( "github.com/q191201771/lal/pkg/mpegts" "github.com/q191201771/naza/pkg/bele" "github.com/q191201771/naza/pkg/nazalog" + "github.com/q191201771/naza/pkg/nazastring" ) type StreamerObserver interface { @@ -124,25 +127,19 @@ func (s *Streamer) feedVideo(msg base.RTMPMsg) { // 优化这块buffer out := s.videoOut[0:0] - // tag中可能有多个NALU,逐个获取 - for i := 5; i != len(msg.Payload); { - if i+4 > len(msg.Payload) { - nazalog.Errorf("[%s] slice len not enough. i=%d, len=%d", s.UniqueKey, i, len(msg.Payload)) - return - } - nalBytes := int(bele.BEUint32(msg.Payload[i:])) - i += 4 - if i+nalBytes > len(msg.Payload) { - nazalog.Errorf("[%s] slice len not enough. i=%d, payload len=%d, nalBytes=%d", s.UniqueKey, i, len(msg.Payload), nalBytes) - return - } - + // msg中可能有多个NALU,逐个获取 + nals, err := avc.IterateNALUAVCC(msg.Payload[5:]) + if err != nil { + nazalog.Errorf("[%s] iterate nalu failed. err=%+v, payload=%s", err, s.UniqueKey, hex.Dump(nazastring.SubSliceSafety(msg.Payload, 32))) + return + } + for _, nal := range nals { var nalType uint8 switch codecID { case base.RTMPCodecIDAVC: - nalType = avc.ParseNALUType(msg.Payload[i]) + nalType = avc.ParseNALUType(nal[0]) case base.RTMPCodecIDHEVC: - nalType = hevc.ParseNALUType(msg.Payload[i]) + nalType = hevc.ParseNALUType(nal[0]) } //nazalog.Debugf("[%s] naltype=%d, len=%d(%d), cts=%d, key=%t.", s.UniqueKey, nalType, nalBytes, len(msg.Payload), cts, msg.IsVideoKeyNALU()) @@ -152,7 +149,6 @@ func (s *Streamer) feedVideo(msg base.RTMPMsg) { // aud有自己的写入逻辑 if (codecID == base.RTMPCodecIDAVC && (nalType == avc.NALUTypeSPS || nalType == avc.NALUTypePPS || nalType == avc.NALUTypeAUD)) || (codecID == base.RTMPCodecIDHEVC && (nalType == hevc.NALUTypeVPS || nalType == hevc.NALUTypeSPS || nalType == hevc.NALUTypePPS || nalType == hevc.NALUTypeAUD)) { - i += nalBytes continue } @@ -209,9 +205,7 @@ func (s *Streamer) feedVideo(msg base.RTMPMsg) { out = append(out, avc.NALUStartCode3...) } - out = append(out, msg.Payload[i:i+nalBytes]...) - - i += nalBytes + out = append(out, nal...) } dts := uint64(msg.Header.TimestampAbs) * 90 diff --git a/pkg/remux/avpacket2flv.go b/pkg/remux/avpacket2flv.go index 284ae78..70abfd9 100644 --- a/pkg/remux/avpacket2flv.go +++ b/pkg/remux/avpacket2flv.go @@ -146,13 +146,15 @@ func AVPacket2FLVTag(pkt base.AVPacket) (tag httpflv.Tag, err error) { tag.Raw[9] = 0 tag.Raw[10] = 0 - // TODO chef: 这段代码应该放在更合适的地方,或者在AVPacket中标识是否包含关键帧 - for i := 0; i != len(pkt.Payload); { - naluSize := int(bele.BEUint32(pkt.Payload[i:])) - + var nals [][]byte + nals, err = avc.IterateNALUAVCC(pkt.Payload) + if err != nil { + return + } + for _, nal := range nals { switch pkt.PayloadType { case base.AVPacketPTAVC: - t := avc.ParseNALUType(pkt.Payload[i+4]) + t := avc.ParseNALUType(nal[0]) if t == avc.NALUTypeIDRSlice { tag.Raw[httpflv.TagHeaderSize] = base.RTMPAVCKeyFrame } else { @@ -160,7 +162,7 @@ func AVPacket2FLVTag(pkt base.AVPacket) (tag httpflv.Tag, err error) { } tag.Raw[httpflv.TagHeaderSize+1] = base.RTMPAVCPacketTypeNALU case base.AVPacketPTHEVC: - t := hevc.ParseNALUType(pkt.Payload[i+4]) + t := hevc.ParseNALUType(nal[0]) if t == hevc.NALUTypeSliceIDR || t == hevc.NALUTypeSliceIDRNLP { tag.Raw[httpflv.TagHeaderSize] = base.RTMPHEVCKeyFrame } else { @@ -168,8 +170,6 @@ func AVPacket2FLVTag(pkt base.AVPacket) (tag httpflv.Tag, err error) { } tag.Raw[httpflv.TagHeaderSize+1] = base.RTMPHEVCPacketTypeNALU } - - i += 4 + naluSize } tag.Raw[httpflv.TagHeaderSize+2] = 0x0 // cts diff --git a/pkg/remux/avpacket2rtmp.go b/pkg/remux/avpacket2rtmp.go index f641ac6..0e01e4c 100644 --- a/pkg/remux/avpacket2rtmp.go +++ b/pkg/remux/avpacket2rtmp.go @@ -14,7 +14,6 @@ import ( "github.com/q191201771/lal/pkg/base" "github.com/q191201771/lal/pkg/hevc" "github.com/q191201771/lal/pkg/rtmp" - "github.com/q191201771/naza/pkg/bele" ) // @return 返回的内存块为新申请的独立内存块 @@ -139,13 +138,15 @@ func AVPacket2RTMPMsg(pkt base.AVPacket) (msg base.RTMPMsg, err error) { msg.Payload = make([]byte, msg.Header.MsgLen) - // TODO chef: 这段代码应该放在更合适的地方,或者在AVPacket中标识是否包含关键帧 - for i := 0; i != len(pkt.Payload); { - naluSize := int(bele.BEUint32(pkt.Payload[i:])) - + var nals [][]byte + nals, err = avc.IterateNALUAVCC(pkt.Payload) + if err != nil { + return + } + for _, nal := range nals { switch pkt.PayloadType { case base.AVPacketPTAVC: - t := avc.ParseNALUType(pkt.Payload[i+4]) + t := avc.ParseNALUType(nal[0]) if t == avc.NALUTypeIDRSlice { msg.Payload[0] = base.RTMPAVCKeyFrame } else { @@ -153,7 +154,7 @@ func AVPacket2RTMPMsg(pkt base.AVPacket) (msg base.RTMPMsg, err error) { } msg.Payload[1] = base.RTMPAVCPacketTypeNALU case base.AVPacketPTHEVC: - t := hevc.ParseNALUType(pkt.Payload[i+4]) + t := hevc.ParseNALUType(nal[0]) if t == hevc.NALUTypeSliceIDR || t == hevc.NALUTypeSliceIDRNLP { msg.Payload[0] = base.RTMPHEVCKeyFrame } else { @@ -161,8 +162,6 @@ func AVPacket2RTMPMsg(pkt base.AVPacket) (msg base.RTMPMsg, err error) { } msg.Payload[1] = base.RTMPHEVCPacketTypeNALU } - - i += 4 + naluSize } msg.Payload[2] = 0x0 // cts