SquashSRS4: Regine DTLS and add regression tests. 4.0.84

pull/2252/head
winlin 4 years ago
parent dc93836489
commit e74810230a

@ -186,6 +186,7 @@ Other documents:
## V4 changes ## V4 changes
* v4.0, 2021-03-09, DTLS: Fix ARQ bug, use openssl timeout. 4.0.84
* v4.0, 2021-03-08, DTLS: Fix dead loop by duplicated Alert message. 4.0.83 * v4.0, 2021-03-08, DTLS: Fix dead loop by duplicated Alert message. 4.0.83
* v4.0, 2021-03-08, Fix bug when client DTLS is passive. 4.0.82 * v4.0, 2021-03-08, Fix bug when client DTLS is passive. 4.0.82
* v4.0, 2021-03-03, Fix [#2106][bug #2106], [#2011][bug #2011], RTMP/AAC transcode to Opus bug. 4.0.81 * v4.0, 2021-03-03, Fix [#2106][bug #2106], [#2011][bug #2011], RTMP/AAC transcode to Opus bug. 4.0.81

2
trunk/.gitignore vendored

@ -34,7 +34,7 @@
/research/speex/ /research/speex/
/test/ /test/
.DS_Store .DS_Store
srs ./srs
*.dSYM/ *.dSYM/
*.gcov *.gcov
*.ts *.ts

@ -1,17 +0,0 @@
#!/bin/bash
# check exists.
if [[ -f /usr/local/bin/ccache ]]; then
echo "ccache is ok";
exit 0;
fi
# check sudoer.
sudo echo "ok" > /dev/null 2>&1;
ret=$?; if [[ 0 -ne ${ret} ]]; then echo "you must be sudoer"; exit 1; fi
unzip ccache-3.1.9.zip && cd ccache-3.1.9 && ./configure && make
ret=$?; if [[ $ret -ne 0 ]]; then echo "build ccache failed."; exit $ret; fi
sudo cp ccache /usr/local/bin && sudo ln -s ccache /usr/local/bin/gcc && sudo ln -s ccache /usr/local/bin/g++ && sudo ln -s ccache /usr/local/bin/cc && sudo ln -s ccache /usr/local/bin/c++
ret=$?; if [[ $ret -ne 0 ]]; then echo "install ccache failed."; exit $ret; fi

Binary file not shown.

@ -1,11 +0,0 @@
ccache是samba组织提供的加速编译过程的工具
使用虚拟机编译可以考虑用这个工具,让编译过程飞快。
链接:
http://ccache.samba.org/
http://samba.org/ftp/ccache/ccache-3.1.9.tar.xz
http://ccache.samba.org/manual.html
安装方法:
bash build_ccache.sh
注意要求以sudoer执行要修改文件。

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2021 srs-bench(ossrs)
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -5,18 +5,18 @@ default: bench test
clean: clean:
rm -f ./objs/srs_bench ./objs/srs_test rm -f ./objs/srs_bench ./objs/srs_test
.format.txt: *.go rtc/*.go srs/*.go .format.txt: *.go srs/*.go vnet/*.go
gofmt -w . gofmt -w .
echo "done" > .format.txt echo "done" > .format.txt
bench: ./objs/srs_bench bench: ./objs/srs_bench
./objs/srs_bench: .format.txt *.go rtc/*.go srs/*.go Makefile ./objs/srs_bench: .format.txt *.go srs/*.go vnet/*.go Makefile
go build -mod=vendor -o objs/srs_bench . go build -mod=vendor -o objs/srs_bench .
test: ./objs/srs_test test: ./objs/srs_test
./objs/srs_test: .format.txt *.go rtc/*.go srs/*.go Makefile ./objs/srs_test: .format.txt *.go srs/*.go vnet/*.go Makefile
go test ./srs -mod=vendor -c -o ./objs/srs_test go test ./srs -mod=vendor -c -o ./objs/srs_test
help: help:

@ -102,8 +102,7 @@ ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/
回归测试需要先启动[SRS](https://github.com/ossrs/srs/issues/307)支持WebRTC推拉流 回归测试需要先启动[SRS](https://github.com/ossrs/srs/issues/307)支持WebRTC推拉流
```bash ```bash
eip=$(ifconfig en0 inet| grep 'inet '|awk '{print $2}') if [[ ! -z $(ifconfig en0 inet| grep 'inet '|awk '{print $2}') ]]; then
if [[ ! -z $eip ]]; then
docker run -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \ docker run -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \
--rm --env CANDIDATE=$(ifconfig en0 inet| grep 'inet '|awk '{print $2}')\ --rm --env CANDIDATE=$(ifconfig en0 inet| grep 'inet '|awk '{print $2}')\
registry.cn-hangzhou.aliyuncs.com/ossrs/srs:v4.0.76 objs/srs -c conf/rtc.conf registry.cn-hangzhou.aliyuncs.com/ossrs/srs:v4.0.76 objs/srs -c conf/rtc.conf
@ -119,7 +118,20 @@ go test ./srs -mod=vendor -v
也可以用make编译出重复使用的二进制 也可以用make编译出重复使用的二进制
```bash ```bash
make test && ./objs/srs_test -test.v make && ./objs/srs_test -test.v
```
> Note: 注意由于pion不支持`DTLS 1.0`所以SFU必须要支持`DTLS 1.2`才行。
运行结果如下:
```bash
$ make && ./objs/srs_test -test.v
=== RUN TestRTCServerVersion
--- PASS: TestRTCServerVersion (0.00s)
=== RUN TestRTCServerPublishPlay
--- PASS: TestRTCServerPublishPlay (1.28s)
PASS
``` ```
可以给回归测试传参数,这样可以测试不同的序列,比如: 可以给回归测试传参数,这样可以测试不同的序列,比如:
@ -127,23 +139,43 @@ make test && ./objs/srs_test -test.v
```bash ```bash
go test ./srs -mod=vendor -v -srs-server=127.0.0.1 go test ./srs -mod=vendor -v -srs-server=127.0.0.1
# Or # Or
make test && ./objs/srs_test -test.v -srs-server=127.0.0.1 make && ./objs/srs_test -test.v -srs-server=127.0.0.1
``` ```
支持的参数如下: 支持的参数如下:
* `-srs-server`RTC服务器地址。默认值`127.0.0.1` * `-srs-server`RTC服务器地址。默认值`127.0.0.1`
* `-srs-stream`RTC流地址。默认值`/rtc/regression` * `-srs-stream`RTC流地址。默认值`/rtc/regression`
* `-srs-log`,是否开启详细日志。默认值:`false`
* `-srs-timeout`每个Case的超时时间毫秒。默认值`3000`即3秒。 * `-srs-timeout`每个Case的超时时间毫秒。默认值`3000`即3秒。
* `-srs-play-pli`播放时PLI的间隔毫秒。默认值`5000`即5秒。
* `-srs-play-ok-packets`,播放时,收到多少个包认为是测试通过,默认值:`10`
* `-srs-publish-audio`,推流时,使用的音频文件。默认值:`avatar.ogg` * `-srs-publish-audio`,推流时,使用的音频文件。默认值:`avatar.ogg`
* `-srs-publish-video`,推流时,使用的视频文件。默认值:`avatar.h264` * `-srs-publish-video`,推流时,使用的视频文件。默认值:`avatar.h264`
* `-srs-publish-video-fps`推流时视频文件的FPS。默认值`25` * `-srs-publish-video-fps`推流时视频文件的FPS。默认值`25`
* `-srs-vnet-client-ip`,设置[pion/vnet](https://github.com/ossrs/srs-bench/blob/feature/rtc/vnet/example_test.go)客户端的虚拟IP不能和服务器IP冲突。默认值`192.168.168.168`
其他不常用参数: 其他不常用参数:
* `-srs-log`,是否开启详细日志。默认值:`false`
* `-srs-play-ok-packets`,播放时,收到多少个包认为是测试通过,默认值:`10`
* `-srs-publish-ok-packets`,推流时,发送多少个包认为时测试通过,默认值:`10`
* `-srs-https`是否连接HTTPS-API。默认值`false`即连接HTTP-API。 * `-srs-https`是否连接HTTPS-API。默认值`false`即连接HTTP-API。
* `-srs-play-pli`播放时PLI的间隔毫秒。默认值`5000`即5秒。
* `-srs-dtls-drop-packets`DTLS丢包测试丢了多少个包算成功默认值`5`
## GCOVR
本机生成覆盖率时,我们使用工具[gcovr](https://gcovr.com/en/stable/guide.html)。
在macOS上安装gcovr
```bash
pip3 install gcovr
```
在CentOS上安装gcovr
```bash
yum install -y python2-pip &&
pip install lxml && pip install gcovr
```
2021.01, Winlin 2021.01, Winlin

Binary file not shown.

@ -1,3 +1,23 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package main package main
import ( import (
@ -42,7 +62,7 @@ func main() {
flag.IntVar(&delay, "delay", 50, "") flag.IntVar(&delay, "delay", 50, "")
var statListen string var statListen string
flag.StringVar(&statListen, "stat", ":18000", "") flag.StringVar(&statListen, "stat", "", "")
flag.Usage = func() { flag.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0])) fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
@ -52,7 +72,7 @@ func main() {
fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50")) fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50"))
fmt.Println(fmt.Sprintf(" -al [Optional] Whether enable audio-level. Default: true")) fmt.Println(fmt.Sprintf(" -al [Optional] Whether enable audio-level. Default: true"))
fmt.Println(fmt.Sprintf(" -twcc [Optional] Whether enable vdieo-twcc. Default: true")) fmt.Println(fmt.Sprintf(" -twcc [Optional] Whether enable vdieo-twcc. Default: true"))
fmt.Println(fmt.Sprintf(" -stat [Optional] The stat server API listen port. Default: :18000")) fmt.Println(fmt.Sprintf(" -stat [Optional] The stat server API listen port."))
fmt.Println(fmt.Sprintf("Player or Subscriber:")) fmt.Println(fmt.Sprintf("Player or Subscriber:"))
fmt.Println(fmt.Sprintf(" -sr The url to play/subscribe. If sn exceed 1, auto append variable %%d.")) fmt.Println(fmt.Sprintf(" -sr The url to play/subscribe. If sn exceed 1, auto append variable %%d."))
fmt.Println(fmt.Sprintf(" -da [Optional] The file path to dump audio, ignore if empty.")) fmt.Println(fmt.Sprintf(" -da [Optional] The file path to dump audio, ignore if empty."))

@ -1,5 +0,0 @@
package rtc
const (
rtpOutboundMTU = 1200
)

@ -1,27 +0,0 @@
package rtc
import (
"github.com/pion/rtp"
"github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3"
"strings"
)
func payloaderForCodec(codec webrtc.RTPCodecCapability) (rtp.Payloader, error) {
switch strings.ToLower(codec.MimeType) {
case strings.ToLower(webrtc.MimeTypeH264):
return &codecs.H264Payloader{}, nil
case strings.ToLower(webrtc.MimeTypeOpus):
return &codecs.OpusPayloader{}, nil
case strings.ToLower(webrtc.MimeTypeVP8):
return &codecs.VP8Payloader{}, nil
case strings.ToLower(webrtc.MimeTypeVP9):
return &codecs.VP9Payloader{}, nil
case strings.ToLower(webrtc.MimeTypeG722):
return &codecs.G722Payloader{}, nil
case strings.ToLower(webrtc.MimeTypePCMU), strings.ToLower(webrtc.MimeTypePCMA):
return &codecs.G711Payloader{}, nil
default:
return nil, webrtc.ErrNoPayloaderForCodec
}
}

@ -1,27 +0,0 @@
package rtc
import (
"github.com/pion/webrtc/v3"
"strings"
)
// Do a fuzzy find for a codec in the list of codecs
// Used for lookup up a codec in an existing list to find a match
func codecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []webrtc.RTPCodecParameters) (webrtc.RTPCodecParameters, error) {
// First attempt to match on MimeType + SDPFmtpLine
for _, c := range haystack {
if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) &&
c.RTPCodecCapability.SDPFmtpLine == needle.RTPCodecCapability.SDPFmtpLine {
return c, nil
}
}
// Fallback to just MimeType
for _, c := range haystack {
if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) {
return c, nil
}
}
return webrtc.RTPCodecParameters{}, webrtc.ErrCodecNotFound
}

@ -1,246 +0,0 @@
package rtc
import (
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"strings"
"sync"
)
// trackBinding is a single bind for a Track
// Bind can be called multiple times, this stores the
// result for a single bind call so that it can be used when writing
type trackBinding struct {
id string
ssrc webrtc.SSRC
payloadType webrtc.PayloadType
writeStream webrtc.TrackLocalWriter
}
// TrackLocalStaticRTP is a TrackLocal that has a pre-set codec and accepts RTP Packets.
// If you wish to send a media.Sample use TrackLocalStaticSample
type TrackLocalStaticRTP struct {
mu sync.RWMutex
bindings []trackBinding
codec webrtc.RTPCodecCapability
id, streamID string
}
// NewTrackLocalStaticRTP returns a TrackLocalStaticRTP.
func NewTrackLocalStaticRTP(c webrtc.RTPCodecCapability, id, streamID string) (*TrackLocalStaticRTP, error) {
return &TrackLocalStaticRTP{
codec: c,
bindings: []trackBinding{},
id: id,
streamID: streamID,
}, nil
}
// Bind is called by the PeerConnection after negotiation is complete
// This asserts that the code requested is supported by the remote peer.
// If so it setups all the state (SSRC and PayloadType) to have a call
func (s *TrackLocalStaticRTP) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
s.mu.Lock()
defer s.mu.Unlock()
parameters := webrtc.RTPCodecParameters{RTPCodecCapability: s.codec}
if codec, err := codecParametersFuzzySearch(parameters, t.CodecParameters()); err == nil {
s.bindings = append(s.bindings, trackBinding{
ssrc: t.SSRC(),
payloadType: codec.PayloadType,
writeStream: t.WriteStream(),
id: t.ID(),
})
return codec, nil
}
return webrtc.RTPCodecParameters{}, webrtc.ErrUnsupportedCodec
}
// Unbind implements the teardown logic when the track is no longer needed. This happens
// because a track has been stopped.
func (s *TrackLocalStaticRTP) Unbind(t webrtc.TrackLocalContext) error {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.bindings {
if s.bindings[i].id == t.ID() {
s.bindings[i] = s.bindings[len(s.bindings)-1]
s.bindings = s.bindings[:len(s.bindings)-1]
return nil
}
}
return webrtc.ErrUnbindFailed
}
// ID is the unique identifier for this Track. This should be unique for the
// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video'
// and StreamID would be 'desktop' or 'webcam'
func (s *TrackLocalStaticRTP) ID() string { return s.id }
// StreamID is the group this track belongs too. This must be unique
func (s *TrackLocalStaticRTP) StreamID() string { return s.streamID }
// Kind controls if this TrackLocal is audio or video
func (s *TrackLocalStaticRTP) Kind() webrtc.RTPCodecType {
switch {
case strings.HasPrefix(s.codec.MimeType, "audio/"):
return webrtc.RTPCodecTypeAudio
case strings.HasPrefix(s.codec.MimeType, "video/"):
return webrtc.RTPCodecTypeVideo
default:
return webrtc.RTPCodecType(0)
}
}
// Codec gets the Codec of the track
func (s *TrackLocalStaticRTP) Codec() webrtc.RTPCodecCapability {
return s.codec
}
// WriteRTP writes a RTP Packet to the TrackLocalStaticRTP
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error {
s.mu.RLock()
defer s.mu.RUnlock()
writeErrs := []error{}
outboundPacket := *p
for _, b := range s.bindings {
outboundPacket.Header.SSRC = uint32(b.ssrc)
outboundPacket.Header.PayloadType = uint8(b.payloadType)
if _, err := b.writeStream.WriteRTP(&outboundPacket.Header, outboundPacket.Payload); err != nil {
writeErrs = append(writeErrs, err)
}
}
return FlattenErrs(writeErrs)
}
// Write writes a RTP Packet as a buffer to the TrackLocalStaticRTP
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) {
packet := &rtp.Packet{}
if err = packet.Unmarshal(b); err != nil {
return 0, err
}
return len(b), s.WriteRTP(packet)
}
// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples.
// If you wish to send a RTP Packet use TrackLocalStaticRTP
type TrackLocalStaticSample struct {
packetizer rtp.Packetizer
rtpTrack *TrackLocalStaticRTP
clockRate float64
// Set the callback before write RTP packet.
OnBeforeWritePacket func(rtp *rtp.Packet)
}
// NewTrackLocalStaticSample returns a TrackLocalStaticSample
func NewTrackLocalStaticSample(c webrtc.RTPCodecCapability, id, streamID string) (*TrackLocalStaticSample, error) {
rtpTrack, err := NewTrackLocalStaticRTP(c, id, streamID)
if err != nil {
return nil, err
}
return &TrackLocalStaticSample{
rtpTrack: rtpTrack,
}, nil
}
// ID is the unique identifier for this Track. This should be unique for the
// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video'
// and StreamID would be 'desktop' or 'webcam'
func (s *TrackLocalStaticSample) ID() string { return s.rtpTrack.ID() }
// StreamID is the group this track belongs too. This must be unique
func (s *TrackLocalStaticSample) StreamID() string { return s.rtpTrack.StreamID() }
// Kind controls if this TrackLocal is audio or video
func (s *TrackLocalStaticSample) Kind() webrtc.RTPCodecType { return s.rtpTrack.Kind() }
// Codec gets the Codec of the track
func (s *TrackLocalStaticSample) Codec() webrtc.RTPCodecCapability {
return s.rtpTrack.Codec()
}
// Bind is called by the PeerConnection after negotiation is complete
// This asserts that the code requested is supported by the remote peer.
// If so it setups all the state (SSRC and PayloadType) to have a call
func (s *TrackLocalStaticSample) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
codec, err := s.rtpTrack.Bind(t)
if err != nil {
return codec, err
}
s.rtpTrack.mu.Lock()
defer s.rtpTrack.mu.Unlock()
// We only need one packetizer
if s.packetizer != nil {
return codec, nil
}
payloader, err := payloaderForCodec(codec.RTPCodecCapability)
if err != nil {
return codec, err
}
s.packetizer = rtp.NewPacketizer(
rtpOutboundMTU,
0, // Value is handled when writing
0, // Value is handled when writing
payloader,
rtp.NewRandomSequencer(),
codec.ClockRate,
)
s.clockRate = float64(codec.RTPCodecCapability.ClockRate)
return codec, nil
}
// Unbind implements the teardown logic when the track is no longer needed. This happens
// because a track has been stopped.
func (s *TrackLocalStaticSample) Unbind(t webrtc.TrackLocalContext) error {
return s.rtpTrack.Unbind(t)
}
// WriteSample writes a Sample to the TrackLocalStaticSample
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
s.rtpTrack.mu.RLock()
p := s.packetizer
clockRate := s.clockRate
s.rtpTrack.mu.RUnlock()
if p == nil {
return nil
}
samples := sample.Duration.Seconds() * clockRate
packets := p.(rtp.Packetizer).Packetize(sample.Data, uint32(samples))
writeErrs := []error{}
for _, p := range packets {
if s.OnBeforeWritePacket != nil {
s.OnBeforeWritePacket(p)
}
if err := s.rtpTrack.WriteRTP(p); err != nil {
writeErrs = append(writeErrs, err)
}
}
return FlattenErrs(writeErrs)
}

@ -1,10 +0,0 @@
package rtc
import "fmt"
func FlattenErrs(errors []error) error {
if len(errors) == 0 {
return nil
}
return fmt.Errorf("%v", errors)
}

@ -0,0 +1,285 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package srs
import (
"context"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/pion/interceptor"
"github.com/pion/rtp"
"github.com/pion/sdp/v3"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/pion/webrtc/v3/pkg/media/h264reader"
"github.com/pion/webrtc/v3/pkg/media/oggreader"
"io"
"os"
"strings"
"time"
)
type videoIngester struct {
sourceVideo string
fps int
markerInterceptor *RTPInterceptor
sVideoTrack *webrtc.TrackLocalStaticSample
sVideoSender *webrtc.RTPSender
}
func NewVideoIngester(sourceVideo string) *videoIngester {
return &videoIngester{markerInterceptor: &RTPInterceptor{}, sourceVideo: sourceVideo}
}
func (v *videoIngester) Close() error {
if v.sVideoSender != nil {
v.sVideoSender.Stop()
v.sVideoSender = nil
}
return nil
}
func (v *videoIngester) AddTrack(pc *webrtc.PeerConnection, fps int) error {
v.fps = fps
mimeType, trackID := "video/H264", "video"
if strings.HasSuffix(v.sourceVideo, ".ivf") {
mimeType = "video/VP8"
}
var err error
v.sVideoTrack, err = webrtc.NewTrackLocalStaticSample(
webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 90000}, trackID, "pion",
)
if err != nil {
return errors.Wrapf(err, "Create video track")
}
v.sVideoSender, err = pc.AddTrack(v.sVideoTrack)
if err != nil {
return errors.Wrapf(err, "Add video track")
}
return err
}
func (v *videoIngester) Ingest(ctx context.Context) error {
source, sender, track, fps := v.sourceVideo, v.sVideoSender, v.sVideoTrack, v.fps
f, err := os.Open(source)
if err != nil {
return errors.Wrapf(err, "Open file %v", source)
}
defer f.Close()
// TODO: FIXME: Support ivf for vp8.
h264, err := h264reader.NewReader(f)
if err != nil {
return errors.Wrapf(err, "Open h264 %v", source)
}
enc := sender.GetParameters().Encodings[0]
codec := sender.GetParameters().Codecs[0]
headers := sender.GetParameters().HeaderExtensions
logger.Tf(ctx, "Video %v, tbn=%v, fps=%v, ssrc=%v, pt=%v, header=%v",
codec.MimeType, codec.ClockRate, fps, enc.SSRC, codec.PayloadType, headers)
clock := newWallClock()
sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 / uint64(fps))
for ctx.Err() == nil {
var sps, pps *h264reader.NAL
var oFrames []*h264reader.NAL
for ctx.Err() == nil {
frame, err := h264.NextNAL()
if err == io.EOF {
return io.EOF
}
if err != nil {
return errors.Wrapf(err, "Read h264")
}
oFrames = append(oFrames, frame)
logger.If(ctx, "NALU %v PictureOrderCount=%v, ForbiddenZeroBit=%v, RefIdc=%v, %v bytes",
frame.UnitType.String(), frame.PictureOrderCount, frame.ForbiddenZeroBit, frame.RefIdc, len(frame.Data))
if frame.UnitType == h264reader.NalUnitTypeSPS {
sps = frame
} else if frame.UnitType == h264reader.NalUnitTypePPS {
pps = frame
} else {
break
}
}
var frames []*h264reader.NAL
// Package SPS/PPS to STAP-A
if sps != nil && pps != nil {
stapA := packageAsSTAPA(sps, pps)
frames = append(frames, stapA)
}
// Append other original frames.
for _, frame := range oFrames {
if frame.UnitType != h264reader.NalUnitTypeSPS && frame.UnitType != h264reader.NalUnitTypePPS {
frames = append(frames, frame)
}
}
// Covert frames to sample(buffers).
for i, frame := range frames {
sample := media.Sample{Data: frame.Data, Duration: sampleDuration}
// Use the sample timestamp for frames.
if i != len(frames)-1 {
sample.Duration = 0
}
// For STAP-A, set marker to false, to make Chrome happy.
if ri := v.markerInterceptor; ri.rtpWriter == nil {
ri.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
// TODO: Should we decode to check whether SPS/PPS?
if len(payload) > 0 && payload[0]&0x1f == 24 {
header.Marker = false // 24, STAP-A
}
return ri.nextRTPWriter.Write(header, payload, attributes)
}
}
if err = track.WriteSample(sample); err != nil {
return errors.Wrapf(err, "Write sample")
}
}
if d := clock.Tick(sampleDuration); d > 0 {
time.Sleep(d)
}
}
return ctx.Err()
}
type audioIngester struct {
sourceAudio string
audioLevelInterceptor *RTPInterceptor
sAudioTrack *webrtc.TrackLocalStaticSample
sAudioSender *webrtc.RTPSender
}
func NewAudioIngester(sourceAudio string) *audioIngester {
return &audioIngester{audioLevelInterceptor: &RTPInterceptor{}, sourceAudio: sourceAudio}
}
func (v *audioIngester) Close() error {
if v.sAudioSender != nil {
v.sAudioSender.Stop()
v.sAudioSender = nil
}
return nil
}
func (v *audioIngester) AddTrack(pc *webrtc.PeerConnection) error {
var err error
mimeType, trackID := "audio/opus", "audio"
v.sAudioTrack, err = webrtc.NewTrackLocalStaticSample(
webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 48000, Channels: 2}, trackID, "pion",
)
if err != nil {
return errors.Wrapf(err, "Create audio track")
}
v.sAudioSender, err = pc.AddTrack(v.sAudioTrack)
if err != nil {
return errors.Wrapf(err, "Add audio track")
}
return nil
}
func (v *audioIngester) Ingest(ctx context.Context) error {
source, sender, track := v.sourceAudio, v.sAudioSender, v.sAudioTrack
f, err := os.Open(source)
if err != nil {
return errors.Wrapf(err, "Open file %v", source)
}
defer f.Close()
ogg, _, err := oggreader.NewWith(f)
if err != nil {
return errors.Wrapf(err, "Open ogg %v", source)
}
enc := sender.GetParameters().Encodings[0]
codec := sender.GetParameters().Codecs[0]
headers := sender.GetParameters().HeaderExtensions
logger.Tf(ctx, "Audio %v, tbn=%v, channels=%v, ssrc=%v, pt=%v, header=%v",
codec.MimeType, codec.ClockRate, codec.Channels, enc.SSRC, codec.PayloadType, headers)
// Whether should encode the audio-level in RTP header.
var audioLevel *webrtc.RTPHeaderExtensionParameter
for _, h := range headers {
if h.URI == sdp.AudioLevelURI {
audioLevel = &h
}
}
clock := newWallClock()
var lastGranule uint64
for ctx.Err() == nil {
pageData, pageHeader, err := ogg.ParseNextPage()
if err == io.EOF {
return io.EOF
}
if err != nil {
return errors.Wrapf(err, "Read ogg")
}
// The amount of samples is the difference between the last and current timestamp
sampleCount := uint64(pageHeader.GranulePosition - lastGranule)
lastGranule = pageHeader.GranulePosition
sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 * sampleCount / uint64(codec.ClockRate))
// For audio-level, set the extensions if negotiated.
if ri := v.audioLevelInterceptor; ri.rtpWriter == nil {
ri.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
if audioLevel != nil {
audioLevelPayload, err := new(rtp.AudioLevelExtension).Marshal()
if err != nil {
return 0, err
}
header.SetExtension(uint8(audioLevel.ID), audioLevelPayload)
}
return ri.nextRTPWriter.Write(header, payload, attributes)
}
}
if err = track.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); err != nil {
return errors.Wrapf(err, "Write sample")
}
if d := clock.Tick(sampleDuration); d > 0 {
time.Sleep(d)
}
}
return ctx.Err()
}

@ -0,0 +1,175 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package srs
import (
"github.com/pion/interceptor"
"github.com/pion/rtcp"
"github.com/pion/rtp"
)
type RTPInterceptorOptionFunc func(i *RTPInterceptor)
// Common RTP packet interceptor for benchmark.
// @remark Should never merge with RTCPInterceptor, because they has the same Write interface.
type RTPInterceptor struct {
localInfo *interceptor.StreamInfo
remoteInfo *interceptor.StreamInfo
// If rtpReader is nil, use the default next one to read.
rtpReader interceptor.RTPReaderFunc
nextRTPReader interceptor.RTPReader
// If rtpWriter is nil, use the default next one to write.
rtpWriter interceptor.RTPWriterFunc
nextRTPWriter interceptor.RTPWriter
BypassInterceptor
}
func NewRTPInterceptor(options ...RTPInterceptorOptionFunc) *RTPInterceptor {
v := &RTPInterceptor{}
for _, opt := range options {
opt(v)
}
return v
}
func (v *RTPInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
if v.localInfo != nil {
return writer // Only handle one stream.
}
v.localInfo = info
v.nextRTPWriter = writer
return v // Handle all RTP
}
func (v *RTPInterceptor) Write(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
if v.rtpWriter != nil {
return v.rtpWriter(header, payload, attributes)
}
return v.nextRTPWriter.Write(header, payload, attributes)
}
func (v *RTPInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) {
if v.localInfo == nil || v.localInfo.ID != info.ID {
return
}
v.localInfo = nil // Reset the interceptor.
}
func (v *RTPInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
if v.remoteInfo != nil {
return reader // Only handle one stream.
}
v.nextRTPReader = reader
return v // Handle all RTP
}
func (v *RTPInterceptor) Read(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
if v.rtpReader != nil {
return v.rtpReader(b, a)
}
return v.nextRTPReader.Read(b, a)
}
func (v *RTPInterceptor) UnbindRemoteStream(info *interceptor.StreamInfo) {
if v.remoteInfo == nil || v.remoteInfo.ID != info.ID {
return
}
v.remoteInfo = nil
}
type RTCPInterceptorOptionFunc func(i *RTCPInterceptor)
// Common RTCP packet interceptor for benchmark.
// @remark Should never merge with RTPInterceptor, because they has the same Write interface.
type RTCPInterceptor struct {
// If rtcpReader is nil, use the default next one to read.
rtcpReader interceptor.RTCPReaderFunc
nextRTCPReader interceptor.RTCPReader
// If rtcpWriter is nil, use the default next one to write.
rtcpWriter interceptor.RTCPWriterFunc
nextRTCPWriter interceptor.RTCPWriter
BypassInterceptor
}
func NewRTCPInterceptor(options ...RTCPInterceptorOptionFunc) *RTCPInterceptor {
v := &RTCPInterceptor{}
for _, opt := range options {
opt(v)
}
return v
}
func (v *RTCPInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
v.nextRTCPReader = reader
return v // Handle all RTCP
}
func (v *RTCPInterceptor) Read(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
if v.rtcpReader != nil {
return v.rtcpReader(b, a)
}
return v.nextRTCPReader.Read(b, a)
}
func (v *RTCPInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
v.nextRTCPWriter = writer
return v // Handle all RTCP
}
func (v *RTCPInterceptor) Write(pkts []rtcp.Packet, attributes interceptor.Attributes) (int, error) {
if v.rtcpWriter != nil {
return v.rtcpWriter(pkts, attributes)
}
return v.nextRTCPWriter.Write(pkts, attributes)
}
// Do nothing.
type BypassInterceptor struct {
interceptor.Interceptor
}
func (v *BypassInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
return reader
}
func (v *BypassInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
return writer
}
func (v *BypassInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
return writer
}
func (v *BypassInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) {
}
func (v *BypassInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
return reader
}
func (v *BypassInterceptor) UnbindRemoteStream(info *interceptor.StreamInfo) {
}
func (v *BypassInterceptor) Close() error {
return nil
}

@ -1,3 +1,23 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package srs package srs
import ( import (
@ -65,7 +85,14 @@ func StartPlay(ctx context.Context, r, dumpAudio, dumpVideo string, enableAudioL
if err != nil { if err != nil {
return errors.Wrapf(err, "Create PC") return errors.Wrapf(err, "Create PC")
} }
defer pc.Close()
var receivers []*webrtc.RTPReceiver
defer func() {
pc.Close()
for _, receiver := range receivers {
receiver.Stop()
}
}()
pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly, Direction: webrtc.RTPTransceiverDirectionRecvonly,
@ -132,6 +159,8 @@ func StartPlay(ctx context.Context, r, dumpAudio, dumpVideo string, enableAudioL
} }
}() }()
receivers = append(receivers, receiver)
codec := track.Codec() codec := track.Codec()
trackDesc := fmt.Sprintf("channels=%v", codec.Channels) trackDesc := fmt.Sprintf("channels=%v", codec.Channels)

@ -1,20 +1,33 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package srs package srs
import ( import (
"context" "context"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger" "github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/rtc"
"github.com/pion/interceptor" "github.com/pion/interceptor"
"github.com/pion/rtp"
"github.com/pion/sdp/v3" "github.com/pion/sdp/v3"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/pion/webrtc/v3/pkg/media/h264reader"
"github.com/pion/webrtc/v3/pkg/media/oggreader"
"io" "io"
"os"
"strings"
"sync" "sync"
"time" "time"
) )
@ -26,7 +39,12 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
logger.Tf(ctx, "Start publish url=%v, audio=%v, video=%v, fps=%v, audio-level=%v, twcc=%v", logger.Tf(ctx, "Start publish url=%v, audio=%v, video=%v, fps=%v, audio-level=%v, twcc=%v",
r, sourceAudio, sourceVideo, fps, enableAudioLevel, enableTWCC) r, sourceAudio, sourceVideo, fps, enableAudioLevel, enableTWCC)
// For audio-level. // Filter for SPS/PPS marker.
var aIngester *audioIngester
var vIngester *videoIngester
// For audio-level and sps/pps marker.
// TODO: FIXME: Should share with player.
webrtcNewPeerConnection := func(configuration webrtc.Configuration) (*webrtc.PeerConnection, error) { webrtcNewPeerConnection := func(configuration webrtc.Configuration) (*webrtc.PeerConnection, error) {
m := &webrtc.MediaEngine{} m := &webrtc.MediaEngine{}
if err := m.RegisterDefaultCodecs(); err != nil { if err := m.RegisterDefaultCodecs(); err != nil {
@ -53,12 +71,21 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
} }
} }
i := &interceptor.Registry{} registry := &interceptor.Registry{}
if err := webrtc.RegisterDefaultInterceptors(m, i); err != nil { if err := webrtc.RegisterDefaultInterceptors(m, registry); err != nil {
return nil, err return nil, err
} }
api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithInterceptorRegistry(i)) if sourceAudio != "" {
aIngester = NewAudioIngester(sourceAudio)
registry.Add(aIngester.audioLevelInterceptor)
}
if sourceVideo != "" {
vIngester = NewVideoIngester(sourceVideo)
registry.Add(vIngester.markerInterceptor)
}
api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithInterceptorRegistry(registry))
return api.NewPeerConnection(configuration) return api.NewPeerConnection(configuration)
} }
@ -66,46 +93,30 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
if err != nil { if err != nil {
return errors.Wrapf(err, "Create PC") return errors.Wrapf(err, "Create PC")
} }
defer pc.Close()
var sVideoTrack *rtc.TrackLocalStaticSample
var sVideoSender *webrtc.RTPSender
if sourceVideo != "" {
mimeType, trackID := "video/H264", "video"
if strings.HasSuffix(sourceVideo, ".ivf") {
mimeType = "video/VP8"
}
sVideoTrack, err = rtc.NewTrackLocalStaticSample( doClose := func() {
webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 90000}, trackID, "pion", if pc != nil {
) pc.Close()
if err != nil {
return errors.Wrapf(err, "Create video track")
} }
if vIngester != nil {
sVideoSender, err = pc.AddTrack(sVideoTrack) vIngester.Close()
if err != nil { }
return errors.Wrapf(err, "Add video track") if aIngester != nil {
aIngester.Close()
} }
sVideoSender.Stop()
} }
defer doClose()
var sAudioTrack *rtc.TrackLocalStaticSample if vIngester != nil {
var sAudioSender *webrtc.RTPSender if err := vIngester.AddTrack(pc, fps); err != nil {
if sourceAudio != "" { return errors.Wrapf(err, "Add track")
mimeType, trackID := "audio/opus", "audio"
sAudioTrack, err = rtc.NewTrackLocalStaticSample(
webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 48000, Channels: 2}, trackID, "pion",
)
if err != nil {
return errors.Wrapf(err, "Create audio track")
} }
}
sAudioSender, err = pc.AddTrack(sAudioTrack) if aIngester != nil {
if err != nil { if err := aIngester.AddTrack(pc); err != nil {
return errors.Wrapf(err, "Add audio track") return errors.Wrapf(err, "Add track")
} }
defer sAudioSender.Stop()
} }
offer, err := pc.CreateOffer(nil) offer, err := pc.CreateOffer(nil)
@ -139,9 +150,11 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
logger.Tf(ctx, "Signaling state %v", state) logger.Tf(ctx, "Signaling state %v", state)
}) })
sAudioSender.Transport().OnStateChange(func(state webrtc.DTLSTransportState) { if aIngester != nil {
logger.Tf(ctx, "DTLS state %v", state) aIngester.sAudioSender.Transport().OnStateChange(func(state webrtc.DTLSTransportState) {
}) logger.Tf(ctx, "DTLS state %v", state)
})
}
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
pcDone, pcDoneCancel := context.WithCancel(context.Background()) pcDone, pcDoneCancel := context.WithCancel(context.Background())
@ -165,11 +178,18 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
// Wait for event from context or tracks. // Wait for event from context or tracks.
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
doClose() // Interrupt the RTCP read.
}()
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if sAudioSender == nil { if aIngester == nil {
return return
} }
@ -181,7 +201,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
buf := make([]byte, 1500) buf := make([]byte, 1500)
for ctx.Err() == nil { for ctx.Err() == nil {
if _, _, err := sAudioSender.Read(buf); err != nil { if _, _, err := aIngester.sAudioSender.Read(buf); err != nil {
return return
} }
} }
@ -191,7 +211,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
go func() { go func() {
defer wg.Done() defer wg.Done()
if sAudioTrack == nil { if aIngester == nil {
return return
} }
@ -201,8 +221,9 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
logger.Tf(ctx, "PC(ICE+DTLS+SRTP) done, start ingest audio %v", sourceAudio) logger.Tf(ctx, "PC(ICE+DTLS+SRTP) done, start ingest audio %v", sourceAudio)
} }
// Read audio and send out.
for ctx.Err() == nil { for ctx.Err() == nil {
if err := readAudioTrackFromDisk(ctx, sourceAudio, sAudioSender, sAudioTrack); err != nil { if err := aIngester.Ingest(ctx); err != nil {
if errors.Cause(err) == io.EOF { if errors.Cause(err) == io.EOF {
logger.Tf(ctx, "EOF, restart ingest audio %v", sourceAudio) logger.Tf(ctx, "EOF, restart ingest audio %v", sourceAudio)
continue continue
@ -216,7 +237,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
go func() { go func() {
defer wg.Done() defer wg.Done()
if sVideoSender == nil { if vIngester == nil {
return return
} }
@ -228,7 +249,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
buf := make([]byte, 1500) buf := make([]byte, 1500)
for ctx.Err() == nil { for ctx.Err() == nil {
if _, _, err := sVideoSender.Read(buf); err != nil { if _, _, err := vIngester.sVideoSender.Read(buf); err != nil {
return return
} }
} }
@ -238,7 +259,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
go func() { go func() {
defer wg.Done() defer wg.Done()
if sVideoTrack == nil { if vIngester == nil {
return return
} }
@ -249,7 +270,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
} }
for ctx.Err() == nil { for ctx.Err() == nil {
if err := readVideoTrackFromDisk(ctx, sourceVideo, sVideoSender, fps, sVideoTrack); err != nil { if err := vIngester.Ingest(ctx); err != nil {
if errors.Cause(err) == io.EOF { if errors.Cause(err) == io.EOF {
logger.Tf(ctx, "EOF, restart ingest video %v", sourceVideo) logger.Tf(ctx, "EOF, restart ingest video %v", sourceVideo)
continue continue
@ -276,154 +297,3 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
wg.Wait() wg.Wait()
return nil return nil
} }
func readAudioTrackFromDisk(ctx context.Context, source string, sender *webrtc.RTPSender, track *rtc.TrackLocalStaticSample) error {
f, err := os.Open(source)
if err != nil {
return errors.Wrapf(err, "Open file %v", source)
}
defer f.Close()
ogg, _, err := oggreader.NewWith(f)
if err != nil {
return errors.Wrapf(err, "Open ogg %v", source)
}
enc := sender.GetParameters().Encodings[0]
codec := sender.GetParameters().Codecs[0]
headers := sender.GetParameters().HeaderExtensions
logger.Tf(ctx, "Audio %v, tbn=%v, channels=%v, ssrc=%v, pt=%v, header=%v",
codec.MimeType, codec.ClockRate, codec.Channels, enc.SSRC, codec.PayloadType, headers)
// Whether should encode the audio-level in RTP header.
var audioLevel *webrtc.RTPHeaderExtensionParameter
for _, h := range headers {
if h.URI == sdp.AudioLevelURI {
audioLevel = &h
}
}
clock := newWallClock()
var lastGranule uint64
for ctx.Err() == nil {
pageData, pageHeader, err := ogg.ParseNextPage()
if err == io.EOF {
return nil
}
if err != nil {
return errors.Wrapf(err, "Read ogg")
}
// The amount of samples is the difference between the last and current timestamp
sampleCount := uint64(pageHeader.GranulePosition - lastGranule)
lastGranule = pageHeader.GranulePosition
sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 * sampleCount / uint64(codec.ClockRate))
// For audio-level, set the extensions if negotiated.
track.OnBeforeWritePacket = func(p *rtp.Packet) {
if audioLevel != nil {
if b, err := new(rtp.AudioLevelExtension).Marshal(); err == nil {
p.SetExtension(uint8(audioLevel.ID), b)
}
}
}
if err = track.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); err != nil {
return errors.Wrapf(err, "Write sample")
}
if d := clock.Tick(sampleDuration); d > 0 {
time.Sleep(d)
}
}
return nil
}
func readVideoTrackFromDisk(ctx context.Context, source string, sender *webrtc.RTPSender, fps int, track *rtc.TrackLocalStaticSample) error {
f, err := os.Open(source)
if err != nil {
return errors.Wrapf(err, "Open file %v", source)
}
defer f.Close()
// TODO: FIXME: Support ivf for vp8.
h264, err := h264reader.NewReader(f)
if err != nil {
return errors.Wrapf(err, "Open h264 %v", source)
}
enc := sender.GetParameters().Encodings[0]
codec := sender.GetParameters().Codecs[0]
headers := sender.GetParameters().HeaderExtensions
logger.Tf(ctx, "Video %v, tbn=%v, fps=%v, ssrc=%v, pt=%v, header=%v",
codec.MimeType, codec.ClockRate, fps, enc.SSRC, codec.PayloadType, headers)
clock := newWallClock()
sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 / uint64(fps))
for ctx.Err() == nil {
var sps, pps *h264reader.NAL
var oFrames []*h264reader.NAL
for ctx.Err() == nil {
frame, err := h264.NextNAL()
if err == io.EOF {
return nil
}
if err != nil {
return errors.Wrapf(err, "Read h264")
}
oFrames = append(oFrames, frame)
logger.If(ctx, "NALU %v PictureOrderCount=%v, ForbiddenZeroBit=%v, RefIdc=%v, %v bytes",
frame.UnitType.String(), frame.PictureOrderCount, frame.ForbiddenZeroBit, frame.RefIdc, len(frame.Data))
if frame.UnitType == h264reader.NalUnitTypeSPS {
sps = frame
} else if frame.UnitType == h264reader.NalUnitTypePPS {
pps = frame
} else {
break
}
}
var frames []*h264reader.NAL
// Package SPS/PPS to STAP-A
if sps != nil && pps != nil {
stapA := packageAsSTAPA(sps, pps)
frames = append(frames, stapA)
}
// Append other original frames.
for _, frame := range oFrames {
if frame.UnitType != h264reader.NalUnitTypeSPS && frame.UnitType != h264reader.NalUnitTypePPS {
frames = append(frames, frame)
}
}
// Covert frames to sample(buffers).
for i, frame := range frames {
sample := media.Sample{Data: frame.Data, Duration: sampleDuration}
// Use the sample timestamp for frames.
if i != len(frames)-1 {
sample.Duration = 0
}
// For STAP-A, set marker to false, to make Chrome happy.
track.OnBeforeWritePacket = func(p *rtp.Packet) {
if i < len(frames)-1 {
p.Header.Marker = false
}
}
if err = track.WriteSample(sample); err != nil {
return errors.Wrapf(err, "Write sample")
}
}
if d := clock.Tick(sampleDuration); d > 0 {
time.Sleep(d)
}
}
return nil
}

File diff suppressed because it is too large Load Diff

@ -1,3 +1,23 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package srs package srs
import ( import (

@ -1,3 +1,23 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package srs package srs
import ( import (
@ -7,10 +27,14 @@ import (
"fmt" "fmt"
"github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger" "github.com/ossrs/go-oryx-lib/logger"
"github.com/pion/transport/vnet"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media/h264reader" "github.com/pion/webrtc/v3/pkg/media/h264reader"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
) )
@ -140,3 +164,305 @@ func (v *wallClock) Tick(d time.Duration) time.Duration {
} }
return 0 return 0
} }
// Set to active, as DTLS client, to start ClientHello.
func testUtilSetupActive(s *webrtc.SessionDescription) error {
if strings.Contains(s.SDP, "setup:passive") {
return errors.New("set to active")
}
s.SDP = strings.ReplaceAll(s.SDP, "setup:actpass", "setup:active")
return nil
}
// Set to passive, as DTLS client, to start ClientHello.
func testUtilSetupPassive(s *webrtc.SessionDescription) error {
if strings.Contains(s.SDP, "setup:active") {
return errors.New("set to passive")
}
s.SDP = strings.ReplaceAll(s.SDP, "setup:actpass", "setup:passive")
return nil
}
// Parse address from SDP.
// candidate:0 1 udp 2130706431 192.168.3.8 8000 typ host generation 0
func parseAddressOfCandidate(answerSDP string) (*net.UDPAddr, error) {
answer := webrtc.SessionDescription{Type: webrtc.SDPTypeAnswer, SDP: answerSDP}
answerObject, err := answer.Unmarshal()
if err != nil {
return nil, errors.Wrapf(err, "unmarshal answer %v", answerSDP)
}
if len(answerObject.MediaDescriptions) == 0 {
return nil, errors.New("no media")
}
candidate, ok := answerObject.MediaDescriptions[0].Attribute("candidate")
if !ok {
return nil, errors.New("no candidate")
}
// candidate:0 1 udp 2130706431 192.168.3.8 8000 typ host generation 0
attrs := strings.Split(candidate, " ")
if len(attrs) <= 6 {
return nil, errors.Errorf("no address in %v", candidate)
}
// Parse ip and port from answer.
ip := attrs[4]
port, err := strconv.Atoi(attrs[5])
if err != nil {
return nil, errors.Wrapf(err, "invalid port %v", candidate)
}
address := fmt.Sprintf("%v:%v", ip, port)
addr, err := net.ResolveUDPAddr("udp4", address)
if err != nil {
return nil, errors.Wrapf(err, "parse %v", address)
}
return addr, nil
}
// Filter the test error, ignore context.Canceled
func filterTestError(errs ...error) error {
var filteredErrors []error
for _, err := range errs {
if err == nil || errors.Cause(err) == context.Canceled {
continue
}
filteredErrors = append(filteredErrors, err)
}
if len(filteredErrors) == 0 {
return nil
}
if len(filteredErrors) == 1 {
return filteredErrors[0]
}
var descs []string
for i, err := range filteredErrors[1:] {
descs = append(descs, fmt.Sprintf("err #%d, %+v", i, err))
}
return errors.Wrapf(filteredErrors[0], "with %v", strings.Join(descs, ","))
}
// For STUN packet, 0x00 is binding request, 0x01 is binding success response.
// @see srs_is_stun of https://github.com/ossrs/srs
func srsIsStun(b []byte) bool {
return len(b) > 0 && (b[0] == 0 || b[0] == 1)
}
// change_cipher_spec(20), alert(21), handshake(22), application_data(23)
// @see https://tools.ietf.org/html/rfc2246#section-6.2.1
// @see srs_is_dtls of https://github.com/ossrs/srs
func srsIsDTLS(b []byte) bool {
return (len(b) >= 13 && (b[0] > 19 && b[0] < 64))
}
// For RTP or RTCP, the V=2 which is in the high 2bits, 0xC0 (1100 0000)
// @see srs_is_rtp_or_rtcp of https://github.com/ossrs/srs
func srsIsRTPOrRTCP(b []byte) bool {
return (len(b) >= 12 && (b[0]&0xC0) == 0x80)
}
// For RTCP, PT is [128, 223] (or without marker [0, 95]).
// Literally, RTCP starts from 64 not 0, so PT is [192, 223] (or without marker [64, 95]).
// @note For RTP, the PT is [96, 127], or [224, 255] with marker.
// @see srs_is_rtcp of https://github.com/ossrs/srs
func srsIsRTCP(b []byte) bool {
return (len(b) >= 12) && (b[0]&0x80) != 0 && (b[1] >= 192 && b[1] <= 223)
}
type ChunkType int
const (
ChunkTypeICE ChunkType = iota + 1
ChunkTypeDTLS
ChunkTypeRTP
ChunkTypeRTCP
)
func (v ChunkType) String() string {
switch v {
case ChunkTypeICE:
return "ICE"
case ChunkTypeDTLS:
return "DTLS"
case ChunkTypeRTP:
return "RTP"
case ChunkTypeRTCP:
return "RTCP"
default:
return "Unknown"
}
}
type DTLSContentType int
const (
DTLSContentTypeHandshake DTLSContentType = 22
DTLSContentTypeChangeCipherSpec DTLSContentType = 20
DTLSContentTypeAlert DTLSContentType = 21
)
func (v DTLSContentType) String() string {
switch v {
case DTLSContentTypeHandshake:
return "Handshake"
case DTLSContentTypeChangeCipherSpec:
return "ChangeCipherSpec"
default:
return "Unknown"
}
}
type DTLSHandshakeType int
const (
DTLSHandshakeTypeClientHello DTLSHandshakeType = 1
DTLSHandshakeTypeServerHello DTLSHandshakeType = 2
DTLSHandshakeTypeCertificate DTLSHandshakeType = 11
DTLSHandshakeTypeServerKeyExchange DTLSHandshakeType = 12
DTLSHandshakeTypeCertificateRequest DTLSHandshakeType = 13
DTLSHandshakeTypeServerDone DTLSHandshakeType = 14
DTLSHandshakeTypeCertificateVerify DTLSHandshakeType = 15
DTLSHandshakeTypeClientKeyExchange DTLSHandshakeType = 16
DTLSHandshakeTypeFinished DTLSHandshakeType = 20
)
func (v DTLSHandshakeType) String() string {
switch v {
case DTLSHandshakeTypeClientHello:
return "ClientHello"
case DTLSHandshakeTypeServerHello:
return "ServerHello"
case DTLSHandshakeTypeCertificate:
return "Certificate"
case DTLSHandshakeTypeServerKeyExchange:
return "ServerKeyExchange"
case DTLSHandshakeTypeCertificateRequest:
return "CertificateRequest"
case DTLSHandshakeTypeServerDone:
return "ServerDone"
case DTLSHandshakeTypeCertificateVerify:
return "CertificateVerify"
case DTLSHandshakeTypeClientKeyExchange:
return "ClientKeyExchange"
case DTLSHandshakeTypeFinished:
return "Finished"
default:
return "Unknown"
}
}
type ChunkMessageType struct {
chunk ChunkType
content DTLSContentType
handshake DTLSHandshakeType
}
func (v *ChunkMessageType) String() string {
if v.chunk == ChunkTypeDTLS {
return fmt.Sprintf("%v-%v-%v", v.chunk, v.content, v.handshake)
}
return fmt.Sprintf("%v", v.chunk)
}
func NewChunkMessageType(c vnet.Chunk) (*ChunkMessageType, bool) {
b := c.UserData()
if len(b) == 0 {
return nil, false
}
v := &ChunkMessageType{}
if srsIsRTPOrRTCP(b) {
if srsIsRTCP(b) {
v.chunk = ChunkTypeRTCP
} else {
v.chunk = ChunkTypeRTP
}
return v, true
}
if srsIsStun(b) {
v.chunk = ChunkTypeICE
return v, true
}
if !srsIsDTLS(b) {
return nil, false
}
v.chunk, v.content = ChunkTypeDTLS, DTLSContentType(b[0])
if v.content != DTLSContentTypeHandshake {
return v, true
}
if len(b) < 14 {
return v, false
}
v.handshake = DTLSHandshakeType(b[13])
return v, true
}
func (v *ChunkMessageType) IsHandshake() bool {
return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake
}
func (v *ChunkMessageType) IsClientHello() bool {
return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake && v.handshake == DTLSHandshakeTypeClientHello
}
func (v *ChunkMessageType) IsServerHello() bool {
return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake && v.handshake == DTLSHandshakeTypeServerHello
}
func (v *ChunkMessageType) IsCertificate() bool {
return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake && v.handshake == DTLSHandshakeTypeCertificate
}
func (v *ChunkMessageType) IsChangeCipherSpec() bool {
return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeChangeCipherSpec
}
type DTLSRecord struct {
ContentType DTLSContentType
Version uint16
Epoch uint16
SequenceNumber uint64
Length uint16
Data []byte
}
func NewDTLSRecord(b []byte) (*DTLSRecord, error) {
v := &DTLSRecord{}
return v, v.Unmarshal(b)
}
func (v *DTLSRecord) String() string {
return fmt.Sprintf("epoch=%v, sequence=%v", v.Epoch, v.SequenceNumber)
}
func (v *DTLSRecord) Equals(p *DTLSRecord) bool {
return v.Epoch == p.Epoch && v.SequenceNumber == p.SequenceNumber
}
func (v *DTLSRecord) Unmarshal(b []byte) error {
if len(b) < 13 {
return errors.Errorf("requires 13B only %v", len(b))
}
v.ContentType = DTLSContentType(uint8(b[0]))
v.Version = uint16(b[1])<<8 | uint16(b[2])
v.Epoch = uint16(b[3])<<8 | uint16(b[4])
v.SequenceNumber = uint64(b[5])<<40 | uint64(b[6])<<32 | uint64(b[7])<<24 | uint64(b[8])<<16 | uint64(b[9])<<8 | uint64(b[10])
v.Length = uint16(b[11])<<8 | uint16(b[12])
v.Data = b[13:]
return nil
}

@ -0,0 +1,723 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package srs
import (
"context"
"encoding/json"
"flag"
"fmt"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
vnet_proxy "github.com/ossrs/srs-bench/vnet"
"github.com/pion/interceptor"
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/transport/vnet"
"github.com/pion/webrtc/v3"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"strings"
"sync"
"testing"
"time"
)
var srsSchema = "http"
var srsHttps = flag.Bool("srs-https", false, "Whther connect to HTTPS-API")
var srsServer = flag.String("srs-server", "127.0.0.1", "The RTC server to connect to")
var srsStream = flag.String("srs-stream", "/rtc/regression", "The RTC stream to play")
var srsLog = flag.Bool("srs-log", false, "Whether enable the detail log")
var srsTimeout = flag.Int("srs-timeout", 5000, "For each case, the timeout in ms")
var srsPlayPLI = flag.Int("srs-play-pli", 5000, "The PLI interval in seconds for player.")
var srsPlayOKPackets = flag.Int("srs-play-ok-packets", 10, "If got N packets, it's ok, or fail")
var srsPublishOKPackets = flag.Int("srs-publish-ok-packets", 10, "If send N packets, it's ok, or fail")
var srsPublishAudio = flag.String("srs-publish-audio", "avatar.ogg", "The audio file for publisher.")
var srsPublishVideo = flag.String("srs-publish-video", "avatar.h264", "The video file for publisher.")
var srsPublishVideoFps = flag.Int("srs-publish-video-fps", 25, "The video fps for publisher.")
var srsVnetClientIP = flag.String("srs-vnet-client-ip", "192.168.168.168", "The client ip in pion/vnet.")
var srsDTLSDropPackets = flag.Int("srs-dtls-drop-packets", 5, "If dropped N packets, it's ok, or fail")
func prepareTest() error {
var err error
// Should parse it first.
flag.Parse()
// The stream should starts with /, for example, /rtc/regression
if !strings.HasPrefix(*srsStream, "/") {
*srsStream = "/" + *srsStream
}
// Generate srs protocol from whether use HTTPS.
if *srsHttps {
srsSchema = "https"
}
// Check file.
tryOpenFile := func(filename string) (string, error) {
if filename == "" {
return filename, nil
}
f, err := os.Open(filename)
if err != nil {
nfilename := path.Join("../", filename)
f2, err := os.Open(nfilename)
if err != nil {
return filename, errors.Wrapf(err, "No video file at %v or %v", filename, nfilename)
}
defer f2.Close()
return nfilename, nil
}
defer f.Close()
return filename, nil
}
if *srsPublishVideo, err = tryOpenFile(*srsPublishVideo); err != nil {
return err
}
if *srsPublishAudio, err = tryOpenFile(*srsPublishAudio); err != nil {
return err
}
return nil
}
func TestMain(m *testing.M) {
if err := prepareTest(); err != nil {
logger.Ef(nil, "Prepare test fail, err %+v", err)
os.Exit(-1)
}
// Disable the logger during all tests.
if *srsLog == false {
olw := logger.Switch(ioutil.Discard)
defer func() {
logger.Switch(olw)
}()
}
os.Exit(m.Run())
}
type TestWebRTCAPIOptionFunc func(api *TestWebRTCAPI)
type TestWebRTCAPI struct {
// The options to setup the api.
options []TestWebRTCAPIOptionFunc
// The api and settings.
api *webrtc.API
mediaEngine *webrtc.MediaEngine
registry *interceptor.Registry
settingEngine *webrtc.SettingEngine
// The vnet router, can be shared by different apis, but we do not share it.
router *vnet.Router
// The network for api.
network *vnet.Net
// The vnet UDP proxy bind to the router.
proxy *vnet_proxy.UDPProxy
}
func NewTestWebRTCAPI(options ...TestWebRTCAPIOptionFunc) (*TestWebRTCAPI, error) {
v := &TestWebRTCAPI{}
v.mediaEngine = &webrtc.MediaEngine{}
if err := v.mediaEngine.RegisterDefaultCodecs(); err != nil {
return nil, err
}
v.registry = &interceptor.Registry{}
if err := webrtc.RegisterDefaultInterceptors(v.mediaEngine, v.registry); err != nil {
return nil, err
}
for _, setup := range options {
setup(v)
}
v.settingEngine = &webrtc.SettingEngine{}
return v, nil
}
func (v *TestWebRTCAPI) Close() error {
if v.proxy != nil {
v.proxy.Close()
v.proxy = nil
}
if v.router != nil {
v.router.Stop()
v.router = nil
}
return nil
}
func (v *TestWebRTCAPI) Setup(vnetClientIP string, options ...TestWebRTCAPIOptionFunc) error {
// Setting engine for https://github.com/pion/transport/tree/master/vnet
setupVnet := func(vnetClientIP string) (err error) {
// We create a private router for a api, however, it's possible to share the
// same router between apis.
if v.router, err = vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0", // Accept all ip, no sub router.
LoggerFactory: logging.NewDefaultLoggerFactory(),
}); err != nil {
return errors.Wrapf(err, "create router for api")
}
// Each api should bind to a network, however, it's possible to share it
// for different apis.
v.network = vnet.NewNet(&vnet.NetConfig{
StaticIP: vnetClientIP,
})
if err = v.router.AddNet(v.network); err != nil {
return errors.Wrapf(err, "create network for api")
}
v.settingEngine.SetVNet(v.network)
// Create a proxy bind to the router.
if v.proxy, err = vnet_proxy.NewProxy(v.router); err != nil {
return errors.Wrapf(err, "create proxy for router")
}
return v.router.Start()
}
if err := setupVnet(vnetClientIP); err != nil {
return err
}
for _, setup := range options {
setup(v)
}
for _, setup := range v.options {
setup(v)
}
v.api = webrtc.NewAPI(
webrtc.WithMediaEngine(v.mediaEngine),
webrtc.WithInterceptorRegistry(v.registry),
webrtc.WithSettingEngine(*v.settingEngine),
)
return nil
}
func (v *TestWebRTCAPI) NewPeerConnection(configuration webrtc.Configuration) (*webrtc.PeerConnection, error) {
return v.api.NewPeerConnection(configuration)
}
type TestPlayerOptionFunc func(p *TestPlayer)
type TestPlayer struct {
pc *webrtc.PeerConnection
receivers []*webrtc.RTPReceiver
// root api object
api *TestWebRTCAPI
// Optional suffix for stream url.
streamSuffix string
}
func NewTestPlayer(api *TestWebRTCAPI, options ...TestPlayerOptionFunc) *TestPlayer {
v := &TestPlayer{api: api}
for _, opt := range options {
opt(v)
}
return v
}
func (v *TestPlayer) Close() error {
if v.pc != nil {
v.pc.Close()
v.pc = nil
}
for _, receiver := range v.receivers {
receiver.Stop()
}
v.receivers = nil
return nil
}
func (v *TestPlayer) Run(ctx context.Context, cancel context.CancelFunc) error {
r := fmt.Sprintf("%v://%v%v", srsSchema, *srsServer, *srsStream)
if v.streamSuffix != "" {
r = fmt.Sprintf("%v-%v", r, v.streamSuffix)
}
pli := time.Duration(*srsPlayPLI) * time.Millisecond
logger.Tf(ctx, "Start play url=%v", r)
pc, err := v.api.NewPeerConnection(webrtc.Configuration{})
if err != nil {
return errors.Wrapf(err, "Create PC")
}
v.pc = pc
pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
})
pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
})
offer, err := pc.CreateOffer(nil)
if err != nil {
return errors.Wrapf(err, "Create Offer")
}
if err := pc.SetLocalDescription(offer); err != nil {
return errors.Wrapf(err, "Set offer %v", offer)
}
answer, err := apiRtcRequest(ctx, "/rtc/v1/play", r, offer.SDP)
if err != nil {
return errors.Wrapf(err, "Api request offer=%v", offer.SDP)
}
// Start a proxy for real server and vnet.
if address, err := parseAddressOfCandidate(answer); err != nil {
return errors.Wrapf(err, "parse address of %v", answer)
} else if err := v.api.proxy.Proxy(v.api.network, address); err != nil {
return errors.Wrapf(err, "proxy %v to %v", v.api.network, address)
}
if err := pc.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeAnswer, SDP: answer,
}); err != nil {
return errors.Wrapf(err, "Set answer %v", answer)
}
handleTrack := func(ctx context.Context, track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) error {
// Send a PLI on an interval so that the publisher is pushing a keyframe
go func() {
if track.Kind() == webrtc.RTPCodecTypeAudio {
return
}
for {
select {
case <-ctx.Done():
return
case <-time.After(pli):
_ = pc.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{
MediaSSRC: uint32(track.SSRC()),
}})
}
}
}()
v.receivers = append(v.receivers, receiver)
for ctx.Err() == nil {
_, _, err := track.ReadRTP()
if err != nil {
return errors.Wrapf(err, "Read RTP")
}
}
return nil
}
pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
err = handleTrack(ctx, track, receiver)
if err != nil {
codec := track.Codec()
err = errors.Wrapf(err, "Handle track %v, pt=%v", codec.MimeType, codec.PayloadType)
cancel()
}
})
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state == webrtc.ICEConnectionStateFailed || state == webrtc.ICEConnectionStateClosed {
err = errors.Errorf("Close for ICE state %v", state)
cancel()
}
})
<-ctx.Done()
return err
}
type TestPublisherOptionFunc func(p *TestPublisher)
type TestPublisher struct {
onOffer func(s *webrtc.SessionDescription) error
onAnswer func(s *webrtc.SessionDescription) error
iceReadyCancel context.CancelFunc
// internal objects
aIngester *audioIngester
vIngester *videoIngester
pc *webrtc.PeerConnection
// root api object
api *TestWebRTCAPI
// Optional suffix for stream url.
streamSuffix string
}
func NewTestPublisher(api *TestWebRTCAPI, options ...TestPublisherOptionFunc) *TestPublisher {
sourceVideo, sourceAudio := *srsPublishVideo, *srsPublishAudio
v := &TestPublisher{api: api}
for _, opt := range options {
opt(v)
}
// Create ingesters.
if sourceAudio != "" {
v.aIngester = NewAudioIngester(sourceAudio)
}
if sourceVideo != "" {
v.vIngester = NewVideoIngester(sourceVideo)
}
// Setup the interceptors for packets.
api.options = append(api.options, func(api *TestWebRTCAPI) {
// Filter for RTCP packets.
rtcpInterceptor := &RTCPInterceptor{}
rtcpInterceptor.rtcpReader = func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) {
return rtcpInterceptor.nextRTCPReader.Read(buf, attributes)
}
rtcpInterceptor.rtcpWriter = func(pkts []rtcp.Packet, attributes interceptor.Attributes) (int, error) {
return rtcpInterceptor.nextRTCPWriter.Write(pkts, attributes)
}
api.registry.Add(rtcpInterceptor)
// Filter for ingesters.
if sourceAudio != "" {
api.registry.Add(v.aIngester.audioLevelInterceptor)
}
if sourceVideo != "" {
api.registry.Add(v.vIngester.markerInterceptor)
}
})
return v
}
func (v *TestPublisher) Close() error {
if v.vIngester != nil {
v.vIngester.Close()
}
if v.aIngester != nil {
v.aIngester.Close()
}
if v.pc != nil {
v.pc.Close()
}
return nil
}
func (v *TestPublisher) SetStreamSuffix(suffix string) *TestPublisher {
v.streamSuffix = suffix
return v
}
func (v *TestPublisher) Run(ctx context.Context, cancel context.CancelFunc) error {
r := fmt.Sprintf("%v://%v%v", srsSchema, *srsServer, *srsStream)
if v.streamSuffix != "" {
r = fmt.Sprintf("%v-%v", r, v.streamSuffix)
}
sourceVideo, sourceAudio, fps := *srsPublishVideo, *srsPublishAudio, *srsPublishVideoFps
logger.Tf(ctx, "Start publish url=%v, audio=%v, video=%v, fps=%v",
r, sourceAudio, sourceVideo, fps)
pc, err := v.api.NewPeerConnection(webrtc.Configuration{})
if err != nil {
return errors.Wrapf(err, "Create PC")
}
v.pc = pc
if v.vIngester != nil {
if err := v.vIngester.AddTrack(pc, fps); err != nil {
return errors.Wrapf(err, "Add track")
}
defer v.vIngester.Close()
}
if v.aIngester != nil {
if err := v.aIngester.AddTrack(pc); err != nil {
return errors.Wrapf(err, "Add track")
}
defer v.aIngester.Close()
}
offer, err := pc.CreateOffer(nil)
if err != nil {
return errors.Wrapf(err, "Create Offer")
}
if err := pc.SetLocalDescription(offer); err != nil {
return errors.Wrapf(err, "Set offer %v", offer)
}
if v.onOffer != nil {
if err := v.onOffer(&offer); err != nil {
return errors.Wrapf(err, "sdp %v %v", offer.Type, offer.SDP)
}
}
answerSDP, err := apiRtcRequest(ctx, "/rtc/v1/publish", r, offer.SDP)
if err != nil {
return errors.Wrapf(err, "Api request offer=%v", offer.SDP)
}
// Start a proxy for real server and vnet.
if address, err := parseAddressOfCandidate(answerSDP); err != nil {
return errors.Wrapf(err, "parse address of %v", answerSDP)
} else if err := v.api.proxy.Proxy(v.api.network, address); err != nil {
return errors.Wrapf(err, "proxy %v to %v", v.api.network, address)
}
answer := &webrtc.SessionDescription{
Type: webrtc.SDPTypeAnswer, SDP: answerSDP,
}
if v.onAnswer != nil {
if err := v.onAnswer(answer); err != nil {
return errors.Wrapf(err, "on answerSDP")
}
}
if err := pc.SetRemoteDescription(*answer); err != nil {
return errors.Wrapf(err, "Set answerSDP %v", answerSDP)
}
logger.Tf(ctx, "State signaling=%v, ice=%v, conn=%v", pc.SignalingState(), pc.ICEConnectionState(), pc.ConnectionState())
// ICE state management.
pc.OnICEGatheringStateChange(func(state webrtc.ICEGathererState) {
logger.Tf(ctx, "ICE gather state %v", state)
})
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
logger.Tf(ctx, "ICE candidate %v %v:%v", candidate.Protocol, candidate.Address, candidate.Port)
})
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
logger.Tf(ctx, "ICE state %v", state)
})
pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
logger.Tf(ctx, "Signaling state %v", state)
})
if v.aIngester != nil {
v.aIngester.sAudioSender.Transport().OnStateChange(func(state webrtc.DTLSTransportState) {
logger.Tf(ctx, "DTLS state %v", state)
})
}
pcDone, pcDoneCancel := context.WithCancel(context.Background())
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
logger.Tf(ctx, "PC state %v", state)
if state == webrtc.PeerConnectionStateConnected {
pcDoneCancel()
if v.iceReadyCancel != nil {
v.iceReadyCancel()
}
}
if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed {
err = errors.Errorf("Close for PC state %v", state)
cancel()
}
})
// Wait for event from context or tracks.
var wg sync.WaitGroup
var finalErr error
wg.Add(1)
go func() {
defer wg.Done()
defer logger.Tf(ctx, "ingest notify done")
<-ctx.Done()
if v.aIngester != nil && v.aIngester.sAudioSender != nil {
v.aIngester.sAudioSender.Stop()
}
if v.vIngester != nil && v.vIngester.sVideoSender != nil {
v.vIngester.sVideoSender.Stop()
}
}()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
if v.aIngester == nil {
return
}
select {
case <-ctx.Done():
return
case <-pcDone.Done():
}
wg.Add(1)
go func() {
defer wg.Done()
defer logger.Tf(ctx, "aingester sender read done")
buf := make([]byte, 1500)
for ctx.Err() == nil {
if _, _, err := v.aIngester.sAudioSender.Read(buf); err != nil {
return
}
}
}()
for {
if err := v.aIngester.Ingest(ctx); err != nil {
if err == io.EOF {
logger.Tf(ctx, "aingester retry for %v", err)
continue
}
if err != context.Canceled {
finalErr = errors.Wrapf(err, "audio")
}
logger.Tf(ctx, "aingester err=%v, final=%v", err, finalErr)
return
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
if v.vIngester == nil {
return
}
select {
case <-ctx.Done():
return
case <-pcDone.Done():
logger.Tf(ctx, "PC(ICE+DTLS+SRTP) done, start ingest video %v", sourceVideo)
}
wg.Add(1)
go func() {
defer wg.Done()
defer logger.Tf(ctx, "vingester sender read done")
buf := make([]byte, 1500)
for ctx.Err() == nil {
// The Read() might block in r.rtcpInterceptor.Read(b, a),
// so that the Stop() can not stop it.
if _, _, err := v.vIngester.sVideoSender.Read(buf); err != nil {
return
}
}
}()
for {
if err := v.vIngester.Ingest(ctx); err != nil {
if err == io.EOF {
logger.Tf(ctx, "vingester retry for %v", err)
continue
}
if err != context.Canceled {
finalErr = errors.Wrapf(err, "video")
}
logger.Tf(ctx, "vingester err=%v, final=%v", err, finalErr)
return
}
}
}()
wg.Wait()
logger.Tf(ctx, "ingester done ctx=%v, final=%v", ctx.Err(), finalErr)
if finalErr != nil {
return finalErr
}
return ctx.Err()
}
func TestRTCServerVersion(t *testing.T) {
api := fmt.Sprintf("http://%v:1985/api/v1/versions", *srsServer)
req, err := http.NewRequest("POST", api, nil)
if err != nil {
t.Errorf("Request %v", api)
return
}
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("Do request %v", api)
return
}
b, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Read body of %v", api)
return
}
obj := struct {
Code int `json:"code"`
Server string `json:"server"`
Data struct {
Major int `json:"major"`
Minor int `json:"minor"`
Revision int `json:"revision"`
Version string `json:"version"`
} `json:"data"`
}{}
if err := json.Unmarshal(b, &obj); err != nil {
t.Errorf("Parse %v", string(b))
return
}
if obj.Code != 0 {
t.Errorf("Server err code=%v, server=%v", obj.Code, obj.Server)
return
}
if obj.Data.Major == 0 && obj.Data.Minor == 0 {
t.Errorf("Invalid version %v", obj.Data)
return
}
}

@ -0,0 +1,278 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package vnet_test
import (
"net"
vnet_proxy "github.com/ossrs/srs-bench/vnet"
"github.com/pion/logging"
"github.com/pion/transport/vnet"
)
// Proxy many vnet endpoint to one real server endpoint.
// For example:
// vnet(10.0.0.11:5787) => proxy => 192.168.1.10:8000
// vnet(10.0.0.11:5788) => proxy => 192.168.1.10:8000
// vnet(10.0.0.11:5789) => proxy => 192.168.1.10:8000
func ExampleUDPProxyManyToOne() { // nolint:govet
var clientNetwork *vnet.Net
var serverAddr *net.UDPAddr
if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000"); err != nil {
// handle error
} else {
serverAddr = addr
}
// Setup the network and proxy.
if true {
// Create vnet WAN with one endpoint, please read from
// https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
if err != nil {
// handle error
}
// Create a network and add to router, for example, for client.
clientNetwork = vnet.NewNet(&vnet.NetConfig{
StaticIP: "10.0.0.11",
})
if err = router.AddNet(clientNetwork); err != nil {
// handle error
}
// Start the router.
if err = router.Start(); err != nil {
// handle error
}
defer router.Stop() // nolint:errcheck
// Create a proxy, bind to the router.
proxy, err := vnet_proxy.NewProxy(router)
if err != nil {
// handle error
}
defer proxy.Close() // nolint:errcheck
// Start to proxy some addresses, clientNetwork is a hit for proxy,
// that the client in vnet is from this network.
if err := proxy.Proxy(clientNetwork, serverAddr); err != nil {
// handle error
}
}
// Now, all packets from client, will be proxy to real server, vice versa.
client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787")
if err != nil {
// handle error
}
_, _ = client0.WriteTo([]byte("Hello"), serverAddr)
client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788")
if err != nil {
// handle error
}
_, _ = client1.WriteTo([]byte("Hello"), serverAddr)
client2, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5789")
if err != nil {
// handle error
}
_, _ = client2.WriteTo([]byte("Hello"), serverAddr)
}
// Proxy many vnet endpoint to one real server endpoint.
// For example:
// vnet(10.0.0.11:5787) => proxy => 192.168.1.10:8000
// vnet(10.0.0.11:5788) => proxy => 192.168.1.10:8000
func ExampleUDPProxyMultileTimes() { // nolint:govet
var clientNetwork *vnet.Net
var serverAddr *net.UDPAddr
if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000"); err != nil {
// handle error
} else {
serverAddr = addr
}
// Setup the network and proxy.
var proxy *vnet_proxy.UDPProxy
if true {
// Create vnet WAN with one endpoint, please read from
// https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
if err != nil {
// handle error
}
// Create a network and add to router, for example, for client.
clientNetwork = vnet.NewNet(&vnet.NetConfig{
StaticIP: "10.0.0.11",
})
if err = router.AddNet(clientNetwork); err != nil {
// handle error
}
// Start the router.
if err = router.Start(); err != nil {
// handle error
}
defer router.Stop() // nolint:errcheck
// Create a proxy, bind to the router.
proxy, err = vnet_proxy.NewProxy(router)
if err != nil {
// handle error
}
defer proxy.Close() // nolint:errcheck
}
if true {
// Start to proxy some addresses, clientNetwork is a hit for proxy,
// that the client in vnet is from this network.
if err := proxy.Proxy(clientNetwork, serverAddr); err != nil {
// handle error
}
// Now, all packets from client, will be proxy to real server, vice versa.
client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787")
if err != nil {
// handle error
}
_, _ = client0.WriteTo([]byte("Hello"), serverAddr)
}
if true {
// It's ok to proxy multiple times, for example, the publisher and player
// might need to proxy when got answer.
if err := proxy.Proxy(clientNetwork, serverAddr); err != nil {
// handle error
}
client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788")
if err != nil {
// handle error
}
_, _ = client1.WriteTo([]byte("Hello"), serverAddr)
}
}
// Proxy one vnet endpoint to one real server endpoint.
// For example:
// vnet(10.0.0.11:5787) => proxy0 => 192.168.1.10:8000
// vnet(10.0.0.11:5788) => proxy1 => 192.168.1.10:8001
// vnet(10.0.0.11:5789) => proxy2 => 192.168.1.10:8002
func ExampleUDPProxyOneToOne() { // nolint:govet
var clientNetwork *vnet.Net
var serverAddr0 *net.UDPAddr
if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000"); err != nil {
// handle error
} else {
serverAddr0 = addr
}
var serverAddr1 *net.UDPAddr
if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8001"); err != nil {
// handle error
} else {
serverAddr1 = addr
}
var serverAddr2 *net.UDPAddr
if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8002"); err != nil {
// handle error
} else {
serverAddr2 = addr
}
// Setup the network and proxy.
if true {
// Create vnet WAN with one endpoint, please read from
// https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
if err != nil {
// handle error
}
// Create a network and add to router, for example, for client.
clientNetwork = vnet.NewNet(&vnet.NetConfig{
StaticIP: "10.0.0.11",
})
if err = router.AddNet(clientNetwork); err != nil {
// handle error
}
// Start the router.
if err = router.Start(); err != nil {
// handle error
}
defer router.Stop() // nolint:errcheck
// Create a proxy, bind to the router.
proxy, err := vnet_proxy.NewProxy(router)
if err != nil {
// handle error
}
defer proxy.Close() // nolint:errcheck
// Start to proxy some addresses, clientNetwork is a hit for proxy,
// that the client in vnet is from this network.
if err := proxy.Proxy(clientNetwork, serverAddr0); err != nil {
// handle error
}
if err := proxy.Proxy(clientNetwork, serverAddr1); err != nil {
// handle error
}
if err := proxy.Proxy(clientNetwork, serverAddr2); err != nil {
// handle error
}
}
// Now, all packets from client, will be proxy to real server, vice versa.
client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787")
if err != nil {
// handle error
}
_, _ = client0.WriteTo([]byte("Hello"), serverAddr0)
client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788")
if err != nil {
// handle error
}
_, _ = client1.WriteTo([]byte("Hello"), serverAddr1)
client2, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5789")
if err != nil {
// handle error
}
_, _ = client2.WriteTo([]byte("Hello"), serverAddr2)
}

@ -0,0 +1,222 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package vnet
import (
"net"
"sync"
"time"
"github.com/pion/transport/vnet"
)
// A UDP proxy between real server(net.UDPConn) and vnet.UDPConn.
//
// High level design:
// ..............................................
// : Virtual Network (vnet) :
// : :
// +-------+ * 1 +----+ +--------+ :
// | :App |------------>|:Net|--o<-----|:Router | .............................
// +-------+ +----+ | | : UDPProxy :
// : | | +----+ +---------+ +---------+ +--------+
// : | |--->o--|:Net|-->o-| vnet. |-->o-| net. |--->-| :Real |
// : | | +----+ | UDPConn | | UDPConn | | Server |
// : | | : +---------+ +---------+ +--------+
// : | | ............................:
// : +--------+ :
// ...............................................
//
// The whole big picture:
// ......................................
// : Virtual Network (vnet) :
// : :
// +-------+ * 1 +----+ +--------+ :
// | :App |------------>|:Net|--o<-----|:Router | .............................
// +-------+ +----+ | | : UDPProxy :
// +-----------+ * 1 +----+ | | +----+ +---------+ +---------+ +--------+
// |:STUNServer|-------->|:Net|--o<-----| |--->o--|:Net|-->o-| vnet. |-->o-| net. |--->-| :Real |
// +-----------+ +----+ | | +----+ | UDPConn | | UDPConn | | Server |
// +-----------+ * 1 +----+ | | : +---------+ +---------+ +--------+
// |:TURNServer|-------->|:Net|--o<-----| | ............................:
// +-----------+ +----+ [1] | | :
// : 1 | | 1 <<has>> :
// : +---<>| |<>----+ [2] :
// : | +--------+ | :
// To form | *| v 0..1 :
// a subnet tree | o [3] +-----+ :
// : | ^ |:NAT | :
// : | | +-----+ :
// : +-------+ :
// ......................................
type UDPProxy struct {
// The router bind to.
router *vnet.Router
// Each vnet source, bind to a real socket to server.
// key is real server addr, which is net.Addr
// value is *aUDPProxyWorker
workers sync.Map
// For each endpoint, we never know when to start and stop proxy,
// so we stop the endpoint when timeout.
timeout time.Duration
// For utest, to mock the target real server.
// Optional, use the address of received client packet.
mockRealServerAddr *net.UDPAddr
}
// NewProxy create a proxy, the router for this proxy belongs/bind to. If need to proxy for
// please create a new proxy for each router. For all addresses we proxy, we will create a
// vnet.Net in this router and proxy all packets.
func NewProxy(router *vnet.Router) (*UDPProxy, error) {
v := &UDPProxy{router: router, timeout: 2 * time.Minute}
return v, nil
}
// Close the proxy, stop all workers.
func (v *UDPProxy) Close() error {
// nolint:godox // TODO: FIXME: Do cleanup.
return nil
}
// Proxy starts a worker for server, ignore if already started.
func (v *UDPProxy) Proxy(client *vnet.Net, server *net.UDPAddr) error {
// Note that even if the worker exists, it's also ok to create a same worker,
// because the router will use the last one, and the real server will see a address
// change event after we switch to the next worker.
if _, ok := v.workers.Load(server.String()); ok {
// nolint:godox // TODO: Need to restart the stopped worker?
return nil
}
// Not exists, create a new one.
worker := &aUDPProxyWorker{
router: v.router, mockRealServerAddr: v.mockRealServerAddr,
}
v.workers.Store(server.String(), worker)
return worker.Proxy(client, server)
}
// A proxy worker for a specified proxy server.
type aUDPProxyWorker struct {
router *vnet.Router
mockRealServerAddr *net.UDPAddr
// Each vnet source, bind to a real socket to server.
// key is vnet client addr, which is net.Addr
// value is *net.UDPConn
endpoints sync.Map
}
func (v *aUDPProxyWorker) Proxy(client *vnet.Net, serverAddr *net.UDPAddr) error { // nolint:gocognit
// Create vnet for real server by serverAddr.
nw := vnet.NewNet(&vnet.NetConfig{
StaticIP: serverAddr.IP.String(),
})
if err := v.router.AddNet(nw); err != nil {
return err
}
// We must create a "same" vnet.UDPConn as the net.UDPConn,
// which has the same ip:port, to copy packets between them.
vnetSocket, err := nw.ListenUDP("udp4", serverAddr)
if err != nil {
return err
}
// Start a proxy goroutine.
var findEndpointBy func(addr net.Addr) (*net.UDPConn, error)
// nolint:godox // TODO: FIXME: Do cleanup.
go func() {
buf := make([]byte, 1500)
for {
n, addr, err := vnetSocket.ReadFrom(buf)
if err != nil {
return
}
if n <= 0 || addr == nil {
continue // Drop packet
}
realSocket, err := findEndpointBy(addr)
if err != nil {
continue // Drop packet.
}
if _, err := realSocket.Write(buf[:n]); err != nil {
return
}
}
}()
// Got new vnet client, start a new endpoint.
findEndpointBy = func(addr net.Addr) (*net.UDPConn, error) {
// Exists binding.
if value, ok := v.endpoints.Load(addr.String()); ok {
// Exists endpoint, reuse it.
return value.(*net.UDPConn), nil
}
// The real server we proxy to, for utest to mock it.
realAddr := serverAddr
if v.mockRealServerAddr != nil {
realAddr = v.mockRealServerAddr
}
// Got new vnet client, create new endpoint.
realSocket, err := net.DialUDP("udp4", nil, realAddr)
if err != nil {
return nil, err
}
// Bind address.
v.endpoints.Store(addr.String(), realSocket)
// Got packet from real serverAddr, we should proxy it to vnet.
// nolint:godox // TODO: FIXME: Do cleanup.
go func(vnetClientAddr net.Addr) {
buf := make([]byte, 1500)
for {
n, _, err := realSocket.ReadFrom(buf)
if err != nil {
return
}
if n <= 0 {
continue // Drop packet
}
if _, err := vnetSocket.WriteTo(buf[:n], vnetClientAddr); err != nil {
return
}
}
}(addr)
return realSocket, nil
}
return nil
}

@ -0,0 +1,61 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package vnet
import (
"net"
)
func (v *UDPProxy) Deliver(sourceAddr, destAddr net.Addr, b []byte) (nn int, err error) {
v.workers.Range(func(key, value interface{}) bool {
if nn, err := value.(*aUDPProxyWorker).Deliver(sourceAddr, destAddr, b); err != nil {
return false // Fail, abort.
} else if nn == len(b) {
return false // Done.
}
return true // Deliver by next worker.
})
return
}
func (v *aUDPProxyWorker) Deliver(sourceAddr, destAddr net.Addr, b []byte) (nn int, err error) {
addr, ok := sourceAddr.(*net.UDPAddr)
if !ok {
return 0, nil
}
// TODO: Support deliver packet from real server to vnet.
// If packet is from vent, proxy to real server.
var realSocket *net.UDPConn
if value, ok := v.endpoints.Load(addr.String()); !ok {
return 0, nil
} else {
realSocket = value.(*net.UDPConn)
}
// Send to real server.
if _, err := realSocket.Write(b); err != nil {
return 0, err
}
return len(b), nil
}

@ -0,0 +1,184 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package vnet
import (
"context"
"fmt"
"github.com/pion/logging"
"github.com/pion/transport/vnet"
"net"
"sync"
"testing"
"time"
)
// vnet client:
// 10.0.0.11:5787
// proxy to real server:
// 192.168.1.10:8000
func TestUDPProxyDirectDeliver(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var r0, r1, r2 error
defer func() {
if r0 != nil || r1 != nil || r2 != nil {
t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v", ctx.Err(), r0, r1, r2)
}
}()
var wg sync.WaitGroup
defer wg.Wait()
// Timeout, fail
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
select {
case <-ctx.Done():
case <-time.After(time.Duration(*testTimeout) * time.Millisecond):
r2 = fmt.Errorf("timeout")
}
}()
// For utest, we always proxy vnet packets to the random port we listen to.
mockServer := NewMockUDPEchoServer()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
if err := mockServer.doMockUDPServer(ctx); err != nil {
r0 = err
}
}()
// Create a vent and proxy.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
// When real server is ready, start the vnet test.
select {
case <-ctx.Done():
return
case <-mockServer.realServerReady.Done():
}
doVnetProxy := func() error {
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
if err != nil {
return err
}
clientNetwork := vnet.NewNet(&vnet.NetConfig{
StaticIP: "10.0.0.11",
})
if err = router.AddNet(clientNetwork); err != nil {
return err
}
if err := router.Start(); err != nil {
return err
}
defer router.Stop()
proxy, err := NewProxy(router)
if err != nil {
return err
}
defer proxy.Close()
// For utest, mock the target real server.
proxy.mockRealServerAddr = mockServer.realServerAddr
// The real server address to proxy to.
// Note that for utest, we will proxy to a local address.
serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000")
if err != nil {
return err
}
if err := proxy.Proxy(clientNetwork, serverAddr); err != nil {
return err
}
// Now, all packets from client, will be proxy to real server, vice versa.
client, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787")
if err != nil {
return err
}
// When system quit, interrupt client.
selfKill, selfKillCancel := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
selfKillCancel()
client.Close()
}()
// Write by vnet client.
if _, err := client.WriteTo([]byte("Hello"), serverAddr); err != nil {
return err
}
buf := make([]byte, 1500)
if n, addr, err := client.ReadFrom(buf); err != nil {
if selfKill.Err() == context.Canceled {
return nil
}
return err
} else if n != 5 || addr == nil {
return fmt.Errorf("n=%v, addr=%v", n, addr)
} else if string(buf[:n]) != "Hello" {
return fmt.Errorf("data %v", buf[:n])
}
// Directly write, simulate the ARQ packet.
// We should got the echo packet also.
if _, err := proxy.Deliver(client.LocalAddr(), serverAddr, []byte("Hello")); err != nil {
return err
}
if n, addr, err := client.ReadFrom(buf); err != nil {
if selfKill.Err() == context.Canceled {
return nil
}
return err
} else if n != 5 || addr == nil {
return fmt.Errorf("n=%v, addr=%v", n, addr)
} else if string(buf[:n]) != "Hello" {
return fmt.Errorf("data %v", buf[:n])
}
return err
}
if err := doVnetProxy(); err != nil {
r1 = err
}
}()
}

@ -0,0 +1,615 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 srs-bench(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package vnet
import (
"context"
"errors"
"flag"
"fmt"
"net"
"os"
"sync"
"testing"
"time"
"github.com/pion/logging"
"github.com/pion/transport/vnet"
)
type MockUDPEchoServer struct {
realServerAddr *net.UDPAddr
realServerReady context.Context
realServerReadyCancel context.CancelFunc
}
func NewMockUDPEchoServer() *MockUDPEchoServer {
v := &MockUDPEchoServer{}
v.realServerReady, v.realServerReadyCancel = context.WithCancel(context.Background())
return v
}
func (v *MockUDPEchoServer) doMockUDPServer(ctx context.Context) error {
// Listen to a random port.
laddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:0")
if err != nil {
return err
}
conn, err := net.ListenUDP("udp4", laddr)
if err != nil {
return err
}
v.realServerAddr = conn.LocalAddr().(*net.UDPAddr)
v.realServerReadyCancel()
// When system quit, interrupt client.
selfKill, selfKillCancel := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
selfKillCancel()
_ = conn.Close()
}()
// Note that if they has the same ID, the address should not changed.
addrs := make(map[string]net.Addr)
// Start an echo UDP server.
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, addr, err := conn.ReadFrom(buf)
if err != nil {
if errors.Is(selfKill.Err(), context.Canceled) {
return nil
}
return err
} else if n == 0 || addr == nil {
return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113
} else if nn, err := conn.WriteTo(buf[:n], addr); err != nil {
return err
} else if nn != n {
return fmt.Errorf("nn=%v, n=%v", nn, n) // nolint:goerr113
}
// Check the address, shold not change, use content as ID.
clientID := string(buf[:n])
if oldAddr, ok := addrs[clientID]; ok && oldAddr.String() != addr.String() {
return fmt.Errorf("address change %v to %v", oldAddr.String(), addr.String()) // nolint:goerr113
}
addrs[clientID] = addr
}
return nil
}
var testTimeout = flag.Int("timeout", 5000, "For each case, the timeout in ms") // nolint:gochecknoglobals
func TestMain(m *testing.M) {
flag.Parse()
os.Exit(m.Run())
}
// vnet client:
// 10.0.0.11:5787
// proxy to real server:
// 192.168.1.10:8000
func TestUDPProxyOne2One(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var r0, r1, r2 error
defer func() {
if r0 != nil || r1 != nil || r2 != nil {
t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v", ctx.Err(), r0, r1, r2)
}
}()
var wg sync.WaitGroup
defer wg.Wait()
// Timeout, fail
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
select {
case <-ctx.Done():
case <-time.After(time.Duration(*testTimeout) * time.Millisecond):
r2 = fmt.Errorf("timeout") // nolint:goerr113
}
}()
// For utest, we always proxy vnet packets to the random port we listen to.
mockServer := NewMockUDPEchoServer()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
if err := mockServer.doMockUDPServer(ctx); err != nil {
r0 = err
}
}()
// Create a vent and proxy.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
// When real server is ready, start the vnet test.
select {
case <-ctx.Done():
return
case <-mockServer.realServerReady.Done():
}
doVnetProxy := func() error {
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
if err != nil {
return err
}
clientNetwork := vnet.NewNet(&vnet.NetConfig{
StaticIP: "10.0.0.11",
})
if err = router.AddNet(clientNetwork); err != nil {
return err
}
if err = router.Start(); err != nil {
return err
}
defer router.Stop() // nolint:errcheck
proxy, err := NewProxy(router)
if err != nil {
return err
}
defer proxy.Close() // nolint:errcheck
// For utest, mock the target real server.
proxy.mockRealServerAddr = mockServer.realServerAddr
// The real server address to proxy to.
// Note that for utest, we will proxy to a local address.
serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000")
if err != nil {
return err
}
if err = proxy.Proxy(clientNetwork, serverAddr); err != nil {
return err
}
// Now, all packets from client, will be proxy to real server, vice versa.
client, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787")
if err != nil {
return err
}
// When system quit, interrupt client.
selfKill, selfKillCancel := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
selfKillCancel()
_ = client.Close() // nolint:errcheck
}()
for i := 0; i < 10; i++ {
if _, err = client.WriteTo([]byte("Hello"), serverAddr); err != nil {
return err
}
var n int
var addr net.Addr
buf := make([]byte, 1500)
if n, addr, err = client.ReadFrom(buf); err != nil { // nolint:gocritic
if errors.Is(selfKill.Err(), context.Canceled) {
return nil
}
return err
} else if n != 5 || addr == nil {
return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113
} else if string(buf[:n]) != "Hello" {
return fmt.Errorf("data %v", buf[:n]) // nolint:goerr113
}
// Wait for awhile for each UDP packet, to simulate real network.
select {
case <-ctx.Done():
return nil
case <-time.After(30 * time.Millisecond):
}
}
return err
}
if err := doVnetProxy(); err != nil {
r1 = err
}
}()
}
// vnet client:
// 10.0.0.11:5787
// 10.0.0.11:5788
// proxy to real server:
// 192.168.1.10:8000
func TestUDPProxyTwo2One(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var r0, r1, r2, r3 error
defer func() {
if r0 != nil || r1 != nil || r2 != nil || r3 != nil {
t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v, r3=%v", ctx.Err(), r0, r1, r2, r3)
}
}()
var wg sync.WaitGroup
defer wg.Wait()
// Timeout, fail
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
select {
case <-ctx.Done():
case <-time.After(time.Duration(*testTimeout) * time.Millisecond):
r2 = fmt.Errorf("timeout") // nolint:goerr113
}
}()
// For utest, we always proxy vnet packets to the random port we listen to.
mockServer := NewMockUDPEchoServer()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
if err := mockServer.doMockUDPServer(ctx); err != nil {
r0 = err
}
}()
// Create a vent and proxy.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
// When real server is ready, start the vnet test.
select {
case <-ctx.Done():
return
case <-mockServer.realServerReady.Done():
}
doVnetProxy := func() error {
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
if err != nil {
return err
}
clientNetwork := vnet.NewNet(&vnet.NetConfig{
StaticIP: "10.0.0.11",
})
if err = router.AddNet(clientNetwork); err != nil {
return err
}
if err = router.Start(); err != nil {
return err
}
defer router.Stop() // nolint:errcheck
proxy, err := NewProxy(router)
if err != nil {
return err
}
defer proxy.Close() // nolint:errcheck
// For utest, mock the target real server.
proxy.mockRealServerAddr = mockServer.realServerAddr
// The real server address to proxy to.
// Note that for utest, we will proxy to a local address.
serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000")
if err != nil {
return err
}
if err = proxy.Proxy(clientNetwork, serverAddr); err != nil {
return err
}
handClient := func(address, echoData string) error {
// Now, all packets from client, will be proxy to real server, vice versa.
client, err := clientNetwork.ListenPacket("udp4", address) // nolint:govet
if err != nil {
return err
}
// When system quit, interrupt client.
selfKill, selfKillCancel := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
selfKillCancel()
_ = client.Close()
}()
for i := 0; i < 10; i++ {
if _, err := client.WriteTo([]byte(echoData), serverAddr); err != nil { // nolint:govet
return err
}
var n int
var addr net.Addr
buf := make([]byte, 1400)
if n, addr, err = client.ReadFrom(buf); err != nil { // nolint:gocritic
if errors.Is(selfKill.Err(), context.Canceled) {
return nil
}
return err
} else if n != len(echoData) || addr == nil {
return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113
} else if string(buf[:n]) != echoData {
return fmt.Errorf("check data %v", buf[:n]) // nolint:goerr113
}
// Wait for awhile for each UDP packet, to simulate real network.
select {
case <-ctx.Done():
return nil
case <-time.After(30 * time.Millisecond):
}
}
return nil
}
client0, client0Cancel := context.WithCancel(context.Background())
go func() {
defer client0Cancel()
address := "10.0.0.11:5787"
if err := handClient(address, "Hello"); err != nil { // nolint:govet
r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113
}
}()
client1, client1Cancel := context.WithCancel(context.Background())
go func() {
defer client1Cancel()
address := "10.0.0.11:5788"
if err := handClient(address, "World"); err != nil { // nolint:govet
r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113
}
}()
select {
case <-ctx.Done():
case <-client0.Done():
case <-client1.Done():
}
return err
}
if err := doVnetProxy(); err != nil {
r1 = err
}
}()
}
// vnet client:
// 10.0.0.11:5787
// proxy to real server:
// 192.168.1.10:8000
//
// vnet client:
// 10.0.0.11:5788
// proxy to real server:
// 192.168.1.10:8000
func TestUDPProxyProxyTwice(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var r0, r1, r2, r3 error
defer func() {
if r0 != nil || r1 != nil || r2 != nil || r3 != nil {
t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v, r3=%v", ctx.Err(), r0, r1, r2, r3)
}
}()
var wg sync.WaitGroup
defer wg.Wait()
// Timeout, fail
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
select {
case <-ctx.Done():
case <-time.After(time.Duration(*testTimeout) * time.Millisecond):
r2 = fmt.Errorf("timeout") // nolint:goerr113
}
}()
// For utest, we always proxy vnet packets to the random port we listen to.
mockServer := NewMockUDPEchoServer()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
if err := mockServer.doMockUDPServer(ctx); err != nil {
r0 = err
}
}()
// Create a vent and proxy.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
// When real server is ready, start the vnet test.
select {
case <-ctx.Done():
return
case <-mockServer.realServerReady.Done():
}
doVnetProxy := func() error {
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
if err != nil {
return err
}
clientNetwork := vnet.NewNet(&vnet.NetConfig{
StaticIP: "10.0.0.11",
})
if err = router.AddNet(clientNetwork); err != nil {
return err
}
if err = router.Start(); err != nil {
return err
}
defer router.Stop() // nolint:errcheck
proxy, err := NewProxy(router)
if err != nil {
return err
}
defer proxy.Close() // nolint:errcheck
// For utest, mock the target real server.
proxy.mockRealServerAddr = mockServer.realServerAddr
// The real server address to proxy to.
// Note that for utest, we will proxy to a local address.
serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000")
if err != nil {
return err
}
handClient := func(address, echoData string) error {
// We proxy multiple times, for example, in publisher and player, both call
// the proxy when got answer.
if err := proxy.Proxy(clientNetwork, serverAddr); err != nil { // nolint:govet
return err
}
// Now, all packets from client, will be proxy to real server, vice versa.
client, err := clientNetwork.ListenPacket("udp4", address) // nolint:govet
if err != nil {
return err
}
// When system quit, interrupt client.
selfKill, selfKillCancel := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
selfKillCancel()
_ = client.Close() // nolint:errcheck
}()
for i := 0; i < 10; i++ {
if _, err = client.WriteTo([]byte(echoData), serverAddr); err != nil {
return err
}
buf := make([]byte, 1500)
if n, addr, err := client.ReadFrom(buf); err != nil { // nolint:gocritic,govet
if errors.Is(selfKill.Err(), context.Canceled) {
return nil
}
return err
} else if n != len(echoData) || addr == nil {
return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113
} else if string(buf[:n]) != echoData {
return fmt.Errorf("verify data %v", buf[:n]) // nolint:goerr113
}
// Wait for awhile for each UDP packet, to simulate real network.
select {
case <-ctx.Done():
return nil
case <-time.After(30 * time.Millisecond):
}
}
return nil
}
client0, client0Cancel := context.WithCancel(context.Background())
go func() {
defer client0Cancel()
address := "10.0.0.11:5787"
if err = handClient(address, "Hello"); err != nil {
r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113
}
}()
client1, client1Cancel := context.WithCancel(context.Background())
go func() {
defer client1Cancel()
// Slower than client0, 60ms.
// To simulate the real player or publisher, might not start at the same time.
select {
case <-ctx.Done():
return
case <-time.After(150 * time.Millisecond):
}
address := "10.0.0.11:5788"
if err = handClient(address, "World"); err != nil {
r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113
}
}()
select {
case <-ctx.Done():
case <-client0.Done():
case <-client1.Done():
}
return err
}
if err := doVnetProxy(); err != nil {
r1 = err
}
}()
}

@ -35,6 +35,7 @@ using namespace std;
#include <srs_app_utility.hpp> #include <srs_app_utility.hpp>
#include <srs_kernel_rtc_rtp.hpp> #include <srs_kernel_rtc_rtp.hpp>
#include <srs_app_log.hpp> #include <srs_app_log.hpp>
#include <srs_kernel_utility.hpp>
#include <srtp2/srtp.h> #include <srtp2/srtp.h>
#include <openssl/ssl.h> #include <openssl/ssl.h>
@ -43,6 +44,35 @@ using namespace std;
// Defined in HTTP/HTTPS client. // Defined in HTTP/HTTPS client.
extern int srs_verify_callback(int preverify_ok, X509_STORE_CTX *ctx); extern int srs_verify_callback(int preverify_ok, X509_STORE_CTX *ctx);
// Setup the openssl timeout for DTLS packet.
// @see https://www.openssl.org/docs/man1.1.1/man3/DTLS_set_timer_cb.html
//
// Use step timeout for ARQ, [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200] in ms,
// then total timeout is sum([50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200]) = 102350ms.
//
// @remark The connection might be closed for timeout in about 30s by default, which stop the DTLS ARQ.
unsigned int dtls_timer_cb(SSL* dtls, unsigned int previous_us)
{
SrsDtlsImpl* dtls_impl = (SrsDtlsImpl*)SSL_get_ex_data(dtls, 0);
srs_assert(dtls_impl);
// Double the timeout. Note that it can be 0.
unsigned int timeout_us = previous_us * 2;
// If previous_us is 0, for example, the HelloVerifyRequest, we should response it ASAP.
// When got ServerHello, we should reset the timer.
if (previous_us == 0 || dtls_impl->should_reset_timer()) {
timeout_us = 50 * 1000; // in us
}
// Never exceed the max timeout.
timeout_us = srs_min(timeout_us, 30 * 1000 * 1000); // in us
srs_info("DTLS: ARQ timer cb timeout=%ums, previous=%ums", timeout_us/1000, previous_us/1000);
return timeout_us;
}
// Print the information of SSL, DTLS alert as such. // Print the information of SSL, DTLS alert as such.
void ssl_on_info(const SSL* dtls, int where, int ret) void ssl_on_info(const SSL* dtls, int where, int ret)
{ {
@ -377,8 +407,6 @@ SrsDtlsImpl::SrsDtlsImpl(ISrsDtlsCallback* callback)
callback_ = callback; callback_ = callback;
handshake_done_for_us = false; handshake_done_for_us = false;
last_outgoing_packet_cache = new uint8_t[kRtpPacketSize];
nn_last_outgoing_packet = 0;
nn_arq_packets = 0; nn_arq_packets = 0;
version_ = SrsDtlsVersionAuto; version_ = SrsDtlsVersionAuto;
@ -401,8 +429,6 @@ SrsDtlsImpl::~SrsDtlsImpl()
SSL_free(dtls); SSL_free(dtls);
dtls = NULL; dtls = NULL;
} }
srs_freepa(last_outgoing_packet_cache);
} }
srs_error_t SrsDtlsImpl::initialize(std::string version, std::string role) srs_error_t SrsDtlsImpl::initialize(std::string version, std::string role)
@ -431,6 +457,19 @@ srs_error_t SrsDtlsImpl::initialize(std::string version, std::string role)
SSL_set_options(dtls, SSL_OP_NO_QUERY_MTU); SSL_set_options(dtls, SSL_OP_NO_QUERY_MTU);
SSL_set_mtu(dtls, kRtpPacketSize); SSL_set_mtu(dtls, kRtpPacketSize);
// @see https://linux.die.net/man/3/openssl_version_number
// MM NN FF PP S
// 0x1010102fL = 0x1 01 01 02 fL // 1.1.1b release
// MM(major) = 0x1 // 1.*
// NN(minor) = 0x01 // 1.1.*
// FF(fix) = 0x01 // 1.1.1*
// PP(patch) = 'a' + 0x02 - 1 = 'b' // 1.1.1b *
// S(status) = 0xf = release // 1.1.1b release
// @note Status 0 for development, 1 to e for betas 1 to 14, and f for release.
#if OPENSSL_VERSION_NUMBER >= 0x1010102fL // 1.1.1b
DTLS_set_timer_cb(dtls, dtls_timer_cb);
#endif
if ((bio_in = BIO_new(BIO_s_mem())) == NULL) { if ((bio_in = BIO_new(BIO_s_mem())) == NULL) {
return srs_error_new(ERROR_OpenSslBIONew, "BIO_new in"); return srs_error_new(ERROR_OpenSslBIONew, "BIO_new in");
} }
@ -461,6 +500,12 @@ srs_error_t SrsDtlsImpl::do_on_dtls(char* data, int nb_data)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
// When already done, only for us, we still got message from client,
// it might be our response is lost, or application data.
if (handshake_done_for_us) {
srs_info("DTLS: After done, got %d bytes", nb_data);
}
int r0 = 0; int r0 = 0;
// TODO: FIXME: Why reset it before writing? // TODO: FIXME: Why reset it before writing?
if ((r0 = BIO_reset(bio_in)) != 1) { if ((r0 = BIO_reset(bio_in)) != 1) {
@ -471,7 +516,7 @@ srs_error_t SrsDtlsImpl::do_on_dtls(char* data, int nb_data)
} }
// Trace the detail of DTLS packet. // Trace the detail of DTLS packet.
state_trace((uint8_t*)data, nb_data, true, r0, SSL_ERROR_NONE, false, false); state_trace((uint8_t*)data, nb_data, true, r0, SSL_ERROR_NONE, false);
if ((r0 = BIO_write(bio_in, data, nb_data)) <= 0) { if ((r0 = BIO_write(bio_in, data, nb_data)) <= 0) {
// TODO: 0 or -1 maybe block, use BIO_should_retry to check. // TODO: 0 or -1 maybe block, use BIO_should_retry to check.
@ -502,6 +547,18 @@ srs_error_t SrsDtlsImpl::do_on_dtls(char* data, int nb_data)
if (r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE) { if (r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE) {
break; break;
} }
// We got data in memory, which can not read by SSL_read, generally, it's handshake data.
uint8_t* data = NULL;
int size = BIO_get_mem_data(bio_out, (char**)&data);
// Logging when got SSL original data.
state_trace((uint8_t*)data, size, false, r0, r1, false);
if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) {
return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size,
srs_string_dumps_hex((char*)data, size, 32).c_str());
}
continue; continue;
} }
@ -521,6 +578,12 @@ srs_error_t SrsDtlsImpl::do_handshake()
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
// Done for use, ignore handshake packets. If need to ARQ the handshake packets,
// we should use SSL_read to handle it.
if (handshake_done_for_us) {
return err;
}
// Do handshake and get the result. // Do handshake and get the result.
int r0 = SSL_do_handshake(dtls); int r0 = SSL_do_handshake(dtls);
int r1 = SSL_get_error(dtls, r0); int r1 = SSL_get_error(dtls, r0);
@ -537,18 +600,10 @@ srs_error_t SrsDtlsImpl::do_handshake()
// The data to send out to peer. // The data to send out to peer.
uint8_t* data = NULL; uint8_t* data = NULL;
int size = BIO_get_mem_data(bio_out, &data); int size = BIO_get_mem_data(bio_out, (char**)&data);
// Callback when got SSL original data. // Logging when got SSL original data.
bool cache = false; state_trace((uint8_t*)data, size, false, r0, r1, false);
on_ssl_out_data(data, size, cache);
state_trace((uint8_t*)data, size, false, r0, r1, cache, false);
// Update the packet cache.
if (size > 0 && data != last_outgoing_packet_cache && size < kRtpPacketSize) {
memcpy(last_outgoing_packet_cache, data, size);
nn_last_outgoing_packet = size;
}
// Callback for the final output data, before send-out. // Callback for the final output data, before send-out.
if ((err = on_final_out_data(data, size)) != srs_success) { if ((err = on_final_out_data(data, size)) != srs_success) {
@ -569,7 +624,7 @@ srs_error_t SrsDtlsImpl::do_handshake()
return err; return err;
} }
void SrsDtlsImpl::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq) void SrsDtlsImpl::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool arq)
{ {
// change_cipher_spec(20), alert(21), handshake(22), application_data(23) // change_cipher_spec(20), alert(21), handshake(22), application_data(23)
// @see https://tools.ietf.org/html/rfc2246#section-6.2.1 // @see https://tools.ietf.org/html/rfc2246#section-6.2.1
@ -588,8 +643,8 @@ void SrsDtlsImpl::state_trace(uint8_t* data, int length, bool incoming, int r0,
handshake_type = (uint8_t)data[13]; handshake_type = (uint8_t)data[13];
} }
srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u/%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u", srs_trace("DTLS: State %s %s, done=%u, arq=%u/%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u",
(is_dtls_client()? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, cache, arq, (is_dtls_client()? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, arq,
nn_arq_packets, r0, r1, length, content_type, size, handshake_type); nn_arq_packets, r0, r1, length, content_type, size, handshake_type);
} }
@ -640,15 +695,9 @@ SrsDtlsClientImpl::SrsDtlsClientImpl(ISrsDtlsCallback* callback) : SrsDtlsImpl(c
trd = NULL; trd = NULL;
state_ = SrsDtlsStateInit; state_ = SrsDtlsStateInit;
// The first wait and base interval for ARQ. // the max dtls retry num is 12 in openssl.
arq_interval = 10 * SRS_UTIME_MILLISECONDS; arq_max_retry = 12 * 2; // Max ARQ limit shared for ClientHello and Certificate.
reset_timer_ = true;
// Use step timeout for ARQ, the total timeout is sum(arq_to_ratios)*arq_interval.
// for example, if arq_interval is 10ms, arq_to_ratios is [3, 6, 9, 15, 20, 40, 80, 160],
// then total timeout is sum([3, 6, 9, 15, 20, 40, 80, 160]) * 10ms = 3330ms.
int ratios[] = {3, 6, 9, 15, 20, 40, 80, 160};
srs_assert(sizeof(arq_to_ratios) == sizeof(ratios));
memcpy(arq_to_ratios, ratios, sizeof(ratios));
} }
SrsDtlsClientImpl::~SrsDtlsClientImpl() SrsDtlsClientImpl::~SrsDtlsClientImpl()
@ -672,60 +721,47 @@ srs_error_t SrsDtlsClientImpl::initialize(std::string version, std::string role)
} }
srs_error_t SrsDtlsClientImpl::start_active_handshake() srs_error_t SrsDtlsClientImpl::start_active_handshake()
{
return do_handshake();
}
srs_error_t SrsDtlsClientImpl::on_dtls(char* data, int nb_data)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
// When got packet, stop the ARQ if server in the first ARQ state SrsDtlsStateServerHello. if ((err = do_handshake()) != srs_success) {
// @note But for ARQ state, we should never stop the ARQ, for example, we are in the second ARQ sate return srs_error_wrap(err, "start handshake");
// SrsDtlsStateServerDone, but we got previous late wrong packet ServeHello, which is not the expect
// packet SessionNewTicket, we should never stop the ARQ thread.
if (state_ == SrsDtlsStateServerHello) {
stop_arq();
} }
if ((err = SrsDtlsImpl::on_dtls(data, nb_data)) != srs_success) { if ((err = start_arq()) != srs_success) {
return err; return srs_error_wrap(err, "start arq");
} }
return err; return err;
} }
void SrsDtlsClientImpl::on_ssl_out_data(uint8_t*& data, int& size, bool& cached) bool SrsDtlsClientImpl::should_reset_timer()
{ {
// DTLS client use ARQ thread to send cached packet. bool v = reset_timer_;
cached = false; reset_timer_ = false;
return v;
} }
// Note that only handshake sending packets drives the state, neither ARQ nor the
// final-packets(after handshake done) drives it.
srs_error_t SrsDtlsClientImpl::on_final_out_data(uint8_t* data, int size) srs_error_t SrsDtlsClientImpl::on_final_out_data(uint8_t* data, int size)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
// Driven ARQ and state for DTLS client.
// If we are sending client hello, change from init to new state. // If we are sending client hello, change from init to new state.
if (state_ == SrsDtlsStateInit && size > 14 && data[13] == 1) { if (state_ == SrsDtlsStateInit && size > 14 && data[0] == 22 && data[13] == 1) {
state_ = SrsDtlsStateClientHello; state_ = SrsDtlsStateClientHello;
} return err;
// If we are sending certificate, change from SrsDtlsStateServerHello to new state.
if (state_ == SrsDtlsStateServerHello && size > 14 && data[13] == 11) {
state_ = SrsDtlsStateClientCertificate;
} }
// Try to start the ARQ for client. // If we are sending certificate, change from SrsDtlsStateClientHello to new state.
if ((state_ == SrsDtlsStateClientHello || state_ == SrsDtlsStateClientCertificate)) { if (state_ == SrsDtlsStateClientHello && size > 14 && data[0] == 22 && data[13] == 11) {
if (state_ == SrsDtlsStateClientHello) { state_ = SrsDtlsStateClientCertificate;
state_ = SrsDtlsStateServerHello;
} else if (state_ == SrsDtlsStateClientCertificate) {
state_ = SrsDtlsStateServerDone;
}
if ((err = start_arq()) != srs_success) { // When we send out the certificate, we should reset the timer.
return srs_error_wrap(err, "start arq"); reset_timer_ = true;
} srs_info("DTLS: Reset the timer for ServerHello");
return err;
} }
return err; return err;
@ -735,8 +771,15 @@ srs_error_t SrsDtlsClientImpl::on_handshake_done()
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
// When handshake done, stop the ARQ. // Ignore if done.
if (state_ == SrsDtlsStateClientDone) {
return err;
}
// Change to done state.
state_ = SrsDtlsStateClientDone; state_ = SrsDtlsStateClientDone;
// When handshake done, stop the ARQ.
stop_arq(); stop_arq();
// Notify connection the DTLS is done. // Notify connection the DTLS is done.
@ -756,8 +799,6 @@ srs_error_t SrsDtlsClientImpl::start_arq()
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
srs_info("start arq, state=%u", state_);
// Dispose the previous ARQ thread. // Dispose the previous ARQ thread.
srs_freep(trd); srs_freep(trd);
trd = new SrsSTCoroutine("dtls", this, _srs_context->get_id()); trd = new SrsSTCoroutine("dtls", this, _srs_context->get_id());
@ -772,20 +813,24 @@ srs_error_t SrsDtlsClientImpl::start_arq()
void SrsDtlsClientImpl::stop_arq() void SrsDtlsClientImpl::stop_arq()
{ {
srs_info("stop arq, state=%u", state_);
srs_freep(trd); srs_freep(trd);
srs_info("stop arq, done");
} }
srs_error_t SrsDtlsClientImpl::cycle() srs_error_t SrsDtlsClientImpl::cycle()
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
// Limit the max retry for ARQ. // Limit the max retry for ARQ, to avoid infinite loop.
for (int i = 0; i < (int)(sizeof(arq_to_ratios) / sizeof(int)); i++) { // Note that we set the timeout to [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200] in ms,
srs_utime_t arq_to = arq_interval * arq_to_ratios[i]; // but the actual timeout is limit to 1s:
srs_usleep(arq_to); // 50ms, 100ms, 200ms, 400ms, 800ms, (1000ms,600ms), (200ms,1000ms,1000ms,1000ms),
// (400ms,1000ms,1000ms,1000ms,1000ms,1000ms,1000ms), ...
// So when the max ARQ limit to 12 times, the max loop is about 103.
// @remark We change the max sleep to 100ms, so we limit about (103*10)/2=500.
const int max_loop = 512;
int arq_count = 0;
for (int i = 0; arq_count < arq_max_retry && i < max_loop; i++) {
// We ignore any error for ARQ thread. // We ignore any error for ARQ thread.
if ((err = trd->pull()) != srs_success) { if ((err = trd->pull()) != srs_success) {
srs_freep(err); srs_freep(err);
@ -798,27 +843,62 @@ srs_error_t SrsDtlsClientImpl::cycle()
} }
// For DTLS client ARQ, the state should be specified. // For DTLS client ARQ, the state should be specified.
if (state_ != SrsDtlsStateServerHello && state_ != SrsDtlsStateServerDone) { if (state_ != SrsDtlsStateClientHello && state_ != SrsDtlsStateClientCertificate) {
return err; return err;
} }
// Try to retransmit the packet. // If there is a timeout in progress, it sets *out to the time remaining
uint8_t* data = last_outgoing_packet_cache; // and returns one. Otherwise, it returns zero.
int size = nn_last_outgoing_packet; int r0 = 0; timeval to = {0};
if ((r0 = DTLSv1_get_timeout(dtls, &to)) == 0) {
// No timeout, for example?, wait for a default 50ms.
srs_usleep(50 * SRS_UTIME_MILLISECONDS);
continue;
}
srs_utime_t timeout = to.tv_sec + to.tv_usec;
// There is timeout to wait, so we should wait, because there is no packet in openssl.
if (timeout > 0) {
// Never wait too long, because we might need to retransmit other messages.
// For example, we have transmit 2 ClientHello as [50ms, 100ms] then we sleep(200ms),
// during this we reset the openssl timer to 50ms and need to retransmit Certificate,
// we still need to wait 200ms not 50ms.
timeout = srs_min(100 * SRS_UTIME_MILLISECONDS, timeout);
timeout = srs_max(50 * SRS_UTIME_MILLISECONDS, timeout);
srs_usleep(timeout);
continue;
}
if (size) { // The timeout is 0, so there must be a ARQ packet to transmit in openssl.
// Trace the detail of DTLS packet. r0 = BIO_reset(bio_out); int r1 = SSL_get_error(dtls, r0);
state_trace((uint8_t*)data, size, false, 1, SSL_ERROR_NONE, true, true); if (r0 != 1) {
nn_arq_packets++; return srs_error_new(ERROR_OpenSslBIOReset, "BIO_reset r0=%d, r1=%d", r0, r1);
}
if ((err = callback_->write_dtls_data(data, size)) != srs_success) { // DTLSv1_handle_timeout is called when a DTLS handshake timeout expires. If no timeout
return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, // had expired, it returns 0. Otherwise, it retransmits the previous flight of handshake
srs_string_dumps_hex((char*)data, size, 32).c_str()); // messages and returns 1. If too many timeouts had expired without progress or an error
} // occurs, it returns -1.
r0 = DTLSv1_handle_timeout(dtls); r1 = SSL_get_error(dtls, r0);
if (r0 == 0) {
continue; // No timeout had expired.
} }
if (r0 != 1) {
return srs_error_new(ERROR_RTC_DTLS, "ARQ r0=%d, r1=%d", r0, r1);
}
// The data to send out to peer.
uint8_t* data = NULL;
int size = BIO_get_mem_data(bio_out, (char**)&data);
arq_count++;
nn_arq_packets++;
state_trace((uint8_t*)data, size, false, r0, r1, true);
srs_info("arq cycle, done=%u, state=%u, retry=%d, interval=%dms, to=%dms, size=%d, nn=%d", handshake_done_for_us, if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) {
state_, i, srsu2msi(arq_interval), srsu2msi(arq_to), size, nn_arq_packets); return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size,
srs_string_dumps_hex((char*)data, size, 32).c_str());
}
} }
return err; return err;
@ -848,23 +928,19 @@ srs_error_t SrsDtlsServerImpl::initialize(std::string version, std::string role)
srs_error_t SrsDtlsServerImpl::start_active_handshake() srs_error_t SrsDtlsServerImpl::start_active_handshake()
{ {
// For DTLS server, we do nothing, because DTLS client drive it.
return srs_success; return srs_success;
} }
void SrsDtlsServerImpl::on_ssl_out_data(uint8_t*& data, int& size, bool& cached) bool SrsDtlsServerImpl::should_reset_timer()
{ {
// If outgoing packet is empty, we use the last cache. // For DTLS server, we never use timer for ARQ, because DTLS client drive it.
// @remark Only for DTLS server, because DTLS client use ARQ thread to send cached packet. return false;
if (size <= 0 && nn_last_outgoing_packet) {
size = nn_last_outgoing_packet;
data = last_outgoing_packet_cache;
nn_arq_packets++;
cached = true;
}
} }
srs_error_t SrsDtlsServerImpl::on_final_out_data(uint8_t* data, int size) srs_error_t SrsDtlsServerImpl::on_final_out_data(uint8_t* data, int size)
{ {
// No ARQ, driven by DTLS client packets.
return srs_success; return srs_success;
} }

@ -121,9 +121,6 @@ protected:
// Whether the handshake is done, for us only. // Whether the handshake is done, for us only.
// @remark For us only, means peer maybe not done, we also need to handle the DTLS packet. // @remark For us only, means peer maybe not done, we also need to handle the DTLS packet.
bool handshake_done_for_us; bool handshake_done_for_us;
// DTLS packet cache, only last out-going packet.
uint8_t* last_outgoing_packet_cache;
int nn_last_outgoing_packet;
// The stat for ARQ packets. // The stat for ARQ packets.
int nn_arq_packets; int nn_arq_packets;
public: public:
@ -132,16 +129,16 @@ public:
public: public:
virtual srs_error_t initialize(std::string version, std::string role); virtual srs_error_t initialize(std::string version, std::string role);
virtual srs_error_t start_active_handshake() = 0; virtual srs_error_t start_active_handshake() = 0;
virtual bool should_reset_timer() = 0;
virtual srs_error_t on_dtls(char* data, int nb_data); virtual srs_error_t on_dtls(char* data, int nb_data);
protected: protected:
srs_error_t do_on_dtls(char* data, int nb_data); srs_error_t do_on_dtls(char* data, int nb_data);
srs_error_t do_handshake(); srs_error_t do_handshake();
void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq); void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool arq);
public: public:
srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key); srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key);
void callback_by_ssl(std::string type, std::string desc); void callback_by_ssl(std::string type, std::string desc);
protected: protected:
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached) = 0;
virtual srs_error_t on_final_out_data(uint8_t* data, int size) = 0; virtual srs_error_t on_final_out_data(uint8_t* data, int size) = 0;
virtual srs_error_t on_handshake_done() = 0; virtual srs_error_t on_handshake_done() = 0;
virtual bool is_dtls_client() = 0; virtual bool is_dtls_client() = 0;
@ -155,18 +152,19 @@ private:
SrsCoroutine* trd; SrsCoroutine* trd;
// The DTLS-client state to drive the ARQ thread. // The DTLS-client state to drive the ARQ thread.
SrsDtlsState state_; SrsDtlsState state_;
// The timeout for ARQ. // The max ARQ retry.
srs_utime_t arq_interval; int arq_max_retry;
int arq_to_ratios[8]; // Should we reset the timer?
// It's true when init, or in state ServerHello.
bool reset_timer_;
public: public:
SrsDtlsClientImpl(ISrsDtlsCallback* callback); SrsDtlsClientImpl(ISrsDtlsCallback* callback);
virtual ~SrsDtlsClientImpl(); virtual ~SrsDtlsClientImpl();
public: public:
virtual srs_error_t initialize(std::string version, std::string role); virtual srs_error_t initialize(std::string version, std::string role);
virtual srs_error_t start_active_handshake(); virtual srs_error_t start_active_handshake();
virtual srs_error_t on_dtls(char* data, int nb_data); virtual bool should_reset_timer();
protected: protected:
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached);
virtual srs_error_t on_final_out_data(uint8_t* data, int size); virtual srs_error_t on_final_out_data(uint8_t* data, int size);
virtual srs_error_t on_handshake_done(); virtual srs_error_t on_handshake_done();
virtual bool is_dtls_client(); virtual bool is_dtls_client();
@ -185,8 +183,8 @@ public:
public: public:
virtual srs_error_t initialize(std::string version, std::string role); virtual srs_error_t initialize(std::string version, std::string role);
virtual srs_error_t start_active_handshake(); virtual srs_error_t start_active_handshake();
virtual bool should_reset_timer();
protected: protected:
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached);
virtual srs_error_t on_final_out_data(uint8_t* data, int size); virtual srs_error_t on_final_out_data(uint8_t* data, int size);
virtual srs_error_t on_handshake_done(); virtual srs_error_t on_handshake_done();
virtual bool is_dtls_client(); virtual bool is_dtls_client();

@ -24,6 +24,6 @@
#ifndef SRS_CORE_VERSION4_HPP #ifndef SRS_CORE_VERSION4_HPP
#define SRS_CORE_VERSION4_HPP #define SRS_CORE_VERSION4_HPP
#define SRS_VERSION4_REVISION 83 #define SRS_VERSION4_REVISION 84
#endif #endif

@ -871,9 +871,12 @@ srs_error_t MockDtlsCallback::cycle()
} }
// Wait for mock io to done, try to switch to coroutine many times. // Wait for mock io to done, try to switch to coroutine many times.
void mock_wait_dtls_io_done(int count = 100, int interval = 0) void mock_wait_dtls_io_done(SrsDtlsImpl* client_impl, int count = 100, int interval = 0)
{ {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
if (client_impl) {
dynamic_cast<SrsDtlsClientImpl*>(client_impl)->reset_timer_ = true;
}
srs_usleep(interval * SRS_UTIME_MILLISECONDS); srs_usleep(interval * SRS_UTIME_MILLISECONDS);
} }
} }
@ -895,138 +898,6 @@ public:
} }
}; };
VOID TEST(KernelRTCTest, DTLSARQLimitTest)
{
srs_error_t err = srs_success;
// ClientHello lost, client retransmit the ClientHello.
if (true) {
MockDtlsCallback cio; SrsDtls client(&cio);
MockDtlsCallback sio; SrsDtls server(&sio);
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 10 packets, total packets should be 9(max to 9).
// Note that only one server hello.
cio.nn_client_hello_lost = 10;
HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3);
EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success);
EXPECT_FALSE(cio.done);
EXPECT_FALSE(sio.done);
EXPECT_EQ(9, cio.nn_client_hello);
EXPECT_EQ(0, sio.nn_server_hello);
EXPECT_EQ(0, cio.nn_certificate);
EXPECT_EQ(0, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher);
}
// Certificate lost, client retransmit the Certificate.
if (true) {
MockDtlsCallback cio; SrsDtls client(&cio);
MockDtlsCallback sio; SrsDtls server(&sio);
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 10 packets, total packets should be 9(max to 9).
// Note that only one server NewSessionTicket.
cio.nn_certificate_lost = 10;
HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3);
EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success);
EXPECT_FALSE(cio.done);
EXPECT_FALSE(sio.done);
EXPECT_EQ(1, cio.nn_client_hello);
EXPECT_EQ(1, sio.nn_server_hello);
EXPECT_EQ(9, cio.nn_certificate);
EXPECT_EQ(0, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher);
}
// ServerHello lost, client retransmit the ClientHello.
if (true) {
MockDtlsCallback cio; SrsDtls client(&cio);
MockDtlsCallback sio; SrsDtls server(&sio);
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 10 packets, total packets should be 9(max to 9).
sio.nn_server_hello_lost = 10;
HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3);
EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success);
EXPECT_FALSE(cio.done);
EXPECT_FALSE(sio.done);
EXPECT_EQ(9, cio.nn_client_hello);
EXPECT_EQ(9, sio.nn_server_hello);
EXPECT_EQ(0, cio.nn_certificate);
EXPECT_EQ(0, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher);
}
// NewSessionTicket lost, client retransmit the Certificate.
if (true) {
MockDtlsCallback cio; SrsDtls client(&cio);
MockDtlsCallback sio; SrsDtls server(&sio);
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 10 packets, total packets should be 9(max to 9).
sio.nn_new_session_lost = 10;
HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3);
EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success);
// Although the packet is lost, but it's done for server, and not done for client.
EXPECT_FALSE(cio.done);
EXPECT_TRUE(sio.done);
EXPECT_EQ(1, cio.nn_client_hello);
EXPECT_EQ(1, sio.nn_server_hello);
EXPECT_EQ(9, cio.nn_certificate);
EXPECT_EQ(9, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher);
}
}
VOID TEST(KernelRTCTest, DTLSClientARQTest) VOID TEST(KernelRTCTest, DTLSClientARQTest)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
@ -1040,7 +911,7 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
HELPER_EXPECT_SUCCESS(client.start_active_handshake()); HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(30, 1); mock_wait_dtls_io_done(client.impl, 15, 5);
EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success);
@ -1050,8 +921,8 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
EXPECT_EQ(1, cio.nn_client_hello); EXPECT_EQ(1, cio.nn_client_hello);
EXPECT_EQ(1, sio.nn_server_hello); EXPECT_EQ(1, sio.nn_server_hello);
EXPECT_TRUE(1 <= cio.nn_certificate); EXPECT_EQ(1, cio.nn_certificate);
EXPECT_TRUE(1 <= sio.nn_new_session); EXPECT_EQ(1, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher); EXPECT_EQ(0, sio.nn_change_cipher);
} }
@ -1063,16 +934,12 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 2 packets, total packets should be 3. // Lost 2 packets, total packets should be 3.
// Note that only one server hello. // Note that only one server hello.
cio.nn_client_hello_lost = 2; cio.nn_client_hello_lost = 1;
HELPER_EXPECT_SUCCESS(client.start_active_handshake()); HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3); mock_wait_dtls_io_done(client.impl, 15, 5);
EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success);
@ -1080,10 +947,10 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
EXPECT_TRUE(cio.done); EXPECT_TRUE(cio.done);
EXPECT_TRUE(sio.done); EXPECT_TRUE(sio.done);
EXPECT_TRUE(3 <= cio.nn_client_hello); EXPECT_EQ(2, cio.nn_client_hello);
EXPECT_TRUE(1 <= sio.nn_server_hello); EXPECT_EQ(1, sio.nn_server_hello);
EXPECT_TRUE(1 <= cio.nn_certificate); EXPECT_EQ(1, cio.nn_certificate);
EXPECT_TRUE(1 <= sio.nn_new_session); EXPECT_EQ(1, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher); EXPECT_EQ(0, sio.nn_change_cipher);
} }
@ -1095,16 +962,12 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 2 packets, total packets should be 3. // Lost 2 packets, total packets should be 3.
// Note that only one server NewSessionTicket. // Note that only one server NewSessionTicket.
cio.nn_certificate_lost = 2; cio.nn_certificate_lost = 1;
HELPER_EXPECT_SUCCESS(client.start_active_handshake()); HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3); mock_wait_dtls_io_done(client.impl, 15, 5);
EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success);
@ -1113,9 +976,9 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
EXPECT_TRUE(sio.done); EXPECT_TRUE(sio.done);
EXPECT_EQ(1, cio.nn_client_hello); EXPECT_EQ(1, cio.nn_client_hello);
EXPECT_EQ(1, sio.nn_server_hello); EXPECT_EQ(2, sio.nn_server_hello);
EXPECT_TRUE(3 <= cio.nn_certificate); EXPECT_EQ(2, cio.nn_certificate);
EXPECT_TRUE(1 <= sio.nn_new_session); EXPECT_EQ(0, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher); EXPECT_EQ(0, sio.nn_change_cipher);
} }
} }
@ -1133,7 +996,7 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
HELPER_EXPECT_SUCCESS(client.start_active_handshake()); HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(30, 1); mock_wait_dtls_io_done(client.impl, 15, 5);
EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success);
@ -1143,8 +1006,8 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
EXPECT_EQ(1, cio.nn_client_hello); EXPECT_EQ(1, cio.nn_client_hello);
EXPECT_EQ(1, sio.nn_server_hello); EXPECT_EQ(1, sio.nn_server_hello);
EXPECT_TRUE(1 <= cio.nn_certificate); EXPECT_EQ(1, cio.nn_certificate);
EXPECT_TRUE(1 <= sio.nn_new_session); EXPECT_EQ(1, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher); EXPECT_EQ(0, sio.nn_change_cipher);
} }
@ -1156,15 +1019,11 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 2 packets, total packets should be 3. // Lost 2 packets, total packets should be 3.
sio.nn_server_hello_lost = 2; sio.nn_server_hello_lost = 1;
HELPER_EXPECT_SUCCESS(client.start_active_handshake()); HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3); mock_wait_dtls_io_done(client.impl, 15, 5);
EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success);
@ -1172,10 +1031,10 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
EXPECT_TRUE(cio.done); EXPECT_TRUE(cio.done);
EXPECT_TRUE(sio.done); EXPECT_TRUE(sio.done);
EXPECT_EQ(3, cio.nn_client_hello); EXPECT_EQ(2, cio.nn_client_hello);
EXPECT_EQ(3, sio.nn_server_hello); EXPECT_EQ(2, sio.nn_server_hello);
EXPECT_TRUE(1 <= cio.nn_certificate); EXPECT_EQ(1, cio.nn_certificate);
EXPECT_TRUE(1 <= sio.nn_new_session); EXPECT_EQ(1, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher); EXPECT_EQ(0, sio.nn_change_cipher);
} }
@ -1187,15 +1046,11 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
// Use very short interval for utest.
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
HELPER_ARRAY_INIT(dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_to_ratios, 8, 1);
// Lost 2 packets, total packets should be 3. // Lost 2 packets, total packets should be 3.
sio.nn_new_session_lost = 2; sio.nn_new_session_lost = 1;
HELPER_EXPECT_SUCCESS(client.start_active_handshake()); HELPER_EXPECT_SUCCESS(client.start_active_handshake());
mock_wait_dtls_io_done(10, 3); mock_wait_dtls_io_done(client.impl, 15, 5);
EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(sio.r0 == srs_success);
EXPECT_TRUE(cio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success);
@ -1205,8 +1060,8 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
EXPECT_EQ(1, cio.nn_client_hello); EXPECT_EQ(1, cio.nn_client_hello);
EXPECT_EQ(1, sio.nn_server_hello); EXPECT_EQ(1, sio.nn_server_hello);
EXPECT_EQ(3, cio.nn_certificate); EXPECT_EQ(2, cio.nn_certificate);
EXPECT_EQ(3, sio.nn_new_session); EXPECT_EQ(2, sio.nn_new_session);
EXPECT_EQ(0, sio.nn_change_cipher); EXPECT_EQ(0, sio.nn_change_cipher);
} }
} }
@ -1250,10 +1105,10 @@ VOID TEST(KernelRTCTest, DTLSClientFlowTest)
{4, "auto", "dtls1.0", true, true, false, false}, {4, "auto", "dtls1.0", true, true, false, false},
// OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0 // OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0
{5, "auto", "dtls1.2", true, true, false, false}, {5, "auto", "dtls1.2", true, true, false, false},
// Fail, Client: DTLS v1.0, Server: DTLS v1.2 // OK?, Client: DTLS v1.0, Server: DTLS v1.2
{6, "dtls1.0", "dtls1.2", false, false, false, true}, {6, "dtls1.0", "dtls1.2", true, true, false, false},
// Fail, Client: DTLS v1.2, Server: DTLS v1.0 // OK?, Client: DTLS v1.2, Server: DTLS v1.0
{7, "dtls1.2", "dtls1.0", false, false, true, false}, {7, "dtls1.2", "dtls1.0", true, true, false, false},
}; };
for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSFlowCase)); i++) { for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSFlowCase)); i++) {
@ -1266,14 +1121,14 @@ VOID TEST(KernelRTCTest, DTLSClientFlowTest)
HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c; HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c;
HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c; HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c;
mock_wait_dtls_io_done(); mock_wait_dtls_io_done(client.impl, 15, 5);
// Note that the cio error is generated from server, vice versa. // Note that the cio error is generated from server, vice versa.
EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c;
EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c;
EXPECT_EQ(c.ClientDone, cio.done) << c; EXPECT_EQ(c.ClientDone, cio.done) << c;
EXPECT_EQ(c.ServerDone, sio.done) << c; EXPECT_EQ(c.ServerDone, sio.done) << c;
EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c;
EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c;
} }
} }
@ -1294,10 +1149,10 @@ VOID TEST(KernelRTCTest, DTLSServerFlowTest)
{4, "auto", "dtls1.0", true, true, false, false}, {4, "auto", "dtls1.0", true, true, false, false},
// OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0 // OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0
{5, "auto", "dtls1.2", true, true, false, false}, {5, "auto", "dtls1.2", true, true, false, false},
// Fail, Client: DTLS v1.0, Server: DTLS v1.2 // OK?, Client: DTLS v1.0, Server: DTLS v1.2
{6, "dtls1.0", "dtls1.2", false, false, false, true}, {6, "dtls1.0", "dtls1.2", true, true, false, false},
// Fail, Client: DTLS v1.2, Server: DTLS v1.0 // OK?, Client: DTLS v1.2, Server: DTLS v1.0
{7, "dtls1.2", "dtls1.0", false, false, true, false}, {7, "dtls1.2", "dtls1.0", true, true, false, false},
}; };
for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSFlowCase)); i++) { for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSFlowCase)); i++) {
@ -1310,14 +1165,14 @@ VOID TEST(KernelRTCTest, DTLSServerFlowTest)
HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c; HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c;
HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c; HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c;
mock_wait_dtls_io_done(); mock_wait_dtls_io_done(NULL, 15, 5);
// Note that the cio error is generated from server, vice versa. // Note that the cio error is generated from server, vice versa.
EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c;
EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c;
EXPECT_EQ(c.ClientDone, cio.done) << c; EXPECT_EQ(c.ClientDone, cio.done) << c;
EXPECT_EQ(c.ServerDone, sio.done) << c; EXPECT_EQ(c.ServerDone, sio.done) << c;
EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c;
EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c;
} }
} }

Loading…
Cancel
Save