You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
srs/proxy/srs.go

554 lines
17 KiB
Go

// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"os"
"strconv"
"strings"
"time"
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// Liveness windows. An entity whose most recent update is older than its
// window is no longer considered alive.
const (
	// srsServerAliveDuration: a server is alive if it heartbeated within this window.
	srsServerAliveDuration = 300 * time.Second
	// srsHLSAliveDuration: an HLS stream is alive if it was updated within this window.
	srsHLSAliveDuration = 120 * time.Second
	// srsRTCAliveDuration: a WebRTC stream is alive if it was updated within this window.
	srsRTCAliveDuration = 120 * time.Second
)
// SRSServer describes one SRS media server process known to the proxy.
// Instances are registered and refreshed through SRSLoadBalancer.Update, and
// the Redis load balancer stores them as JSON using the tags below.
type SRSServer struct {
	// The server IP.
	IP string `json:"ip,omitempty"`
	// The server device ID, configured by user.
	DeviceID string `json:"device_id,omitempty"`
	// The server id of SRS, store in file, may not change, mandatory.
	ServerID string `json:"server_id,omitempty"`
	// The service id of SRS, always change when restarted, mandatory.
	ServiceID string `json:"service_id,omitempty"`
	// The process id of SRS, always change when restarted, mandatory.
	PID string `json:"pid,omitempty"`
	// The RTMP listen endpoints.
	RTMP []string `json:"rtmp,omitempty"`
	// The HTTP Stream listen endpoints.
	HTTP []string `json:"http,omitempty"`
	// The HTTP API listen endpoints.
	API []string `json:"api,omitempty"`
	// The SRT server listen endpoints.
	SRT []string `json:"srt,omitempty"`
	// The RTC server listen endpoints.
	RTC []string `json:"rtc,omitempty"`
	// Last update time. The in-memory balancer's Pick compares it against
	// srsServerAliveDuration to decide whether the server is alive.
	// Note: the JSON key is spelled "update_at", and encoding/json ignores
	// omitempty for struct values such as time.Time, so a zero time is still
	// encoded.
	UpdatedAt time.Time `json:"update_at,omitempty"`
}
// ID returns the identity of this SRS process in the form
// "<ServerID>-<ServiceID>-<PID>".
func (v *SRSServer) ID() string {
	parts := []string{v.ServerID, v.ServiceID, v.PID}
	return strings.Join(parts, "-")
}
// String implements fmt.Stringer. It returns the short form that Format
// renders for the %v verb.
func (v *SRSServer) String() string {
	return fmt.Sprint(v)
}
// Format implements fmt.Formatter.
//
//   - %v and %s print a short summary: "SRS ip=<IP>, id=<ID()>".
//   - %+v and %+s add the pid/server/service IDs, the device ID and every
//     endpoint list when non-empty, and the last update time.
//   - Any other verb prints the %v form followed by ", fmt=%<verb>", so an
//     unsupported verb is visible in the output rather than silently ignored.
func (v *SRSServer) Format(f fmt.State, c rune) {
	switch c {
	case 'v', 's':
		if !f.Flag('+') {
			fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID())
			return
		}

		// Write straight into the builder instead of WriteString(Sprintf(...)),
		// which would allocate an intermediate string per field.
		var sb strings.Builder
		fmt.Fprintf(&sb, "pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID)
		if v.DeviceID != "" {
			fmt.Fprintf(&sb, ", device=%v", v.DeviceID)
		}

		// Endpoint lists are printed only when non-empty, in a fixed order.
		writeEndpoints := func(name string, endpoints []string) {
			if len(endpoints) > 0 {
				fmt.Fprintf(&sb, ", %v=[%v]", name, strings.Join(endpoints, ","))
			}
		}
		writeEndpoints("rtmp", v.RTMP)
		writeEndpoints("http", v.HTTP)
		writeEndpoints("api", v.API)
		writeEndpoints("srt", v.SRT)
		writeEndpoints("rtc", v.RTC)

		fmt.Fprintf(&sb, ", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999"))
		fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String())
	default:
		fmt.Fprintf(f, "%v, fmt=%%%c", v, c)
	}
}
// NewSRSServer allocates an empty SRSServer and applies each option to it,
// in the order given (functional-options style).
func NewSRSServer(opts ...func(*SRSServer)) *SRSServer {
	server := new(SRSServer)
	for i := range opts {
		opts[i](server)
	}
	return server
}
// NewDefaultSRSForDebugging builds a synthetic SRS media server from the
// envDefaultBackend* settings, for debugging only.
//
// It returns (nil, nil) when the default backend is not switched "on", and
// an error when it is on but the IP or RTMP endpoint is empty. HTTP, API,
// RTC and SRT endpoints are optional and set only when configured.
func NewDefaultSRSForDebugging() (*SRSServer, error) {
	if envDefaultBackendEnabled() != "on" {
		return nil, nil
	}

	ip, rtmp := envDefaultBackendIP(), envDefaultBackendRTMP()
	switch {
	case ip == "":
		return nil, fmt.Errorf("empty default backend ip")
	case rtmp == "":
		return nil, fmt.Errorf("empty default backend rtmp")
	}

	server := NewSRSServer(func(srs *SRSServer) {
		srs.IP = ip
		srs.RTMP = []string{rtmp}
		// Fresh random IDs each run, so the synthetic server looks like a new process.
		srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID())
		srs.ServiceID = logger.GenerateContextID()
		srs.PID = strconv.Itoa(os.Getpid())
		srs.UpdatedAt = time.Now()
	})

	if endpoint := envDefaultBackendHttp(); endpoint != "" {
		server.HTTP = []string{endpoint}
	}
	if endpoint := envDefaultBackendAPI(); endpoint != "" {
		server.API = []string{endpoint}
	}
	if endpoint := envDefaultBackendRTC(); endpoint != "" {
		server.RTC = []string{endpoint}
	}
	if endpoint := envDefaultBackendSRT(); endpoint != "" {
		server.SRT = []string{endpoint}
	}
	return server, nil
}
// SRSLoadBalancer is the interface to load balance the SRS servers.
//
// Besides choosing a backend per stream URL, it indexes HLS streams by SPBHID
// and WebRTC connections by ICE ufrag, so that follow-up requests can be routed
// to the right stream. srsMemoryLoadBalancer keeps this state in process
// memory; srsRedisLoadBalancer keeps it in Redis.
type SRSLoadBalancer interface {
	// Initialize the load balancer. Both implementations also register the
	// optional default server from NewDefaultSRSForDebugging here.
	Initialize(ctx context.Context) error
	// Update registers or refreshes the backend server.
	Update(ctx context.Context, server *SRSServer) error
	// Pick a backend server for the specified stream URL. Implementations pin
	// a stream URL to the server they picked for it first.
	Pick(ctx context.Context, streamURL string) (*SRSServer, error)
	// Load or store the HLS streaming for the specified stream URL, returning
	// the stream the caller should use.
	LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error)
	// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
	LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error)
	// Store the WebRTC streaming, indexed by both the stream URL and its ufrag.
	StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error
	// Load the WebRTC streaming by ufrag, the ICE username.
	LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
}

