For regression test, add srs-bench to 3rdparty

pull/2257/head
winlin 4 years ago
parent de87dd427d
commit 876210f6c9

@ -57,10 +57,8 @@ jobs:
- run: |
echo "Build and run SRS with regression config" &&
cd trunk && ./configure && make && ./objs/srs -c conf/regression-test.conf &&
echo "Clone srs-bench for regression test" &&
cd 3rdparty && git clone -b feature/srs --depth=1 https://github.com/ossrs/srs-bench &&
echo "Run srs-bench regression test" &&
cd srs-bench && make && ./objs/srs_test -test.v
cd 3rdparty/srs-bench && make && ./objs/srs_test -test.v
workflows:
version: 2
build_and_test:

@ -0,0 +1,8 @@
.idea
objs
*.ogg
*.ivf
*.h264
.DS_Store
.format.txt

@ -0,0 +1,25 @@
.PHONY: help default clean bench test
default: bench test
clean:
rm -f ./objs/srs_bench ./objs/srs_test
.format.txt: *.go rtc/*.go srs/*.go
gofmt -w .
echo "done" > .format.txt
bench: ./objs/srs_bench
./objs/srs_bench: .format.txt *.go rtc/*.go srs/*.go Makefile
go build -mod=vendor -o objs/srs_bench .
test: ./objs/srs_test
./objs/srs_test: .format.txt *.go rtc/*.go srs/*.go Makefile
go test ./srs -c -o ./objs/srs_test
help:
@echo "Usage: make [bench|test]"
@echo " bench Make the bench to ./objs/srs_bench"
@echo " test Make the test tool to ./objs/srs_test"

@ -0,0 +1,149 @@
# srs-bench
WebRTC benchmark on [pion/webrtc](https://github.com/pion/webrtc) for [SRS](https://github.com/ossrs/srs).
## Usage
编译和使用:
```bash
git clone https://github.com/ossrs/srs-bench.git && git checkout feature/rtc &&
make && ./objs/srs_bench -h
```
## Player for Live
直播播放压测,一个流,很多个播放。
首先推流到SRS
```bash
ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/live/livestream
```
然后启动压测比如100个
```bash
./objs/srs_bench -sr webrtc://localhost/live/livestream -nn 100
```
## Publisher for Live or RTC
直播或会议场景推流压测,一般会推多个流。
首先,推流依赖于录制的文件,请参考[DVR](#dvr)。
然后启动推流压测比如100个流
```bash
./objs/srs_bench -pr webrtc://localhost/live/livestream_%d -sn 100 -sa a.ogg -sv v.h264 -fps 25
```
> 注意帧率是原始视频的帧率由于264中没有这个信息所以需要传递。
## Multipel Player or Publisher for RTC
会议场景的播放压测会多个客户端播放多个流比如3人会议那么就有3个推流每个流有2个播放。
首先启动推流压测比如3个流
```bash
./objs/srs_bench -pr webrtc://localhost/live/livestream_%d -sn 3 -sa a.ogg -sv v.h264 -fps 25
```
然后每个流都启动播放压测比如每个流2个播放
```bash
./objs/srs_bench -sr webrtc://localhost/live/livestream_%d -sn 3 -nn 2
```
> 备注:压测都是基于流,可以任意设计推流和播放的流路数,实现不同的场景。
> 备注URL的变量格式参考Go的`fmt.Sprintf`,比如可以用`webrtc://localhost/live/livestream_%03d`。
## DVR
录制场景,主要是把内容录制下来后,可分析,也可以用于推流。
首先推流到SRS参考[live](#player-for-live)。
```bash
ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/live/livestream
```
然后,启动录制:
```bash
./objs/srs_bench -sr webrtc://localhost/live/livestream -da a.ogg -dv v.h264
```
> 备注:录制下来的`a.ogg`和`v.h264`可以用做推流。
## RTC Plaintext
压测RTC明文播放
首先推流到SRS
```bash
ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/live/livestream
```
然后启动压测指定是明文非加密比如100个
```bash
./objs/srs_bench -sr webrtc://localhost/live/livestream?encrypt=false -nn 100
```
> Note: 可以传递更多参数详细参考SRS支持的参数。
## Regression Test
回归测试需要先启动[SRS](https://github.com/ossrs/srs/issues/307)支持WebRTC推拉流
```bash
eip=$(ifconfig en0 inet| grep 'inet '|awk '{print $2}')
if [[ ! -z $eip ]]; then
docker run -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \
--rm --env CANDIDATE=$(ifconfig en0 inet| grep 'inet '|awk '{print $2}')\
registry.cn-hangzhou.aliyuncs.com/ossrs/srs:v4.0.76 objs/srs -c conf/rtc.conf
fi
```
然后运行回归测试用例,如果只跑一次,可以直接运行:
```bash
go test ./srs -v
```
也可以用make编译出重复使用的二进制
```bash
make test && ./objs/srs_test -test.v
```
可以给回归测试传参数,这样可以测试不同的序列,比如:
```bash
go test ./srs -v -srs-server=127.0.0.1
# Or
make test && ./objs/srs_test -test.v -srs-server=127.0.0.1
```
支持的参数如下:
* `-srs-server`RTC服务器地址。默认值`127.0.0.1`
* `-srs-stream`RTC流地址。默认值`/rtc/regression`
* `-srs-log`,是否开启详细日志。默认值:`false`
* `-srs-timeout`每个Case的超时时间毫秒。默认值`3000`即3秒。
* `-srs-play-pli`播放时PLI的间隔毫秒。默认值`5000`即5秒。
* `-srs-play-ok-packets`,播放时,收到多少个包认为是测试通过,默认值:`10`
* `-srs-publish-audio`,推流时,使用的音频文件。默认值:`avatar.ogg`
* `-srs-publish-video`,推流时,使用的视频文件。默认值:`avatar.h264`
* `-srs-publish-video-fps`推流时视频文件的FPS。默认值`25`
其他不常用参数:
* `-srs-https`是否连接HTTPS-API。默认值`false`即连接HTTP-API。
2021.01, Winlin

Binary file not shown.

@ -0,0 +1,12 @@
module github.com/ossrs/srs-bench
go 1.15
require (
github.com/ossrs/go-oryx-lib v0.0.8
github.com/pion/interceptor v0.0.9
github.com/pion/rtcp v1.2.6
github.com/pion/rtp v1.6.2
github.com/pion/sdp/v3 v3.0.4
github.com/pion/webrtc/v3 v3.0.4
)

@ -0,0 +1,124 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc=
github.com/ossrs/go-oryx-lib v0.0.8 h1:k8ml3ZLsjIMoQEdZdWuy8zkU0w/fbJSyHvT/s9NyeCc=
github.com/ossrs/go-oryx-lib v0.0.8/go.mod h1:i2tH4TZBzAw5h+HwGrNOKvP/nmZgSQz0OEnLLdzcT/8=
github.com/pion/datachannel v1.4.21 h1:3ZvhNyfmxsAqltQrApLPQMhSFNA+aT87RqyCq4OXmf0=
github.com/pion/datachannel v1.4.21/go.mod h1:oiNyP4gHx2DIwRzX/MFyH0Rz/Gz05OgBlayAI2hAWjg=
github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI=
github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI=
github.com/pion/ice/v2 v2.0.14 h1:FxXxauyykf89SWAtkQCfnHkno6G8+bhRkNguSh9zU+4=
github.com/pion/ice/v2 v2.0.14/go.mod h1:wqaUbOq5ObDNU5ox1hRsEst0rWfsKuH1zXjQFEWiZwM=
github.com/pion/interceptor v0.0.9 h1:fk5hTdyLO3KURQsf/+RjMpEm4NE3yeTY9Kh97b5BvwA=
github.com/pion/interceptor v0.0.9/go.mod h1:dHgEP5dtxOTf21MObuBAjJeAayPxLUAZjerGH8Xr07c=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY=
github.com/pion/mdns v0.0.4/go.mod h1:R1sL0p50l42S5lJs91oNdUL58nm0QHrhxnSegr++qC0=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.6 h1:1zvwBbyd0TeEuuWftrd/4d++m+/kZSeiguxU61LFWpo=
github.com/pion/rtcp v1.2.6/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0=
github.com/pion/rtp v1.6.2 h1:iGBerLX6JiDjB9NXuaPzHyxHFG9JsIEdgwTC0lp5n/U=
github.com/pion/rtp v1.6.2/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko=
github.com/pion/sctp v1.7.10/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0=
github.com/pion/sctp v1.7.11 h1:UCnj7MsobLKLuP/Hh+JMiI/6W5Bs/VF45lWKgHFjSIE=
github.com/pion/sctp v1.7.11/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0=
github.com/pion/sdp/v3 v3.0.4 h1:2Kf+dgrzJflNCSw3TV5v2VLeI0s/qkzy2r5jlR0wzf8=
github.com/pion/sdp/v3 v3.0.4/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk=
github.com/pion/srtp/v2 v2.0.1 h1:kgfh65ob3EcnFYA4kUBvU/menCp9u7qaJLXwWgpobzs=
github.com/pion/srtp/v2 v2.0.1/go.mod h1:c8NWHhhkFf/drmHTAblkdu8++lsISEBBdAuiyxgqIsE=
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8=
github.com/pion/transport v0.10.0/go.mod h1:BnHnUipd0rZQyTVB2SBGojFHT9CBt5C5TcsJSQGkvSE=
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
github.com/pion/transport v0.12.0/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/transport v0.12.2 h1:WYEjhloRHt1R86LhUKjC5y+P52Y11/QqEUalvtzVoys=
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/turn/v2 v2.0.5 h1:iwMHqDfPEDEOFzwWKT56eFmh6DYC6o/+xnLAEzgISbA=
github.com/pion/turn/v2 v2.0.5/go.mod h1:APg43CFyt/14Uy7heYUOGWdkem/Wu4PhCO/bjyrTqMw=
github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI=
github.com/pion/udp v0.1.0/go.mod h1:BPELIjbwE9PRbd/zxI/KYBnbo7B6+oA6YuEaNE8lths=
github.com/pion/webrtc v1.2.0 h1:3LGGPQEMacwG2hcDfhdvwQPz315gvjZXOfY4vaF4+I4=
github.com/pion/webrtc/v3 v3.0.4 h1:Tiw3H9fpfcwkvaxonB+Gv1DG9tmgYBQaM1vBagDHP40=
github.com/pion/webrtc/v3 v3.0.4/go.mod h1:1TmFSLpPYFTFXFHPtoq9eGP1ASTa9LC6FBh7sUY8cd4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7 h1:3uJsdck53FDIpWwLeAXlia9p4C8j0BO2xZrqzKpL0D8=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

@ -0,0 +1,263 @@
package main
import (
"context"
"flag"
"fmt"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/srs"
"net"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
)
func main() {
var sr, dumpAudio, dumpVideo string
var pli int
flag.StringVar(&sr, "sr", "", "")
flag.StringVar(&dumpAudio, "da", "", "")
flag.StringVar(&dumpVideo, "dv", "", "")
flag.IntVar(&pli, "pli", 10, "")
var pr, sourceAudio, sourceVideo string
var fps int
flag.StringVar(&pr, "pr", "", "")
flag.StringVar(&sourceAudio, "sa", "", "")
flag.StringVar(&sourceVideo, "sv", "", "")
flag.IntVar(&fps, "fps", 0, "")
var audioLevel, videoTWCC bool
flag.BoolVar(&audioLevel, "al", true, "")
flag.BoolVar(&videoTWCC, "twcc", true, "")
var clients, streams, delay int
flag.IntVar(&clients, "nn", 1, "")
flag.IntVar(&streams, "sn", 1, "")
flag.IntVar(&delay, "delay", 50, "")
var statListen string
flag.StringVar(&statListen, "stat", ":18000", "")
flag.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -nn The number of clients to simulate. Default: 1"))
fmt.Println(fmt.Sprintf(" -sn The number of streams to simulate. Variable: %%d. Default: 1"))
fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50"))
fmt.Println(fmt.Sprintf(" -al [Optional] Whether enable audio-level. Default: true"))
fmt.Println(fmt.Sprintf(" -twcc [Optional] Whether enable vdieo-twcc. Default: true"))
fmt.Println(fmt.Sprintf(" -stat [Optional] The stat server API listen port. Default: :18000"))
fmt.Println(fmt.Sprintf("Player or Subscriber:"))
fmt.Println(fmt.Sprintf(" -sr The url to play/subscribe. If sn exceed 1, auto append variable %%d."))
fmt.Println(fmt.Sprintf(" -da [Optional] The file path to dump audio, ignore if empty."))
fmt.Println(fmt.Sprintf(" -dv [Optional] The file path to dump video, ignore if empty."))
fmt.Println(fmt.Sprintf(" -pli [Optional] PLI request interval in seconds. Default: 10"))
fmt.Println(fmt.Sprintf("Publisher:"))
fmt.Println(fmt.Sprintf(" -pr The url to publish. If sn exceed 1, auto append variable %%d."))
fmt.Println(fmt.Sprintf(" -fps The fps of .h264 source file."))
fmt.Println(fmt.Sprintf(" -sa [Optional] The file path to read audio, ignore if empty."))
fmt.Println(fmt.Sprintf(" -sv [Optional] The file path to read video, ignore if empty."))
fmt.Println(fmt.Sprintf("\n例如1个播放1个推流:"))
fmt.Println(fmt.Sprintf(" %v -sr webrtc://localhost/live/livestream", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -pr webrtc://localhost/live/livestream -sa a.ogg -sv v.h264 -fps 25", os.Args[0]))
fmt.Println(fmt.Sprintf("\n例如1个流3个播放共3个客户端"))
fmt.Println(fmt.Sprintf(" %v -sr webrtc://localhost/live/livestream -nn 3", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -pr webrtc://localhost/live/livestream -sa a.ogg -sv v.h264 -fps 25", os.Args[0]))
fmt.Println(fmt.Sprintf("\n例如2个流每个流3个播放共6个客户端"))
fmt.Println(fmt.Sprintf(" %v -sr webrtc://localhost/live/livestream_%%d -sn 2 -nn 3", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -pr webrtc://localhost/live/livestream_%%d -sn 2 -sa a.ogg -sv v.h264 -fps 25", os.Args[0]))
fmt.Println(fmt.Sprintf("\n例如2个推流"))
fmt.Println(fmt.Sprintf(" %v -pr webrtc://localhost/live/livestream_%%d -sn 2 -sa a.ogg -sv v.h264 -fps 25", os.Args[0]))
fmt.Println(fmt.Sprintf("\n例如1个录制"))
fmt.Println(fmt.Sprintf(" %v -sr webrtc://localhost/live/livestream -da a.ogg -dv v.h264", os.Args[0]))
fmt.Println(fmt.Sprintf("\n例如1个明文播放"))
fmt.Println(fmt.Sprintf(" %v -sr webrtc://localhost/live/livestream?encrypt=false", os.Args[0]))
fmt.Println()
}
flag.Parse()
showHelp := (clients <= 0 || streams <= 0)
if sr == "" && pr == "" {
showHelp = true
}
if pr != "" && (sourceAudio == "" && sourceVideo == "") {
showHelp = true
}
if showHelp {
flag.Usage()
os.Exit(-1)
}
if statListen != "" && !strings.Contains(statListen, ":") {
statListen = ":" + statListen
}
ctx := context.Background()
summaryDesc := fmt.Sprintf("clients=%v, delay=%v, al=%v, twcc=%v, stat=%v", clients, delay, audioLevel, videoTWCC, statListen)
if sr != "" {
summaryDesc = fmt.Sprintf("%v, play(url=%v, da=%v, dv=%v, pli=%v)", summaryDesc, sr, dumpAudio, dumpVideo, pli)
}
if pr != "" {
summaryDesc = fmt.Sprintf("%v, publish(url=%v, sa=%v, sv=%v, fps=%v)",
summaryDesc, pr, sourceAudio, sourceVideo, fps)
}
logger.Tf(ctx, "Start benchmark with %v", summaryDesc)
checkFlag := func() error {
if dumpVideo != "" && !strings.HasSuffix(dumpVideo, ".h264") && !strings.HasSuffix(dumpVideo, ".ivf") {
return errors.Errorf("Should be .ivf or .264, actual %v", dumpVideo)
}
if sourceVideo != "" && !strings.HasSuffix(sourceVideo, ".h264") {
return errors.Errorf("Should be .264, actual %v", sourceVideo)
}
if sourceVideo != "" && strings.HasSuffix(sourceVideo, ".h264") && fps <= 0 {
return errors.Errorf("Video fps should >0, actual %v", fps)
}
return nil
}
if err := checkFlag(); err != nil {
logger.Ef(ctx, "Check faile err %+v", err)
os.Exit(-1)
}
ctx, cancel := context.WithCancel(ctx)
// Process all signals.
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT)
for sig := range sigs {
logger.Wf(ctx, "Quit for signal %v", sig)
cancel()
}
}()
// Run tasks.
var wg sync.WaitGroup
// Start STAT API server.
wg.Add(1)
go func() {
defer wg.Done()
if statListen == "" {
return
}
var lc net.ListenConfig
ln, err := lc.Listen(ctx, "tcp", statListen)
if err != nil {
logger.Ef(ctx, "stat listen err+%v", err)
cancel()
return
}
mux := http.NewServeMux()
srs.HandleStat(ctx, mux, statListen)
srv := &http.Server{
Handler: mux,
BaseContext: func(listener net.Listener) context.Context {
return ctx
},
}
go func() {
<-ctx.Done()
srv.Shutdown(ctx)
}()
logger.Tf(ctx, "Stat listen at %v", statListen)
if err := srv.Serve(ln); err != nil {
if ctx.Err() == nil {
logger.Ef(ctx, "stat serve err+%v", err)
cancel()
}
return
}
}()
// Start all subscribers or players.
for i := 0; sr != "" && i < streams && ctx.Err() == nil; i++ {
r_auto := sr
if streams > 1 && !strings.Contains(r_auto, "%") {
r_auto += "%d"
}
r2 := r_auto
if strings.Contains(r2, "%") {
r2 = fmt.Sprintf(r2, i)
}
for j := 0; sr != "" && j < clients && ctx.Err() == nil; j++ {
// Dump audio or video only for the first client.
da, dv := dumpAudio, dumpVideo
if i > 0 {
da, dv = "", ""
}
srs.StatRTC.Subscribers.Expect++
srs.StatRTC.Subscribers.Alive++
wg.Add(1)
go func(sr, da, dv string) {
defer wg.Done()
defer func() {
srs.StatRTC.Subscribers.Alive--
}()
if err := srs.StartPlay(ctx, sr, da, dv, audioLevel, videoTWCC, pli); err != nil {
if errors.Cause(err) != context.Canceled {
logger.Wf(ctx, "Run err %+v", err)
}
}
}(r2, da, dv)
time.Sleep(time.Duration(delay) * time.Millisecond)
}
}
// Start all publishers.
for i := 0; pr != "" && i < streams && ctx.Err() == nil; i++ {
r_auto := pr
if streams > 1 && !strings.Contains(r_auto, "%") {
r_auto += "%d"
}
r2 := r_auto
if strings.Contains(r2, "%") {
r2 = fmt.Sprintf(r2, i)
}
srs.StatRTC.Publishers.Expect++
srs.StatRTC.Publishers.Alive++
wg.Add(1)
go func(pr string) {
defer wg.Done()
defer func() {
srs.StatRTC.Publishers.Alive--
}()
if err := srs.StartPublish(ctx, pr, sourceAudio, sourceVideo, fps, audioLevel, videoTWCC); err != nil {
if errors.Cause(err) != context.Canceled {
logger.Wf(ctx, "Run err %+v", err)
}
}
}(r2)
time.Sleep(time.Duration(delay) * time.Millisecond)
}
wg.Wait()
logger.Tf(ctx, "Done")
}

@ -0,0 +1,5 @@
package rtc
const (
rtpOutboundMTU = 1200
)

@ -0,0 +1,27 @@
package rtc
import (
"github.com/pion/rtp"
"github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3"
"strings"
)
func payloaderForCodec(codec webrtc.RTPCodecCapability) (rtp.Payloader, error) {
switch strings.ToLower(codec.MimeType) {
case strings.ToLower(webrtc.MimeTypeH264):
return &codecs.H264Payloader{}, nil
case strings.ToLower(webrtc.MimeTypeOpus):
return &codecs.OpusPayloader{}, nil
case strings.ToLower(webrtc.MimeTypeVP8):
return &codecs.VP8Payloader{}, nil
case strings.ToLower(webrtc.MimeTypeVP9):
return &codecs.VP9Payloader{}, nil
case strings.ToLower(webrtc.MimeTypeG722):
return &codecs.G722Payloader{}, nil
case strings.ToLower(webrtc.MimeTypePCMU), strings.ToLower(webrtc.MimeTypePCMA):
return &codecs.G711Payloader{}, nil
default:
return nil, webrtc.ErrNoPayloaderForCodec
}
}

@ -0,0 +1,27 @@
package rtc
import (
"github.com/pion/webrtc/v3"
"strings"
)
// Do a fuzzy find for a codec in the list of codecs
// Used for lookup up a codec in an existing list to find a match
func codecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []webrtc.RTPCodecParameters) (webrtc.RTPCodecParameters, error) {
// First attempt to match on MimeType + SDPFmtpLine
for _, c := range haystack {
if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) &&
c.RTPCodecCapability.SDPFmtpLine == needle.RTPCodecCapability.SDPFmtpLine {
return c, nil
}
}
// Fallback to just MimeType
for _, c := range haystack {
if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) {
return c, nil
}
}
return webrtc.RTPCodecParameters{}, webrtc.ErrCodecNotFound
}

@ -0,0 +1,246 @@
package rtc
import (
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"strings"
"sync"
)
// trackBinding is a single bind for a Track
// Bind can be called multiple times, this stores the
// result for a single bind call so that it can be used when writing
type trackBinding struct {
id string
ssrc webrtc.SSRC
payloadType webrtc.PayloadType
writeStream webrtc.TrackLocalWriter
}
// TrackLocalStaticRTP is a TrackLocal that has a pre-set codec and accepts RTP Packets.
// If you wish to send a media.Sample use TrackLocalStaticSample
type TrackLocalStaticRTP struct {
mu sync.RWMutex
bindings []trackBinding
codec webrtc.RTPCodecCapability
id, streamID string
}
// NewTrackLocalStaticRTP returns a TrackLocalStaticRTP.
func NewTrackLocalStaticRTP(c webrtc.RTPCodecCapability, id, streamID string) (*TrackLocalStaticRTP, error) {
return &TrackLocalStaticRTP{
codec: c,
bindings: []trackBinding{},
id: id,
streamID: streamID,
}, nil
}
// Bind is called by the PeerConnection after negotiation is complete
// This asserts that the code requested is supported by the remote peer.
// If so it setups all the state (SSRC and PayloadType) to have a call
func (s *TrackLocalStaticRTP) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
s.mu.Lock()
defer s.mu.Unlock()
parameters := webrtc.RTPCodecParameters{RTPCodecCapability: s.codec}
if codec, err := codecParametersFuzzySearch(parameters, t.CodecParameters()); err == nil {
s.bindings = append(s.bindings, trackBinding{
ssrc: t.SSRC(),
payloadType: codec.PayloadType,
writeStream: t.WriteStream(),
id: t.ID(),
})
return codec, nil
}
return webrtc.RTPCodecParameters{}, webrtc.ErrUnsupportedCodec
}
// Unbind implements the teardown logic when the track is no longer needed. This happens
// because a track has been stopped.
func (s *TrackLocalStaticRTP) Unbind(t webrtc.TrackLocalContext) error {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.bindings {
if s.bindings[i].id == t.ID() {
s.bindings[i] = s.bindings[len(s.bindings)-1]
s.bindings = s.bindings[:len(s.bindings)-1]
return nil
}
}
return webrtc.ErrUnbindFailed
}
// ID is the unique identifier for this Track. This should be unique for the
// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video'
// and StreamID would be 'desktop' or 'webcam'
func (s *TrackLocalStaticRTP) ID() string { return s.id }
// StreamID is the group this track belongs too. This must be unique
func (s *TrackLocalStaticRTP) StreamID() string { return s.streamID }
// Kind controls if this TrackLocal is audio or video
func (s *TrackLocalStaticRTP) Kind() webrtc.RTPCodecType {
switch {
case strings.HasPrefix(s.codec.MimeType, "audio/"):
return webrtc.RTPCodecTypeAudio
case strings.HasPrefix(s.codec.MimeType, "video/"):
return webrtc.RTPCodecTypeVideo
default:
return webrtc.RTPCodecType(0)
}
}
// Codec gets the Codec of the track
func (s *TrackLocalStaticRTP) Codec() webrtc.RTPCodecCapability {
return s.codec
}
// WriteRTP writes a RTP Packet to the TrackLocalStaticRTP
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error {
s.mu.RLock()
defer s.mu.RUnlock()
writeErrs := []error{}
outboundPacket := *p
for _, b := range s.bindings {
outboundPacket.Header.SSRC = uint32(b.ssrc)
outboundPacket.Header.PayloadType = uint8(b.payloadType)
if _, err := b.writeStream.WriteRTP(&outboundPacket.Header, outboundPacket.Payload); err != nil {
writeErrs = append(writeErrs, err)
}
}
return FlattenErrs(writeErrs)
}
// Write writes a RTP Packet as a buffer to the TrackLocalStaticRTP
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) {
packet := &rtp.Packet{}
if err = packet.Unmarshal(b); err != nil {
return 0, err
}
return len(b), s.WriteRTP(packet)
}
// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples.
// If you wish to send a RTP Packet use TrackLocalStaticRTP
type TrackLocalStaticSample struct {
packetizer rtp.Packetizer
rtpTrack *TrackLocalStaticRTP
clockRate float64
// Set the callback before write RTP packet.
OnBeforeWritePacket func(rtp *rtp.Packet)
}
// NewTrackLocalStaticSample returns a TrackLocalStaticSample
func NewTrackLocalStaticSample(c webrtc.RTPCodecCapability, id, streamID string) (*TrackLocalStaticSample, error) {
rtpTrack, err := NewTrackLocalStaticRTP(c, id, streamID)
if err != nil {
return nil, err
}
return &TrackLocalStaticSample{
rtpTrack: rtpTrack,
}, nil
}
// ID is the unique identifier for this Track. This should be unique for the
// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video'
// and StreamID would be 'desktop' or 'webcam'
func (s *TrackLocalStaticSample) ID() string { return s.rtpTrack.ID() }
// StreamID is the group this track belongs too. This must be unique
func (s *TrackLocalStaticSample) StreamID() string { return s.rtpTrack.StreamID() }
// Kind controls if this TrackLocal is audio or video
func (s *TrackLocalStaticSample) Kind() webrtc.RTPCodecType { return s.rtpTrack.Kind() }
// Codec gets the Codec of the track
func (s *TrackLocalStaticSample) Codec() webrtc.RTPCodecCapability {
return s.rtpTrack.Codec()
}
// Bind is called by the PeerConnection after negotiation is complete
// This asserts that the code requested is supported by the remote peer.
// If so it setups all the state (SSRC and PayloadType) to have a call
func (s *TrackLocalStaticSample) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
codec, err := s.rtpTrack.Bind(t)
if err != nil {
return codec, err
}
s.rtpTrack.mu.Lock()
defer s.rtpTrack.mu.Unlock()
// We only need one packetizer
if s.packetizer != nil {
return codec, nil
}
payloader, err := payloaderForCodec(codec.RTPCodecCapability)
if err != nil {
return codec, err
}
s.packetizer = rtp.NewPacketizer(
rtpOutboundMTU,
0, // Value is handled when writing
0, // Value is handled when writing
payloader,
rtp.NewRandomSequencer(),
codec.ClockRate,
)
s.clockRate = float64(codec.RTPCodecCapability.ClockRate)
return codec, nil
}
// Unbind implements the teardown logic when the track is no longer needed. This happens
// because a track has been stopped.
func (s *TrackLocalStaticSample) Unbind(t webrtc.TrackLocalContext) error {
return s.rtpTrack.Unbind(t)
}
// WriteSample writes a Sample to the TrackLocalStaticSample
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
s.rtpTrack.mu.RLock()
p := s.packetizer
clockRate := s.clockRate
s.rtpTrack.mu.RUnlock()
if p == nil {
return nil
}
samples := sample.Duration.Seconds() * clockRate
packets := p.(rtp.Packetizer).Packetize(sample.Data, uint32(samples))
writeErrs := []error{}
for _, p := range packets {
if s.OnBeforeWritePacket != nil {
s.OnBeforeWritePacket(p)
}
if err := s.rtpTrack.WriteRTP(p); err != nil {
writeErrs = append(writeErrs, err)
}
}
return FlattenErrs(writeErrs)
}

@ -0,0 +1,10 @@
package rtc
import "fmt"
func FlattenErrs(errors []error) error {
if len(errors) == 0 {
return nil
}
return fmt.Errorf("%v", errors)
}

@ -0,0 +1,9 @@
language: go
go:
- 1.4.3
- 1.5.3
- tip
script:
- go test -v ./...

@ -0,0 +1,10 @@
# How to contribute
We definitely welcome patches and contribution to this project!
### Legal requirements
In order to protect both you and ourselves, you will need to sign the
[Contributor License Agreement](https://cla.developers.google.com/clas).
You may have already signed it for other Google projects.

@ -0,0 +1,9 @@
Paul Borman <borman@google.com>
bmatsuo
shawnps
theory
jboverfelt
dsymonds
cd1
wallclockbuilder
dansouza

@ -0,0 +1,27 @@
Copyright (c) 2009,2014 Google Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,19 @@
# uuid ![build status](https://travis-ci.org/google/uuid.svg?branch=master)
The uuid package generates and inspects UUIDs based on
[RFC 4122](http://tools.ietf.org/html/rfc4122)
and DCE 1.1: Authentication and Security Services.
This package is based on the github.com/pborman/uuid package (previously named
code.google.com/p/go-uuid). It differs from these earlier packages in that
a UUID is a 16 byte array rather than a byte slice. One loss due to this
change is the ability to represent an invalid UUID (vs a NIL UUID).
###### Install
`go get github.com/google/uuid`
###### Documentation
[![GoDoc](https://godoc.org/github.com/google/uuid?status.svg)](http://godoc.org/github.com/google/uuid)
Full `go doc` style documentation for the package can be viewed online without
installing this package by using the GoDoc site here:
http://pkg.go.dev/github.com/google/uuid

@ -0,0 +1,80 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"fmt"
"os"
)
// A Domain represents a Version 2 domain
type Domain byte
// Domain constants for DCE Security (Version 2) UUIDs.
const (
Person = Domain(0)
Group = Domain(1)
Org = Domain(2)
)
// NewDCESecurity returns a DCE Security (Version 2) UUID.
//
// The domain should be one of Person, Group or Org.
// On a POSIX system the id should be the users UID for the Person
// domain and the users GID for the Group. The meaning of id for
// the domain Org or on non-POSIX systems is site defined.
//
// For a given domain/id pair the same token may be returned for up to
// 7 minutes and 10 seconds.
func NewDCESecurity(domain Domain, id uint32) (UUID, error) {
uuid, err := NewUUID()
if err == nil {
uuid[6] = (uuid[6] & 0x0f) | 0x20 // Version 2
uuid[9] = byte(domain)
binary.BigEndian.PutUint32(uuid[0:], id)
}
return uuid, err
}
// NewDCEPerson returns a DCE Security (Version 2) UUID in the person
// domain with the id returned by os.Getuid.
//
// NewDCESecurity(Person, uint32(os.Getuid()))
func NewDCEPerson() (UUID, error) {
return NewDCESecurity(Person, uint32(os.Getuid()))
}
// NewDCEGroup returns a DCE Security (Version 2) UUID in the group
// domain with the id returned by os.Getgid.
//
// NewDCESecurity(Group, uint32(os.Getgid()))
func NewDCEGroup() (UUID, error) {
return NewDCESecurity(Group, uint32(os.Getgid()))
}
// Domain returns the domain for a Version 2 UUID. Domains are only defined
// for Version 2 UUIDs.
func (uuid UUID) Domain() Domain {
return Domain(uuid[9])
}
// ID returns the id for a Version 2 UUID. IDs are only defined for Version 2
// UUIDs.
func (uuid UUID) ID() uint32 {
return binary.BigEndian.Uint32(uuid[0:4])
}
func (d Domain) String() string {
switch d {
case Person:
return "Person"
case Group:
return "Group"
case Org:
return "Org"
}
return fmt.Sprintf("Domain%d", int(d))
}

@ -0,0 +1,12 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package uuid generates and inspects UUIDs.
//
// UUIDs are based on RFC 4122 and DCE 1.1: Authentication and Security
// Services.
//
// A UUID is a 16 byte (128 bit) array. UUIDs may be used as keys to
// maps or compared directly.
package uuid

@ -0,0 +1 @@
module github.com/google/uuid

@ -0,0 +1,53 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"crypto/md5"
"crypto/sha1"
"hash"
)
// Well known namespace IDs and UUIDs
var (
NameSpaceDNS = Must(Parse("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceURL = Must(Parse("6ba7b811-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceOID = Must(Parse("6ba7b812-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceX500 = Must(Parse("6ba7b814-9dad-11d1-80b4-00c04fd430c8"))
Nil UUID // empty UUID, all zeros
)
// NewHash returns a new UUID derived from the hash of space concatenated with
// data generated by h. The hash should be at least 16 byte in length. The
// first 16 bytes of the hash are used to form the UUID. The version of the
// UUID will be the lower 4 bits of version. NewHash is used to implement
// NewMD5 and NewSHA1.
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
h.Reset()
h.Write(space[:])
h.Write(data)
s := h.Sum(nil)
var uuid UUID
copy(uuid[:], s)
uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4)
uuid[8] = (uuid[8] & 0x3f) | 0x80 // RFC 4122 variant
return uuid
}
// NewMD5 returns a new MD5 (Version 3) UUID based on the
// supplied name space and data. It is the same as calling:
//
// NewHash(md5.New(), space, data, 3)
func NewMD5(space UUID, data []byte) UUID {
return NewHash(md5.New(), space, data, 3)
}
// NewSHA1 returns a new SHA1 (Version 5) UUID based on the
// supplied name space and data. It is the same as calling:
//
// NewHash(sha1.New(), space, data, 5)
func NewSHA1(space UUID, data []byte) UUID {
return NewHash(sha1.New(), space, data, 5)
}

@ -0,0 +1,38 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "fmt"
// MarshalText implements encoding.TextMarshaler.
func (uuid UUID) MarshalText() ([]byte, error) {
var js [36]byte
encodeHex(js[:], uuid)
return js[:], nil
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (uuid *UUID) UnmarshalText(data []byte) error {
id, err := ParseBytes(data)
if err != nil {
return err
}
*uuid = id
return nil
}
// MarshalBinary implements encoding.BinaryMarshaler.
func (uuid UUID) MarshalBinary() ([]byte, error) {
return uuid[:], nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler.
func (uuid *UUID) UnmarshalBinary(data []byte) error {
if len(data) != 16 {
return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
}
copy(uuid[:], data)
return nil
}

@ -0,0 +1,90 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"sync"
)
var (
nodeMu sync.Mutex
ifname string // name of interface being used
nodeID [6]byte // hardware for version 1 UUIDs
zeroID [6]byte // nodeID with only 0's
)
// NodeInterface returns the name of the interface from which the NodeID was
// derived. The interface "user" is returned if the NodeID was set by
// SetNodeID.
func NodeInterface() string {
defer nodeMu.Unlock()
nodeMu.Lock()
return ifname
}
// SetNodeInterface selects the hardware address to be used for Version 1 UUIDs.
// If name is "" then the first usable interface found will be used or a random
// Node ID will be generated. If a named interface cannot be found then false
// is returned.
//
// SetNodeInterface never fails when name is "".
func SetNodeInterface(name string) bool {
defer nodeMu.Unlock()
nodeMu.Lock()
return setNodeInterface(name)
}
func setNodeInterface(name string) bool {
iname, addr := getHardwareInterface(name) // null implementation for js
if iname != "" && addr != nil {
ifname = iname
copy(nodeID[:], addr)
return true
}
// We found no interfaces with a valid hardware address. If name
// does not specify a specific interface generate a random Node ID
// (section 4.1.6)
if name == "" {
ifname = "random"
randomBits(nodeID[:])
return true
}
return false
}
// NodeID returns a slice of a copy of the current Node ID, setting the Node ID
// if not already set.
func NodeID() []byte {
defer nodeMu.Unlock()
nodeMu.Lock()
if nodeID == zeroID {
setNodeInterface("")
}
nid := nodeID
return nid[:]
}
// SetNodeID sets the Node ID to be used for Version 1 UUIDs. The first 6 bytes
// of id are used. If id is less than 6 bytes then false is returned and the
// Node ID is not set.
func SetNodeID(id []byte) bool {
if len(id) < 6 {
return false
}
defer nodeMu.Unlock()
nodeMu.Lock()
copy(nodeID[:], id)
ifname = "user"
return true
}
// NodeID returns the 6 byte node id encoded in uuid. It returns nil if uuid is
// not valid. The NodeID is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) NodeID() []byte {
var node [6]byte
copy(node[:], uuid[10:])
return node[:]
}

@ -0,0 +1,12 @@
// Copyright 2017 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build js
package uuid
// getHardwareInterface returns nil values for the JS version of the code.
// This remvoves the "net" dependency, because it is not used in the browser.
// Using the "net" library inflates the size of the transpiled JS code by 673k bytes.
func getHardwareInterface(name string) (string, []byte) { return "", nil }

@ -0,0 +1,33 @@
// Copyright 2017 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !js
package uuid
import "net"
var interfaces []net.Interface // cached list of interfaces
// getHardwareInterface returns the name and hardware address of interface name.
// If name is "" then the name and hardware address of one of the system's
// interfaces is returned. If no interfaces are found (name does not exist or
// there are no interfaces) then "", nil is returned.
//
// Only addresses of at least 6 bytes are returned.
func getHardwareInterface(name string) (string, []byte) {
if interfaces == nil {
var err error
interfaces, err = net.Interfaces()
if err != nil {
return "", nil
}
}
for _, ifs := range interfaces {
if len(ifs.HardwareAddr) >= 6 && (name == "" || name == ifs.Name) {
return ifs.Name, ifs.HardwareAddr
}
}
return "", nil
}

@ -0,0 +1,59 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"database/sql/driver"
"fmt"
)
// Scan implements sql.Scanner so UUIDs can be read from databases transparently
// Currently, database types that map to string and []byte are supported. Please
// consult database-specific driver documentation for matching types.
func (uuid *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case nil:
return nil
case string:
// if an empty UUID comes from a table, we return a null UUID
if src == "" {
return nil
}
// see Parse for required string format
u, err := Parse(src)
if err != nil {
return fmt.Errorf("Scan: %v", err)
}
*uuid = u
case []byte:
// if an empty UUID comes from a table, we return a null UUID
if len(src) == 0 {
return nil
}
// assumes a simple slice of bytes if 16 bytes
// otherwise attempts to parse
if len(src) != 16 {
return uuid.Scan(string(src))
}
copy((*uuid)[:], src)
default:
return fmt.Errorf("Scan: unable to scan type %T into UUID", src)
}
return nil
}
// Value implements sql.Valuer so that UUIDs can be written to databases
// transparently. Currently, UUIDs map to strings. Please consult
// database-specific driver documentation for matching types.
func (uuid UUID) Value() (driver.Value, error) {
return uuid.String(), nil
}

@ -0,0 +1,123 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"sync"
"time"
)
// A Time represents a time as the number of 100's of nanoseconds since 15 Oct
// 1582.
type Time int64
const (
lillian = 2299160 // Julian day of 15 Oct 1582
unix = 2440587 // Julian day of 1 Jan 1970
epoch = unix - lillian // Days between epochs
g1582 = epoch * 86400 // seconds between epochs
g1582ns100 = g1582 * 10000000 // 100s of a nanoseconds between epochs
)
var (
timeMu sync.Mutex
lasttime uint64 // last time we returned
clockSeq uint16 // clock sequence for this run
timeNow = time.Now // for testing
)
// UnixTime converts t the number of seconds and nanoseconds using the Unix
// epoch of 1 Jan 1970.
func (t Time) UnixTime() (sec, nsec int64) {
sec = int64(t - g1582ns100)
nsec = (sec % 10000000) * 100
sec /= 10000000
return sec, nsec
}
// GetTime returns the current Time (100s of nanoseconds since 15 Oct 1582) and
// clock sequence as well as adjusting the clock sequence as needed. An error
// is returned if the current time cannot be determined.
func GetTime() (Time, uint16, error) {
defer timeMu.Unlock()
timeMu.Lock()
return getTime()
}
func getTime() (Time, uint16, error) {
t := timeNow()
// If we don't have a clock sequence already, set one.
if clockSeq == 0 {
setClockSequence(-1)
}
now := uint64(t.UnixNano()/100) + g1582ns100
// If time has gone backwards with this clock sequence then we
// increment the clock sequence
if now <= lasttime {
clockSeq = ((clockSeq + 1) & 0x3fff) | 0x8000
}
lasttime = now
return Time(now), clockSeq, nil
}
// ClockSequence returns the current clock sequence, generating one if not
// already set. The clock sequence is only used for Version 1 UUIDs.
//
// The uuid package does not use global static storage for the clock sequence or
// the last time a UUID was generated. Unless SetClockSequence is used, a new
// random clock sequence is generated the first time a clock sequence is
// requested by ClockSequence, GetTime, or NewUUID. (section 4.2.1.1)
func ClockSequence() int {
defer timeMu.Unlock()
timeMu.Lock()
return clockSequence()
}
func clockSequence() int {
if clockSeq == 0 {
setClockSequence(-1)
}
return int(clockSeq & 0x3fff)
}
// SetClockSequence sets the clock sequence to the lower 14 bits of seq. Setting to
// -1 causes a new sequence to be generated.
func SetClockSequence(seq int) {
defer timeMu.Unlock()
timeMu.Lock()
setClockSequence(seq)
}
func setClockSequence(seq int) {
if seq == -1 {
var b [2]byte
randomBits(b[:]) // clock sequence
seq = int(b[0])<<8 | int(b[1])
}
oldSeq := clockSeq
clockSeq = uint16(seq&0x3fff) | 0x8000 // Set our variant
if oldSeq != clockSeq {
lasttime = 0
}
}
// Time returns the time in 100s of nanoseconds since 15 Oct 1582 encoded in
// uuid. The time is only defined for version 1 and 2 UUIDs.
func (uuid UUID) Time() Time {
time := int64(binary.BigEndian.Uint32(uuid[0:4]))
time |= int64(binary.BigEndian.Uint16(uuid[4:6])) << 32
time |= int64(binary.BigEndian.Uint16(uuid[6:8])&0xfff) << 48
return Time(time)
}
// ClockSequence returns the clock sequence encoded in uuid.
// The clock sequence is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) ClockSequence() int {
return int(binary.BigEndian.Uint16(uuid[8:10])) & 0x3fff
}

@ -0,0 +1,43 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"io"
)
// randomBits completely fills slice b with random data.
func randomBits(b []byte) {
if _, err := io.ReadFull(rander, b); err != nil {
panic(err.Error()) // rand should never fail
}
}
// xvalues returns the value of a byte as a hexadecimal digit or 255.
var xvalues = [256]byte{
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
}
// xtob converts hex characters x1 and x2 into a byte.
func xtob(x1, x2 byte) (byte, bool) {
b1 := xvalues[x1]
b2 := xvalues[x2]
return (b1 << 4) | b2, b1 != 255 && b2 != 255
}

@ -0,0 +1,245 @@
// Copyright 2018 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
// A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC
// 4122.
type UUID [16]byte
// A Version represents a UUID's version.
type Version byte
// A Variant represents a UUID's variant.
type Variant byte
// Constants returned by Variant.
const (
Invalid = Variant(iota) // Invalid UUID
RFC4122 // The variant specified in RFC4122
Reserved // Reserved, NCS backward compatibility.
Microsoft // Reserved, Microsoft Corporation backward compatibility.
Future // Reserved for future definition.
)
var rander = rand.Reader // random function
// Parse decodes s into a UUID or returns an error. Both the standard UUID
// forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded as well as the
// Microsoft encoding {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} and the raw hex
// encoding: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.
func Parse(s string) (UUID, error) {
var uuid UUID
switch len(s) {
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36:
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36 + 9:
if strings.ToLower(s[:9]) != "urn:uuid:" {
return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9])
}
s = s[9:]
// {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
case 36 + 2:
s = s[1:]
// xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
case 32:
var ok bool
for i := range uuid {
uuid[i], ok = xtob(s[i*2], s[i*2+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(s))
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return uuid, errors.New("invalid UUID format")
}
for i, x := range [16]int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34} {
v, ok := xtob(s[x], s[x+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
uuid[i] = v
}
return uuid, nil
}
// ParseBytes is like Parse, except it parses a byte slice instead of a string.
func ParseBytes(b []byte) (UUID, error) {
var uuid UUID
switch len(b) {
case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) {
return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9])
}
b = b[9:]
case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
b = b[1:]
case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
var ok bool
for i := 0; i < 32; i += 2 {
uuid[i/2], ok = xtob(b[i], b[i+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(b))
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
return uuid, errors.New("invalid UUID format")
}
for i, x := range [16]int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34} {
v, ok := xtob(b[x], b[x+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
uuid[i] = v
}
return uuid, nil
}
// MustParse is like Parse but panics if the string cannot be parsed.
// It simplifies safe initialization of global variables holding compiled UUIDs.
func MustParse(s string) UUID {
uuid, err := Parse(s)
if err != nil {
panic(`uuid: Parse(` + s + `): ` + err.Error())
}
return uuid
}
// FromBytes creates a new UUID from a byte slice. Returns an error if the slice
// does not have a length of 16. The bytes are copied from the slice.
func FromBytes(b []byte) (uuid UUID, err error) {
err = uuid.UnmarshalBinary(b)
return uuid, err
}
// Must returns uuid if err is nil and panics otherwise.
func Must(uuid UUID, err error) UUID {
if err != nil {
panic(err)
}
return uuid
}
// String returns the string form of uuid, xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
// , or "" if uuid is invalid.
func (uuid UUID) String() string {
var buf [36]byte
encodeHex(buf[:], uuid)
return string(buf[:])
}
// URN returns the RFC 2141 URN form of uuid,
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, or "" if uuid is invalid.
func (uuid UUID) URN() string {
var buf [36 + 9]byte
copy(buf[:], "urn:uuid:")
encodeHex(buf[9:], uuid)
return string(buf[:])
}
func encodeHex(dst []byte, uuid UUID) {
hex.Encode(dst, uuid[:4])
dst[8] = '-'
hex.Encode(dst[9:13], uuid[4:6])
dst[13] = '-'
hex.Encode(dst[14:18], uuid[6:8])
dst[18] = '-'
hex.Encode(dst[19:23], uuid[8:10])
dst[23] = '-'
hex.Encode(dst[24:], uuid[10:])
}
// Variant returns the variant encoded in uuid.
func (uuid UUID) Variant() Variant {
switch {
case (uuid[8] & 0xc0) == 0x80:
return RFC4122
case (uuid[8] & 0xe0) == 0xc0:
return Microsoft
case (uuid[8] & 0xe0) == 0xe0:
return Future
default:
return Reserved
}
}
// Version returns the version of uuid.
func (uuid UUID) Version() Version {
return Version(uuid[6] >> 4)
}
func (v Version) String() string {
if v > 15 {
return fmt.Sprintf("BAD_VERSION_%d", v)
}
return fmt.Sprintf("VERSION_%d", v)
}
func (v Variant) String() string {
switch v {
case RFC4122:
return "RFC4122"
case Reserved:
return "Reserved"
case Microsoft:
return "Microsoft"
case Future:
return "Future"
case Invalid:
return "Invalid"
}
return fmt.Sprintf("BadVariant%d", int(v))
}
// SetRand sets the random number generator to r, which implements io.Reader.
// If r.Read returns an error when the package requests random data then
// a panic will be issued.
//
// Calling SetRand with nil sets the random number generator to the default
// generator.
func SetRand(r io.Reader) {
if r == nil {
rander = rand.Reader
return
}
rander = r
}

@ -0,0 +1,44 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
)
// NewUUID returns a Version 1 UUID based on the current NodeID and clock
// sequence, and the current time. If the NodeID has not been set by SetNodeID
// or SetNodeInterface then it will be set automatically. If the NodeID cannot
// be set NewUUID returns nil. If clock sequence has not been set by
// SetClockSequence then it will be set automatically. If GetTime fails to
// return the current NewUUID returns nil and an error.
//
// In most cases, New should be used.
func NewUUID() (UUID, error) {
var uuid UUID
now, seq, err := GetTime()
if err != nil {
return uuid, err
}
timeLow := uint32(now & 0xffffffff)
timeMid := uint16((now >> 32) & 0xffff)
timeHi := uint16((now >> 48) & 0x0fff)
timeHi |= 0x1000 // Version 1
binary.BigEndian.PutUint32(uuid[0:], timeLow)
binary.BigEndian.PutUint16(uuid[4:], timeMid)
binary.BigEndian.PutUint16(uuid[6:], timeHi)
binary.BigEndian.PutUint16(uuid[8:], seq)
nodeMu.Lock()
if nodeID == zeroID {
setNodeInterface("")
}
copy(uuid[10:], nodeID[:])
nodeMu.Unlock()
return uuid, nil
}

@ -0,0 +1,43 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "io"
// New creates a new random UUID or panics. New is equivalent to
// the expression
//
// uuid.Must(uuid.NewRandom())
func New() UUID {
return Must(NewRandom())
}
// NewRandom returns a Random (Version 4) UUID.
//
// The strength of the UUIDs is based on the strength of the crypto/rand
// package.
//
// A note about uniqueness derived from the UUID Wikipedia entry:
//
// Randomly generated UUIDs have 122 random bits. One's annual risk of being
// hit by a meteorite is estimated to be one chance in 17 billion, that
// means the probability is about 0.00000000006 (6 × 1011),
// equivalent to the odds of creating a few tens of trillions of UUIDs in a
// year and having one duplicate.
func NewRandom() (UUID, error) {
return NewRandomFromReader(rander)
}
// NewRandomFromReader returns a UUID based on bytes read from a given io.Reader.
func NewRandomFromReader(r io.Reader) (UUID, error) {
var uuid UUID
_, err := io.ReadFull(r, uuid[:])
if err != nil {
return Nil, err
}
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
return uuid, nil
}

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2013-2017 winlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,23 @@
Copyright (c) 2015, Dave Cheney <dave@cheney.net>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,52 @@
# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors)
Package errors provides simple error handling primitives.
`go get github.com/pkg/errors`
The traditional error handling idiom in Go is roughly akin to
```go
if err != nil {
return err
}
```
which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error.
## Adding context to an error
The errors.Wrap function returns a new error that adds context to the original error. For example
```go
_, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrap(err, "read failed")
}
```
## Retrieving the cause of an error
Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`.
```go
type causer interface {
Cause() error
}
```
`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example:
```go
switch err := errors.Cause(err).(type) {
case *MyError:
// handle specifically
default:
// unknown error
}
```
[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors).
## Contributing
We welcome pull requests, bug fixes and issue reports. With that said, the bar for adding new symbols to this package is intentionally set high.
Before proposing a change, please discuss your change by raising an issue.
## Licence
BSD-2-Clause

@ -0,0 +1,270 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
// if err != nil {
// return err
// }
//
// which applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
// type causer interface {
// Cause() error
// }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be
// the original cause. For example:
//
// switch err := errors.Cause(err).(type) {
// case *MyError:
// // handle specifically
// default:
// // unknown error
// }
//
// causer interface is not exported by this package, but is considered a part
// of stable public API.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported
//
// %s print the error. If the error has a Cause it will be
// printed recursively
// %v see %s
// %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface.
//
// type stackTracer interface {
// StackTrace() errors.StackTrace
// }
//
// Where errors.StackTrace is defined as
//
// type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
// if err, ok := err.(stackTracer); ok {
// for _, f := range err.StackTrace() {
// fmt.Printf("%+s:%d", f)
// }
// }
//
// stackTracer interface is not exported by this package, but is considered a part
// of stable public API.
//
// See the documentation for Frame.Format for more details.
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
)
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{
msg: message,
stack: callers(),
}
}
// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
return &fundamental{
msg: fmt.Sprintf(format, args...),
stack: callers(),
}
}
// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
msg string
*stack
}
func (f *fundamental) Error() string { return f.msg }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
io.WriteString(s, f.msg)
f.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, f.msg)
case 'q':
fmt.Fprintf(s, "%q", f.msg)
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
}
func (w *withStack) Cause() error { return w.error }
func (w *withStack) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", w.Cause())
w.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, w.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
}
}
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: message,
}
return &withStack{
err,
callers(),
}
}
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
return &withStack{
err,
callers(),
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string
}
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error { return w.cause }
func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v\n", w.Cause())
io.WriteString(s, w.msg)
return
}
fallthrough
case 's', 'q':
io.WriteString(s, w.Error())
}
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
// type causer interface {
// Cause() error
// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
type causer interface {
Cause() error
}
for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

@ -0,0 +1,187 @@
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
"path"
"runtime"
"strings"
)
// Frame represents a program counter inside a stack frame.
type Frame uintptr
// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return "unknown"
}
file, _ := fn.FileLine(f.pc())
return file
}
// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return 0
}
_, line := fn.FileLine(f.pc())
return line
}
// Format formats the frame according to the fmt.Formatter interface.
//
// %s source file
// %d source line
// %n function name
// %v equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+s path of source file relative to the compile time GOPATH
// %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
switch verb {
case 's':
switch {
case s.Flag('+'):
pc := f.pc()
fn := runtime.FuncForPC(pc)
if fn == nil {
io.WriteString(s, "unknown")
} else {
file, _ := fn.FileLine(pc)
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
}
default:
io.WriteString(s, path.Base(f.file()))
}
case 'd':
fmt.Fprintf(s, "%d", f.line())
case 'n':
name := runtime.FuncForPC(f.pc()).Name()
io.WriteString(s, funcname(name))
case 'v':
f.Format(s, 's')
io.WriteString(s, ":")
f.Format(s, 'd')
}
}
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
for _, f := range st {
fmt.Fprintf(s, "\n%+v", f)
}
case s.Flag('#'):
fmt.Fprintf(s, "%#v", []Frame(st))
default:
fmt.Fprintf(s, "%v", []Frame(st))
}
case 's':
fmt.Fprintf(s, "%s", []Frame(st))
}
}
// stack represents a stack of program counters.
type stack []uintptr
func (s *stack) Format(st fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case st.Flag('+'):
for _, pc := range *s {
f := Frame(pc)
fmt.Fprintf(st, "\n%+v", f)
}
}
}
}
func (s *stack) StackTrace() StackTrace {
f := make([]Frame, len(*s))
for i := 0; i < len(f); i++ {
f[i] = Frame((*s)[i])
}
return f
}
func callers() *stack {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
var st stack = pcs[0:n]
return &st
}
// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
i := strings.LastIndex(name, "/")
name = name[i+1:]
i = strings.Index(name, ".")
return name[i+1:]
}
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}

@ -0,0 +1,86 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// +build go1.7
package logger
import (
"context"
"fmt"
"os"
)
func (v *loggerPlus) Println(ctx Context, a ...interface{}) {
args := v.contextFormat(ctx, a...)
v.doPrintln(args...)
}
func (v *loggerPlus) Printf(ctx Context, format string, a ...interface{}) {
format, args := v.contextFormatf(ctx, format, a...)
v.doPrintf(format, args...)
}
func (v *loggerPlus) contextFormat(ctx Context, a ...interface{}) []interface{} {
if ctx, ok := ctx.(context.Context); ok {
if cid, ok := ctx.Value(cidKey).(int); ok {
return append([]interface{}{fmt.Sprintf("[%v][%v]", os.Getpid(), cid)}, a...)
}
} else {
return v.format(ctx, a...)
}
return a
}
func (v *loggerPlus) contextFormatf(ctx Context, format string, a ...interface{}) (string, []interface{}) {
if ctx, ok := ctx.(context.Context); ok {
if cid, ok := ctx.Value(cidKey).(int); ok {
return "[%v][%v] " + format, append([]interface{}{os.Getpid(), cid}, a...)
}
} else {
return v.formatf(ctx, format, a...)
}
return format, a
}
// User should use context with value to pass the cid.
type key string
var cidKey key = "cid.logger.ossrs.org"
var gCid int = 999
// Create context with value.
func WithContext(ctx context.Context) context.Context {
gCid += 1
return context.WithValue(ctx, cidKey, gCid)
}
// Create context with value from parent, copy the cid from source context.
// @remark Create new cid if source has no cid represent.
func AliasContext(parent context.Context, source context.Context) context.Context {
if source != nil {
if cid, ok := source.Value(cidKey).(int); ok {
return context.WithValue(parent, cidKey, cid)
}
}
return WithContext(parent)
}

@ -0,0 +1,239 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// The oryx logger package provides connection-oriented log service.
// logger.I(ctx, ...)
// logger.T(ctx, ...)
// logger.W(ctx, ...)
// logger.E(ctx, ...)
// Or use format:
// logger.If(ctx, format, ...)
// logger.Tf(ctx, format, ...)
// logger.Wf(ctx, format, ...)
// logger.Ef(ctx, format, ...)
// @remark the Context is optional thus can be nil.
// @remark From 1.7+, the ctx could be context.Context, wrap by logger.WithContext,
// please read ExampleLogger_ContextGO17().
package logger
import (
"fmt"
"io"
"io/ioutil"
"log"
"os"
)
// default level for logger.
const (
logInfoLabel = "[info] "
logTraceLabel = "[trace] "
logWarnLabel = "[warn] "
logErrorLabel = "[error] "
)
// The context for current goroutine.
// It maybe a cidContext or context.Context from GO1.7.
// @remark Use logger.WithContext(ctx) to wrap the context.
type Context interface{}
// The context to get current coroutine cid.
type cidContext interface {
Cid() int
}
// the LOG+ which provides connection-based log.
type loggerPlus struct {
logger *log.Logger
}
func NewLoggerPlus(l *log.Logger) Logger {
return &loggerPlus{logger: l}
}
func (v *loggerPlus) format(ctx Context, a ...interface{}) []interface{} {
if ctx == nil {
return append([]interface{}{fmt.Sprintf("[%v] ", os.Getpid())}, a...)
} else if ctx, ok := ctx.(cidContext); ok {
return append([]interface{}{fmt.Sprintf("[%v][%v] ", os.Getpid(), ctx.Cid())}, a...)
}
return a
}
func (v *loggerPlus) formatf(ctx Context, format string, a ...interface{}) (string, []interface{}) {
if ctx == nil {
return "[%v] " + format, append([]interface{}{os.Getpid()}, a...)
} else if ctx, ok := ctx.(cidContext); ok {
return "[%v][%v] " + format, append([]interface{}{os.Getpid(), ctx.Cid()}, a...)
}
return format, a
}
var colorYellow = "\033[33m"
var colorRed = "\033[31m"
var colorBlack = "\033[0m"
func (v *loggerPlus) doPrintln(args ...interface{}) {
if previousCloser == nil {
if v == Error {
fmt.Fprintf(os.Stdout, colorRed)
v.logger.Println(args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else if v == Warn {
fmt.Fprintf(os.Stdout, colorYellow)
v.logger.Println(args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else {
v.logger.Println(args...)
}
} else {
v.logger.Println(args...)
}
}
func (v *loggerPlus) doPrintf(format string, args ...interface{}) {
if previousCloser == nil {
if v == Error {
fmt.Fprintf(os.Stdout, colorRed)
v.logger.Printf(format, args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else if v == Warn {
fmt.Fprintf(os.Stdout, colorYellow)
v.logger.Printf(format, args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else {
v.logger.Printf(format, args...)
}
} else {
v.logger.Printf(format, args...)
}
}
// Info, the verbose info level, very detail log, the lowest level, to discard.
var Info Logger
// Alias for Info level println.
func I(ctx Context, a ...interface{}) {
Info.Println(ctx, a...)
}
// Printf for Info level log.
func If(ctx Context, format string, a ...interface{}) {
Info.Printf(ctx, format, a...)
}
// Trace, the trace level, something important, the default log level, to stdout.
var Trace Logger
// Alias for Trace level println.
func T(ctx Context, a ...interface{}) {
Trace.Println(ctx, a...)
}
// Printf for Trace level log.
func Tf(ctx Context, format string, a ...interface{}) {
Trace.Printf(ctx, format, a...)
}
// Warn, the warning level, dangerous information, to Stdout.
var Warn Logger
// Alias for Warn level println.
func W(ctx Context, a ...interface{}) {
Warn.Println(ctx, a...)
}
// Printf for Warn level log.
func Wf(ctx Context, format string, a ...interface{}) {
Warn.Printf(ctx, format, a...)
}
// Error, the error level, fatal error things, ot Stdout.
var Error Logger
// Alias for Error level println.
func E(ctx Context, a ...interface{}) {
Error.Println(ctx, a...)
}
// Printf for Error level log.
func Ef(ctx Context, format string, a ...interface{}) {
Error.Printf(ctx, format, a...)
}
// The logger for oryx.
type Logger interface {
// Println for logger plus,
// @param ctx the connection-oriented context,
// or context.Context from GO1.7, or nil to ignore.
Println(ctx Context, a ...interface{})
Printf(ctx Context, format string, a ...interface{})
}
func init() {
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(os.Stdout, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(os.Stderr, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(os.Stderr, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
// init writer and closer.
previousWriter = os.Stdout
previousCloser = nil
}
// Switch the underlayer io.
// @remark user must close previous io for logger never close it.
func Switch(w io.Writer) io.Writer {
// TODO: support level, default to trace here.
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(w, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(w, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(w, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
ow := previousWriter
previousWriter = w
if c, ok := w.(io.Closer); ok {
previousCloser = c
}
return ow
}
// The previous underlayer io for logger.
var previousCloser io.Closer
var previousWriter io.Writer
// The interface io.Closer
// Cleanup the logger, discard any log util switch to fresh writer.
func Close() (err error) {
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(ioutil.Discard, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(ioutil.Discard, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(ioutil.Discard, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
if previousCloser != nil {
err = previousCloser.Close()
previousCloser = nil
}
return
}

@ -0,0 +1,34 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// +build !go1.7
package logger
func (v *loggerPlus) Println(ctx Context, a ...interface{}) {
args := v.format(ctx, a...)
v.doPrintln(args...)
}
func (v *loggerPlus) Printf(ctx Context, format string, a ...interface{}) {
format, args := v.formatf(ctx, format, a...)
v.doPrintf(format, args...)
}

@ -0,0 +1,2 @@
# vim temporary files
*.sw[poe]

@ -0,0 +1,8 @@
linters-settings:
govet:
check-shadowing: true
misspell:
locale: US
run:
skip-dirs-use-default: false

@ -0,0 +1,20 @@
<h1 align="center">
Design
</h1>
### Portable
Pion Data Channels is written in Go and extremely portable. Anywhere Golang runs, Pion Data Channels should work as well! Instead of dealing with complicated
cross-compiling of multiple libraries, you now can run anywhere with one `go build`
### Simple API
The API is based on an io.ReadWriteCloser.
### Readable
If code comes from an RFC we try to make sure everything is commented with a link to the spec.
This makes learning and debugging easier, this library was written to also serve as a guide for others.
### Tested
Every commit is tested via travis-ci Go provides fantastic facilities for testing, and more will be added as time goes on.
### Shared libraries
Every pion product is built using shared libraries, allowing others to review and reuse our libraries.

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,45 @@
<h1 align="center">
<br>
Pion Data Channels
<br>
</h1>
<h4 align="center">A Go implementation of WebRTC Data Channels</h4>
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-datachannel-gray.svg?longCache=true&colorB=brightgreen" alt="Pion Data Channels"></a>
<!--<a href="https://sourcegraph.com/github.com/pion/webrtc?badge"><img src="https://sourcegraph.com/github.com/pion/webrtc/-/badge.svg" alt="Sourcegraph Widget"></a>-->
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<br>
<a href="https://travis-ci.org/pion/datachannel"><img src="https://travis-ci.org/pion/datachannel.svg?branch=master" alt="Build Status"></a>
<a href="https://pkg.go.dev/github.com/pion/datachannel"><img src="https://godoc.org/github.com/pion/datachannel?status.svg" alt="GoDoc"></a>
<a href="https://codecov.io/gh/pion/datachannel"><img src="https://codecov.io/gh/pion/datachannel/branch/master/graph/badge.svg" alt="Coverage Status"></a>
<a href="https://goreportcard.com/report/github.com/pion/datachannel"><img src="https://goreportcard.com/badge/github.com/pion/datachannel" alt="Go Report Card"></a>
<!--<a href="https://www.codacy.com/app/Sean-Der/webrtc"><img src="https://api.codacy.com/project/badge/Grade/18f4aec384894e6aac0b94effe51961d" alt="Codacy Badge"></a>-->
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a>
</p>
<br>
See [DESIGN.md](DESIGN.md) for an overview of features and future goals.
### Roadmap
The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.
### Community
Pion has an active community on the [Golang Slack](https://invite.slack.golangbridge.org/). Sign up and join the **#pion** channel for discussions and support. You can also use [Pion mailing list](https://groups.google.com/forum/#!forum/pion).
We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)
### Contributing
Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible:
* [John Bradley](https://github.com/kc5nra) - *Original Author*
* [Sean DuBois](https://github.com/Sean-Der) - *Original Author*
* [Michiel De Backker](https://github.com/backkem) - *Public API*
* [Yutaka Takeda](https://github.com/enobufs) - *PR-SCTP*
* [Hugo Arregui](https://github.com/hugoArregui)
* [Atsushi Watanabe](https://github.com/at-wat)
* [Norman Rasmussen](https://github.com/normanr) - *Fix Empty DataChannel messages*
### License
MIT License - see [LICENSE](LICENSE) for full text

@ -0,0 +1,20 @@
#
# DO NOT EDIT THIS FILE
#
# It is automatically copied from https://github.com/pion/.goassets repository.
#
coverage:
status:
project:
default:
# Allow decreasing 2% of total coverage to avoid noise.
threshold: 2%
patch:
default:
target: 70%
only_pulls: true
ignore:
- "examples/*"
- "examples/**/*"

