You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
lal/pkg/rtsp/client_command_session.go

585 lines
15 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// Copyright 2021, Chef. All rights reserved.
// https://github.com/q191201771/lal
//
// Use of this source code is governed by a MIT-style license
// that can be found in the License file.
//
// Author: Chef (191201771@qq.com)
package rtsp
import (
"bufio"
"context"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/q191201771/naza/pkg/nazaerrors"
"github.com/q191201771/lal/pkg/base"
"github.com/q191201771/lal/pkg/rtprtcp"
"github.com/q191201771/lal/pkg/sdp"
"github.com/q191201771/naza/pkg/connection"
"github.com/q191201771/naza/pkg/nazahttp"
"github.com/q191201771/naza/pkg/nazanet"
)
// ClientCommandSessionType selects which RTSP command sequence a ClientCommandSession runs.
type ClientCommandSessionType int

// Session types understood by ClientCommandSession.
const (
	// CcstPullSession runs OPTIONS, DESCRIBE, SETUP, PLAY.
	CcstPullSession ClientCommandSessionType = 0
	// CcstPushSession runs OPTIONS, ANNOUNCE, SETUP, RECORD.
	CcstPushSession ClientCommandSessionType = 1
)

const (
	// readBufSize is the read buffer size passed to connection.New for the control connection.
	readBufSize = 256
	// writeGetParameterIntervalMs is the GET_PARAMETER keepalive period, in milliseconds.
	writeGetParameterIntervalMs = 10000
)
// ClientCommandSessionOption configures a ClientCommandSession.
type ClientCommandSessionOption struct {
	// DoTimeoutMs bounds the handshake performed by Do, in milliseconds; 0 disables the bound.
	DoTimeoutMs int
	// OverTcp makes SETUP request RTP/RTCP interleaved on the RTSP TCP connection instead of
	// separate UDP sockets.
	OverTcp bool
}

// defaultClientCommandSessionOption is the baseline that ModClientCommandSessionOption
// functions adjust in NewClientCommandSession.
var defaultClientCommandSessionOption = ClientCommandSessionOption{
	OverTcp:     false,
	DoTimeoutMs: 10000,
}
// ClientCommandSessionObserver receives callbacks from ClientCommandSession during the
// handshake and, in TCP mode, for every interleaved packet read afterwards.
type ClientCommandSessionObserver interface {
// OnConnectResult is called once the TCP control connection has been established.
OnConnectResult()
// OnDescribeResponse is only for PullSession: it delivers the SDP parsed from the DESCRIBE response.
OnDescribeResponse(sdpCtx sdp.LogicContext)
// OnSetupWithConn is called after each successful UDP-mode SETUP with the RTP/RTCP
// UDP connections created for uri.
OnSetupWithConn(uri string, rtpConn, rtcpConn *nazanet.UdpConnection)
// OnSetupWithChannel is called after each successful TCP-mode SETUP with the
// interleaved channel ids requested for uri.
OnSetupWithChannel(uri string, rtpChannel, rtcpChannel int)
// OnSetupResult is called once all SETUP requests have succeeded, before PLAY/RECORD is sent.
OnSetupResult()
// OnInterleavedPacket delivers one interleaved packet read from the control connection
// (TCP mode only).
OnInterleavedPacket(packet []byte, channel int)
}
// ClientCommandSession encapsulates the low-level client-side RTSP signaling shared by push and
// pull sessions.
// Application code should use PushSession and PullSession rather than this type directly,
// unless you are sure that is what you need.
type ClientCommandSession struct {
	// Identity and configuration, fixed at construction.
	uniqueKey string
	t         ClientCommandSessionType
	observer  ClientCommandSessionObserver
	option    ClientCommandSessionOption

	// Target address, filled in by connect.
	rawUrl string
	urlCtx base.UrlContext

	// Control connection and RTSP protocol state.
	conn                        connection.Connection
	cseq                        int
	auth                        Auth
	sessionId                   string
	methodGetParameterSupported bool

	// Media description, and the next free interleaved channel id (TCP mode).
	sdpCtx  sdp.LogicContext
	channel int

	disposeOnce sync.Once
}
// ModClientCommandSessionOption adjusts a ClientCommandSessionOption. NewClientCommandSession
// applies each one, in order, on top of defaultClientCommandSessionOption.
type ModClientCommandSessionOption func(option *ClientCommandSessionOption)