// srsLoadBalancer is the global SRS load balancer.
// NOTE(review): it is assigned outside this chunk, presumably at startup to
// NewMemoryLoadBalancer() or NewRedisLoadBalancer() — confirm.
var srsLoadBalancer SRSLoadBalancer
// srsMemoryLoadBalancer stores state in memory.
//
// All fields are typed concurrent maps from the project's sync package, so the
// zero value is ready to use. No code visible here ever deletes entries, so
// state lives for the life of the process.
type srsMemoryLoadBalancer struct {
	// All available SRS servers, key is server ID.
	servers sync.Map[string, *SRSServer]
	// The picked server to serve client by specified stream URL, key is stream url.
	picked sync.Map[string, *SRSServer]
	// The HLS streaming, key is stream URL.
	hlsStreamURL sync.Map[string, *HLSPlayStream]
	// The HLS streaming, key is SPBHID.
	hlsSPBHID sync.Map[string, *HLSPlayStream]
	// The WebRTC streaming, key is stream URL.
	rtcStreamURL sync.Map[string, *RTCConnection]
	// The WebRTC streaming, key is ufrag.
	rtcUfrag sync.Map[string, *RTCConnection]
}
// NewMemoryLoadBalancer returns a load balancer that keeps all of its state in
// the memory of this process. Call Initialize before use.
func NewMemoryLoadBalancer() SRSLoadBalancer {
	lb := new(srsMemoryLoadBalancer)
	return lb
}
// Initialize registers the optional default SRS server returned by
// NewDefaultSRSForDebugging. When one is configured, it also starts a goroutine
// that re-registers it every 30 seconds until ctx is cancelled.
func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error {
	server, err := NewDefaultSRSForDebugging()
	if err != nil {
		return errors.Wrapf(err, "initialize default SRS")
	}
	if server == nil {
		return nil
	}

	if err := v.Update(ctx, server); err != nil {
		return errors.Wrapf(err, "update default SRS %+v", server)
	}

	// Keep alive. Each heartbeat registers a fresh copy with a new UpdatedAt.
	// Re-storing the same pointer would leave UpdatedAt frozen at startup, and
	// Pick would stop treating the server as alive after srsServerAliveDuration.
	// Mutating the already-stored pointer in place would race with readers.
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				heartbeat := *server
				heartbeat.UpdatedAt = time.Now()
				if err := v.Update(ctx, &heartbeat); err != nil {
					logger.Wf(ctx, "update default SRS %+v failed, %+v", &heartbeat, err)
				}
			}
		}
	}()

	logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server)
	return nil
}
// Update inserts server into the in-memory server table, replacing any entry
// with the same ID. It always returns nil.
func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
	id := server.ID()
	v.servers.Store(id, server)
	return nil
}
// Pick returns the backend SRS server for streamURL.
//
// Once a server has been chosen for a stream URL, every later call returns that
// same server. For a new stream URL, a server is chosen uniformly at random
// among those updated within srsServerAliveDuration. If none qualifies, the
// choice is among all known servers. Pick fails only when no server has been
// registered at all.
func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
	// Always proxy to the same server for the same stream URL.
	if server, ok := v.picked.Load(streamURL); ok {
		return server, nil
	}

	// Gather every known server, and separately those updated within
	// srsServerAliveDuration.
	var alive, all []*SRSServer
	v.servers.Range(func(key string, server *SRSServer) bool {
		all = append(all, server)
		if time.Since(server.UpdatedAt) < srsServerAliveDuration {
			alive = append(alive, server)
		}
		return true
	})

	// If no servers available, use all possible servers.
	candidates := alive
	if len(candidates) == 0 {
		candidates = all
	}
	if len(candidates) == 0 {
		return nil, fmt.Errorf("no server available for %v", streamURL)
	}

	// Pick a server randomly. LoadOrStore (rather than Store) makes concurrent
	// first requests for the same stream URL agree: the first store wins, and
	// every racer gets that winner back instead of a server of its own.
	server := candidates[rand.Intn(len(candidates))]
	actual, _ := v.picked.LoadOrStore(streamURL, server)
	return actual, nil
}
// LoadHLSBySPBHID returns the HLS stream indexed under spbhid (the SRS Proxy
// Backend HLS ID), the lookup used for TS files. It fails if none is known.
func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
	stream, found := v.hlsSPBHID.Load(spbhid)
	if !found {
		return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid)
	}
	return stream, nil
}
// LoadOrStoreHLS returns the HLS stream already registered for streamURL (used
// for M3u8), or registers value when there is none. The returned stream is
// also indexed under value's SPBHID, for TS files.
//
// NOTE(review): when a stream already exists, value's SPBHID is still added to
// hlsSPBHID, and nothing visible here removes it. If callers mint a new SPBHID
// on every call, this map grows without bound — confirm against the callers.
func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
	if stream, _ := v.hlsStreamURL.LoadOrStore(streamURL, value); stream != nil {
		v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, stream)
		return stream, nil
	}
	return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL)
}
// StoreWebRTC indexes value under streamURL and under its ICE ufrag,
// overwriting any previous entries for those keys. It always returns nil.
func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
	// Index by stream URL first, then by ICE username fragment.
	v.rtcStreamURL.Store(streamURL, value)
	v.rtcUfrag.Store(value.Ufrag, value)

	return nil
}
// LoadWebRTCByUfrag returns the WebRTC connection indexed under ufrag (the ICE
// username fragment). It fails if none is known.
func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
	conn, found := v.rtcUfrag.Load(ufrag)
	if !found {
		return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
	}
	return conn, nil
}
// srsRedisLoadBalancer stores load-balancer state in Redis, presumably so that
// several proxy instances can share it (confirm with deployment docs).
//
// Key layout (see the redisKey* helpers and Pick):
//   - srs-proxy-server:<server ID>  JSON SRSServer, TTL srsServerAliveDuration
//   - srs-proxy-all-servers         JSON array of server keys, no TTL
//   - srs-proxy-url:<stream URL>    server key picked for the stream, no TTL
//   - srs-proxy-hls:<stream URL>    JSON HLSPlayStream, TTL srsHLSAliveDuration
//   - srs-proxy-spbhid:<SPBHID>     JSON HLSPlayStream, TTL srsHLSAliveDuration
//   - srs-proxy-rtc:<stream URL>    JSON RTCConnection, TTL srsRTCAliveDuration
//   - srs-proxy-ufrag:<ufrag>       JSON RTCConnection, TTL srsRTCAliveDuration
type srsRedisLoadBalancer struct {
	// The redis client sdk; nil until Initialize creates it.
	rdb *redis.Client
}
// NewRedisLoadBalancer returns a load balancer that keeps its state in Redis.
// It holds no client until Initialize connects.
func NewRedisLoadBalancer() SRSLoadBalancer {
	var lb srsRedisLoadBalancer
	return &lb
}
// Initialize connects to Redis using the envRedis* settings and verifies the
// connection with PING. It then registers the optional default SRS server from
// NewDefaultSRSForDebugging. When one is configured, it also starts a goroutine
// that re-registers it every 30 seconds until ctx is cancelled.
func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error {
	redisDatabase, err := strconv.Atoi(envRedisDB())
	if err != nil {
		return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB())
	}

	rdb := redis.NewClient(&redis.Options{
		Addr:     fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()),
		Password: envRedisPassword(),
		DB:       redisDatabase,
	})
	v.rdb = rdb

	if err := rdb.Ping(ctx).Err(); err != nil {
		return errors.Wrapf(err, "unable to connect to redis %v", rdb.String())
	}
	logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String())

	server, err := NewDefaultSRSForDebugging()
	if err != nil {
		return errors.Wrapf(err, "initialize default SRS")
	}
	if server == nil {
		return nil
	}

	if err := v.Update(ctx, server); err != nil {
		return errors.Wrapf(err, "update default SRS %+v", server)
	}

	// Keep alive. Each heartbeat stores a copy with a fresh UpdatedAt, so the
	// JSON in Redis reflects the latest heartbeat instead of the startup time.
	// The shared pointer is not mutated, to avoid racing with other readers.
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				heartbeat := *server
				heartbeat.UpdatedAt = time.Now()
				if err := v.Update(ctx, &heartbeat); err != nil {
					logger.Wf(ctx, "update default SRS %+v failed, %+v", &heartbeat, err)
				}
			}
		}
	}()

	logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server)
	return nil
}
// Update stores server as JSON under its own key with a TTL of
// srsServerAliveDuration, so a server that stops heartbeating disappears from
// Redis. It then rewrites the shared JSON list of server keys: keys whose entry
// no longer exists are pruned, and this server's key is appended if missing.
//
// NOTE(review): the list is read-modify-written without WATCH/MULTI, so
// concurrent Updates from different servers can drop each other's additions.
// Each server re-adds itself on its next heartbeat.
func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
	b, err := json.Marshal(server)
	if err != nil {
		return errors.Wrapf(err, "marshal server %+v", server)
	}

	key := v.redisKeyServer(server.ID())
	if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil {
		return errors.Wrapf(err, "set key=%v server %+v", key, server)
	}

	// Query all servers from redis, in json string. A missing list (redis.Nil)
	// just means there are no servers yet. Any other error must abort;
	// otherwise the list would be overwritten with only this server.
	serversKey := v.redisKeyServers()
	var serverKeys []string
	if b, err := v.rdb.Get(ctx, serversKey).Bytes(); err == nil {
		if err := json.Unmarshal(b, &serverKeys); err != nil {
			return errors.Wrapf(err, "unmarshal key=%v servers %v", serversKey, string(b))
		}
	} else if err != redis.Nil {
		return errors.Wrapf(err, "get key=%v servers", serversKey)
	}

	// Drop servers whose key has expired, preserving order, and note whether
	// this server is already listed. Only a definite "does not exist" prunes a
	// key; a transient Redis error aborts instead of dropping live servers.
	var found bool
	alive := serverKeys[:0]
	for _, serverKey := range serverKeys {
		if serverKey == key {
			// Just written above, so it certainly exists.
			found = true
			alive = append(alive, serverKey)
			continue
		}
		n, err := v.rdb.Exists(ctx, serverKey).Result()
		if err != nil {
			return errors.Wrapf(err, "check key=%v server", serverKey)
		}
		if n > 0 {
			alive = append(alive, serverKey)
		}
	}
	serverKeys = alive
	if !found {
		serverKeys = append(serverKeys, key)
	}

	// Update all servers to redis; the list itself never expires.
	if b, err = json.Marshal(serverKeys); err != nil {
		return errors.Wrapf(err, "marshal servers %+v", serverKeys)
	}
	if err = v.rdb.Set(ctx, serversKey, b, 0).Err(); err != nil {
		return errors.Wrapf(err, "set key=%v servers %+v", serversKey, serverKeys)
	}
	return nil
}
// Pick returns the backend SRS server for streamURL.
//
// If streamURL is already pinned to a server whose key still exists in Redis,
// that server is returned. Otherwise candidates from the shared server list are
// tried in random order until one whose key still exists is found. streamURL is
// then pinned to that server with no expiry.
func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
	key := fmt.Sprintf("srs-proxy-url:%v", streamURL)

	// Always proxy to the same server for the same stream URL.
	if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil {
		// If server not exists, ignore and pick another server for the stream URL.
		if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 {
			var server SRSServer
			if err := json.Unmarshal(b, &server); err != nil {
				return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b))
			}
			// TODO: If server fail, we should migrate the streams to another server.
			return &server, nil
		}
	}

	// Query all servers from redis, in json string.
	var serverKeys []string
	if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
		if err := json.Unmarshal(b, &serverKeys); err != nil {
			return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
		}
	}

	// No server found, failed.
	if len(serverKeys) == 0 {
		return nil, fmt.Errorf("no server available for %v", streamURL)
	}

	// The list is pruned only when some server calls Update, so it may still
	// name servers whose key has already expired. Try every candidate once, in
	// random order, so a live server is found whenever one is listed.
	var serverKey string
	var server SRSServer
	for _, i := range rand.Perm(len(serverKeys)) {
		tryServerKey := serverKeys[i]
		b, err := v.rdb.Get(ctx, tryServerKey).Bytes()
		if err != nil || len(b) == 0 {
			continue
		}
		if err := json.Unmarshal(b, &server); err != nil {
			return nil, errors.Wrapf(err, "unmarshal key=%v server %v", tryServerKey, string(b))
		}
		serverKey = tryServerKey
		break
	}
	if serverKey == "" {
		return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL)
	}

	// Update the picked server for the stream URL.
	if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil {
		return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey)
	}
	return &server, nil
}
// LoadHLSBySPBHID reads and decodes the HLS stream stored under the SPBHID key.
// It fails if the key is missing (for example, expired) or cannot be decoded.
func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
	key := v.redisKeySPBHID(spbhid)

	data, err := v.rdb.Get(ctx, key).Bytes()
	if err != nil {
		return nil, errors.Wrapf(err, "get key=%v HLS", key)
	}

	stream := &HLSPlayStream{}
	if err = json.Unmarshal(data, stream); err != nil {
		return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(data))
	}
	return stream, nil
}
// LoadOrStoreHLS writes value as JSON under both its stream-URL key and its
// SPBHID key, each with a TTL of srsHLSAliveDuration. It then reads the
// stream-URL key back and returns the decoded stream.
//
// NOTE(review): despite the name, this always overwrites (plain SET), so the
// result normally equals what was just written. The in-memory balancer instead
// keeps the first stored value — confirm which semantics callers expect.
func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
	payload, err := json.Marshal(value)
	if err != nil {
		return nil, errors.Wrapf(err, "marshal HLS %v", value)
	}

	urlKey := v.redisKeyHLS(streamURL)
	if err = v.rdb.Set(ctx, urlKey, payload, srsHLSAliveDuration).Err(); err != nil {
		return nil, errors.Wrapf(err, "set key=%v HLS %v", urlKey, value)
	}

	idKey := v.redisKeySPBHID(value.SRSProxyBackendHLSID)
	if err = v.rdb.Set(ctx, idKey, payload, srsHLSAliveDuration).Err(); err != nil {
		return nil, errors.Wrapf(err, "set key=%v HLS %v", idKey, value)
	}

	// Read back what is now stored for the stream URL.
	stored, err := v.rdb.Get(ctx, urlKey).Bytes()
	if err != nil {
		return nil, errors.Wrapf(err, "get key=%v HLS", urlKey)
	}

	result := &HLSPlayStream{}
	if err = json.Unmarshal(stored, result); err != nil {
		return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", urlKey, string(stored))
	}
	return result, nil
}
// StoreWebRTC writes value as JSON under both its stream-URL key and its ICE
// ufrag key, each with a TTL of srsRTCAliveDuration.
func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
	payload, err := json.Marshal(value)
	if err != nil {
		return errors.Wrapf(err, "marshal WebRTC %v", value)
	}

	urlKey := v.redisKeyRTC(streamURL)
	if err = v.rdb.Set(ctx, urlKey, payload, srsRTCAliveDuration).Err(); err != nil {
		return errors.Wrapf(err, "set key=%v WebRTC %v", urlKey, value)
	}

	ufragKey := v.redisKeyUfrag(value.Ufrag)
	if err = v.rdb.Set(ctx, ufragKey, payload, srsRTCAliveDuration).Err(); err != nil {
		return errors.Wrapf(err, "set key=%v WebRTC %v", ufragKey, value)
	}

	return nil
}
// LoadWebRTCByUfrag reads and decodes the WebRTC connection stored under the
// ufrag key. It fails if the key is missing (for example, expired) or cannot
// be decoded.
func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
	key := v.redisKeyUfrag(ufrag)

	data, err := v.rdb.Get(ctx, key).Bytes()
	if err != nil {
		return nil, errors.Wrapf(err, "get key=%v WebRTC", key)
	}

	conn := new(RTCConnection)
	if err = json.Unmarshal(data, conn); err != nil {
		return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(data))
	}
	return conn, nil
}
// redisKeyUfrag returns the Redis key for a WebRTC connection indexed by ICE ufrag.
func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string {
	return "srs-proxy-ufrag:" + ufrag
}
// redisKeyRTC returns the Redis key for a WebRTC connection indexed by stream URL.
func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string {
	return "srs-proxy-rtc:" + streamURL
}
// redisKeySPBHID returns the Redis key for an HLS stream indexed by SPBHID.
func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string {
	return "srs-proxy-spbhid:" + spbhid
}
// redisKeyHLS returns the Redis key for an HLS stream indexed by stream URL.
func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string {
	return "srs-proxy-hls:" + streamURL
}
// redisKeyServer returns the Redis key holding one SRS server's JSON record.
func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string {
	return "srs-proxy-server:" + serverID
}
// redisKeyServers returns the Redis key holding the JSON list of all server
// keys. The key is a constant, so no formatting call is needed (staticcheck
// S1039 flags fmt.Sprintf with a constant, argument-less format string).
func (v *srsRedisLoadBalancer) redisKeyServers() string {
	return "srs-proxy-all-servers"
}