Proxy: Support proxy server for SRS. v7.0.16 (#4158)

Please note that the proxy server is a new architecture or the next
version of the Origin Cluster, which allows the publication of multiple
streams. The SRS origin cluster consists of a group of origin servers
designed to handle a large number of streams.

```text
                         +-----------------------+
                     +---+ SRS Proxy(Deployment) +------+---------------------+
+-----------------+  |   +-----------+-----------+      +                     +
| LB(K8s Service) +--+               +(Redis/MESH)      + SRS Origin Servers  +
+-----------------+  |   +-----------+-----------+      +    (Deployment)     +
                     +---+ SRS Proxy(Deployment) +------+---------------------+
                         +-----------------------+
```

The new origin cluster is designed as a collection of proxy servers. For
more information, see [Discussion
#3634](https://github.com/ossrs/srs/discussions/3634). If you prefer to
use the old origin cluster, please switch to a version before SRS 6.0.

A proxy server can be used for a set of origin servers, which are
isolated and dedicated origin servers. The main improvement in the new
architecture is to store the state for origin servers in the proxy
server, rather than using MESH to communicate between origin servers.
With a proxy server, you can deploy origin servers as stateless servers,
such as in a Kubernetes (K8s) deployment.

Now that the proxy server is a stateful server, it uses Redis to store
the states. For faster development, we use Go to develop the proxy
server, instead of C/C++. Therefore, the proxy server itself is also
stateless, with all states stored in the Redis server or cluster. This
makes the new origin cluster architecture very powerful and robust.

The proxy server is also an architecture designed to solve multiple
process bottlenecks. You can run hundreds of SRS origin servers with one
proxy server on the same machine. This solution can utilize multi-core
machines, such as servers with 128 CPUs. Thus, we can keep SRS
single-threaded and very simple. See
https://github.com/ossrs/srs/discussions/3665#discussioncomment-6474441
for details.

```text
                                       +--------------------+
                               +-------+ SRS Origin Server  +
                               +       +--------------------+
                               +
+-----------------------+      +       +--------------------+
+ SRS Proxy(Deployment) +------+-------+ SRS Origin Server  +
+-----------------------+      +       +--------------------+
                               +
                               +       +--------------------+
                               +-------+ SRS Origin Server  +
                                       +--------------------+
```

Keep in mind that the proxy server for the Origin Cluster is designed to
handle many streams. To address the issue of many viewers, we will
enhance the Edge Cluster to support more protocols.

```text
+------------------+                                               +--------------------+
+ SRS Edge Server  +--+                                    +-------+ SRS Origin Server  +
+------------------+  +                                    +       +--------------------+
                      +                                    +
+------------------+  +     +-----------------------+      +       +--------------------+
+ SRS Edge Server  +--+-----+ SRS Proxy(Deployment) +------+-------+ SRS Origin Server  +
+------------------+  +     +-----------------------+      +       +--------------------+
                      +                                    +
+------------------+  +                                    +       +--------------------+
+ SRS Edge Server  +--+                                    +-------+ SRS Origin Server  +
+------------------+                                               +--------------------+
```

With the new Origin Cluster and Edge Cluster, you have a media system
capable of supporting a large number of streams and viewers. For
example, you can publish 10,000 streams, each with 100,000 viewers.

---------

Co-authored-by: Jacob Su <suzp1984@gmail.com>
pull/4185/head
Winlin 5 months ago committed by GitHub
parent b475d552aa
commit 2e4014ae1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

4
proxy/.gitignore vendored

@ -0,0 +1,4 @@
.idea
srs-proxy
.env
.go-formarted

@ -0,0 +1,23 @@
.PHONY: all build test fmt clean run
all: build
build: fmt ./srs-proxy
./srs-proxy: *.go
go build -o srs-proxy .
test:
go test ./...
fmt: ./.go-formarted
./.go-formarted: *.go
touch .go-formarted
go fmt ./...
clean:
rm -f srs-proxy .go-formarted
run: fmt
go run .

@ -0,0 +1,272 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP,
// to proxy other HTTP API of SRS like the streams and clients, etc.
type srsHTTPAPIServer struct {
// The underlayer HTTP server.
server *http.Server
// The WebRTC server.
rtc *srsWebRTCServer
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer {
v := &srsHTTPAPIServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsHTTPAPIServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := envHttpAPI()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "HTTP API server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
apiResponse(ctx, w, r, map[string]string{
"signature": Signature(),
"version": Version(),
})
})
// The WebRTC WHIP API handler.
logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr)
mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
}
})
// The WebRTC WHEP API handler.
logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr)
mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
}
})
// Run HTTP API server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "HTTP API accept err %+v", err)
} else {
logger.Df(ctx, "HTTP API server done")
}
}
}()
return nil
}
// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service
// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter
// for Prometheus metrics.
type systemAPI struct {
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI {
v := &systemAPI{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *systemAPI) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *systemAPI) Run(ctx context.Context) error {
// Parse address to listen.
addr := envSystemAPI()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "System API server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
apiResponse(ctx, w, r, map[string]string{
"signature": Signature(),
"version": Version(),
})
})
// The register service for SRS media servers.
logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr)
mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) {
if err := func() error {
var deviceID, ip, serverID, serviceID, pid string
var rtmp, stream, api, srt, rtc []string
if err := ParseBody(r.Body, &struct {
// The IP of SRS, mandatory.
IP *string `json:"ip"`
// The server id of SRS, store in file, may not change, mandatory.
ServerID *string `json:"server"`
// The service id of SRS, always change when restarted, mandatory.
ServiceID *string `json:"service"`
// The process id of SRS, always change when restarted, mandatory.
PID *string `json:"pid"`
// The RTMP listen endpoints, mandatory.
RTMP *[]string `json:"rtmp"`
// The HTTP Stream listen endpoints, optional.
HTTP *[]string `json:"http"`
// The API listen endpoints, optional.
API *[]string `json:"api"`
// The SRT listen endpoints, optional.
SRT *[]string `json:"srt"`
// The RTC listen endpoints, optional.
RTC *[]string `json:"rtc"`
// The device id of SRS, optional.
DeviceID *string `json:"device_id"`
}{
IP: &ip, DeviceID: &deviceID,
ServerID: &serverID, ServiceID: &serviceID, PID: &pid,
RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc,
}); err != nil {
return errors.Wrapf(err, "parse body")
}
if ip == "" {
return errors.Errorf("empty ip")
}
if serverID == "" {
return errors.Errorf("empty server")
}
if serviceID == "" {
return errors.Errorf("empty service")
}
if pid == "" {
return errors.Errorf("empty pid")
}
if len(rtmp) == 0 {
return errors.Errorf("empty rtmp")
}
server := NewSRSServer(func(srs *SRSServer) {
srs.IP, srs.DeviceID = ip, deviceID
srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid
srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api
srs.SRT, srs.RTC = srt, rtc
srs.UpdatedAt = time.Now()
})
if err := srsLoadBalancer.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update SRS server %+v", server)
}
logger.Df(ctx, "Register SRS media server, %+v", server)
return nil
}(); err != nil {
apiError(ctx, w, r, err)
}
type Response struct {
Code int `json:"code"`
PID string `json:"pid"`
}
apiResponse(ctx, w, r, &Response{
Code: 0, PID: fmt.Sprintf("%v", os.Getpid()),
})
})
// Run System API server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If System API server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "System API accept err %+v", err)
} else {
logger.Df(ctx, "System API server done")
}
}
}()
return nil
}

@ -0,0 +1,20 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"net/http"
"srs-proxy/logger"
)
func handleGoPprof(ctx context.Context) {
if addr := envGoPprof(); addr != "" {
go func() {
logger.Df(ctx, "Start Go pprof at %v", addr)
http.ListenAndServe(addr, nil)
}()
}
}

@ -0,0 +1,197 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"os"
"path"
"github.com/joho/godotenv"
"srs-proxy/errors"
"srs-proxy/logger"
)
// loadEnvFile loads the environment variables from file. Note that we only use .env file.
func loadEnvFile(ctx context.Context) error {
if workDir, err := os.Getwd(); err != nil {
return errors.Wrapf(err, "getpwd")
} else {
envFile := path.Join(workDir, ".env")
if _, err := os.Stat(envFile); err == nil {
if err := godotenv.Load(envFile); err != nil {
return errors.Wrapf(err, "load %v", envFile)
}
}
}
return nil
}
// buildDefaultEnvironmentVariables setups the default environment variables.
func buildDefaultEnvironmentVariables(ctx context.Context) {
// Whether enable the Go pprof.
setEnvDefault("GO_PPROF", "")
// Force shutdown timeout.
setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s")
// Graceful quit timeout.
setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s")
// The HTTP API server.
setEnvDefault("PROXY_HTTP_API", "11985")
// The HTTP web server.
setEnvDefault("PROXY_HTTP_SERVER", "18080")
// The RTMP media server.
setEnvDefault("PROXY_RTMP_SERVER", "11935")
// The WebRTC media server, via UDP protocol.
setEnvDefault("PROXY_WEBRTC_SERVER", "18000")
// The SRT media server, via UDP protocol.
setEnvDefault("PROXY_SRT_SERVER", "20080")
// The API server of proxy itself.
setEnvDefault("PROXY_SYSTEM_API", "12025")
// The static directory for web server.
setEnvDefault("PROXY_STATIC_FILES", "../trunk/research")
// The load balancer, use redis or memory.
setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory")
// The redis server host.
setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1")
// The redis server port.
setEnvDefault("PROXY_REDIS_PORT", "6379")
// The redis server password.
setEnvDefault("PROXY_REDIS_PASSWORD", "")
// The redis server db.
setEnvDefault("PROXY_REDIS_DB", "0")
// Whether enable the default backend server, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off")
// Default backend server IP, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1")
// Default backend server port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935")
// Default backend api port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985")
// Default backend udp rtc port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000")
// Default backend udp srt port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080")
logger.Df(ctx, "load .env as GO_PPROF=%v, "+
"PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+
"PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+
"PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+
"PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+
"PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+
"PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+
"PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+
"PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+
"PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v",
envGoPprof(),
envForceQuitTimeout(), envGraceQuitTimeout(),
envHttpAPI(), envHttpServer(), envRtmpServer(),
envWebRTCServer(), envSRTServer(),
envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(),
envDefaultBackendIP(), envDefaultBackendRTMP(),
envDefaultBackendHttp(), envDefaultBackendAPI(),
envDefaultBackendRTC(), envDefaultBackendSRT(),
envLoadBalancerType(), envRedisHost(), envRedisPort(),
envRedisPassword(), envRedisDB(),
)
}
func envStaticFiles() string {
return os.Getenv("PROXY_STATIC_FILES")
}
func envDefaultBackendSRT() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_SRT")
}
func envDefaultBackendRTC() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_RTC")
}
func envDefaultBackendAPI() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_API")
}
func envSRTServer() string {
return os.Getenv("PROXY_SRT_SERVER")
}
func envWebRTCServer() string {
return os.Getenv("PROXY_WEBRTC_SERVER")
}
func envDefaultBackendHttp() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP")
}
func envRedisDB() string {
return os.Getenv("PROXY_REDIS_DB")
}
func envRedisPassword() string {
return os.Getenv("PROXY_REDIS_PASSWORD")
}
func envRedisPort() string {
return os.Getenv("PROXY_REDIS_PORT")
}
func envRedisHost() string {
return os.Getenv("PROXY_REDIS_HOST")
}
func envLoadBalancerType() string {
return os.Getenv("PROXY_LOAD_BALANCER_TYPE")
}
func envDefaultBackendRTMP() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP")
}
func envDefaultBackendIP() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_IP")
}
func envDefaultBackendEnabled() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED")
}
func envGraceQuitTimeout() string {
return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT")
}
func envForceQuitTimeout() string {
return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT")
}
func envGoPprof() string {
return os.Getenv("GO_PPROF")
}
func envSystemAPI() string {
return os.Getenv("PROXY_SYSTEM_API")
}
func envRtmpServer() string {
return os.Getenv("PROXY_RTMP_SERVER")
}
func envHttpServer() string {
return os.Getenv("PROXY_HTTP_SERVER")
}
func envHttpAPI() string {
return os.Getenv("PROXY_HTTP_API")
}
// setEnvDefault set env key=value if not set.
func setEnvDefault(key, value string) {
if os.Getenv(key) == "" {
os.Setenv(key, value)
}
}