// NewClientCommandSession creates a session of type t. No network activity happens until Do.
func NewClientCommandSession(t ClientCommandSessionType, uniqueKey string, observer ClientCommandSessionObserver, modOptions ...ModClientCommandSessionOption) *ClientCommandSession {
	opt := defaultClientCommandSessionOption
	for i := range modOptions {
		modOptions[i](&opt)
	}

	session := &ClientCommandSession{
		uniqueKey: uniqueKey,
		t:         t,
		observer:  observer,
		option:    opt,
	}
	Log.Infof("[%s] lifecycle new rtsp ClientCommandSession. session=%p", uniqueKey, session)
	return session
}
// InitWithSdp stores the SDP for a push session (CcstPushSession) and is only meaningful there.
// Call it before Do: the push handshake announces sdpCtx.RawSdp and derives the SETUP URIs
// from it.
func (session *ClientCommandSession) InitWithSdp(sdpCtx sdp.LogicContext) {
session.sdpCtx = sdpCtx
}
// Do connects to rawUrl and runs the whole RTSP handshake:
//   - pull: OPTIONS/DESCRIBE/SETUP/PLAY
//   - push: OPTIONS/ANNOUNCE/SETUP/RECORD
//
// It blocks until the handshake completes, fails, or option.DoTimeoutMs elapses
// (0 means wait without a deadline).
func (session *ClientCommandSession) Do(rawUrl string) error {
	var ctx context.Context
	var cancel context.CancelFunc

	switch timeoutMs := session.option.DoTimeoutMs; timeoutMs {
	case 0:
		ctx, cancel = context.WithCancel(context.Background())
	default:
		ctx, cancel = context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
	}
	defer cancel()

	return session.doContext(ctx, rawUrl)
}
// ---------------------------------------------------------------------------------------------------------------------
// IClientSessionLifecycle interface
// ---------------------------------------------------------------------------------------------------------------------
// Dispose closes the session's control connection. See the IClientSessionLifecycle interface
// for the full contract.
//
// Only the first dispose (from Dispose, a failed or timed-out Do, or the read loop ending) does
// the work. It returns the result of closing the connection, or base.ErrSessionNotStarted if no
// connection was ever established. Later calls return nil.
func (session *ClientCommandSession) Dispose() error {
return session.dispose(nil)
}
// WaitChan returns the channel that reports the end of the control connection.
// See the IClientSessionLifecycle interface for the full contract.
//
// If no connection has been established yet (Do not called, or it failed before connecting),
// the returned channel immediately yields base.ErrSessionNotStarted. Previously the method
// dereferenced the nil connection and panicked; the nil check matches WriteInterleavedPacket
// and RemoteAddr.
func (session *ClientCommandSession) WaitChan() <-chan error {
	if session.conn == nil {
		ch := make(chan error, 1)
		ch <- base.ErrSessionNotStarted
		return ch
	}
	return session.conn.Done()
}
// ---------------------------------------------------------------------------------------------------------------------
// WriteInterleavedPacket frames packet for the given interleaved channel and writes it to the
// RTSP control connection. It returns base.ErrSessionNotStarted if there is no connection yet.
func (session *ClientCommandSession) WriteInterleavedPacket(packet []byte, channel int) error {
	conn := session.conn
	if conn == nil {
		return base.ErrSessionNotStarted
	}

	framed := packInterleaved(channel, packet)
	if _, err := conn.Write(framed); err != nil {
		return err
	}
	return nil
}
// RemoteAddr returns the server address of the control connection, or "" before connecting.
func (session *ClientCommandSession) RemoteAddr() string {
	if conn := session.conn; conn != nil {
		return conn.RemoteAddr().String()
	}
	return ""
}
// Url returns the URL parsed by connect from the address passed to Do.
func (session *ClientCommandSession) Url() string {
	u := &session.urlCtx
	return u.Url
}