@ -0,0 +1,378 @@
// Package datachannel implements WebRTC Data Channels
package datachannel
import (
"fmt"
"io"
"sync/atomic"
"github.com/pion/logging"
"github.com/pion/sctp"
"github.com/pkg/errors"
)
const receiveMTU = 8192
// Reader is an extended io.Reader
// that also returns if the message is text.
type Reader interface {
ReadDataChannel([]byte) (int, bool, error)
}
// Writer is an extended io.Writer
// that also allows indicating if a message is text.
type Writer interface {
WriteDataChannel([]byte, bool) (int, error)
}
// ReadWriteCloser is an extended io.ReadWriteCloser
// that also implements our Reader and Writer.
type ReadWriteCloser interface {
io.Reader
io.Writer
Reader
Writer
io.Closer
}
// DataChannel represents a data channel
type DataChannel struct {
Config
// stats
messagesSent uint32
messagesReceived uint32
bytesSent uint64
bytesReceived uint64
stream *sctp.Stream
log logging.LeveledLogger
}
// Config is used to configure the data channel.
type Config struct {
ChannelType ChannelType
Negotiated bool
Priority uint16
ReliabilityParameter uint32
Label string
Protocol string
LoggerFactory logging.LoggerFactory
}
func newDataChannel(stream *sctp.Stream, config *Config) (*DataChannel, error) {
return &DataChannel{
Config: *config,
stream: stream,
log: config.LoggerFactory.NewLogger("datachannel"),
}, nil
}
// Dial opens a data channels over SCTP
func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) {
stream, err := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary)
if err != nil {
return nil, err
}
dc, err := Client(stream, config)
if err != nil {
return nil, err
}
return dc, nil
}
// Client opens a data channel over an SCTP stream
func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) {
msg := &channelOpen{
ChannelType: config.ChannelType,
Priority: config.Priority,
ReliabilityParameter: config.ReliabilityParameter,
Label: []byte(config.Label),
Protocol: []byte(config.Protocol),
}
if !config.Negotiated {
rawMsg, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal ChannelOpen %v", err)
}
if _, err = stream.WriteSCTP(rawMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
return nil, fmt.Errorf("failed to send ChannelOpen %v", err)
}
}
return newDataChannel(stream, config)
}
// Accept is used to accept incoming data channels over SCTP
func Accept(a *sctp.Association, config *Config) (*DataChannel, error) {
stream, err := a.AcceptStream()
if err != nil {
return nil, err
}
stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
dc, err := Server(stream, config)
if err != nil {
return nil, err
}
return dc, nil
}
// Server accepts a data channel over an SCTP stream
func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) {
buffer := make([]byte, receiveMTU) // TODO: Can probably be smaller
n, ppi, err := stream.ReadSCTP(buffer)
if err != nil {
return nil, err
}
if ppi != sctp.PayloadTypeWebRTCDCEP {
return nil, fmt.Errorf("unexpected packet type: %s", ppi)
}
openMsg, err := parseExpectDataChannelOpen(buffer[:n])
if err != nil {
return nil, errors.Wrap(err, "failed to parse DataChannelOpen packet")
}
config.ChannelType = openMsg.ChannelType
config.Priority = openMsg.Priority
config.ReliabilityParameter = openMsg.ReliabilityParameter
config.Label = string(openMsg.Label)
config.Protocol = string(openMsg.Protocol)
dataChannel, err := newDataChannel(stream, config)
if err != nil {
return nil, err
}
err = dataChannel.writeDataChannelAck()
if err != nil {
return nil, err
}
err = dataChannel.commitReliabilityParams()
if err != nil {
return nil, err
}
return dataChannel, nil
}
// Read reads a packet of len(p) bytes as binary data
func (c *DataChannel) Read(p []byte) (int, error) {
n, _, err := c.ReadDataChannel(p)
return n, err
}
// ReadDataChannel reads a packet of len(p) bytes
func (c *DataChannel) ReadDataChannel(p []byte) (int, bool, error) {
for {
n, ppi, err := c.stream.ReadSCTP(p)
if err == io.EOF {
// When the peer sees that an incoming stream was
// reset, it also resets its corresponding outgoing stream.
closeErr := c.stream.Close()
if closeErr != nil {
return 0, false, closeErr
}
}
if err != nil {
return 0, false, err
}
var isString bool
switch ppi {
case sctp.PayloadTypeWebRTCDCEP:
err = c.handleDCEP(p[:n])
if err != nil {
c.log.Errorf("Failed to handle DCEP: %s", err.Error())
continue
}
continue
case sctp.PayloadTypeWebRTCString, sctp.PayloadTypeWebRTCStringEmpty:
isString = true
}
switch ppi {
case sctp.PayloadTypeWebRTCBinaryEmpty, sctp.PayloadTypeWebRTCStringEmpty:
n = 0
}
atomic.AddUint32(&c.messagesReceived, 1)
atomic.AddUint64(&c.bytesReceived, uint64(n))
return n, isString, err
}
}
// MessagesSent returns the number of messages sent
func (c *DataChannel) MessagesSent() uint32 {
return atomic.LoadUint32(&c.messagesSent)
}
// MessagesReceived returns the number of messages received
func (c *DataChannel) MessagesReceived() uint32 {
return atomic.LoadUint32(&c.messagesReceived)
}
// BytesSent returns the number of bytes sent
func (c *DataChannel) BytesSent() uint64 {
return atomic.LoadUint64(&c.bytesSent)
}
// BytesReceived returns the number of bytes received
func (c *DataChannel) BytesReceived() uint64 {
return atomic.LoadUint64(&c.bytesReceived)
}
// StreamIdentifier returns the Stream identifier associated to the stream.
func (c *DataChannel) StreamIdentifier() uint16 {
return c.stream.StreamIdentifier()
}
func (c *DataChannel) handleDCEP(data []byte) error {
msg, err := parse(data)
if err != nil {
return errors.Wrap(err, "Failed to parse DataChannel packet")
}
switch msg := msg.(type) {
case *channelOpen:
c.log.Debug("Received DATA_CHANNEL_OPEN")
err = c.writeDataChannelAck()
if err != nil {
return fmt.Errorf("failed to ACK channel open: %v", err)
}
// Note: DATA_CHANNEL_OPEN message is handled inside Server() method.
// Therefore, the message will not reach here.
case *channelAck:
c.log.Debug("Received DATA_CHANNEL_ACK")
err = c.commitReliabilityParams()
if err != nil {
return err
}
// TODO: handle ChannelAck (https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-5.2)
default:
return fmt.Errorf("unhandled DataChannel message %v", msg)
}
return nil
}
// Write writes len(p) bytes from p as binary data
func (c *DataChannel) Write(p []byte) (n int, err error) {
return c.WriteDataChannel(p, false)
}
// WriteDataChannel writes len(p) bytes from p
func (c *DataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) {
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
// SCTP does not support the sending of empty user messages. Therefore,
// if an empty message has to be sent, the appropriate PPID (WebRTC
// String Empty or WebRTC Binary Empty) is used and the SCTP user
// message of one zero byte is sent. When receiving an SCTP user
// message with one of these PPIDs, the receiver MUST ignore the SCTP
// user message and process it as an empty message.
var ppi sctp.PayloadProtocolIdentifier
switch {
case !isString && len(p) > 0:
ppi = sctp.PayloadTypeWebRTCBinary
case !isString && len(p) == 0:
ppi = sctp.PayloadTypeWebRTCBinaryEmpty
case isString && len(p) > 0:
ppi = sctp.PayloadTypeWebRTCString
case isString && len(p) == 0:
ppi = sctp.PayloadTypeWebRTCStringEmpty
}
atomic.AddUint32(&c.messagesSent, 1)
atomic.AddUint64(&c.bytesSent, uint64(len(p)))
if len(p) == 0 {
_, err := c.stream.WriteSCTP([]byte{0}, ppi)
return 0, err
}
return c.stream.WriteSCTP(p, ppi)
}
func (c *DataChannel) writeDataChannelAck() error {
ack := channelAck{}
ackMsg, err := ack.Marshal()
if err != nil {
return fmt.Errorf("failed to marshal ChannelOpen ACK: %v", err)
}
_, err = c.stream.WriteSCTP(ackMsg, sctp.PayloadTypeWebRTCDCEP)
if err != nil {
return fmt.Errorf("failed to send ChannelOpen ACK: %v", err)
}
return err
}
// Close closes the DataChannel and the underlying SCTP stream.
func (c *DataChannel) Close() error {
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
// Closing of a data channel MUST be signaled by resetting the
// corresponding outgoing streams [RFC6525]. This means that if one
// side decides to close the data channel, it resets the corresponding
// outgoing stream. When the peer sees that an incoming stream was
// reset, it also resets its corresponding outgoing stream. Once this
// is completed, the data channel is closed. Resetting a stream sets
// the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
// a corresponding notification to the application layer that the reset
// has been performed. Streams are available for reuse after a reset
// has been performed.
return c.stream.Close()
}
// BufferedAmount returns the number of bytes of data currently queued to be
// sent over this stream.
func (c *DataChannel) BufferedAmount() uint64 {
return c.stream.BufferedAmount()
}
// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
// data that is considered "low." Defaults to 0.
func (c *DataChannel) BufferedAmountLowThreshold() uint64 {
return c.stream.BufferedAmountLowThreshold()
}
// SetBufferedAmountLowThreshold is used to update the threshold.
// See BufferedAmountLowThreshold().
func (c *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
c.stream.SetBufferedAmountLowThreshold(th)
}
// OnBufferedAmountLow sets the callback handler which would be called when the
// number of bytes of outgoing data buffered is lower than the threshold.
func (c *DataChannel) OnBufferedAmountLow(f func()) {
c.stream.OnBufferedAmountLow(f)
}
func (c *DataChannel) commitReliabilityParams() error {
switch c.Config.ChannelType {
case ChannelTypeReliable:
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
case ChannelTypeReliableUnordered:
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableRexmit:
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableRexmitUnordered:
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableTimed:
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableTimedUnordered:
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
default:
return fmt.Errorf("invalid ChannelType: %v ", c.Config.ChannelType)
}
return nil
}

