You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
lal/pkg/rtsp/client_command_session.go

562 lines
14 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// Copyright 2021, Chef. All rights reserved.
// https://github.com/q191201771/lal
//
// Use of this source code is governed by a MIT-style license
// that can be found in the License file.
//
// Author: Chef (191201771@qq.com)
package rtsp
import (
"bufio"
"context"
"fmt"
"net"
"strings"
"time"
"github.com/q191201771/lal/pkg/base"
"github.com/q191201771/lal/pkg/rtprtcp"
"github.com/q191201771/lal/pkg/sdp"
"github.com/q191201771/naza/pkg/connection"
"github.com/q191201771/naza/pkg/nazahttp"
"github.com/q191201771/naza/pkg/nazalog"
"github.com/q191201771/naza/pkg/nazanet"
)
// ClientCommandSessionType selects which RTSP command sequence a
// ClientCommandSession runs after OPTIONS.
type ClientCommandSessionType int

const (
	// CcstPullSession runs DESCRIBE -> SETUP -> PLAY.
	CcstPullSession ClientCommandSessionType = iota
	// CcstPushSession runs ANNOUNCE -> SETUP -> RECORD.
	CcstPushSession
)

const (
	// readBufSize is the read buffer size given to connection.New for the
	// RTSP command connection.
	readBufSize = 256

	// writeGetParameterIntervalMs is the period, in milliseconds, of the
	// GET_PARAMETER keepalive used when the server advertises the method.
	writeGetParameterIntervalMs = 10000
)
// ClientCommandSessionOption configures a ClientCommandSession.
type ClientCommandSessionOption struct {
	// DoTimeoutMs bounds the whole handshake performed by Do, in
	// milliseconds. 0 disables the bound.
	DoTimeoutMs int

	// OverTcp carries RTP/RTCP interleaved on the RTSP TCP connection
	// instead of on a pair of UDP sockets.
	OverTcp bool
}

// defaultClientCommandSessionOption is the option set that
// NewClientCommandSession starts from before applying modifiers.
var defaultClientCommandSessionOption = ClientCommandSessionOption{
	OverTcp:     false,
	DoTimeoutMs: 10000,
}
// ClientCommandSessionObserver receives the callbacks of a
// ClientCommandSession, listed roughly in the order the handshake emits them.
type ClientCommandSessionObserver interface {
	// OnConnectResult is called once the TCP command connection is open.
	OnConnectResult()

	// OnDescribeResponse delivers the SDP returned by DESCRIBE, both raw and
	// parsed. Only PullSession triggers it.
	OnDescribeResponse(rawSdp []byte, sdpLogicCtx sdp.LogicContext)

	// OnSetupWithChannel reports, in OverTcp mode, the interleaved channel
	// pair requested by one SETUP of uri.
	OnSetupWithChannel(uri string, rtpChannel, rtcpChannel int)

	// OnSetupWithConn hands over, in UDP mode, the RTP/RTCP connections
	// created for one SETUP of uri.
	OnSetupWithConn(uri string, rtpConn, rtcpConn *nazanet.UdpConnection)

	// OnSetupResult is called after all SETUPs succeed, before PLAY/RECORD.
	OnSetupResult()

	// OnInterleavedPacket delivers one interleaved frame read from the
	// command connection in OverTcp mode.
	OnInterleavedPacket(packet []byte, channel int)
}
// ClientCommandSession wraps the client-side RTSP signalling shared by Push and Pull.
// Application code should use PushSession or PullSession rather than using
// ClientCommandSession directly, unless you are sure you need to.
type ClientCommandSession struct {
	uniqueKey string                       // prefix used in every log line of this session
	t         ClientCommandSessionType     // selects the pull or push command sequence
	observer  ClientCommandSessionObserver // receives handshake and media callbacks
	option    ClientCommandSessionOption

	rawUrl string          // URL exactly as passed to Do
	urlCtx base.UrlContext // parsed form of rawUrl

	conn connection.Connection // RTSP command TCP connection; nil until connect succeeds
	cseq int                   // CSeq of the most recently written request

	methodGetParameterSupported bool // true when the OPTIONS Public header lists GET_PARAMETER
	auth                        Auth // fed from 401 responses, produces the Authorization header

	rawSdp      []byte           // SDP from InitWithSdp (push) or DESCRIBE (pull)
	sdpLogicCtx sdp.LogicContext // parsed rawSdp; drives which SETUPs are sent
	sessionId   string           // Session header of the last SETUP response, parameters stripped
	channel     int              // next interleaved channel to hand out in OverTcp mode (RTP=n, RTCP=n+1)

	waitChan chan error // receives the error that terminates the read loop
}
// ModClientCommandSessionOption modifies the options of a session being
// created by NewClientCommandSession.
type ModClientCommandSessionOption func(option *ClientCommandSessionOption)

