Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 88 additions & 13 deletions internal/servers/srt/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"reflect"
"sort"
"sync"
Expand All @@ -19,6 +20,17 @@ import (
"github.com/bluenviron/mediamtx/internal/logger"
)

// srtListen is a package-level indirection over srt.Listen so that tests
// can substitute a fake listener implementation.
var srtListen = srt.Listen

// Listener restart backoff parameters. Variables (rather than constants) so
// that tests can override them without changing production behavior.
var (
listenerRestartBaseDelay = 500 * time.Millisecond
listenerRestartMaxDelay = 30 * time.Second
)

// ErrConnNotFound is returned when a connection is not found.
var ErrConnNotFound = errors.New("connection not found")

Expand Down Expand Up @@ -87,11 +99,12 @@ type Server struct {
PathManager serverPathManager
Parent serverParent

ctx context.Context
ctxCancel func()
wg sync.WaitGroup
ln srt.Listener
conns map[*conn]struct{}
ctx context.Context
ctxCancel func()
wg sync.WaitGroup
ln srt.Listener
listenerConf srt.Config
conns map[*conn]struct{}

// in
chNewConnRequest chan srt.ConnRequest
Expand All @@ -104,13 +117,13 @@ type Server struct {

// Initialize initializes the server.
func (s *Server) Initialize() error {
conf := srt.DefaultConfig()
conf.ConnectionTimeout = time.Duration(s.ReadTimeout)
conf.PeerIdleTimeout = time.Duration(s.ReadTimeout)
conf.PayloadSize = uint32(srtMaxPayloadSize(s.UDPMaxPayloadSize))
s.listenerConf = srt.DefaultConfig()
s.listenerConf.ConnectionTimeout = time.Duration(s.ReadTimeout)
s.listenerConf.PeerIdleTimeout = time.Duration(s.ReadTimeout)
s.listenerConf.PayloadSize = uint32(srtMaxPayloadSize(s.UDPMaxPayloadSize))

var err error
s.ln, err = srt.Listen("srt", s.Address, conf)
s.ln, err = srtListen("srt", s.Address, s.listenerConf)
if err != nil {
return err
}
Expand Down Expand Up @@ -168,8 +181,16 @@ outer:
for {
select {
case err := <-s.chAcceptErr:
s.Log(logger.Error, "%s", err)
break outer
// ErrListenerClosed is the normal signal emitted when we Close()
// the listener ourselves during shutdown.
if errors.Is(err, srt.ErrListenerClosed) {
break outer
}
s.Log(logger.Warn, "listener failed: %s; attempting restart", err)
if rerr := s.restartListener(); rerr != nil {
s.Log(logger.Error, "listener restart aborted: %s", rerr)
break outer
}

case req := <-s.chNewConnRequest:
c := &conn{
Expand Down Expand Up @@ -235,7 +256,61 @@ outer:

s.ctxCancel()

s.ln.Close()
if s.ln != nil {
s.ln.Close()
}
}

// restartListener disposes of the dead gosrt listener and re-creates a new
// one with the same configuration, using bounded exponential backoff with
// jitter. It returns nil on success, or a non-nil error when the server
// context has been cancelled (i.e. the server is shutting down).
//
// Existing live *conn instances and the server goroutine are untouched;
// only the accept side is recreated.
func (s *Server) restartListener() error {
if s.ln != nil {
s.ln.Close()
s.ln = nil
}

delay := listenerRestartBaseDelay
attempt := 0

for {
attempt++

// jitter: +/- 25% of the current delay
jitter := time.Duration(rand.Int63n(int64(delay/2))) - delay/4
wait := delay + jitter

select {
case <-s.ctx.Done():
return fmt.Errorf("server is closing")
case <-time.After(wait):
}

ln, err := srtListen("srt", s.Address, s.listenerConf)
if err != nil {
s.Log(logger.Warn, "listener restart attempt %d failed: %s", attempt, err)
delay *= 2
if delay > listenerRestartMaxDelay {
delay = listenerRestartMaxDelay
}
continue
}

s.ln = ln
s.Log(logger.Info, "listener restarted on %s after %d attempt(s)", s.Address, attempt)

l := &listener{
ln: s.ln,
wg: &s.wg,
parent: s,
}
l.initialize()
return nil
}
}

func (s *Server) findConnByUUID(uuid uuid.UUID) *conn {
Expand Down
215 changes: 215 additions & 0 deletions internal/servers/srt/server_restart_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package srt

import (
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/test"
srt "github.com/datarhei/gosrt"
"github.com/stretchr/testify/require"
)

// fakeListener is a minimal in-memory implementation of srt.Listener used to
// drive the SRT server's accept-error/restart code paths without touching the
// network.
type fakeListener struct {
// acceptCh, when non-nil, supplies values returned by Accept2(). When
// closed (or drained), Accept2() returns srt.ErrListenerClosed.
acceptCh chan acceptResult

closeOnce sync.Once
closed chan struct{}
}

type acceptResult struct {
req srt.ConnRequest
err error
}

func newFakeListener() *fakeListener {
return &fakeListener{
acceptCh: make(chan acceptResult, 4),
closed: make(chan struct{}),
}
}

func (f *fakeListener) Accept2() (srt.ConnRequest, error) {
select {
case <-f.closed:
return nil, srt.ErrListenerClosed
case res, ok := <-f.acceptCh:
if !ok {
return nil, srt.ErrListenerClosed
}
return res.req, res.err
}
}

func (f *fakeListener) Accept(_ srt.AcceptFunc) (srt.Conn, srt.ConnType, error) {
return nil, srt.REJECT, errors.New("Accept is not used in tests")
}

func (f *fakeListener) Close() {
f.closeOnce.Do(func() {
close(f.closed)
})
}

func (f *fakeListener) Addr() net.Addr {
a, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
return a
}

// withListenerFactory swaps the package-level srtListen indirection for the
// duration of a test and restores it on cleanup.
func withListenerFactory(t *testing.T, fn func(network, address string, config srt.Config) (srt.Listener, error)) {
t.Helper()
orig := srtListen
srtListen = fn
t.Cleanup(func() { srtListen = orig })
}

// withFastBackoff shortens the listener restart backoff so tests run quickly,
// and restores the production values on cleanup.
func withFastBackoff(t *testing.T) {
t.Helper()
origBase, origMax := listenerRestartBaseDelay, listenerRestartMaxDelay
listenerRestartBaseDelay = 5 * time.Millisecond
listenerRestartMaxDelay = 50 * time.Millisecond
t.Cleanup(func() {
listenerRestartBaseDelay = origBase
listenerRestartMaxDelay = origMax
})
}

// newRestartTestServer constructs a minimal *Server suitable for restart
// tests. It does not need a path manager because no connection requests are
// produced by the fake listeners used here.
func newRestartTestServer() *Server {
return &Server{
Address: "127.0.0.1:0",
ReadTimeout: conf.Duration(10 * time.Second),
WriteTimeout: conf.Duration(10 * time.Second),
UDPMaxPayloadSize: 1472,
Parent: test.NilLogger,
}
}

func TestServerRestartsOnTransientListenerError(t *testing.T) {
withFastBackoff(t)

first := newFakeListener()
first.acceptCh <- acceptResult{err: errors.New("read udp: network is unreachable")}

second := newFakeListener()
secondCreated := make(chan struct{})

var calls int32

Check failure on line 112 in internal/servers/srt/server_restart_test.go

View workflow job for this annotation

GitHub Actions / go

atomic: var calls int32 may be simplified using atomic.Int32 (modernize)
withListenerFactory(t, func(_, _ string, _ srt.Config) (srt.Listener, error) {
switch atomic.AddInt32(&calls, 1) {
case 1:
return first, nil
case 2:
close(secondCreated)
return second, nil
default:
return nil, errors.New("unexpected extra Listen call")
}
})

s := newRestartTestServer()
require.NoError(t, s.Initialize())
defer s.Close()

select {
case <-secondCreated:
case <-time.After(2 * time.Second):
t.Fatalf("server did not restart its listener after a transient error")
}

require.EqualValues(t, 2, atomic.LoadInt32(&calls))
}

func TestServerDoesNotRestartOnErrListenerClosed(t *testing.T) {
withFastBackoff(t)

first := newFakeListener()
first.acceptCh <- acceptResult{err: srt.ErrListenerClosed}

var calls int32

Check failure on line 144 in internal/servers/srt/server_restart_test.go

View workflow job for this annotation

GitHub Actions / go

atomic: var calls int32 may be simplified using atomic.Int32 (modernize)
withListenerFactory(t, func(_, _ string, _ srt.Config) (srt.Listener, error) {
atomic.AddInt32(&calls, 1)
return first, nil
})

s := newRestartTestServer()
require.NoError(t, s.Initialize())

// The server should shut its run loop down on its own once
// ErrListenerClosed propagates. Close() must then return promptly.
closed := make(chan struct{})
go func() {
s.Close()
close(closed)
}()

select {
case <-closed:
case <-time.After(2 * time.Second):
t.Fatalf("server.Close() did not return after ErrListenerClosed")
}

require.EqualValues(t, 1, atomic.LoadInt32(&calls), "no listener restart should be attempted")
}

func TestServerRestartGivesUpOnContextCancel(t *testing.T) {
withFastBackoff(t)
// Make every retry slow enough that we are guaranteed to be inside the
// backoff sleep when Close() cancels the context.
listenerRestartBaseDelay = 200 * time.Millisecond
listenerRestartMaxDelay = 200 * time.Millisecond

first := newFakeListener()
first.acceptCh <- acceptResult{err: errors.New("transient failure")}

restartAttempted := make(chan struct{}, 1)

var calls int32

Check failure on line 182 in internal/servers/srt/server_restart_test.go

View workflow job for this annotation

GitHub Actions / go

atomic: var calls int32 may be simplified using atomic.Int32 (modernize)
withListenerFactory(t, func(_, _ string, _ srt.Config) (srt.Listener, error) {
n := atomic.AddInt32(&calls, 1)
if n == 1 {
return first, nil
}
select {
case restartAttempted <- struct{}{}:
default:
}
return nil, errors.New("listen permanently broken")
})

s := newRestartTestServer()
require.NoError(t, s.Initialize())

select {
case <-restartAttempted:
case <-time.After(2 * time.Second):
t.Fatalf("server did not attempt at least one listener restart")
}

closed := make(chan struct{})
go func() {
s.Close()
close(closed)
}()

select {
case <-closed:
case <-time.After(2 * time.Second):
t.Fatalf("server.Close() did not return while restart loop was active")
}
}
Loading