@ -0,0 +1,270 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
// if err != nil {
// return err
// }
//
// which applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
// type causer interface {
// Cause() error
// }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be
// the original cause. For example:
//
// switch err := errors.Cause(err).(type) {
// case *MyError:
// // handle specifically
// default:
// // unknown error
// }
//
// causer interface is not exported by this package, but is considered a part
// of stable public API.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported
//
// %s print the error. If the error has a Cause it will be
// printed recursively
// %v see %s
// %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface.
//
// type stackTracer interface {
// StackTrace() errors.StackTrace
// }
//
// Where errors.StackTrace is defined as
//
// type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
// if err, ok := err.(stackTracer); ok {
// for _, f := range err.StackTrace() {
// fmt.Printf("%+s:%d", f)
// }
// }
//
// stackTracer interface is not exported by this package, but is considered a part
// of stable public API.
//
// See the documentation for Frame.Format for more details.
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
)
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{
msg: message,
stack: callers(),
}
}
// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
return &fundamental{
msg: fmt.Sprintf(format, args...),
stack: callers(),
}
}
// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
msg string
*stack
}
func (f *fundamental) Error() string { return f.msg }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
io.WriteString(s, f.msg)
f.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, f.msg)
case 'q':
fmt.Fprintf(s, "%q", f.msg)
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
}
func (w *withStack) Cause() error { return w.error }
func (w *withStack) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", w.Cause())
w.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, w.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
}
}
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: message,
}
return &withStack{
err,
callers(),
}
}
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
return &withStack{
err,
callers(),
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string
}
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error { return w.cause }
func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v\n", w.Cause())
io.WriteString(s, w.msg)
return
}
fallthrough
case 's', 'q':
io.WriteString(s, w.Error())
}
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
// type causer interface {
// Cause() error
// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
type causer interface {
Cause() error
}
for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

@ -0,0 +1,187 @@
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
"path"
"runtime"
"strings"
)
// Frame represents a program counter inside a stack frame.
type Frame uintptr
// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return "unknown"
}
file, _ := fn.FileLine(f.pc())
return file
}
// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return 0
}
_, line := fn.FileLine(f.pc())
return line
}
// Format formats the frame according to the fmt.Formatter interface.
//
// %s source file
// %d source line
// %n function name
// %v equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+s path of source file relative to the compile time GOPATH
// %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
switch verb {
case 's':
switch {
case s.Flag('+'):
pc := f.pc()
fn := runtime.FuncForPC(pc)
if fn == nil {
io.WriteString(s, "unknown")
} else {
file, _ := fn.FileLine(pc)
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
}
default:
io.WriteString(s, path.Base(f.file()))
}
case 'd':
fmt.Fprintf(s, "%d", f.line())
case 'n':
name := runtime.FuncForPC(f.pc()).Name()
io.WriteString(s, funcname(name))
case 'v':
f.Format(s, 's')
io.WriteString(s, ":")
f.Format(s, 'd')
}
}
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
for _, f := range st {
fmt.Fprintf(s, "\n%+v", f)
}
case s.Flag('#'):
fmt.Fprintf(s, "%#v", []Frame(st))
default:
fmt.Fprintf(s, "%v", []Frame(st))
}
case 's':
fmt.Fprintf(s, "%s", []Frame(st))
}
}
// stack represents a stack of program counters.
type stack []uintptr
func (s *stack) Format(st fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case st.Flag('+'):
for _, pc := range *s {
f := Frame(pc)
fmt.Fprintf(st, "\n%+v", f)
}
}
}
}
func (s *stack) StackTrace() StackTrace {
f := make([]Frame, len(*s))
for i := 0; i < len(f); i++ {
f[i] = Frame((*s)[i])
}
return f
}
func callers() *stack {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
var st stack = pcs[0:n]
return &st
}
// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
i := strings.LastIndex(name, "/")
name = name[i+1:]
i = strings.Index(name, ".")
return name[i+1:]
}
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}

@ -0,0 +1,13 @@
module srs-proxy
go 1.18
require (
github.com/go-redis/redis/v8 v8.11.5
github.com/joho/godotenv v1.5.1
)
require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

@ -0,0 +1,17 @@
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

@ -0,0 +1,419 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
stdSync "sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS,
// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy
// the request to the origin server.
type srsHTTPStreamServer struct {
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg stdSync.WaitGroup
}
func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer {
v := &srsHTTPStreamServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsHTTPStreamServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := envHttpServer()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "HTTP Stream server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
type Response struct {
Code int `json:"code"`
PID string `json:"pid"`
Data struct {
Major int `json:"major"`
Minor int `json:"minor"`
Revision int `json:"revision"`
Version string `json:"version"`
} `json:"data"`
}
res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())}
res.Data.Major = VersionMajor()
res.Data.Minor = VersionMinor()
res.Data.Revision = VersionRevision()
res.Data.Version = Version()
apiResponse(ctx, w, r, &res)
})
// The static web server, for the web pages.
var staticServer http.Handler
if staticFiles := envStaticFiles(); staticFiles != "" {
if _, err := os.Stat(staticFiles); err != nil {
return errors.Wrapf(err, "invalid static files %v", staticFiles)
}
staticServer = http.FileServer(http.Dir(staticFiles))
logger.Df(ctx, "Handle static files at %v", staticFiles)
}
// The default handler, for both static web server and streaming server.
logger.Df(ctx, "Handle / by %v", addr)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// For HLS streaming, we will proxy the request to the streaming server.
if strings.HasSuffix(r.URL.Path, ".m3u8") {
unifiedURL, fullURL := convertURLToStreamURL(r)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest)
return
}
stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) {
s.SRSProxyBackendHLSID = logger.GenerateContextID()
s.StreamURL, s.FullURL = streamURL, fullURL
}))
stream.Initialize(ctx).ServeHTTP(w, r)
return
}
// For HTTP streaming, we will proxy the request to the streaming server.
if strings.HasSuffix(r.URL.Path, ".flv") ||
strings.HasSuffix(r.URL.Path, ".ts") {
// If SPBHID is specified, it must be a HLS stream client.
if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" {
if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil {
http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest)
} else {
stream.Initialize(ctx).ServeHTTP(w, r)
}
return
}
// Use HTTP pseudo streaming to proxy the request.
NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) {
c.ctx = ctx
}).ServeHTTP(w, r)
return
}
// Serve by static server.
if staticServer != nil {
staticServer.ServeHTTP(w, r)
return
}
http.NotFound(w, r)
})
// Run HTTP server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "HTTP Stream accept err %+v", err)
} else {
logger.Df(ctx, "HTTP Stream server done")
}
}
}()
return nil
}
// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS
// connection. There is no state need to be sync between proxy servers.
//
// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request,
// then proxy to the corresponding backend server. All state is in the HTTP request, so this
// connection is stateless.
type HTTPFlvTsConnection struct {
// The context for HTTP streaming.
ctx context.Context
}
func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection {
v := &HTTPFlvTsConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
ctx := logger.WithContext(v.ctx)
if err := v.serve(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
} else {
logger.Df(ctx, "HTTP client done")
}
}
func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no http stream server")
}
var httpPort int
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
} else {
httpPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Wrapf(err, "do request to %v", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
logger.Df(ctx, "HTTP start streaming")
// Proxy the stream from backend to client.
if _, err := io.Copy(w, resp.Body); err != nil {
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
}
return nil
}
// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS
// clients will share this object, and they do not use the same ctx among proxy servers.
//
// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections.
// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create
// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert
// to the stream URL and then query the backend server to serve it.
type HLSPlayStream struct {
// The context for HLS streaming.
ctx context.Context
// The spbhid, used to identify the backend server.
SRSProxyBackendHLSID string `json:"spbhid"`
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The full request URL for HLS streaming
FullURL string `json:"full_url"`
}
func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream {
v := &HLSPlayStream{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
return v
}
func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err := v.serve(v.ctx, w, r); err != nil {
apiError(v.ctx, w, r, err)
} else {
logger.Df(v.ctx, "HLS client %v for %v with %v done",
v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path)
}
}
func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no rtmp server")
}
var httpPort int
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
} else {
httpPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
if r.URL.RawQuery != "" {
backendURL += "?" + r.URL.RawQuery
}
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Errorf("do request to %v EOF", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
// For TS file, directly copy it.
if !strings.HasSuffix(r.URL.Path, ".m3u8") {
if _, err := io.Copy(w, resp.Body); err != nil {
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
}
return nil
}
// Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts
// URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID.
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "read stream from %v", backendURL)
}
m3u8 := string(b)
if strings.Contains(m3u8, ".ts?") {
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID))
} else {
m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID))
}
if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil {
return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL)
}
return nil
}

@ -0,0 +1,43 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
)
type key string
var cidKey key = "cid.proxy.ossrs.org"
// generateContextID generates a random context id in string.
func GenerateContextID() string {
randomBytes := make([]byte, 32)
_, _ = rand.Read(randomBytes)
hash := sha256.Sum256(randomBytes)
hashString := hex.EncodeToString(hash[:])
cid := hashString[:7]
return cid
}
// WithContext creates a new context with cid, which will be used for log.
func WithContext(ctx context.Context) context.Context {
return WithContextID(ctx, GenerateContextID())
}
// WithContextID creates a new context with cid, which will be used for log.
func WithContextID(ctx context.Context, cid string) context.Context {
return context.WithValue(ctx, cidKey, cid)
}
// ContextID returns the cid in context, or empty string if not set.
func ContextID(ctx context.Context) string {
if cid, ok := ctx.Value(cidKey).(string); ok {
return cid
}
return ""
}

@ -0,0 +1,87 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"context"
"io/ioutil"
stdLog "log"
"os"
)
type logger interface {
Printf(ctx context.Context, format string, v ...any)
}
type loggerPlus struct {
logger *stdLog.Logger
level string
}
func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus {
v := &loggerPlus{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) {
format, args := f, a
if cid := ContextID(ctx); cid != "" {
format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...)
}
v.logger.Printf(format, args...)
}
var verboseLogger logger
func Vf(ctx context.Context, format string, a ...interface{}) {
verboseLogger.Printf(ctx, format, a...)
}
var debugLogger logger
func Df(ctx context.Context, format string, a ...interface{}) {
debugLogger.Printf(ctx, format, a...)
}
var warnLogger logger
func Wf(ctx context.Context, format string, a ...interface{}) {
warnLogger.Printf(ctx, format, a...)
}
var errorLogger logger
func Ef(ctx context.Context, format string, a ...interface{}) {
errorLogger.Printf(ctx, format, a...)
}
const (
logVerboseLabel = "verb"
logDebugLabel = "debug"
logWarnLabel = "warn"
logErrorLabel = "error"
)
func init() {
verboseLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logVerboseLabel
})
debugLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logDebugLabel
})
warnLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logWarnLabel
})
errorLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logErrorLabel
})
}