// NewClientCommandSession creates a session of type t reporting to observer.
// modOptions are applied in order to a copy of defaultClientCommandSessionOption.
// No network activity happens until Do is called.
func NewClientCommandSession(t ClientCommandSessionType, uniqueKey string, observer ClientCommandSessionObserver, modOptions ...ModClientCommandSessionOption) *ClientCommandSession {
	opt := defaultClientCommandSessionOption
	for i := range modOptions {
		modOptions[i](&opt)
	}

	session := new(ClientCommandSession)
	session.t = t
	session.uniqueKey = uniqueKey
	session.observer = observer
	session.option = opt
	// Buffered so the read loop can post its final error and exit even if
	// nobody is receiving yet.
	session.waitChan = make(chan error, 1)

	nazalog.Infof("[%s] lifecycle new rtsp ClientCommandSession. session=%p", uniqueKey, session)
	return session
}
// InitWithSdp stores the SDP that a push session ANNOUNCEs and whose media
// sections drive SETUP. Only meaningful for PushSession; call it before Do.
func (session *ClientCommandSession) InitWithSdp(rawSdp []byte, sdpLogicCtx sdp.LogicContext) {
	session.rawSdp, session.sdpLogicCtx = rawSdp, sdpLogicCtx
}
// Do runs the RTSP handshake against rawUrl, bounded by option.DoTimeoutMs
// (0 means unbounded). On success a background read loop is running whose
// terminating error is later delivered on WaitChan.
func (session *ClientCommandSession) Do(rawUrl string) error {
	ctx := context.Background()
	var cancel context.CancelFunc
	if timeoutMs := session.option.DoTimeoutMs; timeoutMs != 0 {
		ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}
	defer cancel()

	return session.doContext(ctx, rawUrl)
}
// WaitChan returns the receive side of the channel on which the read loop
// started by a successful Do posts the error that ended it (possibly nil).
func (session *ClientCommandSession) WaitChan() <-chan error {
	var recvOnly <-chan error = session.waitChan
	return recvOnly
}
// Dispose closes the RTSP command connection. It returns
// base.ErrSessionNotStarted when no connection was ever established.
func (session *ClientCommandSession) Dispose() error {
	nazalog.Infof("[%s] lifecycle dispose rtsp ClientCommandSession. session=%p", session.uniqueKey, session)
	conn := session.conn
	if conn != nil {
		return conn.Close()
	}
	return base.ErrSessionNotStarted
}
// WriteInterleavedPacket frames packet for the given interleaved channel with
// packInterleaved and writes it on the RTSP command connection.
func (session *ClientCommandSession) WriteInterleavedPacket(packet []byte, channel int) error {
	conn := session.conn
	if conn == nil {
		return base.ErrSessionNotStarted
	}
	if _, err := conn.Write(packInterleaved(channel, packet)); err != nil {
		return err
	}
	return nil
}
// RemoteAddr reports the peer address of the command connection, or "" if
// the connection has not been established.
func (session *ClientCommandSession) RemoteAddr() string {
	if c := session.conn; c != nil {
		return c.RemoteAddr().String()
	}
	return ""
}
// Url returns the Url field of the URL context parsed by Do.
func (session *ClientCommandSession) Url() string { return session.urlCtx.Url }