// AppName returns the URL path without its last item.
func (session *ClientCommandSession) AppName() string {
	u := &session.urlCtx
	return u.PathWithoutLastItem
}

// StreamName returns the last item of the URL path.
func (session *ClientCommandSession) StreamName() string {
	u := &session.urlCtx
	return u.LastItemOfPath
}

// RawQuery returns the raw query string of the URL.
func (session *ClientCommandSession) RawQuery() string {
	u := &session.urlCtx
	return u.RawQuery
}

// UniqueKey returns the key this session was created with, used as its log prefix.
func (session *ClientCommandSession) UniqueKey() string {
	key := session.uniqueKey
	return key
}
// doContext runs doHandshake for rawUrl, bounded by ctx.
//
// The handshake runs in its own goroutine so that ctx can cut it short:
//   - ctx done first: the session is disposed with ctx.Err() as the logged reason, and
//     ctx.Err() is returned. Disposing closes the control connection, which should make the
//     blocked handshake I/O fail.
//   - handshake failed: the session is disposed with that error, which is returned.
//   - handshake succeeded: runReadLoop is started in the background and nil is returned.
//
// NOTE(review): session.conn is assigned in the handshake goroutine (connect) but may be read
// by dispose on this goroutine when ctx fires, which is a data race. Also, if ctx fires before
// connect has assigned conn, dispose sees nil and a connection established afterwards is never
// closed. A complete fix needs connect itself to honor ctx.
func (session *ClientCommandSession) doContext(ctx context.Context, rawUrl string) error {
	// Buffered so the handshake goroutine never blocks if we already returned on ctx.Done.
	errChan := make(chan error, 1)
	go func() {
		errChan <- session.doHandshake(rawUrl)
	}()

	select {
	case <-ctx.Done():
		err := ctx.Err()
		_ = session.dispose(err)
		return err
	case err := <-errChan:
		if err != nil {
			_ = session.dispose(err)
			return err
		}
	}

	go session.runReadLoop()
	return nil
}

