1. stream.msgLen -> stream.header.MsgLen 2. rtmp stream name with url raw query 3.

pull/200/head
q191201771 6 years ago
parent 750e7adce0
commit 89af181710

@ -6,50 +6,53 @@ import (
"github.com/q191201771/lal/pkg/rtmp" "github.com/q191201771/lal/pkg/rtmp"
"github.com/q191201771/nezha/pkg/errors" "github.com/q191201771/nezha/pkg/errors"
"github.com/q191201771/nezha/pkg/log" "github.com/q191201771/nezha/pkg/log"
"io"
"os" "os"
"time" "time"
) )
// 将flv文件通过rtmp协议推送至rtmp服务器 // 将flv文件通过rtmp协议推送至rtmp服务器
// //
// -r 表示当文件推送完毕后,是否循环推送
//
// Usage: // Usage:
// ./bin/flvfile2rtmppush -i /tmp/test.flv -o rtmp://push.xxx.com/live/testttt // ./bin/flvfile2rtmppush -r 1 -i /tmp/test.flv -o rtmp://push.xxx.com/live/testttt
func main() { func main() {
flvFileName, rtmpPushURL := parseFlag() var err error
flvFileName, rtmpPushURL, isRecursive := parseFlag()
ps := rtmp.NewPushSession(5000)
err = ps.Push(rtmpPushURL)
errors.PanicIfErrorOccur(err)
log.Infof("push succ.")
var baseTS int
var prevTS int
for {
var ffr httpflv.FlvFileReader var ffr httpflv.FlvFileReader
err := ffr.Open(flvFileName) err = ffr.Open(flvFileName)
errors.PanicIfErrorOccur(err) errors.PanicIfErrorOccur(err)
defer ffr.Dispose()
log.Infof("open succ.") log.Infof("open succ.")
flvHeader, err := ffr.ReadFlvHeader() flvHeader, err := ffr.ReadFlvHeader()
errors.PanicIfErrorOccur(err) errors.PanicIfErrorOccur(err)
log.Infof("read flv header succ. %v", flvHeader) log.Infof("read flv header succ. %v", flvHeader)
ps := rtmp.NewPushSession(5000)
err = ps.Push(rtmpPushURL)
errors.PanicIfErrorOccur(err)
log.Infof("push succ.")
var prevTS uint32
firstA := true
firstV := true
//var aPrevH *rtmp.Header
//var vPrevH *rtmp.Header
//for i := 0; i < 1000*1000; i++ {
for { for {
tag, err := ffr.ReadTag() tag, err := ffr.ReadTag()
if err == io.EOF {
log.Info("EOF")
break
}
errors.PanicIfErrorOccur(err) errors.PanicIfErrorOccur(err)
//log.Infof("tag: %+v %v", tag.Header, tag.Raw[11:])
//log.Infof("tag: %+v %d", tag.Header, len(tag.Raw))
// TODO chef: 转换代码放入lal某个包中 // TODO chef: 转换代码放入lal某个包中
var h rtmp.Header var h rtmp.Header
h.MsgLen = int(tag.Header.DataSize) //len(tag.Raw)-httpflv.TagHeaderSize h.MsgLen = int(tag.Header.DataSize) //len(tag.Raw)-httpflv.TagHeaderSize
h.Timestamp = int(tag.Header.Timestamp) h.Timestamp = int(tag.Header.Timestamp) + int(baseTS)
h.MsgTypeID = int(tag.Header.T) h.MsgTypeID = int(tag.Header.T)
h.MsgStreamID = rtmp.MSID1 h.MsgStreamID = rtmp.MSID1
switch tag.Header.T { switch tag.Header.T {
@ -61,76 +64,40 @@ func main() {
h.CSID = rtmp.CSIDVideo h.CSID = rtmp.CSIDVideo
} }
// 把第一个音频和视频的时间戳改成0 var diff int
if tag.Header.T == httpflv.TagTypeAudio && firstA { if h.Timestamp >= prevTS {
h.Timestamp = 0 diff = int(h.Timestamp) - prevTS
firstA = false } else {
} h.Timestamp = prevTS
if tag.Header.T == httpflv.TagTypeVideo && firstV {
h.Timestamp = 0
firstV = false
} }
//var chunks []byte
//if tag.Header.T == httpflv.TagTypeVideo {
// chunks = rtmp.Message2Chunks(tag.Raw[11:11+h.MsgLen], &h, aPrevH, rtmp.LocalChunkSize)
// aPrevH = &h
//}
//if tag.Header.T == httpflv.TagTypeVideo {
// chunks = rtmp.Message2Chunks(tag.Raw[11:11+h.MsgLen], &h, vPrevH, rtmp.LocalChunkSize)
// vPrevH = &h
//}
//if tag.Header.T == httpflv.TagTypeVideo {
// chunks = rtmp.Message2Chunks(tag.Raw[11:11+h.MsgLen], &h, nil, rtmp.LocalChunkSize)
//}
chunks := rtmp.Message2Chunks(tag.Raw[11:11+h.MsgLen], &h, rtmp.LocalChunkSize) chunks := rtmp.Message2Chunks(tag.Raw[11:11+h.MsgLen], &h, rtmp.LocalChunkSize)
// 第一个包直接发送 log.Debugf("before send. diff=%d, ts=%d, prevTS=%d", diff, h.Timestamp, prevTS)
if prevTS == 0 { time.Sleep(time.Duration(diff) * time.Millisecond)
err = ps.TmpWrite(chunks) log.Debug("send")
errors.PanicIfErrorOccur(err)
prevTS = tag.Header.Timestamp
continue
}
// 相等或回退了直接发送
if tag.Header.Timestamp <= prevTS {
err = ps.TmpWrite(chunks) err = ps.TmpWrite(chunks)
errors.PanicIfErrorOccur(err) errors.PanicIfErrorOccur(err)
prevTS = tag.Header.Timestamp prevTS = h.Timestamp
continue
} }
if tag.Header.Timestamp > prevTS { baseTS = prevTS + 1
diff := tag.Header.Timestamp - prevTS ffr.Dispose()
// 跳跃超过了30秒直接发送 if !isRecursive {
if diff > 30000 { break
err = ps.TmpWrite(chunks)
errors.PanicIfErrorOccur(err)
prevTS = tag.Header.Timestamp
continue
} }
// 睡眠后发送,睡眠时长为时间戳间隔
time.Sleep(time.Duration(diff) * time.Millisecond)
err = ps.TmpWrite(chunks)
errors.PanicIfErrorOccur(err)
prevTS = tag.Header.Timestamp
continue
}
panic("should not reach here.")
} }
} }
func parseFlag() (string, string) { func parseFlag() (string, string, bool) {
i := flag.String("i", "", "specify flv file") i := flag.String("i", "", "specify flv file")
o := flag.String("o", "", "specify rtmp push url") o := flag.String("o", "", "specify rtmp push url")
r := flag.Bool("r", false, "recursive push if reach end of file")
flag.Parse() flag.Parse()
if *i == "" || *o == "" { if *i == "" || *o == "" {
flag.Usage() flag.Usage()
os.Exit(1) os.Exit(1)
} }
return *i, *o return *i, *o, *r
} }