@ -0,0 +1,121 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"os"
"srs-proxy/errors"
"srs-proxy/logger"
)
func main() {
ctx := logger.WithContext(context.Background())
logger.Df(ctx, "%v/%v started", Signature(), Version())
// Install signals.
ctx, cancel := context.WithCancel(ctx)
installSignals(ctx, cancel)
// Start the main loop, ignore the user cancel error.
err := doMain(ctx)
if err != nil && ctx.Err() != context.Canceled {
logger.Ef(ctx, "main: %+v", err)
os.Exit(-1)
}
logger.Df(ctx, "%v done", Signature())
}
func doMain(ctx context.Context) error {
// Setup the environment variables.
if err := loadEnvFile(ctx); err != nil {
return errors.Wrapf(err, "load env")
}
buildDefaultEnvironmentVariables(ctx)
// When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur
// because the main thread exits after the context is cancelled. However, sometimes the main thread
// may be blocked for some reason, so a forced exit is necessary to ensure the program terminates.
if err := installForceQuit(ctx); err != nil {
return errors.Wrapf(err, "install force quit")
}
// Start the Go pprof if enabled.
handleGoPprof(ctx)
// Initialize SRS load balancers.
switch lbType := envLoadBalancerType(); lbType {
case "memory":
srsLoadBalancer = NewMemoryLoadBalancer()
case "redis":
srsLoadBalancer = NewRedisLoadBalancer()
default:
return errors.Errorf("invalid load balancer %v", lbType)
}
if err := srsLoadBalancer.Initialize(ctx); err != nil {
return errors.Wrapf(err, "initialize srs load balancer")
}
// Parse the gracefully quit timeout.
gracefulQuitTimeout, err := parseGracefullyQuitTimeout()
if err != nil {
return errors.Wrapf(err, "parse gracefully quit timeout")
}
// Start the RTMP server.
srsRTMPServer := NewSRSRTMPServer()
defer srsRTMPServer.Close()
if err := srsRTMPServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtmp server")
}
// Start the WebRTC server.
srsWebRTCServer := NewSRSWebRTCServer()
defer srsWebRTCServer.Close()
if err := srsWebRTCServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtc server")
}
// Start the HTTP API server.
srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) {
server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer
})
defer srsHTTPAPIServer.Close()
if err := srsHTTPAPIServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http api server")
}
// Start the SRT server.
srsSRTServer := NewSRSSRTServer()
defer srsSRTServer.Close()
if err := srsSRTServer.Run(ctx); err != nil {
return errors.Wrapf(err, "srt server")
}
// Start the System API server.
systemAPI := NewSystemAPI(func(server *systemAPI) {
server.gracefulQuitTimeout = gracefulQuitTimeout
})
defer systemAPI.Close()
if err := systemAPI.Run(ctx); err != nil {
return errors.Wrapf(err, "system api server")
}
// Start the HTTP web server.
srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) {
server.gracefulQuitTimeout = gracefulQuitTimeout
})
defer srsHTTPStreamServer.Close()
if err := srsHTTPStreamServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http server")
}
// Wait for the main loop to quit.
<-ctx.Done()
return nil
}

@ -0,0 +1,515 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/binary"
"fmt"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
stdSync "sync"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out
// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the
// SDP answer.
type srsWebRTCServer struct {
// The UDP listener for WebRTC server.
listener *net.UDPConn
// Fast cache for the username to identify the connection.
// The key is username, the value is the UDP address.
usernames sync.Map[string, *RTCConnection]
// Fast cache for the udp address to identify the connection.
// The key is UDP address, the value is the username.
// TODO: Support fast earch by uint64 address.
addresses sync.Map[string, *RTCConnection]
// The wait group for server.
wg stdSync.WaitGroup
}
func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer {
v := &srsWebRTCServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsWebRTCServer) Close() error {
if v.listener != nil {
_ = v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *srsWebRTCServer) proxyApiToBackend(
ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer,
remoteSDPOffer string, streamURL string,
) error {
// Parse HTTP port from backend.
if len(backend.API) == 0 {
return errors.Errorf("no http api server")
}
var apiPort int
if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.API[0])
} else {
apiPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
if r.URL.RawQuery != "" {
backendURL += "?" + r.URL.RawQuery
}
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer))
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Errorf("do request to %v EOF", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
// Parse the local SDP answer from backend.
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "read stream from %v", backendURL)
}
// Replace the WebRTC UDP port in answer.
localSDPAnswer := string(b)
for _, endpoint := range backend.RTC {
_, _, port, err := parseListenEndpoint(endpoint)
if err != nil {
return errors.Wrapf(err, "parse endpoint %v", endpoint)
}
from := fmt.Sprintf(" %v typ host", port)
to := fmt.Sprintf(" %v typ host", envWebRTCServer())
localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1)
}
// Fetch the ice-ufrag and ice-pwd from local SDP answer.
remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer)
if err != nil {
return errors.Wrapf(err, "parse remote sdp offer")
}
localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer)
if err != nil {
return errors.Wrapf(err, "parse local sdp answer")
}
// Save the new WebRTC connection to LB.
icePair := &RTCICEPair{
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
}
if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) {
c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag()
c.Initialize(ctx, v.listener)
// Cache the connection for fast search by username.
v.usernames.Store(c.Ufrag, c)
})); err != nil {
return errors.Wrapf(err, "load or store webrtc %v", streamURL)
}
// Response client with local answer.
if _, err = w.Write([]byte(localSDPAnswer)); err != nil {
return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer)
}
logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB",
len(localSDPAnswer), localICEUfrag, len(localICEPwd))
return nil
}
func (v *srsWebRTCServer) Run(ctx context.Context) error {
// Parse address to listen.
endpoint := envWebRTCServer()
if !strings.Contains(endpoint, ":") {
endpoint = fmt.Sprintf(":%v", endpoint)
}
saddr, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
listener, err := net.ListenUDP("udp", saddr)
if err != nil {
return errors.Wrapf(err, "listen udp %v", saddr)
}
v.listener = listener
logger.Df(ctx, "WebRTC server listen at %v", saddr)
// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, caddr, err := listener.ReadFromUDP(buf)
if err != nil {
// TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%+v", err)
continue
}
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
}
}
}()
return nil
}
func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
var connection *RTCConnection
// If STUN binding request, parse the ufrag and identify the connection.
if err := func() error {
if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) {
return nil
}
var pkt RTCStunPacket
if err := pkt.UnmarshalBinary(data); err != nil {
return errors.Wrapf(err, "unmarshal stun packet")
}
// Search the connection in fast cache.
if s, ok := v.usernames.Load(pkt.Username); ok {
connection = s
return nil
}
// Load connection by username.
if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
} else {
connection = s.Initialize(ctx, v.listener)
logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL)
}
// Cache connection for fast search.
if connection != nil {
v.usernames.Store(pkt.Username, connection)
}
return nil
}(); err != nil {
return err
}
// Search the connection by addr.
if s, ok := v.addresses.Load(addr.String()); ok {
connection = s
} else if connection != nil {
// Cache the address for fast search.
v.addresses.Store(addr.String(), connection)
}
// If connection is not found, ignore the packet.
if connection == nil {
// TODO: Should logging the dropped packet, only logging the first one for each address.
return nil
}
// Proxy the packet to backend.
if err := connection.HandlePacket(addr, data); err != nil {
return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL)
}
return nil
}
// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC
// connection, identify by the ufrag in sdp offer/answer and ICE binding request.
//
// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is
// in the client request. The RTCConnection is stateful, and need to sync the ufrag between
// proxy servers.
//
// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch
// to another UDP address, it may connect to another WebRTC proxy, then we should discover the
// RTCConnection by the ufrag from the ICE binding request.
type RTCConnection struct {
// The stream context for WebRTC streaming.
ctx context.Context
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The ufrag for this WebRTC connection.
Ufrag string `json:"ufrag"`
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The client UDP address. Note that it may change.
clientUDP *net.UDPAddr
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
}
func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection {
v := &RTCConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
if listener != nil {
v.listenerUDP = listener
}
return v
}
func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
ctx := v.ctx
// Update the current UDP address.
v.clientUDP = addr
// Start the UDP proxy to backend.
if err := v.connectBackend(ctx); err != nil {
return errors.Wrapf(err, "connect backend for %v", v.StreamURL)
}
// Proxy client message to backend.
if v.backendUDP == nil {
return nil
}
// Proxy all messages from backend to client.
go func() {
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, _, err := v.backendUDP.ReadFromUDP(buf)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
break
}
if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
break
}
}
}()
if _, err := v.backendUDP.Write(data); err != nil {
return errors.Wrapf(err, "write to backend %v", v.StreamURL)
}
return nil
}
func (v *RTCConnection) connectBackend(ctx context.Context) error {
if v.backendUDP != nil {
return nil
}
// Pick a backend SRS server to proxy the RTC stream.
backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL)
if err != nil {
return errors.Wrapf(err, "pick backend")
}
// Parse UDP port from backend.
if len(backend.RTC) == 0 {
return errors.Errorf("no udp server")
}
_, _, udpPort, err := parseListenEndpoint(backend.RTC[0])
if err != nil {
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL)
}
// Connect to backend SRS server via UDP client.
// TODO: FIXME: Support close the connection when timeout or DTLS alert.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v", backendAddr)
} else {
v.backendUDP = backendUDP
}
return nil
}
type RTCICEPair struct {
// The remote ufrag, used for ICE username and session id.
RemoteICEUfrag string `json:"remote_ufrag"`
// The remote pwd, used for ICE password.
RemoteICEPwd string `json:"remote_pwd"`
// The local ufrag, used for ICE username and session id.
LocalICEUfrag string `json:"local_ufrag"`
// The local pwd, used for ICE password.
LocalICEPwd string `json:"local_pwd"`
}
// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag.
func (v *RTCICEPair) Ufrag() string {
return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag)
}
type RTCStunPacket struct {
// The stun message type.
MessageType uint16
// The stun username, or ufrag.
Username string
}
func (v *RTCStunPacket) UnmarshalBinary(data []byte) error {
if len(data) < 20 {
return errors.Errorf("stun packet too short %v", len(data))
}
p := data
v.MessageType = binary.BigEndian.Uint16(p)
messageLen := binary.BigEndian.Uint16(p[2:])
//magicCookie := p[:8]
//transactionID := p[:20]
p = p[20:]
if len(p) != int(messageLen) {
return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen)
}
for len(p) > 0 {
typ := binary.BigEndian.Uint16(p)
length := binary.BigEndian.Uint16(p[2:])
p = p[4:]
if len(p) < int(length) {
return errors.Errorf("stun attribute length invalid %v < %v", len(p), length)
}
value := p[:length]
p = p[length:]
if length%4 != 0 {
p = p[4-length%4:]
}
switch typ {
case 0x0006:
v.Username = string(value)
}
}
return nil
}