// doHandshake performs the blocking signaling sequence for session.t:
//
//	pull: TCP connect, OPTIONS, DESCRIBE, SETUP (per track), OnSetupResult, PLAY
//	push: TCP connect, OPTIONS, ANNOUNCE, SETUP (per track), OnSetupResult, RECORD
//
// For any other session type only connect and OPTIONS run, and nil is returned.
func (session *ClientCommandSession) doHandshake(rawUrl string) error {
	if err := session.connect(rawUrl); err != nil {
		return err
	}
	if err := session.writeOptions(); err != nil {
		return err
	}

	var prepare, start func() error
	switch session.t {
	case CcstPullSession:
		prepare, start = session.writeDescribe, session.writePlay
	case CcstPushSession:
		prepare, start = session.writeAnnounce, session.writeRecord
	default:
		return nil
	}

	if err := prepare(); err != nil {
		return err
	}
	if err := session.writeSetup(); err != nil {
		return err
	}
	session.observer.OnSetupResult()
	return start()
}
// runReadLoop runs after a successful handshake until the control connection fails. It always
// ends by disposing the session with the terminating error.
//
// Its behavior depends on two flags:
//   - option.OverTcp: RTP/RTCP arrive interleaved on the control connection and are forwarded
//     to observer.OnInterleavedPacket.
//   - methodGetParameterSupported: the server listed GET_PARAMETER in its OPTIONS Public
//     header, so a GET_PARAMETER keepalive is sent every writeGetParameterIntervalMs.
//
// Fixes relative to the previous version:
//   - UDP mode with keepalive no longer busy-spins on a select with a default case; it blocks
//     on the ticker.
//   - CSeq is no longer incremented twice per keepalive; writeCmd already increments it.
//   - UDP mode without keepalive keeps reading until the connection errors, so one stray byte
//     no longer ends the session.
func (session *ClientCommandSession) runReadLoop() {
	var loopErr error
	defer func() {
		_ = session.dispose(loopErr)
	}()

	if !session.methodGetParameterSupported {
		if session.option.OverTcp {
			// TCP mode: media is interleaved on the control connection and must be consumed.
			r := bufio.NewReader(session.conn)
			for {
				isInterleaved, packet, channel, err := readInterleaved(r)
				if err != nil {
					loopErr = err
					return
				}
				if isInterleaved {
					session.observer.OnInterleavedPacket(packet, int(channel))
				}
			}
		}

		// UDP mode: nothing is expected on the control connection. Keep reading (and discarding)
		// only so that the peer closing TCP (FIN) or a network error is noticed.
		dummy := make([]byte, 1)
		for {
			if _, err := session.conn.Read(dummy); err != nil {
				loopErr = err
				return
			}
		}
	}

	// The server supports GET_PARAMETER: send one periodically as a keepalive.
	Log.Debugf("[%s] start get_parameter timer.", session.uniqueKey)
	t := time.NewTicker(writeGetParameterIntervalMs * time.Millisecond)
	defer t.Stop()

	if session.option.OverTcp {
		r := bufio.NewReader(session.conn)
		for {
			// Non-blocking ticker check. The keepalive is only sent between reads, so it is delayed
			// while the read below waits for data.
			select {
			case <-t.C:
				// writeCmd increments session.cseq itself.
				if err := session.writeCmd(MethodGetParameter, session.urlCtx.RawUrlWithoutUserInfo, nil, ""); err != nil {
					loopErr = err
					return
				}
			default:
			}

			isInterleaved, packet, channel, err := readInterleaved(r)
			if err != nil {
				loopErr = err
				return
			}
			if isInterleaved {
				session.observer.OnInterleavedPacket(packet, int(channel))
				continue
			}
			// Not interleaved media: this should be the response to a GET_PARAMETER keepalive.
			if _, err := nazahttp.ReadHttpResponseMessage(r); err != nil {
				loopErr = err
				return
			}
		}
	}

	// UDP mode: block until each tick, then send GET_PARAMETER and wait for its response.
	// The ticker channel is never closed, so this loop only exits on an error.
	for range t.C {
		// writeCmdReadResp (via writeCmd) increments session.cseq itself.
		if _, err := session.writeCmdReadResp(MethodGetParameter, session.urlCtx.RawUrlWithoutUserInfo, nil, ""); err != nil {
			loopErr = err
			return
		}
	}
}
// connect parses rawUrl into session.urlCtx and opens the TCP control connection to the server.
// It wraps the connection as session.conn and notifies observer.OnConnectResult.
//
// The dial is bounded by option.DoTimeoutMs, the same budget Do uses (0 means no limit).
// An unreachable server therefore no longer leaves the dial running long after Do's deadline.
func (session *ClientCommandSession) connect(rawUrl string) (err error) {
	session.rawUrl = rawUrl
	session.urlCtx, err = base.ParseRtspUrl(rawUrl)
	if err != nil {
		return err
	}

	Log.Debugf("[%s] > tcp connect.", session.uniqueKey)

	// Establish the TCP connection. net.DialTimeout treats a zero timeout as "no timeout".
	dialTimeout := time.Duration(session.option.DoTimeoutMs) * time.Millisecond
	conn, err := net.DialTimeout("tcp", session.urlCtx.HostWithPort, dialTimeout)
	if err != nil {
		return err
	}
	session.conn = connection.New(conn, func(option *connection.Option) {
		option.ReadBufSize = readBufSize
	})

	Log.Debugf("[%s] < tcp connect. laddr=%s, raddr=%s", session.uniqueKey, conn.LocalAddr().String(), conn.RemoteAddr().String())
	session.observer.OnConnectResult()
	return nil
}
// writeOptions sends OPTIONS. If the response's Public header lists GET_PARAMETER, it sets
// methodGetParameterSupported, which enables the keepalive in runReadLoop.
func (session *ClientCommandSession) writeOptions() error {
	resp, err := session.writeCmdReadResp(MethodOptions, session.urlCtx.RawUrlWithoutUserInfo, nil, "")
	if err != nil {
		return err
	}

	if public := resp.Headers.Get(HeaderPublic); public != "" && strings.Contains(public, MethodGetParameter) {
		session.methodGetParameterSupported = true
	}
	return nil
}
// writeDescribe (pull only) sends DESCRIBE asking for SDP and parses the response body into
// session.sdpCtx. It then hands the result to observer.OnDescribeResponse.
func (session *ClientCommandSession) writeDescribe() error {
	resp, err := session.writeCmdReadResp(MethodDescribe, session.urlCtx.RawUrlWithoutUserInfo,
		map[string]string{HeaderAccept: HeaderAcceptApplicationSdp}, "")
	if err != nil {
		return err
	}

	parsed, err := sdp.ParseSdp2LogicContext(resp.Body)
	if err != nil {
		return err
	}
	session.sdpCtx = parsed
	session.observer.OnDescribeResponse(session.sdpCtx)
	return nil
}
// writeAnnounce (push only) sends ANNOUNCE with the raw SDP previously stored by InitWithSdp as
// the request body.
//
// The body is labelled with Content-Type: application/sdp so the server can identify its
// format; RFC 2326 §10.3 shows ANNOUNCE carrying exactly this header. The Accept header that
// was sent before is kept for compatibility.
func (session *ClientCommandSession) writeAnnounce() error {
	headers := map[string]string{
		HeaderAccept:   HeaderAcceptApplicationSdp,
		"Content-Type": "application/sdp",
	}
	_, err := session.writeCmdReadResp(MethodAnnounce, session.urlCtx.RawUrlWithoutUserInfo, headers, string(session.sdpCtx.RawSdp))
	return err
}
// writeSetup sends one SETUP per media track that session.sdpCtx reports as having a control
// attribute (video first, then audio). Each SETUP uses interleaved TCP or UDP transport,
// depending on option.OverTcp.
func (session *ClientCommandSession) writeSetup() error {
	setupTrack := func(uri string) error {
		if session.option.OverTcp {
			return session.writeOneSetupTcp(uri)
		}
		return session.writeOneSetup(uri)
	}

	if session.sdpCtx.HasVideoAControl() {
		if err := setupTrack(session.sdpCtx.MakeVideoSetupUri(session.urlCtx.RawUrlWithoutUserInfo)); err != nil {
			return err
		}
	}
	// Deliberately not else-if: a stream may carry both a video and an audio track.
	if session.sdpCtx.HasAudioAControl() {
		if err := setupTrack(session.sdpCtx.MakeAudioSetupUri(session.urlCtx.RawUrlWithoutUserInfo)); err != nil {
			return err
		}
	}
	return nil
}
// writeOneSetup sends a SETUP for setupUri using UDP transport. It:
//   - takes a pair of local UDP sockets from availUdpConnPool and advertises their ports in
//     the Transport header;
//   - records the session id from the response;
//   - parses the server's RTP/RTCP ports from the response Transport header;
//   - hands both connections to observer.OnSetupWithConn.
//
// The local sockets are closed on every failure path. Previously they leaked whenever the
// SETUP request, the server-port parsing, or the UDP connection wrapping failed.
func (session *ClientCommandSession) writeOneSetup(setupUri string) error {
	rtpC, lRtpPort, rtcpC, lRtcpPort, err := availUdpConnPool.Acquire2()
	if err != nil {
		return err
	}
	// This function owns the sockets until they are handed to the observer.
	handedOff := false
	defer func() {
		if !handedOff {
			_ = rtpC.Close()
			_ = rtcpC.Close()
		}
	}()

	var htv string
	switch session.t {
	case CcstPushSession:
		htv = fmt.Sprintf(HeaderTransportClientRecordTmpl, lRtpPort, lRtcpPort)
	case CcstPullSession:
		htv = fmt.Sprintf(HeaderTransportClientPlayTmpl, lRtpPort, lRtcpPort)
	}
	headers := map[string]string{
		HeaderTransport: htv,
	}

	ctx, err := session.writeCmdReadResp(MethodSetup, setupUri, headers, "")
	if err != nil {
		return err
	}
	// The Session header may carry parameters after ';' (e.g. a timeout); keep only the id.
	session.sessionId = strings.Split(ctx.Headers.Get(HeaderSession), ";")[0]

	rRtpPort, rRtcpPort, err := parseServerPort(ctx.Headers.Get(HeaderTransport))
	if err != nil {
		return err
	}
	Log.Debugf("[%s] init conn. lRtpPort=%d, lRtcpPort=%d, rRtpPort=%d, rRtcpPort=%d",
		session.uniqueKey, lRtpPort, lRtcpPort, rRtpPort, rRtcpPort)

	rtpConn, err := nazanet.NewUdpConnection(func(option *nazanet.UdpConnectionOption) {
		option.Conn = rtpC
		option.RAddr = net.JoinHostPort(session.urlCtx.Host, fmt.Sprintf("%d", rRtpPort))
		option.MaxReadPacketSize = rtprtcp.MaxRtpRtcpPacketSize
	})
	if err != nil {
		return err
	}
	rtcpConn, err := nazanet.NewUdpConnection(func(option *nazanet.UdpConnectionOption) {
		option.Conn = rtcpC
		option.RAddr = net.JoinHostPort(session.urlCtx.Host, fmt.Sprintf("%d", rRtcpPort))
		option.MaxReadPacketSize = rtprtcp.MaxRtpRtcpPacketSize
	})
	if err != nil {
		return err
	}

	handedOff = true
	session.observer.OnSetupWithConn(setupUri, rtpConn, rtcpConn)
	return nil
}
// writeOneSetupTcp sends a SETUP for setupUri requesting RTP/RTCP interleaved on the control
// connection. It then reports the requested channel ids via observer.OnSetupWithChannel.
func (session *ClientCommandSession) writeOneSetupTcp(setupUri string) error {
	// Each track consumes two consecutive interleaved channel ids: RTP first, then RTCP.
	rtpChannel, rtcpChannel := session.channel, session.channel+1
	session.channel = rtcpChannel + 1

	htv := ""
	switch session.t {
	case CcstPullSession:
		htv = fmt.Sprintf(HeaderTransportClientPlayTcpTmpl, rtpChannel, rtcpChannel)
	case CcstPushSession:
		htv = fmt.Sprintf(HeaderTransportClientRecordTcpTmpl, rtpChannel, rtcpChannel)
	}

	resp, err := session.writeCmdReadResp(MethodSetup, setupUri, map[string]string{HeaderTransport: htv}, "")
	if err != nil {
		return err
	}
	session.sessionId = strings.Split(resp.Headers.Get(HeaderSession), ";")[0]

	// TODO(chef): the channel ids echoed back in the response are not parsed; they are assumed
	// to match the ones in the request.
	session.observer.OnSetupWithChannel(setupUri, rtpChannel, rtcpChannel)
	return nil
}
// writePlay (pull only) sends PLAY on the aggregate URL with Range set to HeaderRangeDefault.
func (session *ClientCommandSession) writePlay() error {
	if _, err := session.writeCmdReadResp(MethodPlay, session.urlCtx.RawUrlWithoutUserInfo, map[string]string{
		HeaderRange: HeaderRangeDefault,
	}, ""); err != nil {
		return err
	}
	return nil
}
// writeRecord (push only) sends RECORD on the aggregate URL with Range set to HeaderRangeDefault.
func (session *ClientCommandSession) writeRecord() error {
	if _, err := session.writeCmdReadResp(MethodRecord, session.urlCtx.RawUrlWithoutUserInfo, map[string]string{
		HeaderRange: HeaderRangeDefault,
	}, ""); err != nil {
		return err
	}
	return nil
}
// writeCmd serializes one RTSP request and writes it to the control connection without reading
// a response.
//
// It increments session.cseq and fills in these headers:
//   - CSeq and User-Agent, always;
//   - Content-Length, for a non-empty body;
//   - Authorization, when session.auth produces one;
//   - Session, once a session id is known.
//
// A non-nil headers map is modified in place.
func (session *ClientCommandSession) writeCmd(method, uri string, headers map[string]string, body string) error {
	if headers == nil {
		headers = map[string]string{}
	}

	session.cseq++
	headers[HeaderCSeq] = fmt.Sprintf("%d", session.cseq)
	headers[HeaderUserAgent] = base.LalRtspPullSessionUa
	if len(body) > 0 {
		headers[HeaderContentLength] = fmt.Sprintf("%d", len(body))
	}

	// Authorization is always computed against the URL without user info, whatever uri is.
	if authValue := session.auth.MakeAuthorization(method, session.urlCtx.RawUrlWithoutUserInfo); authValue != "" {
		headers[HeaderAuthorization] = authValue
	}
	if sid := session.sessionId; sid != "" {
		headers[HeaderSession] = sid
	}

	req := PackRequest(method, uri, headers, body)
	Log.Debugf("[%s] > write %s.", session.uniqueKey, method)
	if _, err := session.conn.Write([]byte(req)); err != nil {
		return err
	}
	return nil
}
// writeCmdReadResp sends one RTSP request and reads its response. When the server answers
// 401 Unauthorized, it retries once with credentials.
//
// @param headers may be nil. A non-nil map is modified in place by writeCmd.
// @param body    may be empty.
//
// Any status other than 401 is returned as-is with a nil error; the status code is not checked
// here. If the server still answers 401 after the authenticated retry, the last response is
// returned together with an error wrapping base.ErrRtsp. (The previous "never reach here" TODO
// was wrong: this path is taken whenever authentication fails twice.)
func (session *ClientCommandSession) writeCmdReadResp(method, uri string, headers map[string]string, body string) (ctx nazahttp.HttpRespMsgCtx, err error) {
	// The first attempt, plus one retry after the 401 challenge has been fed to session.auth so
	// that writeCmd can add an Authorization header.
	const maxAttempts = 2
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if err = session.writeCmd(method, uri, headers, body); err != nil {
			return ctx, err
		}
		if ctx, err = nazahttp.ReadHttpResponseMessage(session.conn); err != nil {
			return ctx, err
		}
		Log.Debugf("[%s] < read response. version=%s, code=%s, reason=%s, headers=%+v, body=%s",
			session.uniqueKey, ctx.Version, ctx.StatusCode, ctx.Reason, ctx.Headers, string(ctx.Body))

		if ctx.StatusCode != "401" {
			return ctx, nil
		}
		session.auth.FeedWwwAuthenticate(ctx.Headers.Values(HeaderWwwAuthenticate), session.urlCtx.Username, session.urlCtx.Password)
	}

	// Reached when the authenticated retry was rejected too. This presumably means wrong
	// credentials or a challenge session.auth cannot satisfy.
	return ctx, nazaerrors.Wrap(base.ErrRtsp)
}
// dispose closes the control connection exactly once and logs err as the reason.
//
// The first call returns the result of conn.Close(), or base.ErrSessionNotStarted when no
// connection was ever made. Every later call is a no-op that returns nil.
func (session *ClientCommandSession) dispose(err error) error {
	var closeErr error
	session.disposeOnce.Do(func() {
		Log.Infof("[%s] lifecycle dispose rtsp ClientCommandSession. session=%p, err=%+v", session.uniqueKey, session, err)
		if session.conn != nil {
			closeErr = session.conn.Close()
		} else {
			closeErr = base.ErrSessionNotStarted
		}
	})
	return closeErr
}