@ -0,0 +1,11 @@
module github.com/pion/datachannel
require (
github.com/pion/logging v0.2.2
github.com/pion/sctp v1.7.10
github.com/pion/transport v0.10.1
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.6.1
)
go 1.13

@ -0,0 +1,38 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/sctp v1.7.10 h1:o3p3/hZB5Cx12RMGyWmItevJtZ6o2cpuxaw6GOS4x+8=
github.com/pion/sctp v1.7.10/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0=
github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM=
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

@ -0,0 +1,94 @@
package datachannel
import (
"fmt"
"github.com/pkg/errors"
)
// message is a parsed DataChannel message
type message interface {
Marshal() ([]byte, error)
Unmarshal([]byte) error
}
// messageType is the first byte in a DataChannel message that specifies type
type messageType byte
// DataChannel Message Types
const (
dataChannelAck messageType = 0x02
dataChannelOpen messageType = 0x03
)
func (t messageType) String() string {
switch t {
case dataChannelAck:
return "DataChannelAck"
case dataChannelOpen:
return "DataChannelOpen"
default:
return fmt.Sprintf("Unknown MessageType: %d", t)
}
}
// parse accepts raw input and returns a DataChannel message
func parse(raw []byte) (message, error) {
if len(raw) == 0 {
return nil, errors.Errorf("DataChannel message is not long enough to determine type ")
}
var msg message
switch messageType(raw[0]) {
case dataChannelOpen:
msg = &channelOpen{}
case dataChannelAck:
msg = &channelAck{}
default:
return nil, errors.Errorf("Unknown MessageType %v", messageType(raw[0]))
}
if err := msg.Unmarshal(raw); err != nil {
return nil, err
}
return msg, nil
}
// parseExpectDataChannelOpen parses a DataChannelOpen message
// or throws an error
func parseExpectDataChannelOpen(raw []byte) (*channelOpen, error) {
if len(raw) == 0 {
return nil, errors.Errorf("the DataChannel message is not long enough to determine type")
}
if actualTyp := messageType(raw[0]); actualTyp != dataChannelOpen {
return nil, errors.Errorf("expected DataChannelOpen but got %s", actualTyp)
}
msg := &channelOpen{}
if err := msg.Unmarshal(raw); err != nil {
return nil, err
}
return msg, nil
}
// parseExpectDataChannelAck parses a DataChannelAck message
// or throws an error
// func parseExpectDataChannelAck(raw []byte) (*channelAck, error) {
// if len(raw) == 0 {
// return nil, errors.Errorf("the DataChannel message is not long enough to determine type")
// }
//
// if actualTyp := messageType(raw[0]); actualTyp != dataChannelAck {
// return nil, errors.Errorf("expected DataChannelAck but got %s", actualTyp)
// }
//
// msg := &channelAck{}
// if err := msg.Unmarshal(raw); err != nil {
// return nil, err
// }
//
// return msg, nil
// }