@ -0,0 +1,655 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/rtmp"
)
// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS
// server. It will figure out the backend server to proxy to. Unlike the edge server, it will
// not cache the stream, but just proxy the stream to backend.
type srsRTMPServer struct {
// The TCP listener for RTMP server.
listener *net.TCPListener
// The random number generator.
rd *rand.Rand
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer {
v := &srsRTMPServer{
rd: rand.New(rand.NewSource(time.Now().UnixNano())),
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsRTMPServer) Close() error {
if v.listener != nil {
v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsRTMPServer) Run(ctx context.Context) error {
endpoint := envRtmpServer()
if !strings.Contains(endpoint, ":") {
endpoint = ":" + endpoint
}
addr, err := net.ResolveTCPAddr("tcp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve rtmp addr %v", endpoint)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return errors.Wrapf(err, "listen rtmp addr %v", addr)
}
v.listener = listener
logger.Df(ctx, "RTMP server listen at %v", addr)
v.wg.Add(1)
go func() {
defer v.wg.Done()
for {
conn, err := v.listener.AcceptTCP()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "RTMP server accept err %+v", err)
} else {
logger.Df(ctx, "RTMP server done")
}
return
}
v.wg.Add(1)
go func(ctx context.Context, conn *net.TCPConn) {
defer v.wg.Done()
defer conn.Close()
handleErr := func(err error) {
if isPeerClosedError(err) {
logger.Df(ctx, "RTMP peer is closed")
} else {
logger.Wf(ctx, "RTMP serve err %+v", err)
}
}
rc := NewRTMPConnection(func(client *RTMPConnection) {
client.rd = v.rd
})
if err := rc.serve(ctx, conn); err != nil {
handleErr(err)
} else {
logger.Df(ctx, "RTMP client done")
}
}(logger.WithContext(ctx), conn)
}
}()
return nil
}
// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between
// proxy servers.
//
// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request,
// then proxy to the corresponding backend server. All state is in the RTMP request, so this
// connection is stateless.
type RTMPConnection struct {
// The random number generator.
rd *rand.Rand
}
func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection {
v := &RTMPConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr())
// If any goroutine quit, cancel another one.
parentCtx := ctx
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var backend *RTMPClientToBackend
if true {
go func() {
<-ctx.Done()
conn.Close()
if backend != nil {
backend.Close()
}
}()
}
// Simple handshake with client.
hs := rtmp.NewHandshake(v.rd)
if _, err := hs.ReadC0S0(conn); err != nil {
return errors.Wrapf(err, "read c0")
}
if _, err := hs.ReadC1S1(conn); err != nil {
return errors.Wrapf(err, "read c1")
}
if err := hs.WriteC0S0(conn); err != nil {
return errors.Wrapf(err, "write s1")
}
if err := hs.WriteC1S1(conn); err != nil {
return errors.Wrapf(err, "write s1")
}
if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write s2")
}
if _, err := hs.ReadC2S2(conn); err != nil {
return errors.Wrapf(err, "read c2")
}
client := rtmp.NewProtocol(conn)
logger.Df(ctx, "RTMP simple handshake done")
// Expect RTMP connect command with tcUrl.
var connectReq *rtmp.ConnectAppPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil {
return errors.Wrapf(err, "expect connect req")
}
if true {
ack := rtmp.NewWindowAcknowledgementSize()
ack.AckSize = 2500000
if err := client.WritePacket(ctx, ack, 0); err != nil {
return errors.Wrapf(err, "write set ack size")
}
}
if true {
chunk := rtmp.NewSetChunkSize()
chunk.ChunkSize = 128
if err := client.WritePacket(ctx, chunk, 0); err != nil {
return errors.Wrapf(err, "write set chunk size")
}
}
connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID)
connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888"))
connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127))
connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1))
connectRes.Args.Set("level", rtmp.NewAmf0String("status"))
connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success"))
connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded"))
connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0))
connectResData := rtmp.NewAmf0EcmaArray()
connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888"))
connectResData.Set("srs_version", rtmp.NewAmf0String(Version()))
connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx)))
connectRes.Args.Set("data", connectResData)
if err := client.WritePacket(ctx, connectRes, 0); err != nil {
return errors.Wrapf(err, "write connect res")
}
tcUrl := connectReq.TcUrl()
logger.Df(ctx, "RTMP connect app %v", tcUrl)
// Expect RTMP command to identify the client, a publisher or viewer.
var currentStreamID, nextStreamID int
var streamName string
var clientType RTMPClientType
for clientType == "" {
var identifyReq rtmp.Packet
if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil {
return errors.Wrapf(err, "expect identify req")
}
var response rtmp.Packet
switch pkt := identifyReq.(type) {
case *rtmp.CallPacket:
if pkt.CommandName == "createStream" {
identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
response = identifyRes
nextStreamID = 1
identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID))
} else if pkt.CommandName == "getStreamLength" {
// Ignore and do not reply these packets.
} else {
// For releaseStream, FCPublish, etc.
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.TransactionID = pkt.TransactionID
identifyRes.CommandName = "_result"
identifyRes.CommandObject = rtmp.NewAmf0Null()
identifyRes.Args = rtmp.NewAmf0Undefined()
}
case *rtmp.PublishPacket:
streamName = string(pkt.StreamName)
clientType = RTMPClientTypePublisher
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.CommandName = "onFCPublish"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
identifyRes.Args = data
case *rtmp.PlayPacket:
streamName = string(pkt.StreamName)
clientType = RTMPClientTypeViewer
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset"))
data.Set("description", rtmp.NewAmf0String("Playing and resetting stream."))
data.Set("details", rtmp.NewAmf0String("stream"))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
}
if response != nil {
if err := client.WritePacket(ctx, response, currentStreamID); err != nil {
return errors.Wrapf(err, "write identify res for req=%v, stream=%v",
identifyReq, currentStreamID)
}
}
// Update the stream ID for next request.
currentStreamID = nextStreamID
}
logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v",
tcUrl, streamName, currentStreamID, clientType)
// Find a backend SRS server to proxy the RTMP stream.
backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) {
client.rd, client.typ = v.rd, clientType
})
defer backend.Close()
if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)
}
// Start the streaming.
if clientType == RTMPClientTypePublisher {
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
return errors.Wrapf(err, "start publish")
}
} else if clientType == RTMPClientTypeViewer {
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start"))
data.Set("description", rtmp.NewAmf0String("Started playing stream."))
data.Set("details", rtmp.NewAmf0String("stream"))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
return errors.Wrapf(err, "start play")
}
}
logger.Df(ctx, "RTMP start streaming")
// For all proxy goroutines.
var wg sync.WaitGroup
defer wg.Wait()
// Proxy all message from backend to client.
wg.Add(1)
var r0 error
go func() {
defer wg.Done()
defer cancel()
r0 = func() error {
for {
m, err := backend.client.ReadMessage(ctx)
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
// TODO: Update the stream ID if not the same.
if err := client.WriteMessage(ctx, m); err != nil {
return errors.Wrapf(err, "write message")
}
}
}()
}()
// Proxy all messages from client to backend.
wg.Add(1)
var r1 error
go func() {
defer wg.Done()
defer cancel()
r1 = func() error {
for {
m, err := client.ReadMessage(ctx)
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
// TODO: Update the stream ID if not the same.
if err := backend.client.WriteMessage(ctx, m); err != nil {
return errors.Wrapf(err, "write message")
}
}
}()
}()
// Wait until all goroutine quit.
wg.Wait()
// Reset the error if caused by another goroutine.
if r0 != nil {
return errors.Wrapf(r0, "proxy backend->client")
}
if r1 != nil {
return errors.Wrapf(r1, "proxy client->backend")
}
return parentCtx.Err()
}
type RTMPClientType string
const (
RTMPClientTypePublisher RTMPClientType = "publisher"
RTMPClientTypeViewer RTMPClientType = "viewer"
)
// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend.
type RTMPClientToBackend struct {
// The random number generator.
rd *rand.Rand
// The underlayer tcp client.
tcpConn *net.TCPConn
// The RTMP protocol client.
client *rtmp.Protocol
// The stream type.
typ RTMPClientType
}
func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend {
v := &RTMPClientToBackend{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPClientToBackend) Close() error {
if v.tcpConn != nil {
v.tcpConn.Close()
}
return nil
}
func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error {
// Build the stream URL in vhost/app/stream schema.
streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName))
if err != nil {
return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
// Parse RTMP port from backend.
if len(backend.RTMP) == 0 {
return errors.Errorf("no rtmp server %+v for %v", backend, streamURL)
}
var rtmpPort int
if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0])
} else {
rtmpPort = int(iv)
}
// Connect to backend SRS server via TCP client.
addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
c, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
}
v.tcpConn = c
hs := rtmp.NewHandshake(v.rd)
client := rtmp.NewProtocol(c)
v.client = client
// Simple RTMP handshake with server.
if err := hs.WriteC0S0(c); err != nil {
return errors.Wrapf(err, "write c0")
}
if err := hs.WriteC1S1(c); err != nil {
return errors.Wrapf(err, "write c1")
}
if _, err = hs.ReadC0S0(c); err != nil {
return errors.Wrapf(err, "read s0")
}
if _, err := hs.ReadC1S1(c); err != nil {
return errors.Wrapf(err, "read s1")
}
if _, err = hs.ReadC2S2(c); err != nil {
return errors.Wrapf(err, "read c2")
}
logger.Df(ctx, "backend simple handshake done, server=%v", addr)
if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write c2")
}
// Connect RTMP app on tcUrl with server.
if true {
connectApp := rtmp.NewConnectAppPacket()
connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl))
if err := client.WritePacket(ctx, connectApp, 1); err != nil {
return errors.Wrapf(err, "write connect app")
}
}
if true {
var connectAppRes *rtmp.ConnectAppResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil {
return errors.Wrapf(err, "expect connect app res")
}
logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID())
}
// Play or view RTMP stream with server.
if v.typ == RTMPClientTypeViewer {
return v.play(ctx, client, streamName)
}
// Publish RTMP stream with server.
return v.publish(ctx, client, streamName)
}
func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error {
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "releaseStream"
identifyReq.TransactionID = 2
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "releaseStream")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect releaseStream res")
}
if identifyRes.CommandName == "_result" {
break
}
}
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "FCPublish"
identifyReq.TransactionID = 3
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "FCPublish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect FCPublish res")
}
if identifyRes.CommandName == "_result" {
break
}
}
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
createStream.TransactionID = 4
createStream.CommandObject = rtmp.NewAmf0Null()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return errors.Wrapf(err, "createStream")
}
}
for {
var identifyRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect createStream res")
}
if sid := identifyRes.StreamID; sid != 0 {
currentStreamID = int(sid)
break
}
}
if true {
publishStream := rtmp.NewPublishPacket()
publishStream.TransactionID = 5
publishStream.CommandObject = rtmp.NewAmf0Null()
publishStream.StreamName = *rtmp.NewAmf0String(streamName)
publishStream.StreamType = *rtmp.NewAmf0String("live")
if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
return errors.Wrapf(err, "publish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect publish res")
}
// Ignore onFCPublish, expect onStatus(NetStream.Publish.Start).
if identifyRes.CommandName == "onStatus" {
if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil {
return errors.Errorf("onStatus args not object")
} else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil {
return errors.Errorf("onStatus code not string")
} else if *code != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
}
break
}
}
logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID)
return nil
}
func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error {
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
createStream.TransactionID = 4
createStream.CommandObject = rtmp.NewAmf0Null()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return errors.Wrapf(err, "createStream")
}
}
for {
var identifyRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect createStream res")
}
if sid := identifyRes.StreamID; sid != 0 {
currentStreamID = int(sid)
break
}
}
playStream := rtmp.NewPlayPacket()
playStream.StreamName = *rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil {
return errors.Wrapf(err, "play")
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect releaseStream res")
}
if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" {
break
}
}
return nil
}