// AppName returns the URL path without its last segment.
func (session *ClientCommandSession) AppName() string { return session.urlCtx.PathWithoutLastItem }

// StreamName returns the last segment of the URL path.
func (session *ClientCommandSession) StreamName() string { return session.urlCtx.LastItemOfPath }

// RawQuery returns the query string of the URL.
func (session *ClientCommandSession) RawQuery() string { return session.urlCtx.RawQuery }

// UniqueKey returns the identifier this session was created with.
func (session *ClientCommandSession) UniqueKey() string { return session.uniqueKey }
// doContext runs the client handshake on a helper goroutine and waits for
// either the handshake result or ctx, whichever comes first.
//
// On success it starts runReadLoop in the background and returns nil. If ctx
// ends first it returns ctx.Err().
//
// NOTE(review): when ctx wins, the handshake goroutine is not interrupted and
// keeps running until its current network call returns; Dispose closes the
// connection, which should unblock it.
func (session *ClientCommandSession) doContext(ctx context.Context, rawUrl string) error {
	// Buffered so the handshake goroutine can always deliver its result and
	// exit, even after we stopped listening because ctx ended.
	errChan := make(chan error, 1)
	go func() {
		errChan <- session.runHandshake(rawUrl)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-errChan:
		if err != nil {
			return err
		}
	}

	go session.runReadLoop()
	return nil
}

// runHandshake connects, sends OPTIONS and then the command sequence of the
// session type: DESCRIBE, SETUP, PLAY for pull and ANNOUNCE, SETUP, RECORD for
// push. OnSetupResult is reported between SETUP and PLAY/RECORD. It stops at
// the first failing step. An unknown session type is rejected before any
// connection is opened.
func (session *ClientCommandSession) runHandshake(rawUrl string) error {
	var prepare, start func() error
	switch session.t {
	case CcstPullSession:
		prepare, start = session.writeDescribe, session.writePlay
	case CcstPushSession:
		prepare, start = session.writeAnnounce, session.writeRecord
	default:
		return fmt.Errorf("rtsp: unknown ClientCommandSessionType %d", session.t)
	}

	if err := session.connect(rawUrl); err != nil {
		return err
	}
	if err := session.writeOptions(); err != nil {
		return err
	}
	if err := prepare(); err != nil {
		return err
	}
	if err := session.writeSetup(); err != nil {
		return err
	}
	session.observer.OnSetupResult()
	return start()
}
// runReadLoop runs after a successful handshake until the command connection
// fails, then posts the terminating error on waitChan exactly once.
//
// Modes:
//   - OverTcp: read the connection, dispatching interleaved frames to the
//     observer and consuming any RTSP message in between. When GET_PARAMETER
//     is supported, also send it periodically as a keepalive.
//   - UDP with GET_PARAMETER: send GET_PARAMETER periodically and read each reply.
//   - UDP without GET_PARAMETER: block on a 1-byte read to notice the peer
//     closing the TCP connection (FIN).
func (session *ClientCommandSession) runReadLoop() {
	var err error
	switch {
	case session.option.OverTcp:
		err = session.readLoopOverTcp()
	case session.methodGetParameterSupported:
		err = session.keepaliveLoopOverUdp()
	default:
		err = session.waitForPeerClose()
	}
	session.waitChan <- err
}