@ -0,0 +1,22 @@
package datachannel
// channelAck is used to ACK a DataChannel open
type channelAck struct{}
const (
channelOpenAckLength = 4
)
// Marshal returns raw bytes for the given message
func (c *channelAck) Marshal() ([]byte, error) {
raw := make([]byte, channelOpenAckLength)
raw[0] = uint8(dataChannelAck)
return raw, nil
}
// Unmarshal populates the struct with the given raw data
func (c *channelAck) Unmarshal(raw []byte) error {
// Message type already checked in Parse and there is no further data
return nil
}

@ -0,0 +1,123 @@
package datachannel
import (
"encoding/binary"
"github.com/pkg/errors"
)
/*
channelOpen represents a DATA_CHANNEL_OPEN Message
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Message Type | Channel Type | Priority |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reliability Parameter |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Label Length | Protocol Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Label |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Protocol |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
type channelOpen struct {
ChannelType ChannelType
Priority uint16
ReliabilityParameter uint32
Label []byte
Protocol []byte
}
const (
channelOpenHeaderLength = 12
)
// ChannelType determines the reliability of the WebRTC DataChannel
type ChannelType byte
// ChannelType enums
const (
// ChannelTypeReliable determines the Data Channel provides a
// reliable in-order bi-directional communication.
ChannelTypeReliable ChannelType = 0x00
// ChannelTypeReliableUnordered determines the Data Channel
// provides a reliable unordered bi-directional communication.
ChannelTypeReliableUnordered ChannelType = 0x80
// ChannelTypePartialReliableRexmit determines the Data Channel
// provides a partially-reliable in-order bi-directional communication.
// User messages will not be retransmitted more times than specified in the Reliability Parameter.
ChannelTypePartialReliableRexmit ChannelType = 0x01
// ChannelTypePartialReliableRexmitUnordered determines
// the Data Channel provides a partial reliable unordered bi-directional communication.
// User messages will not be retransmitted more times than specified in the Reliability Parameter.
ChannelTypePartialReliableRexmitUnordered ChannelType = 0x81
// ChannelTypePartialReliableTimed determines the Data Channel
// provides a partial reliable in-order bi-directional communication.
// User messages might not be transmitted or retransmitted after
// a specified life-time given in milli- seconds in the Reliability Parameter.
// This life-time starts when providing the user message to the protocol stack.
ChannelTypePartialReliableTimed ChannelType = 0x02
// The Data Channel provides a partial reliable unordered bi-directional
// communication. User messages might not be transmitted or retransmitted
// after a specified life-time given in milli- seconds in the Reliability Parameter.
// This life-time starts when providing the user message to the protocol stack.
ChannelTypePartialReliableTimedUnordered ChannelType = 0x82
)
// ChannelPriority enums
const (
ChannelPriorityBelowNormal uint16 = 128
ChannelPriorityNormal uint16 = 256
ChannelPriorityHigh uint16 = 512
ChannelPriorityExtraHigh uint16 = 1024
)
// Marshal returns raw bytes for the given message
func (c *channelOpen) Marshal() ([]byte, error) {
labelLength := len(c.Label)
protocolLength := len(c.Protocol)
totalLen := channelOpenHeaderLength + labelLength + protocolLength
raw := make([]byte, totalLen)
raw[0] = uint8(dataChannelOpen)
raw[1] = byte(c.ChannelType)
binary.BigEndian.PutUint16(raw[2:], c.Priority)
binary.BigEndian.PutUint32(raw[4:], c.ReliabilityParameter)
binary.BigEndian.PutUint16(raw[8:], uint16(labelLength))
binary.BigEndian.PutUint16(raw[10:], uint16(protocolLength))
endLabel := channelOpenHeaderLength + labelLength
copy(raw[channelOpenHeaderLength:endLabel], c.Label)
copy(raw[endLabel:endLabel+protocolLength], c.Protocol)
return raw, nil
}
// Unmarshal populates the struct with the given raw data
func (c *channelOpen) Unmarshal(raw []byte) error {
if len(raw) < channelOpenHeaderLength {
return errors.Errorf("Length of input is not long enough to satisfy header %d", len(raw))
}
c.ChannelType = ChannelType(raw[1])
c.Priority = binary.BigEndian.Uint16(raw[2:])
c.ReliabilityParameter = binary.BigEndian.Uint32(raw[4:])
labelLength := binary.BigEndian.Uint16(raw[8:])
protocolLength := binary.BigEndian.Uint16(raw[10:])
if len(raw) != int(channelOpenHeaderLength+labelLength+protocolLength) {
return errors.Errorf("Label + Protocol length don't match full packet length")
}
c.Label = raw[channelOpenHeaderLength : channelOpenHeaderLength+labelLength]
c.Protocol = raw[channelOpenHeaderLength+labelLength : channelOpenHeaderLength+labelLength+protocolLength]
return nil
}

@ -0,0 +1,15 @@
{
"extends": [
"config:base"
],
"postUpdateOptions": [
"gomodTidy"
],
"commitBody": "Generated by renovateBot",
"packageRules": [
{
"packagePatterns": ["^golang.org/x/"],
"schedule": ["on the first day of the month"]
}
]
}

@ -0,0 +1,21 @@
# http://editorconfig.org/
root = true
[*]
charset = utf-8
insert_final_newline = true
trim_trailing_whitespace = true
end_of_line = lf
[*.go]
indent_style = tab
indent_size = 4
[{*.yml,*.yaml}]
indent_style = space
indent_size = 2
# Makefiles always use tabs for indentation
[Makefile]
indent_style = tab

@ -0,0 +1,89 @@
linters-settings:
govet:
check-shadowing: true
misspell:
locale: US
exhaustive:
default-signifies-exhaustive: true
gomodguard:
blocked:
modules:
- github.com/pkg/errors:
recommendations:
- errors
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bodyclose # checks whether HTTP response body is closed successfully
- deadcode # Finds unused code
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter
- godox # Tool for detection of FIXME, TODO and other comment keywords
- goerr113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # Detects when assignments to existing variables are not used
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- noctx # noctx finds sending http request without context.Context
- scopelint # Scopelint checks for unpinned variables in go programs
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- structcheck # Finds unused struct fields
- stylecheck # Stylecheck is a replacement for golint
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varcheck # Finds unused global variables and constants
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- lll # Reports long lines
- maligned # Tool to detect Go structs that would take less memory if their fields were sorted
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- nolintlint # Reports ill-formed or insufficient nolint directives
- prealloc # Finds slice declarations that could potentially be preallocated
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- wsl # Whitespace Linter - Forces you to use empty lines!
issues:
exclude-use-default: false
exclude-rules:
# Allow complex tests, better to be self contained
- path: _test\.go
linters:
- gocognit
# Allow complex main function in examples
- path: examples
text: "of func `main` is high"
linters:
- gocognit
run:
skip-dirs-use-default: false

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,6 @@
fuzz-build-record-layer: fuzz-prepare
go-fuzz-build -tags gofuzz -func FuzzRecordLayer
fuzz-run-record-layer:
go-fuzz -bin dtls-fuzz.zip -workdir fuzz
fuzz-prepare:
@GO111MODULE=on go mod vendor

@ -0,0 +1,151 @@
<h1 align="center">
<br>
Pion DTLS
<br>
</h1>
<h4 align="center">A Go implementation of DTLS</h4>
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-dtls-gray.svg?longCache=true&colorB=brightgreen" alt="Pion DTLS"></a>
<a href="https://sourcegraph.com/github.com/pion/dtls"><img src="https://sourcegraph.com/github.com/pion/dtls/-/badge.svg" alt="Sourcegraph Widget"></a>
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<br>
<a href="https://travis-ci.org/pion/dtls"><img src="https://travis-ci.org/pion/dtls.svg?branch=master" alt="Build Status"></a>
<a href="https://pkg.go.dev/github.com/pion/dtls"><img src="https://godoc.org/github.com/pion/dtls?status.svg" alt="GoDoc"></a>
<a href="https://codecov.io/gh/pion/dtls"><img src="https://codecov.io/gh/pion/dtls/branch/master/graph/badge.svg" alt="Coverage Status"></a>
<a href="https://goreportcard.com/report/github.com/pion/dtls"><img src="https://goreportcard.com/badge/github.com/pion/dtls" alt="Go Report Card"></a>
<a href="https://www.codacy.com/app/Sean-Der/dtls"><img src="https://api.codacy.com/project/badge/Grade/18f4aec384894e6aac0b94effe51961d" alt="Codacy Badge"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a>
</p>
<br>
Native [DTLS 1.2][rfc6347] implementation in the Go programming language.
A long term goal is a professional security review, and maye inclusion in stdlib.
[rfc6347]: https://tools.ietf.org/html/rfc6347
### Goals/Progress
This will only be targeting DTLS 1.2, and the most modern/common cipher suites.
We would love contributes that fall under the 'Planned Features' and fixing any bugs!
#### Current features
* DTLS 1.2 Client/Server
* Key Exchange via ECDHE(curve25519, nistp256, nistp384) and PSK
* Packet loss and re-ordering is handled during handshaking
* Key export ([RFC 5705][rfc5705])
* Serialization and Resumption of sessions
* Extended Master Secret extension ([RFC 7627][rfc7627])
[rfc5705]: https://tools.ietf.org/html/rfc5705
[rfc7627]: https://tools.ietf.org/html/rfc7627
#### Supported ciphers
##### ECDHE
* TLS_ECDHE_ECDSA_WITH_AES_128_CCM ([RFC 6655][rfc6655])
* TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655])
* TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289])
* TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289])
* TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422])
* TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422])
##### PSK
* TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655])
* TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655])
* TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487])
[rfc5289]: https://tools.ietf.org/html/rfc5289
[rfc8422]: https://tools.ietf.org/html/rfc8422
[rfc6655]: https://tools.ietf.org/html/rfc6655
[rfc5487]: https://tools.ietf.org/html/rfc5487
#### Planned Features
* Chacha20Poly1305
#### Excluded Features
* DTLS 1.0
* Renegotiation
* Compression
### Using
This library needs at least Go 1.13, and you should have [Go modules
enabled](https://github.com/golang/go/wiki/Modules).
#### Pion DTLS
For a DTLS 1.2 Server that listens on 127.0.0.1:4444
```sh
go run examples/listen/selfsign/main.go
```
For a DTLS 1.2 Client that connects to 127.0.0.1:4444
```sh
go run examples/dial/selfsign/main.go
```
#### OpenSSL
Pion DTLS can connect to itself and OpenSSL.
```
// Generate a certificate
openssl ecparam -out key.pem -name prime256v1 -genkey
openssl req -new -sha256 -key key.pem -out server.csr
openssl x509 -req -sha256 -days 365 -in server.csr -signkey key.pem -out cert.pem
// Use with examples/dial/selfsign/main.go
openssl s_server -dtls1_2 -cert cert.pem -key key.pem -accept 4444
// Use with examples/listen/selfsign/main.go
openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -debug -cert cert.pem -key key.pem
```
### Using with PSK
Pion DTLS also comes with examples that do key exchange via PSK
#### Pion DTLS
```sh
go run examples/listen/psk/main.go
```
```sh
go run examples/dial/psk/main.go
```
#### OpenSSL
```
// Use with examples/dial/psk/main.go
openssl s_server -dtls1_2 -accept 4444 -nocert -psk abc123 -cipher PSK-AES128-CCM8
// Use with examples/listen/psk/main.go
openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -psk abc123 -cipher PSK-AES128-CCM8
```
### Contributing
Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible:
* [Sean DuBois](https://github.com/Sean-Der) - *Original Author*
* [Michiel De Backker](https://github.com/backkem) - *Public API*
* [Chris Hiszpanski](https://github.com/thinkski) - *Support Signature Algorithms Extension*
* [Iñigo Garcia Olaizola](https://github.com/igolaizola) - *Serialization & resumption, cert verification, E2E*
* [Daniele Sluijters](https://github.com/daenney) - *AES-CCM support*
* [Jin Lei](https://github.com/jinleileiking) - *Logging*
* [Hugo Arregui](https://github.com/hugoArregui)
* [Lander Noterman](https://github.com/LanderN)
* [Aleksandr Razumov](https://github.com/ernado) - *Fuzzing*
* [Ryan Gordon](https://github.com/ryangordon)
* [Stefan Tatschner](https://rumpelsepp.org/contact.html)
* [Hayden James](https://github.com/hjames9)
* [Jozef Kralik](https://github.com/jkralik)
* [Robert Eperjesi](https://github.com/epes)
* [Atsushi Watanabe](https://github.com/at-wat)
* [Julien Salleyron](https://github.com/juliens) - *Server Name Indication*
* [Jeroen de Bruijn](https://github.com/vidavidorra)
* [bjdgyc](https://github.com/bjdgyc)
* [Jeffrey Stoke (Jeff Ctor)](https://github.com/jeffreystoke) - *Fragmentbuffer Fix*
* [Frank Olbricht](https://github.com/folbricht)
* [ZHENK](https://github.com/scorpionknifes)
* [Carson Hoffman](https://github.com/CarsonHoffman)
* [Vadim Filimonov](https://github.com/fffilimonov)
### License
MIT License - see [LICENSE](LICENSE) for full text

@ -0,0 +1,145 @@
package dtls
import "fmt"
type alertLevel byte
const (
alertLevelWarning alertLevel = 1
alertLevelFatal alertLevel = 2
)
func (a alertLevel) String() string {
switch a {
case alertLevelWarning:
return "LevelWarning"
case alertLevelFatal:
return "LevelFatal"
default:
return "Invalid alert level"
}
}
type alertDescription byte
const (
alertCloseNotify alertDescription = 0
alertUnexpectedMessage alertDescription = 10
alertBadRecordMac alertDescription = 20
alertDecryptionFailed alertDescription = 21
alertRecordOverflow alertDescription = 22
alertDecompressionFailure alertDescription = 30
alertHandshakeFailure alertDescription = 40
alertNoCertificate alertDescription = 41
alertBadCertificate alertDescription = 42
alertUnsupportedCertificate alertDescription = 43
alertCertificateRevoked alertDescription = 44
alertCertificateExpired alertDescription = 45
alertCertificateUnknown alertDescription = 46
alertIllegalParameter alertDescription = 47
alertUnknownCA alertDescription = 48
alertAccessDenied alertDescription = 49
alertDecodeError alertDescription = 50
alertDecryptError alertDescription = 51
alertExportRestriction alertDescription = 60
alertProtocolVersion alertDescription = 70
alertInsufficientSecurity alertDescription = 71
alertInternalError alertDescription = 80
alertUserCanceled alertDescription = 90
alertNoRenegotiation alertDescription = 100
alertUnsupportedExtension alertDescription = 110
)
func (a alertDescription) String() string {
switch a {
case alertCloseNotify:
return "CloseNotify"
case alertUnexpectedMessage:
return "UnexpectedMessage"
case alertBadRecordMac:
return "BadRecordMac"
case alertDecryptionFailed:
return "DecryptionFailed"
case alertRecordOverflow:
return "RecordOverflow"
case alertDecompressionFailure:
return "DecompressionFailure"
case alertHandshakeFailure:
return "HandshakeFailure"
case alertNoCertificate:
return "NoCertificate"
case alertBadCertificate:
return "BadCertificate"
case alertUnsupportedCertificate:
return "UnsupportedCertificate"
case alertCertificateRevoked:
return "CertificateRevoked"
case alertCertificateExpired:
return "CertificateExpired"
case alertCertificateUnknown:
return "CertificateUnknown"
case alertIllegalParameter:
return "IllegalParameter"
case alertUnknownCA:
return "UnknownCA"
case alertAccessDenied:
return "AccessDenied"
case alertDecodeError:
return "DecodeError"
case alertDecryptError:
return "DecryptError"
case alertExportRestriction:
return "ExportRestriction"
case alertProtocolVersion:
return "ProtocolVersion"
case alertInsufficientSecurity:
return "InsufficientSecurity"
case alertInternalError:
return "InternalError"
case alertUserCanceled:
return "UserCanceled"
case alertNoRenegotiation:
return "NoRenegotiation"
case alertUnsupportedExtension:
return "UnsupportedExtension"
default:
return "Invalid alert description"
}
}
// One of the content types supported by the TLS record layer is the
// alert type. Alert messages convey the severity of the message
// (warning or fatal) and a description of the alert. Alert messages
// with a level of fatal result in the immediate termination of the
// connection. In this case, other connections corresponding to the
// session may continue, but the session identifier MUST be invalidated,
// preventing the failed session from being used to establish new
// connections. Like other messages, alert messages are encrypted and
// compressed, as specified by the current connection state.
// https://tools.ietf.org/html/rfc5246#section-7.2
type alert struct {
alertLevel alertLevel
alertDescription alertDescription
}
func (a alert) contentType() contentType {
return contentTypeAlert
}
func (a *alert) Marshal() ([]byte, error) {
return []byte{byte(a.alertLevel), byte(a.alertDescription)}, nil
}
func (a *alert) Unmarshal(data []byte) error {
if len(data) != 2 {
return errBufferTooSmall
}
a.alertLevel = alertLevel(data[0])
a.alertDescription = alertDescription(data[1])
return nil
}
func (a *alert) String() string {
return fmt.Sprintf("Alert %s: %s", a.alertLevel, a.alertDescription)
}

@ -0,0 +1,23 @@
package dtls
// Application data messages are carried by the record layer and are
// fragmented, compressed, and encrypted based on the current connection
// state. The messages are treated as transparent data to the record
// layer.
// https://tools.ietf.org/html/rfc5246#section-10
type applicationData struct {
data []byte
}
func (a applicationData) contentType() contentType {
return contentTypeApplicationData
}
func (a *applicationData) Marshal() ([]byte, error) {
return append([]byte{}, a.data...), nil
}
func (a *applicationData) Unmarshal(data []byte) error {
a.data = append([]byte{}, data...)
return nil
}

@ -0,0 +1,67 @@
package dtls
import (
"crypto/tls"
"crypto/x509"
"strings"
)
func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.nameToCertificate == nil {
nameToCertificate := make(map[string]*tls.Certificate)
for i := range c.localCertificates {
cert := &c.localCertificates[i]
x509Cert := cert.Leaf
if x509Cert == nil {
var parseErr error
x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
if parseErr != nil {
continue
}
}
if len(x509Cert.Subject.CommonName) > 0 {
nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
}
for _, san := range x509Cert.DNSNames {
nameToCertificate[strings.ToLower(san)] = cert
}
}
c.nameToCertificate = nameToCertificate
}
if len(c.localCertificates) == 0 {
return nil, errNoCertificates
}
if len(c.localCertificates) == 1 {
// There's only one choice, so no point doing any work.
return &c.localCertificates[0], nil
}
if len(serverName) == 0 {
return &c.localCertificates[0], nil
}
name := strings.TrimRight(strings.ToLower(serverName), ".")
if cert, ok := c.nameToCertificate[name]; ok {
return cert, nil
}
// try replacing labels in the name with wildcards until we get a
// match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := c.nameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &c.localCertificates[0], nil
}

@ -0,0 +1,25 @@
package dtls
// The change cipher spec protocol exists to signal transitions in
// ciphering strategies. The protocol consists of a single message,
// which is encrypted and compressed under the current (not the pending)
// connection state. The message consists of a single byte of value 1.
// https://tools.ietf.org/html/rfc5246#section-7.1
type changeCipherSpec struct {
}
func (c changeCipherSpec) contentType() contentType {
return contentTypeChangeCipherSpec
}
func (c *changeCipherSpec) Marshal() ([]byte, error) {
return []byte{0x01}, nil
}
func (c *changeCipherSpec) Unmarshal(data []byte) error {
if len(data) == 1 && data[0] == 0x01 {
return nil
}
return errInvalidCipherSpec
}

@ -0,0 +1,206 @@
package dtls
import (
"encoding/binary"
"fmt"
"hash"
)
// CipherSuiteID is an ID for our supported CipherSuites
type CipherSuiteID uint16
// Supported Cipher Suites
const (
// AES-128-CCM
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = 0xc0ac //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = 0xc0ae //nolint:golint,stylecheck
// AES-128-GCM-SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0xc02b //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0xc02f //nolint:golint,stylecheck
// AES-256-CBC-SHA
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = 0xc00a //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = 0xc014 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = 0xc0a4 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = 0xc0a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0x00a8 //nolint:golint,stylecheck
)
var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14
func (c CipherSuiteID) String() string {
switch c {
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8"
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
case TLS_PSK_WITH_AES_128_CCM:
return "TLS_PSK_WITH_AES_128_CCM"
case TLS_PSK_WITH_AES_128_CCM_8:
return "TLS_PSK_WITH_AES_128_CCM_8"
case TLS_PSK_WITH_AES_128_GCM_SHA256:
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
default:
return fmt.Sprintf("unknown(%v)", uint16(c))
}
}
type cipherSuite interface {
String() string
ID() CipherSuiteID
certificateType() clientCertificateType
hashFunc() func() hash.Hash
isPSK() bool
isInitialized() bool
// Generate the internal encryption state
init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error
encrypt(pkt *recordLayer, raw []byte) ([]byte, error)
decrypt(in []byte) ([]byte, error)
}
// CipherSuiteName provides the same functionality as tls.CipherSuiteName
// that appeared first in Go 1.14.
//
// Our implementation differs slightly in that it takes in a CiperSuiteID,
// like the rest of our library, instead of a uint16 like crypto/tls.
func CipherSuiteName(id CipherSuiteID) string {
suite := cipherSuiteForID(id)
if suite != nil {
return suite.String()
}
return fmt.Sprintf("0x%04X", uint16(id))
}
// Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
// A cipherSuite is a specific combination of key agreement, cipher and MAC
// function.
func cipherSuiteForID(id CipherSuiteID) cipherSuite {
switch id {
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
return newCipherSuiteTLSEcdheEcdsaWithAes128Ccm()
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
return newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8()
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{}
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSEcdheRsaWithAes128GcmSha256{}
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
return &cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{}
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
return &cipherSuiteTLSEcdheRsaWithAes256CbcSha{}
case TLS_PSK_WITH_AES_128_CCM:
return newCipherSuiteTLSPskWithAes128Ccm()
case TLS_PSK_WITH_AES_128_CCM_8:
return newCipherSuiteTLSPskWithAes128Ccm8()
case TLS_PSK_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSPskWithAes128GcmSha256{}
}
return nil
}
// CipherSuites we support in order of preference
func defaultCipherSuites() []cipherSuite {
return []cipherSuite{
&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{},
&cipherSuiteTLSEcdheRsaWithAes256CbcSha{},
}
}
func allCipherSuites() []cipherSuite {
return []cipherSuite{
newCipherSuiteTLSEcdheEcdsaWithAes128Ccm(),
newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8(),
&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{},
&cipherSuiteTLSEcdheRsaWithAes256CbcSha{},
newCipherSuiteTLSPskWithAes128Ccm(),
newCipherSuiteTLSPskWithAes128Ccm8(),
&cipherSuiteTLSPskWithAes128GcmSha256{},
}
}
func decodeCipherSuites(buf []byte) ([]cipherSuite, error) {
if len(buf) < 2 {
return nil, errDTLSPacketInvalidLength
}
cipherSuitesCount := int(binary.BigEndian.Uint16(buf[0:])) / 2
rtrn := []cipherSuite{}
for i := 0; i < cipherSuitesCount; i++ {
if len(buf) < (i*2 + 4) {
return nil, errBufferTooSmall
}
id := CipherSuiteID(binary.BigEndian.Uint16(buf[(i*2)+2:]))
if c := cipherSuiteForID(id); c != nil {
rtrn = append(rtrn, c)
}
}
return rtrn, nil
}
func encodeCipherSuites(cipherSuites []cipherSuite) []byte {
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuites)*2))
for _, c := range cipherSuites {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(c.ID()))
}
return out
}
func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNonPSK bool) ([]cipherSuite, error) {
cipherSuitesForIDs := func(ids []CipherSuiteID) ([]cipherSuite, error) {
cipherSuites := []cipherSuite{}
for _, id := range ids {
c := cipherSuiteForID(id)
if c == nil {
return nil, &invalidCipherSuite{id}
}
cipherSuites = append(cipherSuites, c)
}
return cipherSuites, nil
}
var (
cipherSuites []cipherSuite
err error
i int
)
if len(userSelectedSuites) != 0 {
cipherSuites, err = cipherSuitesForIDs(userSelectedSuites)
if err != nil {
return nil, err
}
} else {
cipherSuites = defaultCipherSuites()
}
for _, c := range cipherSuites {
if excludePSK && c.isPSK() || excludeNonPSK && !c.isPSK() {
continue
}
cipherSuites[i] = c
i++
}
cipherSuites = cipherSuites[:i]
if len(cipherSuites) == 0 {
return nil, errNoAvailableCipherSuites
}
return cipherSuites, nil
}

@ -0,0 +1,93 @@
package dtls
import (
"crypto/sha256"
"errors"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteAes128Ccm struct {
ccm atomic.Value // *cryptoCCM
clientCertificateType clientCertificateType
id CipherSuiteID
psk bool
cryptoCCMTagLen cryptoCCMTagLen
}
func newCipherSuiteAes128Ccm(clientCertificateType clientCertificateType, id CipherSuiteID, psk bool, cryptoCCMTagLen cryptoCCMTagLen) *cipherSuiteAes128Ccm {
return &cipherSuiteAes128Ccm{
clientCertificateType: clientCertificateType,
id: id,
psk: psk,
cryptoCCMTagLen: cryptoCCMTagLen,
}
}
func (c *cipherSuiteAes128Ccm) certificateType() clientCertificateType {
return c.clientCertificateType
}
func (c *cipherSuiteAes128Ccm) ID() CipherSuiteID {
return c.id
}
func (c *cipherSuiteAes128Ccm) String() string {
return c.id.String()
}
func (c *cipherSuiteAes128Ccm) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteAes128Ccm) isPSK() bool {
return c.psk
}
func (c *cipherSuiteAes128Ccm) isInitialized() bool {
return c.ccm.Load() != nil
}
func (c *cipherSuiteAes128Ccm) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var ccm *cryptoCCM
if isClient {
ccm, err = newCryptoCCM(c.cryptoCCMTagLen, keys.clientWriteKey, keys.clientWriteIV, keys.serverWriteKey, keys.serverWriteIV)
} else {
ccm, err = newCryptoCCM(c.cryptoCCMTagLen, keys.serverWriteKey, keys.serverWriteIV, keys.clientWriteKey, keys.clientWriteIV)
}
c.ccm.Store(ccm)
return err
}
var errCipherSuiteNotInit = errors.New("CipherSuite has not been initialized")
func (c *cipherSuiteAes128Ccm) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return ccm.(*cryptoCCM).encrypt(pkt, raw)
}
func (c *cipherSuiteAes128Ccm) decrypt(raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return ccm.(*cryptoCCM).decrypt(raw)
}

@ -0,0 +1,36 @@
// +build go1.14
package dtls
import (
"crypto/tls"
)
// Convert from our cipherSuite interface to a tls.CipherSuite struct
func toTLSCipherSuite(c cipherSuite) *tls.CipherSuite {
return &tls.CipherSuite{
ID: uint16(c.ID()),
Name: c.String(),
SupportedVersions: []uint16{VersionDTLS12},
Insecure: false,
}
}
// CipherSuites returns a list of cipher suites currently implemented by this
// package, excluding those with security issues, which are returned by
// InsecureCipherSuites.
func CipherSuites() []*tls.CipherSuite {
suites := allCipherSuites()
res := make([]*tls.CipherSuite, len(suites))
for i, c := range suites {
res[i] = toTLSCipherSuite(c)
}
return res
}
// InsecureCipherSuites returns a list of cipher suites currently implemented by
// this package and which have security issues.
func InsecureCipherSuites() []*tls.CipherSuite {
var res []*tls.CipherSuite
return res
}

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSEcdheEcdsaWithAes128Ccm() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateTypeECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM, false, cryptoCCMTagLength)
}

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateTypeECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, false, cryptoCCM8TagLength)
}

@ -0,0 +1,77 @@
package dtls
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256 struct {
gcm atomic.Value // *cryptoGCM
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateTypeECDSASign
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isPSK() bool {
return false
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isInitialized() bool {
return c.gcm.Load() != nil
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var gcm *cryptoGCM
if isClient {
gcm, err = newCryptoGCM(keys.clientWriteKey, keys.clientWriteIV, keys.serverWriteKey, keys.serverWriteIV)
} else {
gcm, err = newCryptoGCM(keys.serverWriteKey, keys.serverWriteIV, keys.clientWriteKey, keys.clientWriteIV)
}
c.gcm.Store(gcm)
return err
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return gcm.(*cryptoGCM).encrypt(pkt, raw)
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) decrypt(raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return gcm.(*cryptoGCM).decrypt(raw)
}

@ -0,0 +1,83 @@
package dtls
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteTLSEcdheEcdsaWithAes256CbcSha struct {
cbc atomic.Value // *cryptoCBC
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) certificateType() clientCertificateType {
return clientCertificateTypeECDSASign
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) ID() CipherSuiteID {
return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isPSK() bool {
return false
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isInitialized() bool {
return c.cbc.Load() != nil
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 20
prfKeyLen = 32
prfIvLen = 16
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var cbc *cryptoCBC
if isClient {
cbc, err = newCryptoCBC(
keys.clientWriteKey, keys.clientWriteIV, keys.clientMACKey,
keys.serverWriteKey, keys.serverWriteIV, keys.serverMACKey,
)
} else {
cbc, err = newCryptoCBC(
keys.serverWriteKey, keys.serverWriteIV, keys.serverMACKey,
keys.clientWriteKey, keys.clientWriteIV, keys.clientMACKey,
)
}
c.cbc.Store(cbc)
return err
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cbc.(*cryptoCBC).encrypt(pkt, raw)
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) decrypt(raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*cryptoCBC).decrypt(raw)
}

@ -0,0 +1,17 @@
package dtls
type cipherSuiteTLSEcdheRsaWithAes128GcmSha256 struct {
cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateTypeRSASign
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
}

@ -0,0 +1,17 @@
package dtls
type cipherSuiteTLSEcdheRsaWithAes256CbcSha struct {
cipherSuiteTLSEcdheEcdsaWithAes256CbcSha
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) certificateType() clientCertificateType {
return clientCertificateTypeRSASign
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) ID() CipherSuiteID {
return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
}

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSPskWithAes128Ccm() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateType(0), TLS_PSK_WITH_AES_128_CCM, true, cryptoCCMTagLength)
}

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSPskWithAes128Ccm8() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateType(0), TLS_PSK_WITH_AES_128_CCM_8, true, cryptoCCM8TagLength)
}

@ -0,0 +1,21 @@
package dtls
type cipherSuiteTLSPskWithAes128GcmSha256 struct {
cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateType(0)
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_PSK_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) String() string {
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) isPSK() bool {
return true
}

@ -0,0 +1,16 @@
package dtls
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10
type clientCertificateType byte
const (
clientCertificateTypeRSASign clientCertificateType = 1
clientCertificateTypeECDSASign clientCertificateType = 64
)
func clientCertificateTypes() map[clientCertificateType]bool {
return map[clientCertificateType]bool{
clientCertificateTypeRSASign: true,
clientCertificateTypeECDSASign: true,
}
}

@ -0,0 +1,20 @@
#
# DO NOT EDIT THIS FILE
#
# It is automatically copied from https://github.com/pion/.goassets repository.
#
coverage:
status:
project:
default:
# Allow decreasing 2% of total coverage to avoid noise.
threshold: 2%
patch:
default:
target: 70%
only_pulls: true
ignore:
- "examples/*"
- "examples/**/*"