@ -0,0 +1,771 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp
import (
"bytes"
"encoding"
"encoding/binary"
"fmt"
"math"
"sync"
"srs-proxy/errors"
)
// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview
type amf0Marker uint8
const (
amf0MarkerNumber amf0Marker = iota // 0
amf0MarkerBoolean // 1
amf0MarkerString // 2
amf0MarkerObject // 3
amf0MarkerMovieClip // 4
amf0MarkerNull // 5
amf0MarkerUndefined // 6
amf0MarkerReference // 7
amf0MarkerEcmaArray // 8
amf0MarkerObjectEnd // 9
amf0MarkerStrictArray // 10
amf0MarkerDate // 11
amf0MarkerLongString // 12
amf0MarkerUnsupported // 13
amf0MarkerRecordSet // 14
amf0MarkerXmlDocument // 15
amf0MarkerTypedObject // 16
amf0MarkerAvmPlusObject // 17
amf0MarkerForbidden amf0Marker = 0xff
)
func (v amf0Marker) String() string {
switch v {
case amf0MarkerNumber:
return "Amf0Number"
case amf0MarkerBoolean:
return "amf0Boolean"
case amf0MarkerString:
return "Amf0String"
case amf0MarkerObject:
return "Amf0Object"
case amf0MarkerNull:
return "Null"
case amf0MarkerUndefined:
return "Undefined"
case amf0MarkerReference:
return "Reference"
case amf0MarkerEcmaArray:
return "EcmaArray"
case amf0MarkerObjectEnd:
return "ObjectEnd"
case amf0MarkerStrictArray:
return "StrictArray"
case amf0MarkerDate:
return "Date"
case amf0MarkerLongString:
return "LongString"
case amf0MarkerUnsupported:
return "Unsupported"
case amf0MarkerXmlDocument:
return "XmlDocument"
case amf0MarkerTypedObject:
return "TypedObject"
case amf0MarkerAvmPlusObject:
return "AvmPlusObject"
case amf0MarkerMovieClip:
return "MovieClip"
case amf0MarkerRecordSet:
return "RecordSet"
default:
return "Forbidden"
}
}
// For utest to mock it.
type amf0Buffer interface {
Bytes() []byte
WriteByte(c byte) error
Write(p []byte) (n int, err error)
}
var createBuffer = func() amf0Buffer {
return &bytes.Buffer{}
}
// All AMF0 things.
type amf0Any interface {
// Binary marshaler and unmarshaler.
encoding.BinaryUnmarshaler
encoding.BinaryMarshaler
// Get the size of bytes to marshal this object.
Size() int
// Get the Marker of any AMF0 stuff.
amf0Marker() amf0Marker
}
type amf0Converter struct {
from amf0Any
}
func NewAmf0Converter(from amf0Any) *amf0Converter {
return &amf0Converter{from: from}
}
func (v *amf0Converter) ToNumber() *amf0Number {
return amf0AnyTo[*amf0Number](v.from)
}
func (v *amf0Converter) ToBoolean() *amf0Boolean {
return amf0AnyTo[*amf0Boolean](v.from)
}
func (v *amf0Converter) ToString() *amf0String {
return amf0AnyTo[*amf0String](v.from)
}
func (v *amf0Converter) ToObject() *amf0Object {
return amf0AnyTo[*amf0Object](v.from)
}
func (v *amf0Converter) ToNull() *amf0Null {
return amf0AnyTo[*amf0Null](v.from)
}
func (v *amf0Converter) ToUndefined() *amf0Undefined {
return amf0AnyTo[*amf0Undefined](v.from)
}
func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray {
return amf0AnyTo[*amf0EcmaArray](v.from)
}
func (v *amf0Converter) ToStrictArray() *amf0StrictArray {
return amf0AnyTo[*amf0StrictArray](v.from)
}
// Convert any to specified object.
func amf0AnyTo[T amf0Any](a amf0Any) T {
var to T
if a != nil {
if v, ok := a.(T); ok {
return v
}
}
return to
}
// Discovery the amf0 object from the bytes b.
func Amf0Discovery(p []byte) (a amf0Any, err error) {
if len(p) < 1 {
return nil, errors.Errorf("require 1 bytes only %v", len(p))
}
m := amf0Marker(p[0])
switch m {
case amf0MarkerNumber:
return NewAmf0Number(0), nil
case amf0MarkerBoolean:
return NewAmf0Boolean(false), nil
case amf0MarkerString:
return NewAmf0String(""), nil
case amf0MarkerObject:
return NewAmf0Object(), nil
case amf0MarkerNull:
return NewAmf0Null(), nil
case amf0MarkerUndefined:
return NewAmf0Undefined(), nil
case amf0MarkerReference:
case amf0MarkerEcmaArray:
return NewAmf0EcmaArray(), nil
case amf0MarkerObjectEnd:
return &amf0ObjectEOF{}, nil
case amf0MarkerStrictArray:
return NewAmf0StrictArray(), nil
case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument,
amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip,
amf0MarkerRecordSet:
return nil, errors.Errorf("Marker %v is not supported", m)
}
return nil, errors.Errorf("Marker %v is invalid", m)
}
// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8
type amf0UTF8 string
func (v *amf0UTF8) Size() int {
return 2 + len(string(*v))
}
func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 2 {
return errors.Errorf("require 2 bytes only %v", len(p))
}
size := uint16(p[0])<<8 | uint16(p[1])
if p = data[2:]; len(p) < int(size) {
return errors.Errorf("require %v bytes only %v", int(size), len(p))
}
*v = amf0UTF8(string(p[:size]))
return
}
func (v *amf0UTF8) MarshalBinary() (data []byte, err error) {
data = make([]byte, v.Size())
size := uint16(len(string(*v)))
data[0] = byte(size >> 8)
data[1] = byte(size)
if size > 0 {
copy(data[2:], []byte(*v))
}
return
}
// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type
type amf0Number float64
func NewAmf0Number(f float64) *amf0Number {
v := amf0Number(f)
return &v
}
func (v *amf0Number) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
func (v *amf0Number) Size() int {
return 1 + 8
}
func (v *amf0Number) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 9 {
return errors.Errorf("require 9 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerNumber {
return errors.Errorf("Amf0Number amf0Marker %v is illegal", m)
}
f := binary.BigEndian.Uint64(p[1:])
*v = amf0Number(math.Float64frombits(f))
return
}
func (v *amf0Number) MarshalBinary() (data []byte, err error) {
data = make([]byte, 9)
data[0] = byte(amf0MarkerNumber)
f := math.Float64bits(float64(*v))
binary.BigEndian.PutUint64(data[1:], f)
return
}
// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type
type amf0String string
func NewAmf0String(s string) *amf0String {
v := amf0String(s)
return &v
}
func (v *amf0String) amf0Marker() amf0Marker {
return amf0MarkerString
}
func (v *amf0String) Size() int {
u := amf0UTF8(*v)
return 1 + u.Size()
}
func (v *amf0String) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerString {
return errors.Errorf("Amf0String amf0Marker %v is illegal", m)
}
var sv amf0UTF8
if err = sv.UnmarshalBinary(p[1:]); err != nil {
return errors.WithMessage(err, "utf8")
}
*v = amf0String(string(sv))
return
}
func (v *amf0String) MarshalBinary() (data []byte, err error) {
u := amf0UTF8(*v)
var pb []byte
if pb, err = u.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "utf8")
}
data = append([]byte{byte(amf0MarkerString)}, pb...)
return
}
// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type
type amf0ObjectEOF struct {
}
func (v *amf0ObjectEOF) amf0Marker() amf0Marker {
return amf0MarkerObjectEnd
}
func (v *amf0ObjectEOF) Size() int {
return 3
}
func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) {
p := data
if len(p) < 3 {
return errors.Errorf("require 3 bytes only %v", len(p))
}
if p[0] != 0 || p[1] != 0 || p[2] != 9 {
return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3])
}
return
}
func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) {
return []byte{0, 0, 9}, nil
}
// Use array for object and ecma array, to keep the original order.
type amf0Property struct {
key amf0UTF8
value amf0Any
}
// The object-like AMF0 structure, like object and ecma array and strict array.
type amf0ObjectBase struct {
properties []*amf0Property
lock sync.Mutex
}
func (v *amf0ObjectBase) Size() int {
v.lock.Lock()
defer v.lock.Unlock()
var size int
for _, p := range v.properties {
key, value := p.key, p.value
size += key.Size() + value.Size()
}
return size
}
func (v *amf0ObjectBase) Get(key string) amf0Any {
v.lock.Lock()
defer v.lock.Unlock()
for _, p := range v.properties {
if string(p.key) == key {
return p.value
}
}
return nil
}
func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase {
v.lock.Lock()
defer v.lock.Unlock()
prop := &amf0Property{key: amf0UTF8(key), value: value}
var ok bool
for i, p := range v.properties {
if string(p.key) == key {
v.properties[i] = prop
ok = true
}
}
if !ok {
v.properties = append(v.properties, prop)
}
return v
}
func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) {
// if no eof, elems specified by maxElems.
if !eof && maxElems < 0 {
return errors.Errorf("maxElems=%v without eof", maxElems)
}
// if eof, maxElems must be -1.
if eof && maxElems != -1 {
return errors.Errorf("maxElems=%v with eof", maxElems)
}
readOne := func() (amf0UTF8, amf0Any, error) {
var u amf0UTF8
if err = u.UnmarshalBinary(p); err != nil {
return "", nil, errors.WithMessage(err, "prop name")
}
p = p[u.Size():]
var a amf0Any
if a, err = Amf0Discovery(p); err != nil {
return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u)))
}
return u, a, nil
}
pushOne := func(u amf0UTF8, a amf0Any) error {
// For object property, consume the whole bytes.
if err = a.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u)))
}
v.Set(string(u), a)
p = p[a.Size():]
return nil
}
for eof {
u, a, err := readOne()
if err != nil {
return errors.WithMessage(err, "read")
}
// For object EOF, we should only consume total 3bytes.
if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd {
// 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte.
p = p[1:]
return nil
}
if err := pushOne(u, a); err != nil {
return errors.WithMessage(err, "push")
}
}
for len(v.properties) < maxElems {
u, a, err := readOne()
if err != nil {
return errors.WithMessage(err, "read")
}
if err := pushOne(u, a); err != nil {
return errors.WithMessage(err, "push")
}
}
return
}
func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) {
v.lock.Lock()
defer v.lock.Unlock()
var pb []byte
for _, p := range v.properties {
key, value := p.key, p.value
if pb, err = key.MarshalBinary(); err != nil {
return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key)))
}
if _, err = b.Write(pb); err != nil {
return errors.Wrapf(err, "write %v", string(key))
}
if pb, err = value.MarshalBinary(); err != nil {
return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key)))
}
if _, err = b.Write(pb); err != nil {
return errors.Wrapf(err, "marshal value for %v", string(key))
}
}
return
}
// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type
type amf0Object struct {
amf0ObjectBase
eof amf0ObjectEOF
}
func NewAmf0Object() *amf0Object {
v := &amf0Object{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0Object) amf0Marker() amf0Marker {
return amf0MarkerObject
}
func (v *amf0Object) Size() int {
return int(1) + v.eof.Size() + v.amf0ObjectBase.Size()
}
func (v *amf0Object) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 byte only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerObject {
return errors.Errorf("Amf0Object amf0Marker %v is illegal", m)
}
p = p[1:]
if err = v.unmarshal(p, true, -1); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0Object) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerObject)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type
type amf0EcmaArray struct {
amf0ObjectBase
count uint32
eof amf0ObjectEOF
}
func NewAmf0EcmaArray() *amf0EcmaArray {
v := &amf0EcmaArray{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0EcmaArray) amf0Marker() amf0Marker {
return amf0MarkerEcmaArray
}
func (v *amf0EcmaArray) Size() int {
return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size()
}
func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 5 {
return errors.Errorf("require 5 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray {
return errors.Errorf("EcmaArray amf0Marker %v is illegal", m)
}
v.count = binary.BigEndian.Uint32(p[1:])
p = p[5:]
if err = v.unmarshal(p, true, -1); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type
type amf0StrictArray struct {
amf0ObjectBase
count uint32
}
func NewAmf0StrictArray() *amf0StrictArray {
v := &amf0StrictArray{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0StrictArray) amf0Marker() amf0Marker {
return amf0MarkerStrictArray
}
func (v *amf0StrictArray) Size() int {
return int(1) + 4 + v.amf0ObjectBase.Size()
}
func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 5 {
return errors.Errorf("require 5 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerStrictArray {
return errors.Errorf("StrictArray amf0Marker %v is illegal", m)
}
v.count = binary.BigEndian.Uint32(p[1:])
p = p[5:]
if int(v.count) <= 0 {
return
}
if err = v.unmarshal(p, false, int(v.count)); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
return b.Bytes(), nil
}
// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined.
type amf0SingleMarkerObject struct {
target amf0Marker
}
func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject {
return amf0SingleMarkerObject{target: m}
}
func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker {
return v.target
}
func (v *amf0SingleMarkerObject) Size() int {
return int(1)
}
func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 byte only %v", len(p))
}
if m := amf0Marker(p[0]); m != v.target {
return errors.Errorf("%v amf0Marker %v is illegal", v.target, m)
}
return
}
func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) {
return []byte{byte(v.target)}, nil
}
// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type
type amf0Null struct {
amf0SingleMarkerObject
}
func NewAmf0Null() *amf0Null {
v := amf0Null{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull)
return &v
}
// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type
type amf0Undefined struct {
amf0SingleMarkerObject
}
func NewAmf0Undefined() amf0Any {
v := amf0Undefined{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined)
return &v
}
// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type
type amf0Boolean bool
func NewAmf0Boolean(b bool) amf0Any {
v := amf0Boolean(b)
return &v
}
func (v *amf0Boolean) amf0Marker() amf0Marker {
return amf0MarkerBoolean
}
func (v *amf0Boolean) Size() int {
return int(2)
}
func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 2 {
return errors.Errorf("require 2 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerBoolean {
return errors.Errorf("BOOL amf0Marker %v is illegal", m)
}
if p[1] == 0 {
*v = false
} else {
*v = true
}
return
}
func (v *amf0Boolean) MarshalBinary() (data []byte, err error) {
var b byte
if *v {
b = 1
}
return []byte{byte(amf0MarkerBoolean), b}, nil
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,44 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
func installSignals(ctx context.Context, cancel context.CancelFunc) {
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
go func() {
for s := range sc {
logger.Df(ctx, "Got signal %v", s)
cancel()
}
}()
}
func installForceQuit(ctx context.Context) error {
var forceTimeout time.Duration
if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil {
return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout())
} else {
forceTimeout = t
}
go func() {
<-ctx.Done()
time.Sleep(forceTimeout)
logger.Wf(ctx, "Force to exit by timeout")
os.Exit(1)
}()
return nil
}