@ -6,7 +6,6 @@ import (
"github.com/q191201771/nezha/pkg/errors" "github.com/q191201771/nezha/pkg/errors"
"github.com/q191201771/nezha/pkg/log" "github.com/q191201771/nezha/pkg/log"
"os" "os"
"time"
) )
type Obs struct { type Obs struct {
@ -22,7 +21,8 @@ func main() {
session := rtmp.NewPullSession(obs, 2000) session := rtmp.NewPullSession(obs, 2000)
err := session.Pull(url) err := session.Pull(url)
errors.PanicIfErrorOccur(err) errors.PanicIfErrorOccur(err)
time.Sleep(1 * time.Hour) err := session.WaitLoop()
errors.PanicIfErrorOccur(err)
} }
func parseFlag() string { func parseFlag() string {

@ -90,12 +90,7 @@ func (session *PullSession) Connect(rawURL string) error {
} }
// # 建立连接 // # 建立连接
var conn net.Conn conn, err := net.DialTimeout("tcp", session.addr, time.Duration(session.config.ConnectTimeoutMS)*time.Millisecond)
if session.config.ConnectTimeoutMS == 0 {
conn, err = net.Dial("tcp", session.addr)
} else {
conn, err = net.DialTimeout("tcp", session.addr, time.Duration(session.config.ConnectTimeoutMS)*time.Millisecond)
}
if err != nil { if err != nil {
return err return err
} }

@ -69,11 +69,11 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb CompleteMessageCB) error {
// 包头中为绝对时间戳 // 包头中为绝对时间戳
stream.header.Timestamp = int(bele.BEUint24(bootstrap)) stream.header.Timestamp = int(bele.BEUint24(bootstrap))
stream.timestampAbs = stream.header.Timestamp stream.timestampAbs = stream.header.Timestamp
stream.msgLen = int(bele.BEUint24(bootstrap[3:])) stream.header.MsgLen = int(bele.BEUint24(bootstrap[3:]))
stream.header.MsgTypeID = int(bootstrap[6]) stream.header.MsgTypeID = int(bootstrap[6])
stream.header.MsgStreamID = int(bele.LEUint32(bootstrap[7:])) stream.header.MsgStreamID = int(bele.LEUint32(bootstrap[7:]))
stream.msg.reserve(stream.msgLen) stream.msg.reserve(stream.header.MsgLen)
case 1: case 1:
if _, err := io.ReadAtLeast(reader, bootstrap[:7], 7); err != nil { if _, err := io.ReadAtLeast(reader, bootstrap[:7], 7); err != nil {
return err return err
@ -81,10 +81,10 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb CompleteMessageCB) error {
// 包头中为相对时间戳 // 包头中为相对时间戳
stream.header.Timestamp = int(bele.BEUint24(bootstrap)) stream.header.Timestamp = int(bele.BEUint24(bootstrap))
stream.timestampAbs += stream.header.Timestamp stream.timestampAbs += stream.header.Timestamp
stream.msgLen = int(bele.BEUint24(bootstrap[3:])) stream.header.MsgLen = int(bele.BEUint24(bootstrap[3:]))
stream.header.MsgTypeID = int(bootstrap[6]) stream.header.MsgTypeID = int(bootstrap[6])
stream.msg.reserve(stream.msgLen) stream.msg.reserve(stream.header.MsgLen)
case 2: case 2:
if _, err := io.ReadAtLeast(reader, bootstrap[:3], 3); err != nil { if _, err := io.ReadAtLeast(reader, bootstrap[:3], 3); err != nil {
return err return err
@ -119,14 +119,13 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb CompleteMessageCB) error {
} }
} }
//stream.header.CSID = csid //stream.header.CSID = csid
//stream.header.MsgLen = stream.msgLen
//log.Debugf("CHEFGREPME tag1 fmt:%d header:%+v csid:%d len:%d ts:%d", fmt, stream.header, csid, stream.msgLen, stream.timestampAbs) //log.Debugf("CHEFGREPME tag1 fmt:%d header:%+v csid:%d len:%d ts:%d", fmt, stream.header, csid, stream.msgLen, stream.timestampAbs)
var neededSize int var neededSize int
if stream.msgLen <= c.peerChunkSize { if stream.header.MsgLen <= c.peerChunkSize {
neededSize = stream.msgLen neededSize = stream.header.MsgLen
} else { } else {
neededSize = stream.msgLen - stream.msg.len() neededSize = stream.header.MsgLen - stream.msg.len()
if neededSize > c.peerChunkSize { if neededSize > c.peerChunkSize {
neededSize = c.peerChunkSize neededSize = c.peerChunkSize
} }
@ -138,7 +137,7 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb CompleteMessageCB) error {
} }
stream.msg.produced(neededSize) stream.msg.produced(neededSize)
if stream.msg.len() == stream.msgLen { if stream.msg.len() == stream.header.MsgLen {
// 对端设置了chunk size // 对端设置了chunk size
if stream.header.MsgTypeID == typeidSetChunkSize { if stream.header.MsgTypeID == typeidSetChunkSize {
val := int(bele.BEUint32(stream.msg.buf)) val := int(bele.BEUint32(stream.msg.buf))
@ -146,14 +145,13 @@ func (c *ChunkComposer) RunLoop(reader io.Reader, cb CompleteMessageCB) error {
} }
stream.header.CSID = csid stream.header.CSID = csid
stream.header.MsgLen = stream.msgLen //log.Debugf("CHEFGREPME %+v %d %d", stream.header, stream.timestampAbs, stream.header.MsgLen)
//log.Debugf("CHEFGREPME %+v %d %d", stream.header, stream.timestampAbs, stream.msgLen)
if err := cb(stream); err != nil { if err := cb(stream); err != nil {
return err return err
} }
stream.msg.clear() stream.msg.clear()
} }
if stream.msg.len() > stream.msgLen { if stream.msg.len() > stream.header.MsgLen {
panic(0) panic(0)
} }
} }

@ -8,7 +8,7 @@ type PullSession struct {
*ClientSession *ClientSession
} }
func NewPullSession(obs PullSessionObserver, connectTimeout int64) *PullSession { func NewPullSession(obs PullSessionObserver, connectTimeout int) *PullSession {
return &PullSession{ return &PullSession{
ClientSession: NewClientSession(CSTPullSession, obs, connectTimeout), ClientSession: NewClientSession(CSTPullSession, obs, connectTimeout),
} }

@ -4,7 +4,7 @@ type PushSession struct {
*ClientSession *ClientSession
} }
func NewPushSession(connectTimeout int64) *PushSession { func NewPushSession(connectTimeout int) *PushSession {
return &PushSession{ return &PushSession{
ClientSession: NewClientSession(CSTPushSession, nil, connectTimeout), ClientSession: NewClientSession(CSTPushSession, nil, connectTimeout),
} }

@ -19,7 +19,8 @@ type ClientSession struct {
t ClientSessionType t ClientSessionType
obs PullSessionObserver // only for PullSession obs PullSessionObserver // only for PullSession
connectTimeout int64 stageCB StageCB
connectTimeoutMS int
doResultChan chan struct{} doResultChan chan struct{}
errChan chan error errChan chan error
packer *MessagePacker packer *MessagePacker
@ -45,8 +46,19 @@ const (
CSTPushSession CSTPushSession
) )
// set <obs> if <t> equal CSTPullSession type ClientSessionStage int
func NewClientSession(t ClientSessionType, obs PullSessionObserver, connectTimeout int64) *ClientSession {
const (
CSSConnConnectStart ClientSessionStage = iota
CSSConnConnectSucc
)
type StageCB func(stage ClientSessionStage)
// @param t: session的类型只能是推或者拉
// @param obs: 回调结束后buffer会被重复使用
// @param connectTimeoutMS: 建立连接超时,单位毫秒
func NewClientSession(t ClientSessionType, obs PullSessionObserver, connectTimeoutMS int) *ClientSession {
var uk string var uk string
switch t { switch t {
case CSTPullSession: case CSTPullSession:
@ -58,7 +70,7 @@ func NewClientSession(t ClientSessionType, obs PullSessionObserver, connectTimeo
return &ClientSession{ return &ClientSession{
t: t, t: t,
obs: obs, obs: obs,
connectTimeout: connectTimeout, connectTimeoutMS: connectTimeoutMS,
doResultChan: make(chan struct{}), doResultChan: make(chan struct{}),
errChan: make(chan error), errChan: make(chan error),
packer: NewMessagePacker(), packer: NewMessagePacker(),
@ -68,7 +80,7 @@ func NewClientSession(t ClientSessionType, obs PullSessionObserver, connectTimeo
} }
} }
// 阻塞直到收到服务端的 publish start / play start 信令 或者超时 // 阻塞直到收到服务端返回的 publish start / play start 信令 或者超时
func (s *ClientSession) Do(rawURL string) error { func (s *ClientSession) Do(rawURL string) error {
if err := s.parseURL(rawURL); err != nil { if err := s.parseURL(rawURL); err != nil {
return err return err
@ -91,16 +103,13 @@ func (s *ClientSession) Do(rawURL string) error {
s.errChan <- s.runReadLoop() s.errChan <- s.runReadLoop()
}() }()
t := time.NewTimer(time.Duration(s.connectTimeout) * time.Second)
var ret error var ret error
select { select {
case <-s.doResultChan: case <-s.doResultChan:
break break
case <-t.C: case ret = <-s.errChan:
ret = rtmpErr break
} }
t.Stop()
return ret return ret
} }
@ -310,7 +319,8 @@ func (s *ClientSession) parseURL(rawURL string) error {
return rtmpErr return rtmpErr
} }
s.appName = strs[0] s.appName = strs[0]
s.streamName = strs[1] // 有的rtmp服务器会使用url后面的参数比如说用于鉴权这里把它带上
s.streamName = strs[1] + "?" + s.url.RawQuery
log.Debugf("%s %s %s %+v", s.tcURL, s.appName, s.streamName, *s.url) log.Debugf("%s %s %s %+v", s.tcURL, s.appName, s.streamName, *s.url)
return nil return nil
@ -339,7 +349,7 @@ func (s *ClientSession) tcpConnect() error {
} }
var conn net.Conn var conn net.Conn
if conn, err = net.Dial("tcp", addr); err != nil { if conn, err = net.DialTimeout("tcp", addr, time.Duration(s.connectTimeoutMS)*time.Millisecond); err != nil {
return err return err
} }

@ -89,7 +89,7 @@ func (group *Group) DelSubSession(session *ServerSession) {
} }
func (group *Group) Pull(addr string, connectTimeout int64) { func (group *Group) Pull(addr string, connectTimeout int64) {
group.pullSession = NewPullSession(group, connectTimeout) group.pullSession = NewPullSession(group, int(connectTimeout))
defer func() { defer func() {
group.mutex.Lock() group.mutex.Lock()

@ -24,7 +24,6 @@ type StreamMsg struct {
type Stream struct { type Stream struct {
header Header header Header
msgLen int // TODO chef: needed? dup with Header's
timestampAbs int timestampAbs int
msg StreamMsg msg StreamMsg

Loading…
Cancel
Save