// readLoopOverTcp reads interleaved frames and RTSP messages from the command
// connection until an error occurs, interleaving GET_PARAMETER keepalives when
// the server supports them. It returns the first read or write error.
func (session *ClientCommandSession) readLoopOverTcp() error {
	r := bufio.NewReader(session.conn)

	// A nil channel is never ready, so without GET_PARAMETER support the
	// keepalive case below never fires.
	var keepalive <-chan time.Time
	if session.methodGetParameterSupported {
		nazalog.Debugf("[%s] start get_parameter timer.", session.uniqueKey)
		ticker := time.NewTicker(writeGetParameterIntervalMs * time.Millisecond)
		defer ticker.Stop()
		keepalive = ticker.C
	}

	for {
		select {
		case <-keepalive:
			// writeCmd assigns the CSeq itself (incrementing here as well made
			// it jump by two). The reply is consumed below as a
			// non-interleaved message.
			if err := session.writeCmd(MethodGetParameter, session.urlCtx.RawUrlWithoutUserInfo, nil, ""); err != nil {
				return err
			}
		default:
			// Never block here; the blocking happens in readInterleaved.
		}

		isInterleaved, packet, channel, err := readInterleaved(r)
		if err != nil {
			return err
		}
		if isInterleaved {
			session.observer.OnInterleavedPacket(packet, int(channel))
			continue
		}
		// Not a '$' frame, so an RTSP message is next in the stream and must be
		// consumed in every mode. Assumes readInterleaved leaves the non-'$'
		// byte unread (the keepalive path already relied on this); otherwise
		// the loop would spin on the same byte forever.
		if _, err := nazahttp.ReadHttpResponseMessage(r); err != nil {
			return err
		}
	}
}

// keepaliveLoopOverUdp sends GET_PARAMETER every writeGetParameterIntervalMs
// and reads the reply, returning the first error.
func (session *ClientCommandSession) keepaliveLoopOverUdp() error {
	nazalog.Debugf("[%s] start get_parameter timer.", session.uniqueKey)
	ticker := time.NewTicker(writeGetParameterIntervalMs * time.Millisecond)
	defer ticker.Stop()

	// Block on the ticker: a select with a default branch and nothing else
	// blocking would spin a CPU core between ticks.
	for range ticker.C {
		if _, err := session.writeCmdReadResp(MethodGetParameter, session.urlCtx.RawUrlWithoutUserInfo, nil, ""); err != nil {
			return err
		}
	}
	return nil // not reached: ticker.C is never closed
}