@ -0,0 +1,553 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"os"
"strconv"
"strings"
"time"
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// If server heartbeat in this duration, it's alive.
const srsServerAliveDuration = 300 * time.Second
// If HLS streaming update in this duration, it's alive.
const srsHLSAliveDuration = 120 * time.Second
// If WebRTC streaming update in this duration, it's alive.
const srsRTCAliveDuration = 120 * time.Second
type SRSServer struct {
// The server IP.
IP string `json:"ip,omitempty"`
// The server device ID, configured by user.
DeviceID string `json:"device_id,omitempty"`
// The server id of SRS, store in file, may not change, mandatory.
ServerID string `json:"server_id,omitempty"`
// The service id of SRS, always change when restarted, mandatory.
ServiceID string `json:"service_id,omitempty"`
// The process id of SRS, always change when restarted, mandatory.
PID string `json:"pid,omitempty"`
// The RTMP listen endpoints.
RTMP []string `json:"rtmp,omitempty"`
// The HTTP Stream listen endpoints.
HTTP []string `json:"http,omitempty"`
// The HTTP API listen endpoints.
API []string `json:"api,omitempty"`
// The SRT server listen endpoints.
SRT []string `json:"srt,omitempty"`
// The RTC server listen endpoints.
RTC []string `json:"rtc,omitempty"`
// Last update time.
UpdatedAt time.Time `json:"update_at,omitempty"`
}
func (v *SRSServer) ID() string {
return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID)
}
func (v *SRSServer) String() string {
return fmt.Sprintf("%v", v)
}
func (v *SRSServer) Format(f fmt.State, c rune) {
switch c {
case 'v', 's':
if f.Flag('+') {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID))
if v.DeviceID != "" {
sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID))
}
if len(v.RTMP) > 0 {
sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ",")))
}
if len(v.HTTP) > 0 {
sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ",")))
}
if len(v.API) > 0 {
sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ",")))
}
if len(v.SRT) > 0 {
sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ",")))
}
if len(v.RTC) > 0 {
sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ",")))
}
sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999")))
fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String())
} else {
fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID())
}
default:
fmt.Fprintf(f, "%v, fmt=%%%c", v, c)
}
}
func NewSRSServer(opts ...func(*SRSServer)) *SRSServer {
v := &SRSServer{}
for _, opt := range opts {
opt(v)
}
return v
}
// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only.
func NewDefaultSRSForDebugging() (*SRSServer, error) {
if envDefaultBackendEnabled() != "on" {
return nil, nil
}
if envDefaultBackendIP() == "" {
return nil, fmt.Errorf("empty default backend ip")
}
if envDefaultBackendRTMP() == "" {
return nil, fmt.Errorf("empty default backend rtmp")
}
server := NewSRSServer(func(srs *SRSServer) {
srs.IP = envDefaultBackendIP()
srs.RTMP = []string{envDefaultBackendRTMP()}
srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID())
srs.ServiceID = logger.GenerateContextID()
srs.PID = fmt.Sprintf("%v", os.Getpid())
srs.UpdatedAt = time.Now()
})
if envDefaultBackendHttp() != "" {
server.HTTP = []string{envDefaultBackendHttp()}
}
if envDefaultBackendAPI() != "" {
server.API = []string{envDefaultBackendAPI()}
}
if envDefaultBackendRTC() != "" {
server.RTC = []string{envDefaultBackendRTC()}
}
if envDefaultBackendSRT() != "" {
server.SRT = []string{envDefaultBackendSRT()}
}
return server, nil
}
// SRSLoadBalancer is the interface to load balance the SRS servers.
type SRSLoadBalancer interface {
// Initialize the load balancer.
Initialize(ctx context.Context) error
// Update the backer server.
Update(ctx context.Context, server *SRSServer) error
// Pick a backend server for the specified stream URL.
Pick(ctx context.Context, streamURL string) (*SRSServer, error)
// Load or store the HLS streaming for the specified stream URL.
LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error)
// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error)
// Store the WebRTC streaming for the specified stream URL.
StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error
// Load the WebRTC streaming by ufrag, the ICE username.
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
}
// srsLoadBalancer is the global SRS load balancer.
var srsLoadBalancer SRSLoadBalancer
// srsMemoryLoadBalancer stores state in memory.
type srsMemoryLoadBalancer struct {
// All available SRS servers, key is server ID.
servers sync.Map[string, *SRSServer]
// The picked server to servce client by specified stream URL, key is stream url.
picked sync.Map[string, *SRSServer]
// The HLS streaming, key is stream URL.
hlsStreamURL sync.Map[string, *HLSPlayStream]
// The HLS streaming, key is SPBHID.
hlsSPBHID sync.Map[string, *HLSPlayStream]
// The WebRTC streaming, key is stream URL.
rtcStreamURL sync.Map[string, *RTCConnection]
// The WebRTC streaming, key is ufrag.
rtcUfrag sync.Map[string, *RTCConnection]
}
func NewMemoryLoadBalancer() SRSLoadBalancer {
return &srsMemoryLoadBalancer{}
}
func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error {
if server, err := NewDefaultSRSForDebugging(); err != nil {
return errors.Wrapf(err, "initialize default SRS")
} else if server != nil {
if err := v.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update default SRS %+v", server)
}
// Keep alive.
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
if err := v.Update(ctx, server); err != nil {
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
}
}
}
}()
logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server)
}
return nil
}
func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
v.servers.Store(server.ID(), server)
return nil
}
func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
// Always proxy to the same server for the same stream URL.
if server, ok := v.picked.Load(streamURL); ok {
return server, nil
}
// Gather all servers that were alive within the last few seconds.
var servers []*SRSServer
v.servers.Range(func(key string, server *SRSServer) bool {
if time.Since(server.UpdatedAt) < srsServerAliveDuration {
servers = append(servers, server)
}
return true
})
// If no servers available, use all possible servers.
if len(servers) == 0 {
v.servers.Range(func(key string, server *SRSServer) bool {
servers = append(servers, server)
return true
})
}
// No server found, failed.
if len(servers) == 0 {
return nil, fmt.Errorf("no server available for %v", streamURL)
}
// Pick a server randomly from servers.
server := servers[rand.Intn(len(servers))]
v.picked.Store(streamURL, server)
return server, nil
}
func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
// Load the HLS streaming for the SPBHID, for TS files.
if actual, ok := v.hlsSPBHID.Load(spbhid); !ok {
return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid)
} else {
return actual, nil
}
}
func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
// Update the HLS streaming for the stream URL, for M3u8.
actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value)
if actual == nil {
return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL)
}
// Update the HLS streaming for the SPBHID, for TS files.
v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual)
return actual, nil
}
func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
// Update the WebRTC streaming for the stream URL.
v.rtcStreamURL.Store(streamURL, value)
// Update the WebRTC streaming for the ufrag.
v.rtcUfrag.Store(value.Ufrag, value)
return nil
}
func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
if actual, ok := v.rtcUfrag.Load(ufrag); !ok {
return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
} else {
return actual, nil
}
}
type srsRedisLoadBalancer struct {
// The redis client sdk.
rdb *redis.Client
}
func NewRedisLoadBalancer() SRSLoadBalancer {
return &srsRedisLoadBalancer{}
}
func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error {
redisDatabase, err := strconv.Atoi(envRedisDB())
if err != nil {
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB())
}
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()),
Password: envRedisPassword(),
DB: redisDatabase,
})
v.rdb = rdb
if err := rdb.Ping(ctx).Err(); err != nil {
return errors.Wrapf(err, "unable to connect to redis %v", rdb.String())
}
logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String())
if server, err := NewDefaultSRSForDebugging(); err != nil {
return errors.Wrapf(err, "initialize default SRS")
} else if server != nil {
if err := v.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update default SRS %+v", server)
}
// Keep alive.
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
if err := v.Update(ctx, server); err != nil {
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
}
}
}
}()
logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server)
}
return nil
}
func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
b, err := json.Marshal(server)
if err != nil {
return errors.Wrapf(err, "marshal server %+v", server)
}
key := v.redisKeyServer(server.ID())
if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v server %+v", key, server)
}
// Query all servers from redis, in json string.
var serverKeys []string
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
if err := json.Unmarshal(b, &serverKeys); err != nil {
return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
}
}
// Check each server expiration, if not exists in redis, remove from servers.
for i := len(serverKeys) - 1; i >= 0; i-- {
if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil {
serverKeys = append(serverKeys[:i], serverKeys[i+1:]...)
}
}
// Add server to servers if not exists.
var found bool
for _, serverKey := range serverKeys {
if serverKey == key {
found = true
break
}
}
if !found {
serverKeys = append(serverKeys, key)
}
// Update all servers to redis.
b, err = json.Marshal(serverKeys)
if err != nil {
return errors.Wrapf(err, "marshal servers %+v", serverKeys)
}
if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil {
return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys)
}
return nil
}
func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
key := fmt.Sprintf("srs-proxy-url:%v", streamURL)
// Always proxy to the same server for the same stream URL.
if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil {
// If server not exists, ignore and pick another server for the stream URL.
if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 {
var server SRSServer
if err := json.Unmarshal(b, &server); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b))
}
// TODO: If server fail, we should migrate the streams to another server.
return &server, nil
}
}
// Query all servers from redis, in json string.
var serverKeys []string
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
if err := json.Unmarshal(b, &serverKeys); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
}
}
// No server found, failed.
if len(serverKeys) == 0 {
return nil, fmt.Errorf("no server available for %v", streamURL)
}
// All server should be alive, if not, should have been removed by redis. So we only
// random pick one that is always available.
var serverKey string
var server SRSServer
for i := 0; i < 3; i++ {
tryServerKey := serverKeys[rand.Intn(len(serverKeys))]
b, err := v.rdb.Get(ctx, tryServerKey).Bytes()
if err == nil && len(b) > 0 {
if err := json.Unmarshal(b, &server); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b))
}
serverKey = tryServerKey
break
}
}
if serverKey == "" {
return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL)
}
// Update the picked server for the stream URL.
if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey)
}
return &server, nil
}
func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
key := v.redisKeySPBHID(spbhid)
b, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v HLS", key)
}
var actual HLSPlayStream
if err := json.Unmarshal(b, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
b, err := json.Marshal(value)
if err != nil {
return nil, errors.Wrapf(err, "marshal HLS %v", value)
}
key := v.redisKeyHLS(streamURL)
if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value)
}
key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID)
if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value)
}
// Query the HLS streaming from redis.
b2, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v HLS", key)
}
var actual HLSPlayStream
if err := json.Unmarshal(b2, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
b, err := json.Marshal(value)
if err != nil {
return errors.Wrapf(err, "marshal WebRTC %v", value)
}
key := v.redisKeyRTC(streamURL)
if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v WebRTC %v", key, value)
}
key2 := v.redisKeyUfrag(value.Ufrag)
if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value)
}
return nil
}
func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
key := v.redisKeyUfrag(ufrag)
b, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v WebRTC", key)
}
var actual RTCConnection
if err := json.Unmarshal(b, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string {
return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
}
func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string {
return fmt.Sprintf("srs-proxy-rtc:%v", streamURL)
}
func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string {
return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
}
func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string {
return fmt.Sprintf("srs-proxy-hls:%v", streamURL)
}
func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string {
return fmt.Sprintf("srs-proxy-server:%v", serverID)
}
func (v *srsRedisLoadBalancer) redisKeyServers() string {
return fmt.Sprintf("srs-proxy-all-servers")
}

@ -0,0 +1,574 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net"
"strings"
stdSync "sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to
// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the
// backend server.
type srsSRTServer struct {
// The UDP listener for SRT server.
listener *net.UDPConn
// The SRT connections, identify by the socket ID.
sockets sync.Map[uint32, *SRTConnection]
// The system start time.
start time.Time
// The wait group for server.
wg stdSync.WaitGroup
}
func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer {
v := &srsSRTServer{
start: time.Now(),
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsSRTServer) Close() error {
if v.listener != nil {
v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsSRTServer) Run(ctx context.Context) error {
// Parse address to listen.
endpoint := envSRTServer()
if !strings.Contains(endpoint, ":") {
endpoint = ":" + endpoint
}
saddr, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
listener, err := net.ListenUDP("udp", saddr)
if err != nil {
return errors.Wrapf(err, "listen udp %v", saddr)
}
v.listener = listener
logger.Df(ctx, "SRT server listen at %v", saddr)
// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, caddr, err := v.listener.ReadFromUDP(buf)
if err != nil {
// TODO: If SRT server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%+v", err)
continue
}
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
}
}
}()
return nil
}
func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
socketID := srtParseSocketID(data)
var pkt *SRTHandshakePacket
if srtIsHandshake(data) {
pkt = &SRTHandshakePacket{}
if err := pkt.UnmarshalBinary(data); err != nil {
return err
}
if socketID == 0 {
socketID = pkt.SRTSocketID
}
}
conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) {
c.ctx = logger.WithContext(ctx)
c.listenerUDP, c.socketID = v.listener, socketID
c.start = v.start
}))
ctx = conn.ctx
if !ok {
logger.Df(ctx, "Create new SRT connection skt=%v", socketID)
}
if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil {
return errors.Wrapf(err, "handle packet")
} else if newSocketID != 0 && newSocketID != socketID {
// The connection may use a new socket ID.
// TODO: FIXME: Should cleanup the dead SRT connection.
v.sockets.Store(newSocketID, conn)
}
return nil
}
// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT
// connection, identify by the socket ID.
//
// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in
// the client request. The SRTConnection is stateless, and no need to sync between proxy servers.
//
// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the
// client should never switch to another network or port. If this occurs, the client may be served
// by a different proxy server and fail because the other proxy server cannot identify the client.
type SRTConnection struct {
// The stream context for SRT connection.
ctx context.Context
// The current socket ID.
socketID uint32
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
// Listener start time.
start time.Time
// Handshake packets with client.
handshake0 *SRTHandshakePacket
handshake1 *SRTHandshakePacket
handshake2 *SRTHandshakePacket
handshake3 *SRTHandshakePacket
}
func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection {
v := &SRTConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) {
ctx := v.ctx
// If not handshake, try to proxy to backend directly.
if pkt == nil {
// Proxy client message to backend.
if v.backendUDP != nil {
if _, err := v.backendUDP.Write(data); err != nil {
return v.socketID, errors.Wrapf(err, "write to backend")
}
}
return v.socketID, nil
}
// Handle handshake messages.
if err := v.handleHandshake(ctx, pkt, addr, data); err != nil {
return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt)
}
return v.socketID, nil
}
func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error {
// Handle handshake 0 and 1 messages.
if pkt.SynCookie == 0 {
// Save handshake 0 packet.
v.handshake0 = pkt
logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0)
// Response handshake 1.
v.handshake1 = &SRTHandshakePacket{
ControlFlag: pkt.ControlFlag,
ControlType: 0,
SubType: 0,
AdditionalInfo: 0,
Timestamp: uint32(time.Since(v.start).Microseconds()),
SocketID: pkt.SRTSocketID,
Version: 5,
EncryptionField: 0,
ExtensionField: 0x4A17,
InitSequence: pkt.InitSequence,
MTU: pkt.MTU,
FlowWindow: pkt.FlowWindow,
HandshakeType: 1,
SRTSocketID: pkt.SRTSocketID,
SynCookie: 0x418d5e4e,
PeerIP: net.ParseIP("127.0.0.1"),
}
logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1)
if b, err := v.handshake1.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 1")
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
return errors.Wrapf(err, "write handshake 1")
}
return nil
}
// Handle handshake 2 and 3 messages.
// Parse stream id from packet.
streamID, err := pkt.StreamID()
if err != nil {
return errors.Wrapf(err, "parse stream id")
}
// Save handshake packet.
v.handshake2 = pkt
logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID)
// Start the UDP proxy to backend.
if err := v.connectBackend(ctx, streamID); err != nil {
return errors.Wrapf(err, "connect backend for %v", streamID)
}
// Proxy client message to backend.
if v.backendUDP == nil {
return errors.Errorf("no backend for %v", streamID)
}
// Proxy handshake 0 to backend server.
if b, err := v.handshake0.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 0")
} else if _, err = v.backendUDP.Write(b); err != nil {
return errors.Wrapf(err, "write handshake 0")
}
logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0)
// Read handshake 1 from backend server.
b := make([]byte, 4096)
handshake1p := &SRTHandshakePacket{}
if nn, err := v.backendUDP.Read(b); err != nil {
return errors.Wrapf(err, "read handshake 1")
} else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil {
return errors.Wrapf(err, "unmarshal handshake 1")
}
logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p)
// Proxy handshake 2 to backend server.
handshake2p := *v.handshake2
handshake2p.SynCookie = handshake1p.SynCookie
if b, err := handshake2p.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 2")
} else if _, err = v.backendUDP.Write(b); err != nil {
return errors.Wrapf(err, "write handshake 2")
}
logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p)
// Read handshake 3 from backend server.
handshake3p := &SRTHandshakePacket{}
if nn, err := v.backendUDP.Read(b); err != nil {
return errors.Wrapf(err, "read handshake 3")
} else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil {
return errors.Wrapf(err, "unmarshal handshake 3")
}
logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p)
// Response handshake 3 to client.
v.handshake3 = &*handshake3p
v.handshake3.SynCookie = v.handshake1.SynCookie
v.socketID = handshake3p.SRTSocketID
logger.Df(ctx, "Handshake 3: %v", v.handshake3)
if b, err := v.handshake3.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 3")
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
return errors.Wrapf(err, "write handshake 3")
}
// Start a goroutine to proxy message from backend to client.
// TODO: FIXME: Support close the connection when timeout or client disconnected.
go func() {
for ctx.Err() == nil {
nn, err := v.backendUDP.Read(b)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
return
}
if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
return
}
}
}()
return nil
}
func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error {
if v.backendUDP != nil {
return nil
}
// Parse stream id to host and resource.
host, resource, err := parseSRTStreamID(streamID)
if err != nil {
return errors.Wrapf(err, "parse stream id %v", streamID)
}
if host == "" {
host = "localhost"
}
streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource))
if err != nil {
return errors.Wrapf(err, "build stream url %v", streamID)
}
// Pick a backend SRS server to proxy the SRT stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
// Parse UDP port from backend.
if len(backend.SRT) == 0 {
return errors.Errorf("no udp server %v for %v", backend, streamURL)
}
_, _, udpPort, err := parseListenEndpoint(backend.SRT[0])
if err != nil {
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL)
}
// Connect to backend SRS server via UDP client.
// TODO: FIXME: Support close the connection when timeout or client disconnected.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL)
} else {
v.backendUDP = backendUDP
}
return nil
}
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1
type SRTHandshakePacket struct {
// F: 1 bit. Packet Type Flag. The control packet has this flag set to
// "1". The data packet has this flag set to "0".
ControlFlag uint8
// Control Type: 15 bits. Control Packet Type. The use of these bits
// is determined by the control packet type definition.
// Handshake control packets (Control Type = 0x0000) are used to
// exchange peer configurations, to agree on connection parameters, and
// to establish a connection.
ControlType uint16
// Subtype: 16 bits. This field specifies an additional subtype for
// specific packets.
SubType uint16
// Type-specific Information: 32 bits. The use of this field depends on
// the particular control packet type. Handshake packets do not use
// this field.
AdditionalInfo uint32
// Timestamp: 32 bits.
Timestamp uint32
// Destination Socket ID: 32 bits.
SocketID uint32
// Version: 32 bits. A base protocol version number. Currently used
// values are 4 and 5. Values greater than 5 are reserved for future
// use.
Version uint32
// Encryption Field: 16 bits. Block cipher family and key size. The
// values of this field are described in Table 2. The default value
// is AES-128.
// 0 | No Encryption Advertised
// 2 | AES-128
// 3 | AES-192
// 4 | AES-256
EncryptionField uint16
// Extension Field: 16 bits. This field is message specific extension
// related to Handshake Type field. The value MUST be set to 0
// except for the following cases. (1) If the handshake control
// packet is the INDUCTION message, this field is sent back by the
// Listener. (2) In the case of a CONCLUSION message, this field
// value should contain a combination of Extension Type values.
// 0x00000001 | HSREQ
// 0x00000002 | KMREQ
// 0x00000004 | CONFIG
// 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1
ExtensionField uint16
// Initial Packet Sequence Number: 32 bits. The sequence number of the
// very first data packet to be sent.
InitSequence uint32
// Maximum Transmission Unit Size: 32 bits. This value is typically set
// to 1500, which is the default Maximum Transmission Unit (MTU) size
// for Ethernet, but can be less.
MTU uint32
// Maximum Flow Window Size: 32 bits. The value of this field is the
// maximum number of data packets allowed to be "in flight" (i.e. the
// number of sent packets for which an ACK control packet has not yet
// been received).
FlowWindow uint32
// Handshake Type: 32 bits. This field indicates the handshake packet
// type.
// 0xFFFFFFFD | DONE
// 0xFFFFFFFE | AGREEMENT
// 0xFFFFFFFF | CONCLUSION
// 0x00000000 | WAVEHAND
// 0x00000001 | INDUCTION
HandshakeType uint32
// SRT Socket ID: 32 bits. This field holds the ID of the source SRT
// socket from which a handshake packet is issued.
SRTSocketID uint32
// SYN Cookie: 32 bits. Randomized value for processing a handshake.
// The value of this field is specified by the handshake message
// type.
SynCookie uint32
// Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's
// sender. The value consists of four 32-bit fields.
PeerIP net.IP
// Extensions.
// Extension Type: 16 bits. The value of this field is used to process
// an integrated handshake. Each extension can have a pair of
// request and response types.
// Extension Length: 16 bits. The length of the Extension Contents
// field in four-byte blocks.
// Extension Contents: variable length. The payload of the extension.
ExtraData []byte
}
func (v *SRTHandshakePacket) IsData() bool {
return v.ControlFlag == 0x00
}
func (v *SRTHandshakePacket) IsControl() bool {
return v.ControlFlag == 0x80
}
func (v *SRTHandshakePacket) IsHandshake() bool {
return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00
}
func (v *SRTHandshakePacket) StreamID() (string, error) {
p := v.ExtraData
for {
if len(p) < 2 {
return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData))
}
extType := binary.BigEndian.Uint16(p)
extSize := binary.BigEndian.Uint16(p[2:])
p = p[4:]
if len(p) < int(extSize*4) {
return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData))
}
// Ignore other packets except stream id.
if extType != 0x05 {
p = p[extSize*4:]
continue
}
// We must copy it, because we will decode the stream id.
data := append([]byte{}, p[:extSize*4]...)
// Reverse the stream id encoded in little-endian to big-endian.
for i := 0; i < len(data); i += 4 {
value := binary.LittleEndian.Uint32(data[i:])
binary.BigEndian.PutUint32(data[i:], value)
}
// Trim the trailing zero bytes.
data = bytes.TrimRight(data, "\x00")
return string(data), nil
}
}
func (v *SRTHandshakePacket) String() string {
return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB",
v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData))
}
func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error {
if len(b) < 4 {
return errors.Errorf("Invalid packet length %v", len(b))
}
v.ControlFlag = b[0] & 0x80
v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff
v.SubType = binary.BigEndian.Uint16(b[2:4])
if len(b) < 64 {
return errors.Errorf("Invalid packet length %v", len(b))
}
v.AdditionalInfo = binary.BigEndian.Uint32(b[4:])
v.Timestamp = binary.BigEndian.Uint32(b[8:])
v.SocketID = binary.BigEndian.Uint32(b[12:])
v.Version = binary.BigEndian.Uint32(b[16:])
v.EncryptionField = binary.BigEndian.Uint16(b[20:])
v.ExtensionField = binary.BigEndian.Uint16(b[22:])
v.InitSequence = binary.BigEndian.Uint32(b[24:])
v.MTU = binary.BigEndian.Uint32(b[28:])
v.FlowWindow = binary.BigEndian.Uint32(b[32:])
v.HandshakeType = binary.BigEndian.Uint32(b[36:])
v.SRTSocketID = binary.BigEndian.Uint32(b[40:])
v.SynCookie = binary.BigEndian.Uint32(b[44:])
// Only support IPv4.
v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48])
v.ExtraData = b[64:]
return nil
}
func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) {
b := make([]byte, 64+len(v.ExtraData))
binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType)
binary.BigEndian.PutUint16(b[2:], v.SubType)
binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo)
binary.BigEndian.PutUint32(b[8:], v.Timestamp)
binary.BigEndian.PutUint32(b[12:], v.SocketID)
binary.BigEndian.PutUint32(b[16:], v.Version)
binary.BigEndian.PutUint16(b[20:], v.EncryptionField)
binary.BigEndian.PutUint16(b[22:], v.ExtensionField)
binary.BigEndian.PutUint32(b[24:], v.InitSequence)
binary.BigEndian.PutUint32(b[28:], v.MTU)
binary.BigEndian.PutUint32(b[32:], v.FlowWindow)
binary.BigEndian.PutUint32(b[36:], v.HandshakeType)
binary.BigEndian.PutUint32(b[40:], v.SRTSocketID)
binary.BigEndian.PutUint32(b[44:], v.SynCookie)
// Only support IPv4.
ip := v.PeerIP.To4()
b[48] = ip[3]
b[49] = ip[2]
b[50] = ip[1]
b[51] = ip[0]
if len(v.ExtraData) > 0 {
copy(b[64:], v.ExtraData)
}
return b, nil
}