@ -0,0 +1,49 @@
package dtls
type compressionMethodID byte
const (
compressionMethodNull compressionMethodID = 0
)
type compressionMethod struct {
id compressionMethodID
}
func compressionMethods() map[compressionMethodID]*compressionMethod {
return map[compressionMethodID]*compressionMethod{
compressionMethodNull: {id: compressionMethodNull},
}
}
func defaultCompressionMethods() []*compressionMethod {
return []*compressionMethod{
compressionMethods()[compressionMethodNull],
}
}
func decodeCompressionMethods(buf []byte) ([]*compressionMethod, error) {
if len(buf) < 1 {
return nil, errDTLSPacketInvalidLength
}
compressionMethodsCount := int(buf[0])
c := []*compressionMethod{}
for i := 0; i < compressionMethodsCount; i++ {
if len(buf) <= i+1 {
return nil, errBufferTooSmall
}
id := compressionMethodID(buf[i+1])
if compressionMethod, ok := compressionMethods()[id]; ok {
c = append(c, compressionMethod)
}
}
return c, nil
}
func encodeCompressionMethods(c []*compressionMethod) []byte {
out := []byte{byte(len(c))}
for i := len(c); i > 0; i-- {
out = append(out, byte(c[i-1].id))
}
return out
}

@ -0,0 +1,179 @@
package dtls
import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"time"
"github.com/pion/logging"
)
// Config is used to configure a DTLS client or server.
// After a Config is passed to a DTLS function it must not be modified.
type Config struct {
// Certificates contains certificate chain to present to the other side of the connection.
// Server MUST set this if PSK is non-nil
// client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil
Certificates []tls.Certificate
// CipherSuites is a list of supported cipher suites.
// If CipherSuites is nil, a default list is used
CipherSuites []CipherSuiteID
// SignatureSchemes contains the signature and hash schemes that the peer requests to verify.
SignatureSchemes []tls.SignatureScheme
// SRTPProtectionProfiles are the supported protection profiles
// Clients will send this via use_srtp and assert that the server properly responds
// Servers will assert that clients send one of these profiles and will respond as needed
SRTPProtectionProfiles []SRTPProtectionProfile
// ClientAuth determines the server's policy for
// TLS Client Authentication. The default is NoClientCert.
ClientAuth ClientAuthType
// RequireExtendedMasterSecret determines if the "Extended Master Secret" extension
// should be disabled, requested, or required (default requested).
ExtendedMasterSecret ExtendedMasterSecretType
// FlightInterval controls how often we send outbound handshake messages
// defaults to time.Second
FlightInterval time.Duration
// PSK sets the pre-shared key used by this DTLS connection
// If PSK is non-nil only PSK CipherSuites will be used
PSK PSKCallback
PSKIdentityHint []byte
// InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name.
// If InsecureSkipVerify is true, TLS accepts any certificate
// presented by the server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureSkipVerify bool
// InsecureHashes allows the use of hashing algorithms that are known
// to be vulnerable.
InsecureHashes bool
// VerifyPeerCertificate, if not nil, is called after normal
// certificate verification by either a client or server. It
// receives the certificate provided by the peer and also a flag
// that tells if normal verification has succeedded. If it returns a
// non-nil error, the handshake is aborted and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. If normal verification is disabled by
// setting InsecureSkipVerify, or (for a server) when ClientAuth is
// RequestClientCert or RequireAnyClientCert, then this callback will
// be considered but the verifiedChains will always be nil.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
// RootCAs defines the set of root certificate authorities
// that one peer uses when verifying the other peer's certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *x509.CertPool
// ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// ServerName is used to verify the hostname on the returned
// certificates unless InsecureSkipVerify is given.
ServerName string
LoggerFactory logging.LoggerFactory
// ConnectContextMaker is a function to make a context used in Dial(),
// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
// is used. It can be implemented as following.
//
// func ConnectContextMaker() (context.Context, func()) {
// return context.WithTimeout(context.Background(), 30*time.Second)
// }
ConnectContextMaker func() (context.Context, func())
// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int
// ReplayProtectionWindow is the size of the replay attack protection window.
// Duplication of the sequence number is checked in this window size.
// Packet with sequence number older than this value compared to the latest
// accepted packet will be discarded. (default is 64)
ReplayProtectionWindow int
}
func defaultConnectContextMaker() (context.Context, func()) {
return context.WithTimeout(context.Background(), 30*time.Second)
}
func (c *Config) connectContextMaker() (context.Context, func()) {
if c.ConnectContextMaker == nil {
return defaultConnectContextMaker()
}
return c.ConnectContextMaker()
}
const defaultMTU = 1200 // bytes
// PSKCallback is called once we have the remote's PSKIdentityHint.
// If the remote provided none it will be nil
type PSKCallback func([]byte) ([]byte, error)
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
// ClientAuthType enums
const (
NoClientCert ClientAuthType = iota
RequestClientCert
RequireAnyClientCert
VerifyClientCertIfGiven
RequireAndVerifyClientCert
)
// ExtendedMasterSecretType declares the policy the client and server
// will follow for the Extended Master Secret extension
type ExtendedMasterSecretType int
// ExtendedMasterSecretType enums
const (
RequestExtendedMasterSecret ExtendedMasterSecretType = iota
RequireExtendedMasterSecret
DisableExtendedMasterSecret
)
func validateConfig(config *Config) error {
switch {
case config == nil:
return errNoConfigProvided
case len(config.Certificates) > 0 && config.PSK != nil:
return errPSKAndCertificate
case config.PSKIdentityHint != nil && config.PSK == nil:
return errIdentityNoPSK
}
for _, cert := range config.Certificates {
if cert.Certificate == nil {
return errInvalidCertificate
}
if cert.PrivateKey != nil {
switch cert.PrivateKey.(type) {
case ed25519.PrivateKey:
case *ecdsa.PrivateKey:
default:
return errInvalidPrivateKey
}
}
}
_, err := parseCipherSuites(config.CipherSuites, config.PSK == nil, config.PSK != nil)
return err
}