// waitForPeerClose blocks until one byte arrives or the read fails (e.g. the
// peer closed the TCP connection) and returns the read error, nil if a byte
// was received.
func (session *ClientCommandSession) waitForPeerClose() error {
	dummy := make([]byte, 1)
	_, err := session.conn.Read(dummy)
	return err
}
// connect parses rawUrl into session.urlCtx, opens the TCP command
// connection to its host:port, wraps it with a readBufSize read buffer and
// reports OnConnectResult.
//
// The dial is bounded by option.DoTimeoutMs when positive. Without this, a
// handshake already abandoned by Do's timeout would leave this goroutine
// blocked in connect for the OS default.
func (session *ClientCommandSession) connect(rawUrl string) (err error) {
	session.rawUrl = rawUrl
	session.urlCtx, err = base.ParseRtspUrl(rawUrl)
	if err != nil {
		return err
	}

	nazalog.Debugf("[%s] > tcp connect.", session.uniqueKey)

	// Establish the connection. A zero Dialer.Timeout means no timeout,
	// matching the previous plain net.Dial.
	var dialer net.Dialer
	if session.option.DoTimeoutMs > 0 {
		dialer.Timeout = time.Duration(session.option.DoTimeoutMs) * time.Millisecond
	}
	conn, err := dialer.Dial("tcp", session.urlCtx.HostWithPort)
	if err != nil {
		return err
	}

	session.conn = connection.New(conn, func(option *connection.Option) {
		option.ReadBufSize = readBufSize
	})
	nazalog.Debugf("[%s] < tcp connect. laddr=%s, raddr=%s", session.uniqueKey, conn.LocalAddr().String(), conn.RemoteAddr().String())

	session.observer.OnConnectResult()
	return nil
}
// writeOptions sends OPTIONS and records whether the Public header of the
// response advertises GET_PARAMETER, which enables the keepalive in the read
// loop. A missing or empty Public header is not an error.
func (session *ClientCommandSession) writeOptions() error {
	ctx, err := session.writeCmdReadResp(MethodOptions, session.urlCtx.RawUrlWithoutUserInfo, nil, "")
	if err != nil {
		return err
	}

	// Range over every Public value instead of indexing [0]: an empty value
	// list cannot panic, and a method listed in a later Public header is found.
	for _, methods := range ctx.Headers[HeaderPublic] {
		if strings.Contains(methods, MethodGetParameter) {
			session.methodGetParameterSupported = true
			break
		}
	}
	return nil
}
// writeDescribe sends DESCRIBE asking for SDP and parses the response body.
// It stores the raw and parsed SDP on the session and then passes both to the
// observer.
func (session *ClientCommandSession) writeDescribe() error {
	resp, err := session.writeCmdReadResp(MethodDescribe, session.urlCtx.RawUrlWithoutUserInfo,
		map[string]string{HeaderAccept: HeaderAcceptApplicationSdp}, "")
	if err != nil {
		return err
	}

	parsed, err := sdp.ParseSdp2LogicContext(resp.Body)
	if err != nil {
		return err
	}

	session.rawSdp, session.sdpLogicCtx = resp.Body, parsed
	session.observer.OnDescribeResponse(session.rawSdp, session.sdpLogicCtx)
	return nil
}
// writeAnnounce sends ANNOUNCE whose body is the SDP stored by InitWithSdp.
//
// The body is SDP, so it is labelled "Content-Type: application/sdp", as in
// the ANNOUNCE example of RFC 2326 §10.3. Servers that validate the content type
// (FFmpeg's RTSP listen mode reportedly does) reject an ANNOUNCE without it.
// The Accept header previously sent is kept for compatibility.
func (session *ClientCommandSession) writeAnnounce() error {
	headers := map[string]string{
		HeaderAccept:   HeaderAcceptApplicationSdp,
		"Content-Type": "application/sdp",
	}
	_, err := session.writeCmdReadResp(MethodAnnounce, session.urlCtx.RawUrlWithoutUserInfo, headers, string(session.rawSdp))
	return err
}
// writeSetup sends one SETUP per media section that has an a=control
// attribute: video first, then audio. A session may carry either or both. The
// transport (interleaved TCP or UDP) follows option.OverTcp.
func (session *ClientCommandSession) writeSetup() error {
	setupOne := session.writeOneSetup
	if session.option.OverTcp {
		setupOne = session.writeOneSetupTcp
	}
	baseUri := session.urlCtx.RawUrlWithoutUserInfo

	if session.sdpLogicCtx.HasVideoAControl() {
		if err := setupOne(session.sdpLogicCtx.MakeVideoSetupUri(baseUri)); err != nil {
			return err
		}
	}
	// Deliberately not an else-branch: audio is set up independently of video.
	if session.sdpLogicCtx.HasAudioAControl() {
		if err := setupOne(session.sdpLogicCtx.MakeAudioSetupUri(baseUri)); err != nil {
			return err
		}
	}
	return nil
}
// writeOneSetup sends one SETUP for setupUri with RTP/RTCP over UDP.
//
// It takes a pair of local UDP sockets from availUdpConnPool and advertises
// their ports in the Transport header (record template for push, play
// template for pull). From the response it records the Session id and the
// server_port pair. It then wraps both sockets in nazanet.UdpConnection aimed
// at the server and hands them to the observer.
//
// Until the observer takes the sockets, this function owns them and closes
// them on every error path. A response lacking the Session or Transport
// header is reported as an error wrapping ErrRtsp instead of panicking.
func (session *ClientCommandSession) writeOneSetup(setupUri string) error {
	rtpC, lRtpPort, rtcpC, lRtcpPort, err := availUdpConnPool.Acquire2()
	if err != nil {
		return err
	}
	handedOver := false
	defer func() {
		if !handedOver {
			_ = rtpC.Close()
			_ = rtcpC.Close()
		}
	}()

	var htv string
	switch session.t {
	case CcstPushSession:
		htv = fmt.Sprintf(HeaderTransportClientRecordTmpl, lRtpPort, lRtcpPort)
	case CcstPullSession:
		htv = fmt.Sprintf(HeaderTransportClientPlayTmpl, lRtpPort, lRtcpPort)
	}
	headers := map[string]string{
		HeaderTransport: htv,
	}
	ctx, err := session.writeCmdReadResp(MethodSetup, setupUri, headers, "")
	if err != nil {
		return err
	}

	sessionValues := ctx.Headers[HeaderSession]
	if len(sessionValues) == 0 {
		return fmt.Errorf("%w: SETUP response has no %s header", ErrRtsp, HeaderSession)
	}
	// Keep only the id; drop parameters such as ";timeout=60".
	session.sessionId = strings.Split(sessionValues[0], ";")[0]

	transportValues := ctx.Headers[HeaderTransport]
	if len(transportValues) == 0 {
		return fmt.Errorf("%w: SETUP response has no %s header", ErrRtsp, HeaderTransport)
	}
	rRtpPort, rRtcpPort, err := parseServerPort(transportValues[0])
	if err != nil {
		return err
	}

	nazalog.Debugf("[%s] init conn. lRtpPort=%d, lRtcpPort=%d, rRtpPort=%d, rRtcpPort=%d",
		session.uniqueKey, lRtpPort, lRtcpPort, rRtpPort, rRtcpPort)

	rtpConn, err := nazanet.NewUdpConnection(func(option *nazanet.UdpConnectionOption) {
		option.Conn = rtpC
		option.RAddr = net.JoinHostPort(session.urlCtx.Host, fmt.Sprintf("%d", rRtpPort))
		option.MaxReadPacketSize = rtprtcp.MaxRtpRtcpPacketSize
	})
	if err != nil {
		return err
	}
	rtcpConn, err := nazanet.NewUdpConnection(func(option *nazanet.UdpConnectionOption) {
		option.Conn = rtcpC
		option.RAddr = net.JoinHostPort(session.urlCtx.Host, fmt.Sprintf("%d", rRtcpPort))
		option.MaxReadPacketSize = rtprtcp.MaxRtpRtcpPacketSize
	})
	if err != nil {
		return err
	}

	handedOver = true
	session.observer.OnSetupWithConn(setupUri, rtpConn, rtcpConn)
	return nil
}
// writeOneSetupTcp sends one SETUP for setupUri requesting RTP/RTCP
// interleaved on the command connection. It allocates the next two channel
// numbers (RTP=n, RTCP=n+1), records the Session id from the response and
// reports the channel pair to the observer.
//
// A response without a Session header is reported as an error wrapping
// ErrRtsp instead of panicking on an empty value list.
func (session *ClientCommandSession) writeOneSetupTcp(setupUri string) error {
	rtpChannel := session.channel
	rtcpChannel := session.channel + 1
	session.channel += 2

	var htv string
	switch session.t {
	case CcstPushSession:
		htv = fmt.Sprintf(HeaderTransportClientRecordTcpTmpl, rtpChannel, rtcpChannel)
	case CcstPullSession:
		htv = fmt.Sprintf(HeaderTransportClientPlayTcpTmpl, rtpChannel, rtcpChannel)
	}
	headers := map[string]string{
		HeaderTransport: htv,
	}
	ctx, err := session.writeCmdReadResp(MethodSetup, setupUri, headers, "")
	if err != nil {
		return err
	}

	sessionValues := ctx.Headers[HeaderSession]
	if len(sessionValues) == 0 {
		return fmt.Errorf("%w: SETUP response has no %s header", ErrRtsp, HeaderSession)
	}
	// Keep only the id; drop parameters such as ";timeout=60".
	session.sessionId = strings.Split(sessionValues[0], ";")[0]

	// TODO chef: the channel ids echoed in the response are not parsed; they
	// are assumed to match the ones in the request.
	session.observer.OnSetupWithChannel(setupUri, rtpChannel, rtcpChannel)
	return nil
}
// writePlay sends PLAY with the default Range header. Only the error of the
// exchange matters; the response itself is discarded.
func (session *ClientCommandSession) writePlay() error {
	rangeHeader := map[string]string{HeaderRange: HeaderRangeDefault}
	if _, err := session.writeCmdReadResp(MethodPlay, session.urlCtx.RawUrlWithoutUserInfo, rangeHeader, ""); err != nil {
		return err
	}
	return nil
}
// writeRecord sends RECORD with the default Range header. Only the error of
// the exchange matters; the response itself is discarded.
func (session *ClientCommandSession) writeRecord() error {
	rangeHeader := map[string]string{HeaderRange: HeaderRangeDefault}
	if _, err := session.writeCmdReadResp(MethodRecord, session.urlCtx.RawUrlWithoutUserInfo, rangeHeader, ""); err != nil {
		return err
	}
	return nil
}
// writeCmd writes one RTSP request and does not read the response.
//
// It increments session.cseq and builds the header set on a private copy of
// headers, so the caller's map is never modified. It adds CSeq and User-Agent,
// plus Content-Length when body is non-empty. It also adds Authorization when
// session.auth produces one, and Session once a SETUP has returned an id.
func (session *ClientCommandSession) writeCmd(method, uri string, headers map[string]string, body string) error {
	session.cseq++

	// Copy instead of mutating: writeCmdReadResp re-sends the same map after a
	// 401, and callers should not see their input grow extra entries.
	h := make(map[string]string, len(headers)+5)
	for k, v := range headers {
		h[k] = v
	}

	h[HeaderCSeq] = fmt.Sprintf("%d", session.cseq)
	h[HeaderUserAgent] = base.LalRtspPullSessionUa
	if body != "" {
		h[HeaderContentLength] = fmt.Sprintf("%d", len(body))
	}

	// Authentication always uses RawUrlWithoutUserInfo.
	if auth := session.auth.MakeAuthorization(method, session.urlCtx.RawUrlWithoutUserInfo); auth != "" {
		h[HeaderAuthorization] = auth
	}
	if session.sessionId != "" {
		h[HeaderSession] = session.sessionId
	}

	req := PackRequest(method, uri, h, body)
	nazalog.Debugf("[%s] > write %s.", session.uniqueKey, method)
	//nazalog.Debugf("[%s] > write %s. req=%s", session.uniqueKey, method, req)
	_, err := session.conn.Write([]byte(req))
	return err
}
// writeCmdReadResp writes a request and reads its response. If the server
// answers 401, it feeds WWW-Authenticate to session.auth and retries once.
//
// @param headers may be nil
// @param body may be empty
//
// Returns ErrRtsp if the server still answers 401 after the retry. Any other
// 4xx/5xx status yields an error wrapping ErrRtsp, so a rejected request (e.g.
// a PLAY answered with 454) is not mistaken for success. The response context
// is returned alongside the error for diagnostics.
func (session *ClientCommandSession) writeCmdReadResp(method, uri string, headers map[string]string, body string) (ctx nazahttp.HttpRespMsgCtx, err error) {
	for i := 0; i < 2; i++ {
		if err = session.writeCmd(method, uri, headers, body); err != nil {
			return ctx, err
		}
		ctx, err = nazahttp.ReadHttpResponseMessage(session.conn)
		if err != nil {
			return ctx, err
		}
		nazalog.Debugf("[%s] < read response. version=%s, code=%s, reason=%s, headers=%+v, body=%s",
			session.uniqueKey, ctx.Version, ctx.StatusCode, ctx.Reason, ctx.Headers, string(ctx.Body))

		if ctx.StatusCode == "401" {
			session.auth.FeedWwwAuthenticate(ctx.Headers[HeaderWwwAuthenticate], session.urlCtx.Username, session.urlCtx.Password)
			continue
		}
		// 4xx (client error) and 5xx (server error) mean the request failed.
		if code := ctx.StatusCode; len(code) == 3 && (code[0] == '4' || code[0] == '5') {
			return ctx, fmt.Errorf("%w: %s got status %s %s", ErrRtsp, method, code, ctx.Reason)
		}
		return ctx, nil
	}
	return ctx, ErrRtsp
}