@ -0,0 +1,45 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package sync
import "sync"
type Map[K comparable, V any] struct {
m sync.Map
}
func (m *Map[K, V]) Delete(key K) {
m.m.Delete(key)
}
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
v, ok := m.m.Load(key)
if !ok {
return value, ok
}
return v.(V), ok
}
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
v, loaded := m.m.LoadAndDelete(key)
if !loaded {
return value, loaded
}
return v.(V), loaded
}
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
a, loaded := m.m.LoadOrStore(key, value)
return a.(V), loaded
}
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
m.m.Range(func(key, value any) bool {
return f(key.(K), value.(V))
})
}
func (m *Map[K, V]) Store(key K, value V) {
m.m.Store(key, value)
}

@ -0,0 +1,276 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/binary"
"encoding/json"
stdErr "errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path"
"reflect"
"regexp"
"strconv"
"strings"
"syscall"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) {
w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version()))
b, err := json.Marshal(data)
if err != nil {
apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
logger.Wf(ctx, "HTTP API error %+v", err)
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, fmt.Sprintf("%v", err))
}
func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
// Always support CORS. Note that browser may send origin header for m3u8, but no origin header
// for ts. So we always response CORS header.
if true {
// SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin,
// headers, expose headers and methods.
w.Header().Set("Access-Control-Allow-Origin", "*")
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
w.Header().Set("Access-Control-Allow-Headers", "*")
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
w.Header().Set("Access-Control-Allow-Methods", "*")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return true
}
return false
}
func parseGracefullyQuitTimeout() (time.Duration, error) {
if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil {
return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout())
} else {
return t, nil
}
}
// ParseBody read the body from r, and unmarshal JSON to v.
func ParseBody(r io.ReadCloser, v interface{}) error {
b, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrapf(err, "read body")
}
defer r.Close()
if len(b) == 0 {
return nil
}
if err := json.Unmarshal(b, v); err != nil {
return errors.Wrapf(err, "json unmarshal %v", string(b))
}
return nil
}
// buildStreamURL build as vhost/app/stream for stream URL r.
func buildStreamURL(r string) (string, error) {
u, err := url.Parse(r)
if err != nil {
return "", errors.Wrapf(err, "parse url %v", r)
}
// If not domain or ip in hostname, it's __defaultVhost__.
defaultVhost := !strings.Contains(u.Hostname(), ".")
// If hostname is actually an IP address, it's __defaultVhost__.
if ip := net.ParseIP(u.Hostname()); ip.To4() != nil {
defaultVhost = true
}
if defaultVhost {
return fmt.Sprintf("__defaultVhost__%v", u.Path), nil
}
// Ignore port, only use hostname as vhost.
return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil
}
// isPeerClosedError indicates whether peer object closed the connection.
func isPeerClosedError(err error) bool {
causeErr := errors.Cause(err)
if stdErr.Is(causeErr, io.EOF) {
return true
}
if stdErr.Is(causeErr, syscall.EPIPE) {
return true
}
if netErr, ok := causeErr.(*net.OpError); ok {
if sysErr, ok := netErr.Err.(*os.SyscallError); ok {
if stdErr.Is(sysErr.Err, syscall.ECONNRESET) {
return true
}
}
}
return false
}
// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL
// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL
// with extension.
func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
hostname := "__defaultVhost__"
if strings.Contains(r.Host, ":") {
if v, _, err := net.SplitHostPort(r.Host); err == nil {
hostname = v
}
}
var appStream, streamExt string
// Parse app/stream from query string.
q := r.URL.Query()
if app := q.Get("app"); app != "" {
appStream = "/" + app
}
if stream := q.Get("stream"); stream != "" {
appStream = fmt.Sprintf("%v/%v", appStream, stream)
}
// Parse app/stream from path.
if appStream == "" {
streamExt = path.Ext(r.URL.Path)
appStream = strings.TrimSuffix(r.URL.Path, streamExt)
}
unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream)
fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt)
return
}
// rtcIsSTUN returns true if data of UDP payload is a STUN packet.
func rtcIsSTUN(data []byte) bool {
return len(data) > 0 && (data[0] == 0 || data[0] == 1)
}
// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet.
func rtcIsRTPOrRTCP(data []byte) bool {
return len(data) >= 12 && (data[0]&0xC0) == 0x80
}
// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet.
func srtIsHandshake(data []byte) bool {
return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000
}
// srtParseSocketID parse the socket id from the SRT packet.
func srtParseSocketID(data []byte) uint32 {
if len(data) >= 16 {
return binary.BigEndian.Uint32(data[12:])
}
return 0
}
// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP.
func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) {
if true {
ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
ufragMatch := ufragRe.FindStringSubmatch(sdp)
if len(ufragMatch) <= 1 {
return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp)
}
ufrag = ufragMatch[1]
}
if true {
pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
pwdMatch := pwdRe.FindStringSubmatch(sdp)
if len(pwdMatch) <= 1 {
return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp)
}
pwd = pwdMatch[1]
}
return ufrag, pwd, nil
}
// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required).
// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url
func parseSRTStreamID(sid string) (host, resource string, err error) {
if true {
hostRe := regexp.MustCompile(`h=([^,]+)`)
hostMatch := hostRe.FindStringSubmatch(sid)
if len(hostMatch) > 1 {
host = hostMatch[1]
}
}
if true {
resourceRe := regexp.MustCompile(`r=([^,]+)`)
resourceMatch := resourceRe.FindStringSubmatch(sid)
if len(resourceMatch) <= 1 {
return "", "", errors.Errorf("no resource in sid %v", sid)
}
resource = resourceMatch[1]
}
return host, resource, nil
}
// parseListenEndpoint parse the listen endpoint as:
// port The tcp listen port, like 1935.
// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935
func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) {
// If no colon in ep, it's port in string.
if !strings.Contains(ep, ":") {
if p, err := strconv.Atoi(ep); err != nil {
return "", nil, 0, errors.Wrapf(err, "parse port %v", ep)
} else {
return "tcp", nil, uint16(p), nil
}
}
// Must be protocol://ip:port schema.
parts := strings.Split(ep, ":")
if len(parts) != 3 {
return "", nil, 0, errors.Errorf("invalid endpoint %v", ep)
}
if p, err := strconv.Atoi(parts[2]); err != nil {
return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2])
} else {
return parts[0], net.ParseIP(parts[1]), uint16(p), nil
}
}