@ -0,0 +1,978 @@
package dtls
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pion/dtls/v2/internal/closer"
"github.com/pion/dtls/v2/internal/net/connctx"
"github.com/pion/logging"
"github.com/pion/transport/deadline"
"github.com/pion/transport/replaydetector"
)
const (
initialTickerInterval = time.Second
cookieLength = 20
defaultNamedCurve = namedCurveX25519
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
)
var (
errApplicationDataEpochZero = errors.New("ApplicationData with epoch of 0")
errUnhandledContextType = errors.New("unhandled contentType")
)
func invalidKeyingLabels() map[string]bool {
return map[string]bool{
"client finished": true,
"server finished": true,
"master secret": true,
"key expansion": true,
}
}
// Conn represents a DTLS connection
type Conn struct {
lock sync.RWMutex // Internal lock (must not be public)
nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`
state State // Internal state
maximumTransmissionUnit int
handshakeCompletedSuccessfully atomic.Value
encryptedPackets [][]byte
connectionClosedByUser bool
closeLock sync.Mutex
closed *closer.Closer
handshakeLoopsFinished sync.WaitGroup
readDeadline *deadline.Deadline
writeDeadline *deadline.Deadline
log logging.LeveledLogger
reading chan struct{}
handshakeRecv chan chan struct{}
cancelHandshaker func()
cancelHandshakeReader func()
fsm *handshakeFSM
replayProtectionWindow uint
}
func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
err := validateConfig(config)
if err != nil {
return nil, err
}
if nextConn == nil {
return nil, errNilNextConn
}
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.PSK == nil, config.PSK != nil)
if err != nil {
return nil, err
}
signatureSchemes, err := parseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}
workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}
loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
}
logger := loggerFactory.NewLogger("dtls")
mtu := config.MTU
if mtu <= 0 {
mtu = defaultMTU
}
replayProtectionWindow := config.ReplayProtectionWindow
if replayProtectionWindow <= 0 {
replayProtectionWindow = defaultReplayProtectionWindow
}
c := &Conn{
nextConn: connctx.New(nextConn),
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
decrypted: make(chan interface{}, 1),
log: logger,
readDeadline: deadline.New(),
writeDeadline: deadline.New(),
reading: make(chan struct{}, 1),
handshakeRecv: make(chan chan struct{}),
closed: closer.NewCloser(),
cancelHandshaker: func() {},
replayProtectionWindow: uint(replayProtectionWindow),
state: State{
isClient: isClient,
},
}
c.setRemoteEpoch(0)
c.setLocalEpoch(0)
serverName := config.ServerName
// Use host from conn address when serverName is not provided
if isClient && serverName == "" && nextConn.RemoteAddr() != nil {
remoteAddr := nextConn.RemoteAddr().String()
var host string
host, _, err = net.SplitHostPort(remoteAddr)
if err != nil {
serverName = remoteAddr
} else {
serverName = host
}
}
hsCfg := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
}
var initialFlight flightVal
var initialFSMState handshakeState
if initialState != nil {
if c.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished
c.state = *initialState
} else {
if c.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
}
c.log.Trace("Handshake Completed")
return c, nil
}
// Dial connects to the given network address and establishes a DTLS connection on top.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use DialWithContext() instead.
func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()
return DialWithContext(ctx, network, raddr, config)
}
// Client establishes a DTLS connection over an existing connection.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ClientWithContext() instead.
func Client(conn net.Conn, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()
return ClientWithContext(ctx, conn, config)
}
// Server listens for incoming DTLS connections.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ServerWithContext() instead.
func Server(conn net.Conn, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()
return ServerWithContext(ctx, conn, config)
}
// DialWithContext connects to the given network address and establishes a DTLS connection on top.
func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
pConn, err := net.DialUDP(network, nil, raddr)
if err != nil {
return nil, err
}
return ClientWithContext(ctx, pConn, config)
}
// ClientWithContext establishes a DTLS connection over an existing connection.
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
case config.PSK != nil && config.PSKIdentityHint == nil:
return nil, errPSKAndIdentityMustBeSetForClient
}
return createConn(ctx, conn, config, true, nil)
}
// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
case config.PSK == nil && len(config.Certificates) == 0:
return nil, errServerMustHaveCertificate
}
return createConn(ctx, conn, config, false, nil)
}
// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
}
select {
case <-c.readDeadline.Done():
return 0, errDeadlineExceeded
default:
}
for {
select {
case <-c.readDeadline.Done():
return 0, errDeadlineExceeded
case out, ok := <-c.decrypted:
if !ok {
return 0, io.EOF
}
switch val := out.(type) {
case ([]byte):
if len(p) < len(val) {
return 0, errBufferTooSmall
}
copy(p, val)
return len(val), nil
case (error):
return 0, val
}
}
}
}
// Write writes len(p) bytes from p to the DTLS connection
func (c *Conn) Write(p []byte) (int, error) {
if c.isConnectionClosed() {
return 0, ErrConnClosed
}
select {
case <-c.writeDeadline.Done():
return 0, errDeadlineExceeded
default:
}
if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
}
return len(p), c.writePackets(c.writeDeadline, []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
epoch: c.getLocalEpoch(),
protocolVersion: protocolVersion1_2,
},
content: &applicationData{
data: p,
},
},
shouldEncrypt: true,
},
})
}
// Close closes the connection.
func (c *Conn) Close() error {
err := c.close(true)
c.handshakeLoopsFinished.Wait()
return err
}
// ConnectionState returns basic DTLS details about the connection.
// Note that this replaced the `Export` function of v1.
func (c *Conn) ConnectionState() State {
c.lock.RLock()
defer c.lock.RUnlock()
return *c.state.clone()
}
// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
if c.state.srtpProtectionProfile == 0 {
return 0, false
}
return c.state.srtpProtectionProfile, true
}
func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
c.lock.Lock()
defer c.lock.Unlock()
var rawPackets [][]byte
for _, p := range pkts {
if h, ok := p.record.content.(*handshake); ok {
handshakeRaw, err := p.record.Marshal()
if err != nil {
return err
}
c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
srvCliStr(c.state.isClient), h.handshakeHeader.handshakeType.String(),
p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence)
c.handshakeCache.push(handshakeRaw[recordLayerHeaderSize:], p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence, h.handshakeHeader.handshakeType, c.state.isClient)
rawHandshakePackets, err := c.processHandshakePacket(p, h)
if err != nil {
return err
}
rawPackets = append(rawPackets, rawHandshakePackets...)
} else {
rawPacket, err := c.processPacket(p)
if err != nil {
return err
}
rawPackets = append(rawPackets, rawPacket)
}
}
if len(rawPackets) == 0 {
return nil
}
compactedRawPackets := c.compactRawPackets(rawPackets)
for _, compactedRawPackets := range compactedRawPackets {
if _, err := c.nextConn.Write(ctx, compactedRawPackets); err != nil {
return netError(err)
}
}
return nil
}
func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
combinedRawPackets := make([][]byte, 0)
currentCombinedRawPacket := make([]byte, 0)
for _, rawPacket := range rawPackets {
if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
currentCombinedRawPacket = []byte{}
}
currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
}
combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
return combinedRawPackets
}
func (c *Conn) processPacket(p *packet) ([]byte, error) {
epoch := p.record.recordLayerHeader.epoch
for len(c.state.localSequenceNumber) <= int(epoch) {
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
}
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
if seq > maxSequenceNumber {
// RFC 6347 Section 4.1.0
// The implementation must either abandon an association or rehandshake
// prior to allowing the sequence number to wrap.
return nil, errSequenceNumberOverflow
}
p.record.recordLayerHeader.sequenceNumber = seq
rawPacket, err := p.record.Marshal()
if err != nil {
return nil, err
}
if p.shouldEncrypt {
var err error
rawPacket, err = c.state.cipherSuite.encrypt(p.record, rawPacket)
if err != nil {
return nil, err
}
}
return rawPacket, nil
}
func (c *Conn) processHandshakePacket(p *packet, h *handshake) ([][]byte, error) {
rawPackets := make([][]byte, 0)
handshakeFragments, err := c.fragmentHandshake(h)
if err != nil {
return nil, err
}
epoch := p.record.recordLayerHeader.epoch
for len(c.state.localSequenceNumber) <= int(epoch) {
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
}
for _, handshakeFragment := range handshakeFragments {
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
if seq > maxSequenceNumber {
return nil, errSequenceNumberOverflow
}
recordLayerHeader := &recordLayerHeader{
protocolVersion: p.record.recordLayerHeader.protocolVersion,
contentType: p.record.recordLayerHeader.contentType,
contentLen: uint16(len(handshakeFragment)),
epoch: p.record.recordLayerHeader.epoch,
sequenceNumber: seq,
}
recordLayerHeaderBytes, err := recordLayerHeader.Marshal()
if err != nil {
return nil, err
}
p.record.recordLayerHeader = *recordLayerHeader
rawPacket := append(recordLayerHeaderBytes, handshakeFragment...)
if p.shouldEncrypt {
var err error
rawPacket, err = c.state.cipherSuite.encrypt(p.record, rawPacket)
if err != nil {
return nil, err
}
}
rawPackets = append(rawPackets, rawPacket)
}
return rawPackets, nil
}
func (c *Conn) fragmentHandshake(h *handshake) ([][]byte, error) {
content, err := h.handshakeMessage.Marshal()
if err != nil {
return nil, err
}
fragmentedHandshakes := make([][]byte, 0)
contentFragments := splitBytes(content, c.maximumTransmissionUnit)
if len(contentFragments) == 0 {
contentFragments = [][]byte{
{},
}
}
offset := 0
for _, contentFragment := range contentFragments {
contentFragmentLen := len(contentFragment)
handshakeHeaderFragment := &handshakeHeader{
handshakeType: h.handshakeHeader.handshakeType,
length: h.handshakeHeader.length,
messageSequence: h.handshakeHeader.messageSequence,
fragmentOffset: uint32(offset),
fragmentLength: uint32(contentFragmentLen),
}
offset += contentFragmentLen
handshakeHeaderFragmentRaw, err := handshakeHeaderFragment.Marshal()
if err != nil {
return nil, err
}
fragmentedHandshake := append(handshakeHeaderFragmentRaw, contentFragment...)
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
}
return fragmentedHandshakes, nil
}
var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
New: func() interface{} {
b := make([]byte, inboundBufferSize)
return &b
},
}
func (c *Conn) readAndBuffer(ctx context.Context) error {
bufptr := poolReadBuffer.Get().(*[]byte)
defer poolReadBuffer.Put(bufptr)
b := *bufptr
i, err := c.nextConn.Read(ctx, b)
if err != nil {
return netError(err)
}
pkts, err := unpackDatagram(b[:i])
if err != nil {
return err
}
var hasHandshake bool
for _, p := range pkts {
hs, alert, err := c.handleIncomingPacket(p, true)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
err = alertErr
}
}
}
if hs {
hasHandshake = true
}
switch e := err.(type) {
case nil:
case *errAlert:
if e.IsFatalOrCloseNotify() {
return e
}
default:
return e
}
}
if hasHandshake {
done := make(chan struct{})
select {
case c.handshakeRecv <- done:
// If the other party may retransmit the flight,
// we should respond even if it not a new message.
<-done
case <-c.fsm.Done():
}
}
return nil
}
func (c *Conn) handleQueuedPackets(ctx context.Context) error {
pkts := c.encryptedPackets
c.encryptedPackets = nil
for _, p := range pkts {
_, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
err = alertErr
}
}
}
switch e := err.(type) {
case nil:
case *errAlert:
if e.IsFatalOrCloseNotify() {
return e
}
default:
return e
}
}
return nil
}
func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, error) { //nolint:gocognit
h := &recordLayerHeader{}
if err := h.Unmarshal(buf); err != nil {
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}
// Validate epoch
remoteEpoch := c.getRemoteEpoch()
if h.epoch > remoteEpoch {
if h.epoch > remoteEpoch+1 {
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, buf)
}
return false, nil, nil
}
// Anti-replay protection
for len(c.state.replayDetector) <= int(h.epoch) {
c.state.replayDetector = append(c.state.replayDetector,
replaydetector.New(c.replayProtectionWindow, maxSequenceNumber),
)
}
markPacketAsValid, ok := c.state.replayDetector[int(h.epoch)].Check(h.sequenceNumber)
if !ok {
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}
// Decrypt
if h.epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
}
return false, nil, nil
}
var err error
buf, err = c.state.cipherSuite.decrypt(buf)
if err != nil {
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
}
}
isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
if err != nil {
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("defragment failed: %s", err)
return false, nil, nil
} else if isHandshake {
markPacketAsValid()
for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
rawHandshake := &handshake{}
if err := rawHandshake.Unmarshal(out); err != nil {
c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
continue
}
_ = c.handshakeCache.push(out, epoch, rawHandshake.handshakeHeader.messageSequence, rawHandshake.handshakeHeader.handshakeType, !c.state.isClient)
}
return true, nil, nil
}
r := &recordLayer{}
if err := r.Unmarshal(buf); err != nil {
return false, &alert{alertLevelFatal, alertDecodeError}, err
}
switch content := r.content.(type) {
case *alert:
c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
var a *alert
if content.alertDescription == alertCloseNotify {
// Respond with a close_notify [RFC5246 Section 7.2.1]
a = &alert{alertLevelWarning, alertCloseNotify}
}
markPacketAsValid()
return false, a, &errAlert{content}
case *changeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
return false, nil, nil
}
newRemoteEpoch := h.epoch + 1
c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
if c.getRemoteEpoch()+1 == newRemoteEpoch {
c.setRemoteEpoch(newRemoteEpoch)
markPacketAsValid()
}
case *applicationData:
if h.epoch == 0 {
return false, &alert{alertLevelFatal, alertUnexpectedMessage}, errApplicationDataEpochZero
}
markPacketAsValid()
select {
case c.decrypted <- content.data:
case <-c.closed.Done():
}
default:
return false, &alert{alertLevelFatal, alertUnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.contentType())
}
return false, nil, nil
}
func (c *Conn) recvHandshake() <-chan chan struct{} {
return c.handshakeRecv
}
func (c *Conn) notify(ctx context.Context, level alertLevel, desc alertDescription) error {
return c.writePackets(ctx, []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
epoch: c.getLocalEpoch(),
protocolVersion: protocolVersion1_2,
},
content: &alert{
alertLevel: level,
alertDescription: desc,
},
},
shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
},
})
}
func (c *Conn) setHandshakeCompletedSuccessfully() {
c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
}
func (c *Conn) isHandshakeCompletedSuccessfully() bool {
boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
return boolean.bool
}
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
done := make(chan struct{})
ctxRead, cancelRead := context.WithCancel(context.Background())
c.cancelHandshakeReader = cancelRead
cfg.onFlightState = func(f flightVal, s handshakeState) {
if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
c.setHandshakeCompletedSuccessfully()
close(done)
}
}
ctxHs, cancel := context.WithCancel(context.Background())
c.cancelHandshaker = cancel
firstErr := make(chan error, 1)
c.handshakeLoopsFinished.Add(2)
// Handshake routine should be live until close.
// The other party may request retransmission of the last flight to cope with packet drop.
go func() {
defer c.handshakeLoopsFinished.Done()
err := c.fsm.Run(ctxHs, c, initialState)
if !errors.Is(err, context.Canceled) {
select {
case firstErr <- err:
default:
}
}
}()
go func() {
defer func() {
// Escaping read loop.
// It's safe to close decrypted channnel now.
close(c.decrypted)
// Force stop handshaker when the underlying connection is closed.
cancel()
}()
defer c.handshakeLoopsFinished.Done()
for {
if err := c.readAndBuffer(ctxRead); err != nil {
switch e := err.(type) {
case *errAlert:
if !e.IsFatalOrCloseNotify() {
if c.isHandshakeCompletedSuccessfully() {
// Pass the error to Read()
select {
case c.decrypted <- err:
case <-c.closed.Done():
}
}
continue // non-fatal alert must not stop read loop
}
case error:
switch err {
case context.DeadlineExceeded, context.Canceled, io.EOF:
default:
if c.isHandshakeCompletedSuccessfully() {
// Keep read loop and pass the read error to Read()
select {
case c.decrypted <- err:
case <-c.closed.Done():
}
continue // non-fatal alert must not stop read loop
}
}
}
select {
case firstErr <- err:
default:
}
if e, ok := err.(*errAlert); ok {
if e.IsFatalOrCloseNotify() {
_ = c.close(false)
}
}
return
}
}
}()
select {
case err := <-firstErr:
cancelRead()
cancel()
return c.translateHandshakeCtxError(err)
case <-ctx.Done():
cancelRead()
cancel()
return c.translateHandshakeCtxError(ctx.Err())
case <-done:
return nil
}
}
func (c *Conn) translateHandshakeCtxError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
return nil
}
return &HandshakeError{err}
}
func (c *Conn) close(byUser bool) error {
c.cancelHandshaker()
c.cancelHandshakeReader()
if c.isHandshakeCompletedSuccessfully() && byUser {
// Discard error from notify() to return non-error on the first user call of Close()
// even if the underlying connection is already closed.
_ = c.notify(context.Background(), alertLevelWarning, alertCloseNotify)
}
c.closeLock.Lock()
// Don't return ErrConnClosed at the first time of the call from user.
closedByUser := c.connectionClosedByUser
if byUser {
c.connectionClosedByUser = true
}
c.closed.Close()
c.closeLock.Unlock()
if closedByUser {
return ErrConnClosed
}
return c.nextConn.Close()
}
func (c *Conn) isConnectionClosed() bool {
select {
case <-c.closed.Done():
return true
default:
return false
}
}
func (c *Conn) setLocalEpoch(epoch uint16) {
c.state.localEpoch.Store(epoch)
}
func (c *Conn) getLocalEpoch() uint16 {
return c.state.localEpoch.Load().(uint16)
}
func (c *Conn) setRemoteEpoch(epoch uint16) {
c.state.remoteEpoch.Store(epoch)
}
func (c *Conn) getRemoteEpoch() uint16 {
return c.state.remoteEpoch.Load().(uint16)
}
// LocalAddr implements net.Conn.LocalAddr
func (c *Conn) LocalAddr() net.Addr {
return c.nextConn.LocalAddr()
}
// RemoteAddr implements net.Conn.RemoteAddr
func (c *Conn) RemoteAddr() net.Addr {
return c.nextConn.RemoteAddr()
}
// SetDeadline implements net.Conn.SetDeadline
func (c *Conn) SetDeadline(t time.Time) error {
c.readDeadline.Set(t)
return c.SetWriteDeadline(t)
}
// SetReadDeadline implements net.Conn.SetReadDeadline
func (c *Conn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
// Read deadline is fully managed by this layer.
// Don't set read deadline to underlying connection.
return nil
}
// SetWriteDeadline implements net.Conn.SetWriteDeadline
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline.Set(t)
// Write deadline is also fully managed by this layer.
return nil
}

@ -0,0 +1,17 @@
package dtls
// https://tools.ietf.org/html/rfc4346#section-6.2.1
type contentType uint8
const (
contentTypeChangeCipherSpec contentType = 20
contentTypeAlert contentType = 21
contentTypeHandshake contentType = 22
contentTypeApplicationData contentType = 23
)
type content interface {
contentType() contentType
Marshal() ([]byte, error)
Unmarshal(data []byte) error
}

@ -0,0 +1,232 @@
package dtls
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/binary"
"math/big"
"time"
)
type ecdsaSignature struct {
R, S *big.Int
}
func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve namedCurve) []byte {
serverECDHParams := make([]byte, 4)
serverECDHParams[0] = 3 // named curve
binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve))
serverECDHParams[3] = byte(len(publicKey))
plaintext := []byte{}
plaintext = append(plaintext, clientRandom...)
plaintext = append(plaintext, serverRandom...)
plaintext = append(plaintext, serverECDHParams...)
plaintext = append(plaintext, publicKey...)
return plaintext
}
// If the client provided a "signature_algorithms" extension, then all
// certificates provided by the server MUST be signed by a
// hash/signature algorithm pair that appears in that extension
//
// https://tools.ietf.org/html/rfc5246#section-7.4.2
func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve namedCurve, privateKey crypto.PrivateKey, hashAlgorithm hashAlgorithm) ([]byte, error) {
msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve)
switch p := privateKey.(type) {
case ed25519.PrivateKey:
// https://crypto.stackexchange.com/a/55483
return p.Sign(rand.Reader, msg, crypto.Hash(0))
case *ecdsa.PrivateKey:
hashed := hashAlgorithm.digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
case *rsa.PrivateKey:
hashed := hashAlgorithm.digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
}
return nil, errKeySignatureGenerateUnimplemented
}
func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hashAlgorithm, rawCertificates [][]byte) error { //nolint:dupl
if len(rawCertificates) == 0 {
return errLengthMismatch
}
certificate, err := x509.ParseCertificate(rawCertificates[0])
if err != nil {
return err
}
switch p := certificate.PublicKey.(type) {
case ed25519.PublicKey:
if ok := ed25519.Verify(p, message, remoteKeySignature); !ok {
return errKeySignatureMismatch
}
return nil
case *ecdsa.PublicKey:
ecdsaSig := &ecdsaSignature{}
if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil {
return err
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return errInvalidECDSASignature
}
hashed := hashAlgorithm.digest(message)
if !ecdsa.Verify(p, hashed, ecdsaSig.R, ecdsaSig.S) {
return errKeySignatureMismatch
}
return nil
case *rsa.PublicKey:
switch certificate.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
hashed := hashAlgorithm.digest(message)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.cryptoHash(), hashed, remoteKeySignature)
default:
return errKeySignatureVerifyUnimplemented
}
}
return errKeySignatureVerifyUnimplemented
}
// If the server has sent a CertificateRequest message, the client MUST send the Certificate
// message. The ClientKeyExchange message is now sent, and the content
// of that message will depend on the public key algorithm selected
// between the ClientHello and the ServerHello. If the client has sent
// a certificate with signing ability, a digitally-signed
// CertificateVerify message is sent to explicitly verify possession of
// the private key in the certificate.
// https://tools.ietf.org/html/rfc5246#section-7.3
func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hashAlgorithm) ([]byte, error) {
h := sha256.New()
if _, err := h.Write(handshakeBodies); err != nil {
return nil, err
}
hashed := h.Sum(nil)
switch p := privateKey.(type) {
case ed25519.PrivateKey:
// https://crypto.stackexchange.com/a/55483
return p.Sign(rand.Reader, hashed, crypto.Hash(0))
case *ecdsa.PrivateKey:
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
case *rsa.PrivateKey:
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
}
return nil, errInvalidSignatureAlgorithm
}
func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hashAlgorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl
if len(rawCertificates) == 0 {
return errLengthMismatch
}
certificate, err := x509.ParseCertificate(rawCertificates[0])
if err != nil {
return err
}
switch p := certificate.PublicKey.(type) {
case ed25519.PublicKey:
if ok := ed25519.Verify(p, handshakeBodies, remoteKeySignature); !ok {
return errKeySignatureMismatch
}
return nil
case *ecdsa.PublicKey:
ecdsaSig := &ecdsaSignature{}
if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil {
return err
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return errInvalidECDSASignature
}
hash := hashAlgorithm.digest(handshakeBodies)
if !ecdsa.Verify(p, hash, ecdsaSig.R, ecdsaSig.S) {
return errKeySignatureMismatch
}
return nil
case *rsa.PublicKey:
switch certificate.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
hash := hashAlgorithm.digest(handshakeBodies)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.cryptoHash(), hash, remoteKeySignature)
default:
return errKeySignatureVerifyUnimplemented
}
}
return errKeySignatureVerifyUnimplemented
}
func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) {
if len(rawCertificates) == 0 {
return nil, errLengthMismatch
}
certs := make([]*x509.Certificate, 0, len(rawCertificates))
for _, rawCert := range rawCertificates {
cert, err := x509.ParseCertificate(rawCert)
if err != nil {
return nil, err
}
certs = append(certs, cert)
}
return certs, nil
}
func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [][]*x509.Certificate, err error) {
certificate, err := loadCerts(rawCertificates)
if err != nil {
return nil, err
}
intermediateCAPool := x509.NewCertPool()
for _, cert := range certificate[1:] {
intermediateCAPool.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
Intermediates: intermediateCAPool,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
return certificate[0].Verify(opts)
}
func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, err error) {
certificate, err := loadCerts(rawCertificates)
if err != nil {
return nil, err
}
intermediateCAPool := x509.NewCertPool()
for _, cert := range certificate[1:] {
intermediateCAPool.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
DNSName: serverName,
Intermediates: intermediateCAPool,
}
return certificate[0].Verify(opts)
}
func generateAEADAdditionalData(h *recordLayerHeader, payloadLen int) []byte {
var additionalData [13]byte
// SequenceNumber MUST be set first
// we only want uint48, clobbering an extra 2 (using uint64, Golang doesn't have uint48)
binary.BigEndian.PutUint64(additionalData[:], h.sequenceNumber)
binary.BigEndian.PutUint16(additionalData[:], h.epoch)
additionalData[8] = byte(h.contentType)
additionalData[9] = h.protocolVersion.major
additionalData[10] = h.protocolVersion.minor
binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen))
return additionalData[:]
}

@ -0,0 +1,133 @@
package dtls
import ( //nolint:gci
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1" //nolint:gosec
"encoding/binary"
)
// block ciphers using cipher block chaining.
type cbcMode interface {
cipher.BlockMode
SetIV([]byte)
}
// State needed to handle encrypted input/output
type cryptoCBC struct {
writeCBC, readCBC cbcMode
writeMac, readMac []byte
}
// Currently hardcoded to be SHA1 only
var cryptoCBCMacFunc = sha1.New //nolint:gochecknoglobals
func newCryptoCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte) (*cryptoCBC, error) {
writeBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
readBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
return &cryptoCBC{
writeCBC: cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode),
writeMac: localMac,
readCBC: cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode),
readMac: remoteMac,
}, nil
}
func (c *cryptoCBC) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
blockSize := c.writeCBC.BlockSize()
// Generate + Append MAC
h := pkt.recordLayerHeader
MAC, err := prfMac(h.epoch, h.sequenceNumber, h.contentType, h.protocolVersion, payload, c.writeMac)
if err != nil {
return nil, err
}
payload = append(payload, MAC...)
// Generate + Append padding
padding := make([]byte, blockSize-len(payload)%blockSize)
paddingLen := len(padding)
for i := 0; i < paddingLen; i++ {
padding[i] = byte(paddingLen - 1)
}
payload = append(payload, padding...)
// Generate IV
iv := make([]byte, blockSize)
if _, err := rand.Read(iv); err != nil {
return nil, err
}
// Set IV + Encrypt + Prepend IV
c.writeCBC.SetIV(iv)
c.writeCBC.CryptBlocks(payload, payload)
payload = append(iv, payload...)
// Prepend unencrypte header with encrypted payload
raw = append(raw, payload...)
// Update recordLayer size to include IV+MAC+Padding
binary.BigEndian.PutUint16(raw[recordLayerHeaderSize-2:], uint16(len(raw)-recordLayerHeaderSize))
return raw, nil
}
func (c *cryptoCBC) decrypt(in []byte) ([]byte, error) {
body := in[recordLayerHeaderSize:]
blockSize := c.readCBC.BlockSize()
mac := cryptoCBCMacFunc()
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(body)%blockSize != 0 || len(body) < blockSize+max(mac.Size()+1, blockSize):
return nil, errNotEnoughRoomForNonce
}
// Set + remove per record IV
c.readCBC.SetIV(body[:blockSize])
body = body[blockSize:]
// Decrypt
c.readCBC.CryptBlocks(body, body)
// Padding+MAC needs to be checked in constant time
// Otherwise we reveal information about the level of correctness
paddingLen, paddingGood := examinePadding(body)
macSize := mac.Size()
if len(body) < macSize {
return nil, errInvalidMAC
}
dataEnd := len(body) - macSize - paddingLen
expectedMAC := body[dataEnd : dataEnd+macSize]
actualMAC, err := prfMac(h.epoch, h.sequenceNumber, h.contentType, h.protocolVersion, body[:dataEnd], c.readMac)
// Compute Local MAC and compare
if paddingGood != 255 || err != nil || !hmac.Equal(actualMAC, expectedMAC) {
return nil, errInvalidMAC
}
return append(in[:recordLayerHeaderSize], body[:dataEnd]...), nil
}

@ -0,0 +1,100 @@
package dtls
import (
"crypto/aes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"github.com/pion/dtls/v2/pkg/crypto/ccm"
)
var errDecryptPacket = errors.New("decryptPacket")
type cryptoCCMTagLen int
const (
cryptoCCM8TagLength cryptoCCMTagLen = 8
cryptoCCMTagLength cryptoCCMTagLen = 16
cryptoCCMNonceLength = 12
)
// State needed to handle encrypted input/output
type cryptoCCM struct {
localCCM, remoteCCM ccm.CCM
localWriteIV, remoteWriteIV []byte
tagLen cryptoCCMTagLen
}
func newCryptoCCM(tagLen cryptoCCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*cryptoCCM, error) {
localBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
localCCM, err := ccm.NewCCM(localBlock, int(tagLen), cryptoCCMNonceLength)
if err != nil {
return nil, err
}
remoteBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), cryptoCCMNonceLength)
if err != nil {
return nil, err
}
return &cryptoCCM{
localCCM: localCCM,
localWriteIV: localWriteIV,
remoteCCM: remoteCCM,
remoteWriteIV: remoteWriteIV,
tagLen: tagLen,
}, nil
}
func (c *cryptoCCM) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
nonce := append(append([]byte{}, c.localWriteIV[:4]...), make([]byte, 8)...)
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}
additionalData := generateAEADAdditionalData(&pkt.recordLayerHeader, len(payload))
encryptedPayload := c.localCCM.Seal(nil, nonce, payload, additionalData)
encryptedPayload = append(nonce[4:], encryptedPayload...)
raw = append(raw, encryptedPayload...)
// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(raw[recordLayerHeaderSize-2:], uint16(len(raw)-recordLayerHeaderSize))
return raw, nil
}
func (c *cryptoCCM) decrypt(in []byte) ([]byte, error) {
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + recordLayerHeaderSize):
return nil, errNotEnoughRoomForNonce
}
nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[recordLayerHeaderSize:recordLayerHeaderSize+8]...)
out := in[recordLayerHeaderSize+8:]
additionalData := generateAEADAdditionalData(&h, len(out)-int(c.tagLen))
out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err)
}
return append(in[:recordLayerHeaderSize], out...), nil
}

@ -0,0 +1,94 @@
package dtls
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
)
const (
cryptoGCMTagLength = 16
cryptoGCMNonceLength = 12
)
// State needed to handle encrypted input/output
type cryptoGCM struct {
localGCM, remoteGCM cipher.AEAD
localWriteIV, remoteWriteIV []byte
}
func newCryptoGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*cryptoGCM, error) {
localBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
localGCM, err := cipher.NewGCM(localBlock)
if err != nil {
return nil, err
}
remoteBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
remoteGCM, err := cipher.NewGCM(remoteBlock)
if err != nil {
return nil, err
}
return &cryptoGCM{
localGCM: localGCM,
localWriteIV: localWriteIV,
remoteGCM: remoteGCM,
remoteWriteIV: remoteWriteIV,
}, nil
}
func (c *cryptoGCM) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
nonce := make([]byte, cryptoGCMNonceLength)
copy(nonce, c.localWriteIV[:4])
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}
additionalData := generateAEADAdditionalData(&pkt.recordLayerHeader, len(payload))
encryptedPayload := c.localGCM.Seal(nil, nonce, payload, additionalData)
r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload))
copy(r, raw)
copy(r[len(raw):], nonce[4:])
copy(r[len(raw)+len(nonce[4:]):], encryptedPayload)
// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(r[recordLayerHeaderSize-2:], uint16(len(r)-recordLayerHeaderSize))
return r, nil
}
func (c *cryptoGCM) decrypt(in []byte) ([]byte, error) {
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + recordLayerHeaderSize):
return nil, errNotEnoughRoomForNonce
}
nonce := make([]byte, 0, cryptoGCMNonceLength)
nonce = append(append(nonce, c.remoteWriteIV[:4]...), in[recordLayerHeaderSize:recordLayerHeaderSize+8]...)
out := in[recordLayerHeaderSize+8:]
additionalData := generateAEADAdditionalData(&h, len(out)-cryptoGCMTagLength)
out, err = c.remoteGCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err)
}
return append(in[:recordLayerHeaderSize], out...), nil
}

@ -0,0 +1,14 @@
package dtls
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10
type ellipticCurveType byte
const (
ellipticCurveTypeNamedCurve ellipticCurveType = 0x03
)
func ellipticCurveTypes() map[ellipticCurveType]bool {
return map[ellipticCurveType]bool{
ellipticCurveTypeNamedCurve: true,
}
}

@ -0,0 +1,2 @@
// Package dtls implements Datagram Transport Layer Security (DTLS) 1.2
package dtls

@ -0,0 +1,229 @@
package dtls
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"golang.org/x/xerrors"
)
// Typed errors
var (
ErrConnClosed = &FatalError{errors.New("conn is closed")} //nolint:goerr113
errDeadlineExceeded = &TimeoutError{xerrors.Errorf("read/write timeout: %w", context.DeadlineExceeded)}
errBufferTooSmall = &TemporaryError{errors.New("buffer is too small")} //nolint:goerr113
errContextUnsupported = &TemporaryError{errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113
errDTLSPacketInvalidLength = &TemporaryError{errors.New("packet is too short")} //nolint:goerr113
errHandshakeInProgress = &TemporaryError{errors.New("handshake is in progress")} //nolint:goerr113
errInvalidContentType = &TemporaryError{errors.New("invalid content type")} //nolint:goerr113
errInvalidMAC = &TemporaryError{errors.New("invalid mac")} //nolint:goerr113
errInvalidPacketLength = &TemporaryError{errors.New("packet length and declared length do not match")} //nolint:goerr113
errReservedExportKeyingMaterial = &TemporaryError{errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113
errCertificateVerifyNoCertificate = &FatalError{errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113
errCipherSuiteNoIntersection = &FatalError{errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113
errCipherSuiteUnset = &FatalError{errors.New("server hello can not be created without a cipher suite")} //nolint:goerr113
errClientCertificateNotVerified = &FatalError{errors.New("client sent certificate but did not verify it")} //nolint:goerr113
errClientCertificateRequired = &FatalError{errors.New("server required client verification, but got none")} //nolint:goerr113
errClientNoMatchingSRTPProfile = &FatalError{errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113
errClientRequiredButNoServerEMS = &FatalError{errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113
errCompressionMethodUnset = &FatalError{errors.New("server hello can not be created without a compression method")} //nolint:goerr113
errCookieMismatch = &FatalError{errors.New("client+server cookie does not match")} //nolint:goerr113
errCookieTooLong = &FatalError{errors.New("cookie must not be longer then 255 bytes")} //nolint:goerr113
errIdentityNoPSK = &FatalError{errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113
errInvalidCertificate = &FatalError{errors.New("no certificate provided")} //nolint:goerr113
errInvalidCipherSpec = &FatalError{errors.New("cipher spec invalid")} //nolint:goerr113
errInvalidCipherSuite = &FatalError{errors.New("invalid or unknown cipher suite")} //nolint:goerr113
errInvalidClientKeyExchange = &FatalError{errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity")} //nolint:goerr113
errInvalidCompressionMethod = &FatalError{errors.New("invalid or unknown compression method")} //nolint:goerr113
errInvalidECDSASignature = &FatalError{errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113
errInvalidEllipticCurveType = &FatalError{errors.New("invalid or unknown elliptic curve type")} //nolint:goerr113
errInvalidExtensionType = &FatalError{errors.New("invalid extension type")} //nolint:goerr113
errInvalidHashAlgorithm = &FatalError{errors.New("invalid hash algorithm")} //nolint:goerr113
errInvalidNamedCurve = &FatalError{errors.New("invalid named curve")} //nolint:goerr113
errInvalidPrivateKey = &FatalError{errors.New("invalid private key type")} //nolint:goerr113
errInvalidSNIFormat = &FatalError{errors.New("invalid server name format")} //nolint:goerr113
errInvalidSignatureAlgorithm = &FatalError{errors.New("invalid signature algorithm")} //nolint:goerr113
errKeySignatureMismatch = &FatalError{errors.New("expected and actual key signature do not match")} //nolint:goerr113
errNilNextConn = &FatalError{errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113
errNoAvailableCipherSuites = &FatalError{errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113
errNoAvailableSignatureSchemes = &FatalError{errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113
errNoCertificates = &FatalError{errors.New("no certificates configured")} //nolint:goerr113
errNoConfigProvided = &FatalError{errors.New("no config provided")} //nolint:goerr113
errNoSupportedEllipticCurves = &FatalError{errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113
errUnsupportedProtocolVersion = &FatalError{errors.New("unsupported protocol version")} //nolint:goerr113
errPSKAndCertificate = &FatalError{errors.New("Certificate and PSK provided")} //nolint:stylecheck
errPSKAndIdentityMustBeSetForClient = &FatalError{errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113
errRequestedButNoSRTPExtension = &FatalError{errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113
errServerMustHaveCertificate = &FatalError{errors.New("Certificate is mandatory for server")} //nolint:stylecheck
errServerNoMatchingSRTPProfile = &FatalError{errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113
errServerRequiredButNoClientEMS = &FatalError{errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113
errVerifyDataMismatch = &FatalError{errors.New("expected and actual verify data does not match")} //nolint:goerr113
errHandshakeMessageUnset = &InternalError{errors.New("handshake message unset, unable to marshal")} //nolint:goerr113
errInvalidFlight = &InternalError{errors.New("invalid flight number")} //nolint:goerr113
errKeySignatureGenerateUnimplemented = &InternalError{errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113
errKeySignatureVerifyUnimplemented = &InternalError{errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113
errLengthMismatch = &InternalError{errors.New("data length and declared length do not match")} //nolint:goerr113
errNotEnoughRoomForNonce = &InternalError{errors.New("buffer not long enough to contain nonce")} //nolint:goerr113
errNotImplemented = &InternalError{errors.New("feature has not been implemented yet")} //nolint:goerr113
errSequenceNumberOverflow = &InternalError{errors.New("sequence number overflow")} //nolint:goerr113
errUnableToMarshalFragmented = &InternalError{errors.New("unable to marshal fragmented handshakes")} //nolint:goerr113
)
// FatalError indicates that the DTLS connection is no longer available.
// It is mainly caused by wrong configuration of server or client.
type FatalError struct {
Err error
}
// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available.
// It is mainly caused by bugs or tried to use unimplemented features.
type InternalError struct {
Err error
}
// TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary.
type TemporaryError struct {
Err error
}
// TimeoutError indicates that the request was timed out.
type TimeoutError struct {
Err error
}
// HandshakeError indicates that the handshake failed.
type HandshakeError struct {
Err error
}
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
type invalidCipherSuite struct {
id CipherSuiteID
}
func (e *invalidCipherSuite) Error() string {
return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id)
}
func (e *invalidCipherSuite) Is(err error) bool {
if other, ok := err.(*invalidCipherSuite); ok {
return e.id == other.id
}
return false
}
// Timeout implements net.Error.Timeout()
func (*FatalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*FatalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *FatalError) Unwrap() error { return e.Err }
func (e *FatalError) Error() string { return fmt.Sprintf("dtls fatal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*InternalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*InternalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *InternalError) Unwrap() error { return e.Err }
func (e *InternalError) Error() string { return fmt.Sprintf("dtls internal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TemporaryError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*TemporaryError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TemporaryError) Unwrap() error { return e.Err }
func (e *TemporaryError) Error() string { return fmt.Sprintf("dtls temporary: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TimeoutError) Timeout() bool { return true }
// Temporary implements net.Error.Temporary()
func (*TimeoutError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TimeoutError) Unwrap() error { return e.Err }
func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (e *HandshakeError) Timeout() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Timeout()
}
return false
}
// Temporary implements net.Error.Temporary()
func (e *HandshakeError) Temporary() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Temporary()
}
return false
}
// Unwrap implements Go1.13 error unwrapper.
func (e *HandshakeError) Unwrap() error { return e.Err }
func (e *HandshakeError) Error() string { return fmt.Sprintf("handshake error: %v", e.Err) }
// errAlert wraps DTLS alert notification as an error
type errAlert struct {
*alert
}
func (e *errAlert) Error() string {
return fmt.Sprintf("alert: %s", e.alert.String())
}
func (e *errAlert) IsFatalOrCloseNotify() bool {
return e.alertLevel == alertLevelFatal || e.alertDescription == alertCloseNotify
}
func (e *errAlert) Is(err error) bool {
if other, ok := err.(*errAlert); ok {
return e.alertLevel == other.alertLevel && e.alertDescription == other.alertDescription
}
return false
}
// netError translates an error from underlying Conn to corresponding net.Error.
func netError(err error) error {
switch err {
case io.EOF, context.Canceled, context.DeadlineExceeded:
// Return io.EOF and context errors as is.
return err
}
switch e := err.(type) {
case (*net.OpError):
if se, ok := e.Err.(*os.SyscallError); ok {
if se.Timeout() {
return &TimeoutError{err}
}
if isOpErrorTemporary(se) {
return &TemporaryError{err}
}
}
case (net.Error):
return err
}
return &FatalError{err}
}

@ -0,0 +1,25 @@
// +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows
// For systems having syscall.Errno.
// Update build targets by following command:
// $ grep -R ECONN $(go env GOROOT)/src/syscall/zerrors_*.go \
// | tr "." "_" | cut -d"_" -f"2" | sort | uniq
package dtls
import (
"os"
"syscall"
)
func isOpErrorTemporary(err *os.SyscallError) bool {
if ne, ok := err.Err.(syscall.Errno); ok {
switch ne {
case syscall.ECONNREFUSED:
return true
default:
return false
}
}
return false
}

@ -0,0 +1,14 @@
// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!nacl,!nacljs,!netbsd,!openbsd,!solaris,!windows
// For systems without syscall.Errno.
// Build targets must be inverse of errors_errno.go
package dtls
import (
"os"
)
func isOpErrorTemporary(err *os.SyscallError) bool {
return false
}

@ -0,0 +1,88 @@
package dtls
import (
"encoding/binary"
)
// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
type extensionValue uint16
const (
extensionServerNameValue extensionValue = 0
extensionSupportedEllipticCurvesValue extensionValue = 10
extensionSupportedPointFormatsValue extensionValue = 11
extensionSupportedSignatureAlgorithmsValue extensionValue = 13
extensionUseSRTPValue extensionValue = 14
extensionUseExtendedMasterSecretValue extensionValue = 23
extensionRenegotiationInfoValue extensionValue = 65281
)
type extension interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) error
extensionValue() extensionValue
}
func decodeExtensions(buf []byte) ([]extension, error) {
if len(buf) < 2 {
return nil, errBufferTooSmall
}
declaredLen := binary.BigEndian.Uint16(buf)
if len(buf)-2 != int(declaredLen) {
return nil, errLengthMismatch
}
extensions := []extension{}
unmarshalAndAppend := func(data []byte, e extension) error {
err := e.Unmarshal(data)
if err != nil {
return err
}
extensions = append(extensions, e)
return nil
}
for offset := 2; offset < len(buf); {
if len(buf) < (offset + 2) {
return nil, errBufferTooSmall
}
var err error
switch extensionValue(binary.BigEndian.Uint16(buf[offset:])) {
case extensionServerNameValue:
err = unmarshalAndAppend(buf[offset:], &extensionServerName{})
case extensionSupportedEllipticCurvesValue:
err = unmarshalAndAppend(buf[offset:], &extensionSupportedEllipticCurves{})
case extensionUseSRTPValue:
err = unmarshalAndAppend(buf[offset:], &extensionUseSRTP{})
case extensionUseExtendedMasterSecretValue:
err = unmarshalAndAppend(buf[offset:], &extensionUseExtendedMasterSecret{})
case extensionRenegotiationInfoValue:
err = unmarshalAndAppend(buf[offset:], &extensionRenegotiationInfo{})
default:
}
if err != nil {
return nil, err
}
if len(buf) < (offset + 4) {
return nil, errBufferTooSmall
}
extensionLength := binary.BigEndian.Uint16(buf[offset+2:])
offset += (4 + int(extensionLength))
}
return extensions, nil
}
func encodeExtensions(e []extension) ([]byte, error) {
extensions := []byte{}
for _, e := range e {
raw, err := e.Marshal()
if err != nil {
return nil, err
}
extensions = append(extensions, raw...)
}
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out, uint16(len(extensions)))
return append(out, extensions...), nil
}

@ -0,0 +1,37 @@
package dtls
import "encoding/binary"
const (
extensionRenegotiationInfoHeaderSize = 5
)
// https://tools.ietf.org/html/rfc5746
type extensionRenegotiationInfo struct {
renegotiatedConnection uint8
}
func (e extensionRenegotiationInfo) extensionValue() extensionValue {
return extensionRenegotiationInfoValue
}
func (e *extensionRenegotiationInfo) Marshal() ([]byte, error) {
out := make([]byte, extensionRenegotiationInfoHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1)) // length
out[4] = e.renegotiatedConnection
return out, nil
}
func (e *extensionRenegotiationInfo) Unmarshal(data []byte) error {
if len(data) < extensionRenegotiationInfoHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
e.renegotiatedConnection = data[4]
return nil
}

@ -0,0 +1,70 @@
package dtls
import (
"strings"
"golang.org/x/crypto/cryptobyte"
)
const extensionServerNameTypeDNSHostName = 0
type extensionServerName struct {
serverName string
}
func (e extensionServerName) extensionValue() extensionValue {
return extensionServerNameValue
}
func (e *extensionServerName) Marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(uint16(e.extensionValue()))
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(extensionServerNameTypeDNSHostName)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(e.serverName))
})
})
})
return b.Bytes()
}
func (e *extensionServerName) Unmarshal(data []byte) error {
s := cryptobyte.String(data)
var extension uint16
s.ReadUint16(&extension)
if extensionValue(extension) != e.extensionValue() {
return errInvalidExtensionType
}
var extData cryptobyte.String
s.ReadUint16LengthPrefixed(&extData)
var nameList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
return errInvalidSNIFormat
}
for !nameList.Empty() {
var nameType uint8
var serverName cryptobyte.String
if !nameList.ReadUint8(&nameType) ||
!nameList.ReadUint16LengthPrefixed(&serverName) ||
serverName.Empty() {
return errInvalidSNIFormat
}
if nameType != extensionServerNameTypeDNSHostName {
continue
}
if len(e.serverName) != 0 {
// Multiple names of the same name_type are prohibited.
return errInvalidSNIFormat
}
e.serverName = string(serverName)
// An SNI value may not include a trailing dot.
if strings.HasSuffix(e.serverName, ".") {
return errInvalidSNIFormat
}
}
return nil
}

@ -0,0 +1,54 @@
package dtls
import (
"encoding/binary"
)
const (
extensionSupportedGroupsHeaderSize = 6
)
// https://tools.ietf.org/html/rfc8422#section-5.1.1
type extensionSupportedEllipticCurves struct {
ellipticCurves []namedCurve
}
func (e extensionSupportedEllipticCurves) extensionValue() extensionValue {
return extensionSupportedEllipticCurvesValue
}
func (e *extensionSupportedEllipticCurves) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedGroupsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.ellipticCurves)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.ellipticCurves)*2))
for _, v := range e.ellipticCurves {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
return out, nil
}
func (e *extensionSupportedEllipticCurves) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedGroupsHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
groupCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedGroupsHeaderSize+(groupCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < groupCount; i++ {
supportedGroupID := namedCurve(binary.BigEndian.Uint16(data[(extensionSupportedGroupsHeaderSize + (i * 2)):]))
if _, ok := namedCurves()[supportedGroupID]; ok {
e.ellipticCurves = append(e.ellipticCurves, supportedGroupID)
}
}
return nil
}

@ -0,0 +1,56 @@
package dtls
import "encoding/binary"
const (
extensionSupportedPointFormatsSize = 5
)
type ellipticCurvePointFormat byte
const ellipticCurvePointFormatUncompressed ellipticCurvePointFormat = 0
// https://tools.ietf.org/html/rfc4492#section-5.1.2
type extensionSupportedPointFormats struct {
pointFormats []ellipticCurvePointFormat
}
func (e extensionSupportedPointFormats) extensionValue() extensionValue {
return extensionSupportedPointFormatsValue
}
func (e *extensionSupportedPointFormats) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedPointFormatsSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1+(len(e.pointFormats))))
out[4] = byte(len(e.pointFormats))
for _, v := range e.pointFormats {
out = append(out, byte(v))
}
return out, nil
}
func (e *extensionSupportedPointFormats) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedPointFormatsSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
pointFormatCount := int(binary.BigEndian.Uint16(data[4:]))
if extensionSupportedGroupsHeaderSize+(pointFormatCount) > len(data) {
return errLengthMismatch
}
for i := 0; i < pointFormatCount; i++ {
p := ellipticCurvePointFormat(data[extensionSupportedPointFormatsSize+i])
switch p {
case ellipticCurvePointFormatUncompressed:
e.pointFormats = append(e.pointFormats, p)
default:
}
}
return nil
}

@ -0,0 +1,60 @@
package dtls
import (
"encoding/binary"
)
const (
extensionSupportedSignatureAlgorithmsHeaderSize = 6
)
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
type extensionSupportedSignatureAlgorithms struct {
signatureHashAlgorithms []signatureHashAlgorithm
}
func (e extensionSupportedSignatureAlgorithms) extensionValue() extensionValue {
return extensionSupportedSignatureAlgorithmsValue
}
func (e *extensionSupportedSignatureAlgorithms) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedSignatureAlgorithmsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.signatureHashAlgorithms)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.signatureHashAlgorithms)*2))
for _, v := range e.signatureHashAlgorithms {
out = append(out, []byte{0x00, 0x00}...)
out[len(out)-2] = byte(v.hash)
out[len(out)-1] = byte(v.signature)
}
return out, nil
}
func (e *extensionSupportedSignatureAlgorithms) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedSignatureAlgorithmsHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
algorithmCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedSignatureAlgorithmsHeaderSize+(algorithmCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < algorithmCount; i++ {
supportedHashAlgorithm := hashAlgorithm(data[extensionSupportedSignatureAlgorithmsHeaderSize+(i*2)])
supportedSignatureAlgorithm := signatureAlgorithm(data[extensionSupportedSignatureAlgorithmsHeaderSize+(i*2)+1])
if _, ok := hashAlgorithms()[supportedHashAlgorithm]; ok {
if _, ok := signatureAlgorithms()[supportedSignatureAlgorithm]; ok {
e.signatureHashAlgorithms = append(e.signatureHashAlgorithms, signatureHashAlgorithm{
supportedHashAlgorithm,
supportedSignatureAlgorithm,
})
}
}
}
return nil
}

@ -0,0 +1,40 @@
package dtls
import "encoding/binary"
const (
extensionUseExtendedMasterSecretHeaderSize = 4
)
// https://tools.ietf.org/html/rfc8422
type extensionUseExtendedMasterSecret struct {
supported bool
}
func (e extensionUseExtendedMasterSecret) extensionValue() extensionValue {
return extensionUseExtendedMasterSecretValue
}
func (e *extensionUseExtendedMasterSecret) Marshal() ([]byte, error) {
if !e.supported {
return []byte{}, nil
}
out := make([]byte, extensionUseExtendedMasterSecretHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(0)) // length
return out, nil
}
func (e *extensionUseExtendedMasterSecret) Unmarshal(data []byte) error {
if len(data) < extensionUseExtendedMasterSecretHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
e.supported = true
return nil
}

@ -0,0 +1,53 @@
package dtls
import "encoding/binary"
const (
extensionUseSRTPHeaderSize = 6
)
// https://tools.ietf.org/html/rfc8422
type extensionUseSRTP struct {
protectionProfiles []SRTPProtectionProfile
}
func (e extensionUseSRTP) extensionValue() extensionValue {
return extensionUseSRTPValue
}
func (e *extensionUseSRTP) Marshal() ([]byte, error) {
out := make([]byte, extensionUseSRTPHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.protectionProfiles)*2)+ /* MKI Length */ 1))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.protectionProfiles)*2))
for _, v := range e.protectionProfiles {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
out = append(out, 0x00) /* MKI Length */
return out, nil
}
func (e *extensionUseSRTP) Unmarshal(data []byte) error {
if len(data) <= extensionUseSRTPHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedGroupsHeaderSize+(profileCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < profileCount; i++ {
supportedProfile := SRTPProtectionProfile(binary.BigEndian.Uint16(data[(extensionUseSRTPHeaderSize + (i * 2)):]))
if _, ok := srtpProtectionProfiles()[supportedProfile]; ok {
e.protectionProfiles = append(e.protectionProfiles, supportedProfile)
}
}
return nil
}

@ -0,0 +1,75 @@
package dtls
/*
DTLS messages are grouped into a series of message flights, according
to the diagrams below. Although each flight of messages may consist
of a number of messages, they should be viewed as monolithic for the
purpose of timeout and retransmission.
https://tools.ietf.org/html/rfc4347#section-4.2.4
Client Server
------ ------
Waiting Flight 0
ClientHello --------> Flight 1
<------- HelloVerifyRequest Flight 2
ClientHello --------> Flight 3
ServerHello \
Certificate* \
ServerKeyExchange* Flight 4
CertificateRequest* /
<-------- ServerHelloDone /
Certificate* \
ClientKeyExchange \
CertificateVerify* Flight 5
[ChangeCipherSpec] /
Finished --------> /
[ChangeCipherSpec] \ Flight 6
<-------- Finished /
*/
type flightVal uint8
const (
flight0 flightVal = iota + 1
flight1
flight2
flight3
flight4
flight5
flight6
)
func (f flightVal) String() string {
switch f {
case flight0:
return "Flight 0"
case flight1:
return "Flight 1"
case flight2:
return "Flight 2"
case flight3:
return "Flight 3"
case flight4:
return "Flight 4"
case flight5:
return "Flight 5"
case flight6:
return "Flight 6"
default:
return "Invalid Flight"
}
}
func (f flightVal) isLastSendFlight() bool {
return f == flight6
}
func (f flightVal) isLastRecvFlight() bool {
return f == flight5
}

@ -0,0 +1,89 @@
package dtls
import (
"context"
"crypto/rand"
)
func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
seq, msgs, ok := cache.fullPullMap(0,
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
state.handshakeRecvSequence = seq
var clientHello *handshakeMessageClientHello
// Validate type
if clientHello, ok = msgs[handshakeTypeClientHello].(*handshakeMessageClientHello); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
if !clientHello.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
state.remoteRandom = clientHello.random
if state.cipherSuite, ok = findMatchingCipherSuite(clientHello.cipherSuites, cfg.localCipherSuites); !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errCipherSuiteNoIntersection
}
for _, extension := range clientHello.extensions {
switch e := extension.(type) {
case *extensionSupportedEllipticCurves:
if len(e.ellipticCurves) == 0 {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errNoSupportedEllipticCurves
}
state.namedCurve = e.ellipticCurves[0]
case *extensionUseSRTP:
profile, ok := findMatchingSRTPProfile(e.protectionProfiles, cfg.localSRTPProtectionProfiles)
if !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errServerNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
case *extensionUseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
}
case *extensionServerName:
state.serverName = e.serverName // remote server name
}
}
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errServerRequiredButNoClientEMS
}
if state.localKeypair == nil {
var err error
state.localKeypair, err = generateKeypair(state.namedCurve)
if err != nil {
return 0, &alert{alertLevelFatal, alertIllegalParameter}, err
}
}
return flight2, nil, nil
}
func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
// Initialize
state.cookie = make([]byte, cookieLength)
if _, err := rand.Read(state.cookie); err != nil {
return nil, nil, err
}
var zeroEpoch uint16
state.localEpoch.Store(zeroEpoch)
state.remoteEpoch.Store(zeroEpoch)
state.namedCurve = defaultNamedCurve
if err := state.localRandom.populate(); err != nil {
return nil, nil, err
}
return nil, nil, nil
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save