@ -0,0 +1,27 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import "fmt"
func VersionMajor() int {
return 1
}
// VersionMinor specifies the typical version of SRS we adapt to.
func VersionMinor() int {
return 5
}
func VersionRevision() int {
return 0
}
func Version() string {
return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision())
}
func Signature() string {
return "SRSProxy"
}

@ -0,0 +1,57 @@
listen 19351;
max_connections 1000;
pid objs/origin1.pid;
daemon off;
srs_log_tank console;
http_server {
enabled on;
listen 8081;
dir ./objs/nginx/html;
}
http_api {
enabled on;
listen 19851;
}
rtc_server {
enabled on;
listen 8001; # UDP port
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate
candidate $CANDIDATE;
}
srt_server {
enabled on;
listen 10081;
tsbpdmode off;
tlpktdrop off;
}
heartbeat {
enabled on;
interval 9;
url http://127.0.0.1:12025/api/v1/srs/register;
device_id origin1;
ports on;
}
vhost __defaultVhost__ {
http_remux {
enabled on;
mount [vhost]/[app]/[stream].flv;
}
hls {
enabled on;
hls_path ./objs/nginx/html;
hls_fragment 10;
hls_window 60;
}
rtc {
enabled on;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc
rtmp_to_rtc on;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp
rtc_to_rtmp on;
}
srt {
enabled on;
srt_to_rtmp on;
}
}

@ -0,0 +1,57 @@
listen 19352;
max_connections 1000;
pid objs/origin2.pid;
daemon off;
srs_log_tank console;
http_server {
enabled on;
listen 8082;
dir ./objs/nginx/html;
}
http_api {
enabled on;
listen 19853;
}
rtc_server {
enabled on;
listen 8001; # UDP port
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate
candidate $CANDIDATE;
}
srt_server {
enabled on;
listen 10082;
tsbpdmode off;
tlpktdrop off;
}
heartbeat {
enabled on;
interval 9;
url http://127.0.0.1:12025/api/v1/srs/register;
device_id origin2;
ports on;
}
vhost __defaultVhost__ {
http_remux {
enabled on;
mount [vhost]/[app]/[stream].flv;
}
hls {
enabled on;
hls_path ./objs/nginx/html;
hls_fragment 10;
hls_window 60;
}
rtc {
enabled on;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc
rtmp_to_rtc on;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp
rtc_to_rtmp on;
}
srt {
enabled on;
srt_to_rtmp on;
}
}

@ -0,0 +1,57 @@
listen 19353;
max_connections 1000;
pid objs/origin3.pid;
daemon off;
srs_log_tank console;
http_server {
enabled on;
listen 8083;
dir ./objs/nginx/html;
}
http_api {
enabled on;
listen 19852;
}
rtc_server {
enabled on;
listen 8001; # UDP port
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate
candidate $CANDIDATE;
}
srt_server {
enabled on;
listen 10083;
tsbpdmode off;
tlpktdrop off;
}
heartbeat {
enabled on;
interval 9;
url http://127.0.0.1:12025/api/v1/srs/register;
device_id origin3;
ports on;
}
vhost __defaultVhost__ {
http_remux {
enabled on;
mount [vhost]/[app]/[stream].flv;
}
hls {
enabled on;
hls_path ./objs/nginx/html;
hls_fragment 10;
hls_window 60;
}
rtc {
enabled on;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc
rtmp_to_rtc on;
# @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp
rtc_to_rtmp on;
}
srt {
enabled on;
srt_to_rtmp on;
}
}

@ -7,6 +7,7 @@ The changelog for SRS.
<a name="v7-changes"></a> <a name="v7-changes"></a>
## SRS 7.0 Changelog ## SRS 7.0 Changelog
* v7.0, 2024-09-09, Merge [#4158](https://github.com/ossrs/srs/pull/4158): Proxy: Support proxy server for SRS. v7.0.16 (#4158)
* v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171) * v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171)
* v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) * v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165)
* v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166) * v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166)

@ -342,7 +342,12 @@ SrsWaitGroup::SrsWaitGroup()
SrsWaitGroup::~SrsWaitGroup() SrsWaitGroup::~SrsWaitGroup()
{ {
wait(); // In the destructor, we should NOT wait for all coroutines to be done, because user should decide
// to wait or not. Similar to the Go's sync.WaitGroup, it also requires user to wait explicitly. For
// some special use scenarios, such as error handling, for example, if we started three servers with
// wait group, and one of them failed, user may want to return error and quit directly, without wait
// for other running servers to be done. If we wait in the destructor, it will continue to run without
// some servers, in unknown behaviors.
srs_cond_destroy(done_); srs_cond_destroy(done_);
} }

@ -9,6 +9,6 @@
#define VERSION_MAJOR 7 #define VERSION_MAJOR 7
#define VERSION_MINOR 0 #define VERSION_MINOR 0
#define VERSION_REVISION 15 #define VERSION_REVISION 16
#endif #endif
Loading…
Cancel
Save