diff --git a/api/openapi.yaml b/api/openapi.yaml index d27627bedb5..ab61661e7b2 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -685,6 +685,56 @@ components: id: type: string + PushTarget: + type: object + properties: + id: + type: string + description: unique identifier of the push target (UUID). + url: + type: string + description: URL of the push target. + state: + type: string + enum: + - idle + - running + - error + description: current state of the push target. + error: + type: string + description: error message if state is "error". + resolvedURL: + type: string + description: resolved URL with variables expanded. + bytesSent: + type: integer + format: int64 + description: total bytes sent to the push target. + + PushTargetList: + type: object + properties: + pageCount: + type: integer + format: int64 + itemCount: + type: integer + format: int64 + items: + type: array + items: + $ref: '#/components/schemas/PushTarget' + + PushTargetAdd: + type: object + required: + - url + properties: + url: + type: string + description: URL of the push target (supports rtmp://, rtmps://). + HLSMuxer: type: object properties: @@ -1778,6 +1828,192 @@ paths: schema: $ref: '#/components/schemas/Error' + /v3/paths/pushtargets/list/{name}: + get: + operationId: pushTargetsList + tags: [Paths, Push Targets] + summary: returns all push targets for a path. + description: 'Push targets are external servers to which the stream is pushed (e.g., YouTube Live, Twitch).' + parameters: + - name: name + in: path + required: true + description: name of the path. + schema: + type: string + - name: page + in: query + description: page number. + schema: + type: integer + default: 0 + - name: itemsPerPage + in: query + description: items per page. + schema: + type: integer + default: 100 + responses: + '200': + description: the request was successful. + content: + application/json: + schema: + $ref: '#/components/schemas/PushTargetList' + '400': + description: invalid request. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '404': + description: path not found. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '500': + description: server error. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + + /v3/paths/pushtargets/get/{name}/{id}: + get: + operationId: pushTargetsGet + tags: [Paths, Push Targets] + summary: returns a push target. + description: '' + parameters: + - name: name + in: path + required: true + description: name of the path. + schema: + type: string + - name: id + in: path + required: true + description: UUID of the push target. + schema: + type: string + responses: + '200': + description: the request was successful. + content: + application/json: + schema: + $ref: '#/components/schemas/PushTarget' + '400': + description: invalid request. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '404': + description: path or push target not found. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '500': + description: server error. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + + /v3/paths/pushtargets/add/{name}: + post: + operationId: pushTargetsAdd + tags: [Paths, Push Targets] + summary: adds a push target to a path. + description: 'Push the stream to an external server (e.g., YouTube Live, Twitch).' + parameters: + - name: name + in: path + required: true + description: name of the path. + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PushTargetAdd' + responses: + '200': + description: the request was successful. + content: + application/json: + schema: + $ref: '#/components/schemas/PushTarget' + '400': + description: invalid request. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '404': + description: path not found. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '500': + description: server error. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + + /v3/paths/pushtargets/remove/{name}/{id}: + delete: + operationId: pushTargetsRemove + tags: [Paths, Push Targets] + summary: removes a push target from a path. + description: '' + parameters: + - name: name + in: path + required: true + description: name of the path. + schema: + type: string + - name: id + in: path + required: true + description: UUID of the push target. + schema: + type: string + responses: + '200': + description: the request was successful. + content: + application/json: + schema: + $ref: '#/components/schemas/OK' + '400': + description: invalid request. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '404': + description: path or push target not found. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '500': + description: server error. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + /v3/rtspconns/list: get: operationId: rtspConnsList diff --git a/internal/api/api.go b/internal/api/api.go index e44c16a7aaf..e84723726d4 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -134,6 +134,11 @@ func (a *API) Initialize() error { group.GET("/paths/list", a.onPathsList) group.GET("/paths/get/*name", a.onPathsGet) + group.GET("/paths/pushtargets/list/*name", a.onPushTargetsList) + group.GET("/paths/pushtargets/get/*name", a.onPushTargetsGet) + group.POST("/paths/pushtargets/add/*name", a.onPushTargetsAdd) + group.DELETE("/paths/pushtargets/remove/*name", a.onPushTargetsRemove) + if !interfaceIsEmpty(a.HLSServer) { group.GET("/hlsmuxers/list", a.onHLSMuxersList) group.GET("/hlsmuxers/get/*name", a.onHLSMuxersGet) diff --git a/internal/api/api_paths_test.go b/internal/api/api_paths_test.go index 61a3683fb3b..4f45a2ea36f 100644 --- a/internal/api/api_paths_test.go +++ b/internal/api/api_paths_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/test" @@ -31,6 +33,22 @@ func (m *testPathManager) APIPathsGet(name string) (*defs.APIPath, error) { return path, nil } +func (m *testPathManager) APIPushTargetsList(_ string) (*defs.APIPushTargetList, error) { + return &defs.APIPushTargetList{Items: []*defs.APIPushTarget{}}, nil +} + +func (m *testPathManager) APIPushTargetsGet(_ string, _ uuid.UUID) (*defs.APIPushTarget, error) { + return nil, conf.ErrPathNotFound +} + +func (m *testPathManager) APIPushTargetsAdd(_ string, _ defs.APIPushTargetAdd) (*defs.APIPushTarget, error) { + return &defs.APIPushTarget{}, nil +} + +func (m *testPathManager) APIPushTargetsRemove(_ string, _ uuid.UUID) error { + return nil +} + func TestPathsList(t *testing.T) { now := time.Now() pathManager := &testPathManager{ diff --git a/internal/api/api_pushtargets.go b/internal/api/api_pushtargets.go new file mode 100644 index 00000000000..e2926cb8d7e --- /dev/null +++ b/internal/api/api_pushtargets.go @@ -0,0 +1,242 @@ +//nolint:dupl +package api //nolint:revive + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + + "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/defs" + "github.com/bluenviron/mediamtx/internal/push" +) + +func splitPushTargetPath(value string) (string, uuid.UUID, error) { + i := strings.LastIndex(value, "/") + if i < 0 { + return "", uuid.UUID{}, fmt.Errorf("invalid path format, expected: /path/targetID") + } + + id, err := uuid.Parse(value[i+1:]) + if err != nil { + return "", uuid.UUID{}, fmt.Errorf("invalid target ID: %w", err) + } + + return value[:i], id, nil +} + +func pathToOptional(pathConf *conf.Path) (*conf.OptionalPath, error) { + byts, err := json.Marshal(pathConf) + if err != nil { + return nil, err + } + + var optional conf.OptionalPath + err = json.Unmarshal(byts, &optional) + if err != nil { + return nil, err + } + + return &optional, nil +} + +func (a *API) persistPushTargets(runtimePathName string, mutate func(*conf.Path) error) error { + if a.Conf == nil { + return nil + } + + pathData, err := a.PathManager.APIPathsGet(runtimePathName) + if err != nil { + return err + } + + a.mutex.Lock() + defer a.mutex.Unlock() + + newConf := a.Conf.Clone() + pathConf, ok := newConf.Paths[pathData.ConfName] + if !ok { + return conf.ErrPathNotFound + } + + updated := pathConf.Clone() + err = mutate(updated) + if err != nil { + return err + } + + optional, err := pathToOptional(updated) + if err != nil { + return err + } + + err = newConf.ReplacePath(pathData.ConfName, optional) + if err != nil { + return err + } + + err = newConf.Validate(nil) + if err != nil { + return err + } + + a.Conf = newConf + a.Parent.APIConfigSet(newConf) + + return nil +} + +func (a *API) onPushTargetsList(ctx *gin.Context) { + pathName, ok := paramName(ctx) + if !ok { + a.writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid path name")) + return + } + + data, err := a.PathManager.APIPushTargetsList(pathName) + if err != nil { + if errors.Is(err, conf.ErrPathNotFound) { + a.writeError(ctx, http.StatusNotFound, err) + } else { + a.writeError(ctx, http.StatusInternalServerError, err) + } + return + } + + data.ItemCount = len(data.Items) + pageCount, err := paginate(&data.Items, ctx.Query("itemsPerPage"), ctx.Query("page")) + if err != nil { + a.writeError(ctx, http.StatusBadRequest, err) + return + } + data.PageCount = pageCount + + ctx.JSON(http.StatusOK, data) +} + +func (a *API) onPushTargetsGet(ctx *gin.Context) { + pathName, ok := paramName(ctx) + if !ok { + a.writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid path name")) + return + } + + pathName, id, err := splitPushTargetPath(pathName) + if err != nil { + a.writeError(ctx, http.StatusBadRequest, err) + return + } + + data, err := a.PathManager.APIPushTargetsGet(pathName, id) + if err != nil { + if errors.Is(err, conf.ErrPathNotFound) || errors.Is(err, push.ErrTargetNotFound) { + a.writeError(ctx, http.StatusNotFound, err) + } else { + a.writeError(ctx, http.StatusInternalServerError, err) + } + return + } + + ctx.JSON(http.StatusOK, data) +} + +func (a *API) onPushTargetsAdd(ctx *gin.Context) { + pathName, ok := paramName(ctx) + if !ok { + a.writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid path name")) + return + } + + var req defs.APIPushTargetAdd + if err := ctx.ShouldBindJSON(&req); err != nil { + a.writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + if req.URL == "" { + a.writeError(ctx, http.StatusBadRequest, fmt.Errorf("url is required")) + return + } + + data, err := a.PathManager.APIPushTargetsAdd(pathName, req) + if err != nil { + if errors.Is(err, conf.ErrPathNotFound) { + a.writeError(ctx, http.StatusNotFound, err) + } else { + a.writeError(ctx, http.StatusInternalServerError, err) + } + return + } + + err = a.persistPushTargets(pathName, func(pathConf *conf.Path) error { + pathConf.PushTargets = append(pathConf.PushTargets, conf.PushTarget{URL: req.URL}) + return nil + }) + if err != nil { + _ = a.PathManager.APIPushTargetsRemove(pathName, data.ID) + a.writeError(ctx, http.StatusInternalServerError, err) + return + } + + ctx.JSON(http.StatusOK, data) +} + +func (a *API) onPushTargetsRemove(ctx *gin.Context) { + pathName, ok := paramName(ctx) + if !ok { + a.writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid path name")) + return + } + + pathName, id, err := splitPushTargetPath(pathName) + if err != nil { + a.writeError(ctx, http.StatusBadRequest, err) + return + } + + target, err := a.PathManager.APIPushTargetsGet(pathName, id) + if err != nil { + if errors.Is(err, conf.ErrPathNotFound) || errors.Is(err, push.ErrTargetNotFound) { + a.writeError(ctx, http.StatusNotFound, err) + } else { + a.writeError(ctx, http.StatusInternalServerError, err) + } + return + } + + err = a.PathManager.APIPushTargetsRemove(pathName, id) + if err != nil { + if errors.Is(err, conf.ErrPathNotFound) || errors.Is(err, push.ErrTargetNotFound) { + a.writeError(ctx, http.StatusNotFound, err) + } else { + a.writeError(ctx, http.StatusInternalServerError, err) + } + return + } + + err = a.persistPushTargets(pathName, func(pathConf *conf.Path) error { + for i, persisted := range pathConf.PushTargets { + if persisted.URL == target.URL { + pathConf.PushTargets = append(pathConf.PushTargets[:i], pathConf.PushTargets[i+1:]...) + return nil + } + } + + return push.ErrTargetNotFound + }) + if err != nil { + if errors.Is(err, push.ErrTargetNotFound) || errors.Is(err, conf.ErrPathNotFound) { + a.writeError(ctx, http.StatusNotFound, err) + } else { + a.writeError(ctx, http.StatusInternalServerError, err) + } + return + } + + a.writeOK(ctx) +} diff --git a/internal/api/api_pushtargets_test.go b/internal/api/api_pushtargets_test.go new file mode 100644 index 00000000000..c5291b5cff0 --- /dev/null +++ b/internal/api/api_pushtargets_test.go @@ -0,0 +1,192 @@ +package api //nolint:revive + +import ( + "net/http" + "sort" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/defs" + "github.com/bluenviron/mediamtx/internal/push" + "github.com/bluenviron/mediamtx/internal/test" +) + +type testPushTargetPathManager struct { + targets map[string]map[uuid.UUID]*defs.APIPushTarget +} + +func (*testPushTargetPathManager) APIPathsList() (*defs.APIPathList, error) { + panic("unused") +} + +func (*testPushTargetPathManager) APIPathsGet(string) (*defs.APIPath, error) { + panic("unused") +} + +func (m *testPushTargetPathManager) APIPushTargetsList(pathName string) (*defs.APIPushTargetList, error) { + items, ok := m.targets[pathName] + if !ok { + return nil, conf.ErrPathNotFound + } + + out := make([]*defs.APIPushTarget, 0, len(items)) + for _, item := range items { + copied := *item + out = append(out, &copied) + } + + sort.Slice(out, func(i, j int) bool { + return out[i].ID.String() < out[j].ID.String() + }) + + return &defs.APIPushTargetList{Items: out}, nil +} + +func (m *testPushTargetPathManager) APIPushTargetsGet(pathName string, id uuid.UUID) (*defs.APIPushTarget, error) { + items, ok := m.targets[pathName] + if !ok { + return nil, conf.ErrPathNotFound + } + + item, ok := items[id] + if !ok { + return nil, push.ErrTargetNotFound + } + + copied := *item + return &copied, nil +} + +func (m *testPushTargetPathManager) APIPushTargetsAdd(pathName string, req defs.APIPushTargetAdd) (*defs.APIPushTarget, error) { + items, ok := m.targets[pathName] + if !ok { + return nil, conf.ErrPathNotFound + } + + item := &defs.APIPushTarget{ + ID: uuid.New(), + Created: time.Now(), + URL: req.URL, + State: defs.APIPushTargetStateIdle, + BytesSent: 0, + } + items[item.ID] = item + + copied := *item + return &copied, nil +} + +func (m *testPushTargetPathManager) APIPushTargetsRemove(pathName string, id uuid.UUID) error { + items, ok := m.targets[pathName] + if !ok { + return conf.ErrPathNotFound + } + + if _, ok := items[id]; !ok { + return push.ErrTargetNotFound + } + + delete(items, id) + return nil +} + +func TestPushTargetsLifecycle(t *testing.T) { + pathManager := &testPushTargetPathManager{ + targets: map[string]map[uuid.UUID]*defs.APIPushTarget{ + "folder/stream": {}, + }, + } + + api := API{ + Address: "localhost:9996", + ReadTimeout: conf.Duration(10 * time.Second), + WriteTimeout: conf.Duration(10 * time.Second), + AuthManager: test.NilAuthManager, + PathManager: pathManager, + Parent: &testParent{}, + } + err := api.Initialize() + require.NoError(t, err) + defer api.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + var added defs.APIPushTarget + httpRequest(t, hc, http.MethodPost, + "http://localhost:9996/v3/paths/pushtargets/add/folder/stream", + defs.APIPushTargetAdd{URL: "rtmp://example.com/live/test"}, + &added) + + require.Equal(t, "rtmp://example.com/live/test", added.URL) + require.NotEqual(t, uuid.Nil, added.ID) + + var listed defs.APIPushTargetList + httpRequest(t, hc, http.MethodGet, + "http://localhost:9996/v3/paths/pushtargets/list/folder/stream", + nil, + &listed) + + require.Equal(t, 1, listed.ItemCount) + require.Equal(t, 1, listed.PageCount) + require.Len(t, listed.Items, 1) + require.Equal(t, added.ID, listed.Items[0].ID) + + var got defs.APIPushTarget + httpRequest(t, hc, http.MethodGet, + "http://localhost:9996/v3/paths/pushtargets/get/folder/stream/"+added.ID.String(), + nil, + &got) + + require.Equal(t, added.ID, got.ID) + require.Equal(t, added.URL, got.URL) + + httpRequest(t, hc, http.MethodDelete, + "http://localhost:9996/v3/paths/pushtargets/remove/folder/stream/"+added.ID.String(), + nil, + nil) + + httpRequest(t, hc, http.MethodGet, + "http://localhost:9996/v3/paths/pushtargets/list/folder/stream", + nil, + &listed) + + require.Equal(t, 0, listed.ItemCount) + require.Len(t, listed.Items, 0) +} + +func TestPushTargetsGetNotFound(t *testing.T) { + pathManager := &testPushTargetPathManager{ + targets: map[string]map[uuid.UUID]*defs.APIPushTarget{ + "folder/stream": {}, + }, + } + + api := API{ + Address: "localhost:9996", + ReadTimeout: conf.Duration(10 * time.Second), + WriteTimeout: conf.Duration(10 * time.Second), + AuthManager: test.NilAuthManager, + PathManager: pathManager, + Parent: &testParent{}, + } + err := api.Initialize() + require.NoError(t, err) + defer api.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + res, err := hc.Get("http://localhost:9996/v3/paths/pushtargets/get/folder/stream/" + uuid.New().String()) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusNotFound, res.StatusCode) + checkError(t, res.Body, "push target not found") +} diff --git a/internal/conf/conf_test.go b/internal/conf/conf_test.go index 1a4dd199709..474512ef138 100644 --- a/internal/conf/conf_test.go +++ b/internal/conf/conf_test.go @@ -87,6 +87,7 @@ func TestConfFromFile(t *testing.T) { RPICameraSoftwareH264Profile: "baseline", RPICameraSoftwareH264Level: "4.1", RPICameraMJPEGQuality: 60, + PushTargets: PushTargets{}, RunOnDemandStartTimeout: 5 * Duration(time.Second), RunOnDemandCloseAfter: 10 * Duration(time.Second), }, pa) @@ -747,6 +748,14 @@ func TestConfErrors(t *testing.T) { " alwaysAvailableFile: /path/to/file.mp4\n", "'alwaysAvailableFile' and 'alwaysAvailableTracks' cannot be used together", }, + { + "invalid push target URL", + "paths:\n" + + " mypath:\n" + + " pushTargets:\n" + + " - url: http://example.com/live/test\n", + "push target 0: push target URL must start with rtmp://, rtmps://, rtsp://, rtsps://, or srt://", + }, } { t.Run(ca.name, func(t *testing.T) { tmpf, err := createTempFile([]byte(ca.conf)) diff --git a/internal/conf/path.go b/internal/conf/path.go index cca311af915..5b2c92a5ee8 100644 --- a/internal/conf/path.go +++ b/internal/conf/path.go @@ -286,6 +286,9 @@ type Path struct { RPICameraSecondaryFPS float64 `json:"-"` // filled by Validate() RPICameraSecondaryMJPEGQuality uint `json:"-"` // filled by Validate() + // Push + PushTargets PushTargets `json:"pushTargets"` + // Hooks RunOnInit string `json:"runOnInit"` RunOnInitRestart bool `json:"runOnInitRestart"` @@ -846,6 +849,12 @@ func (pconf *Path) validate( }() } + // Push + + if err := pconf.PushTargets.Validate(); err != nil { + return err + } + // Hooks if pconf.RunOnInit != "" && pconf.Regexp != nil { diff --git a/internal/conf/push_target.go b/internal/conf/push_target.go new file mode 100644 index 00000000000..355cc43f858 --- /dev/null +++ b/internal/conf/push_target.go @@ -0,0 +1,48 @@ +package conf + +import ( + "fmt" + "net/url" + "strings" +) + +// PushTarget is a push target configuration. +type PushTarget struct { + URL string `json:"url"` +} + +// Validate validates a push target. +func (pt *PushTarget) Validate() error { + if pt.URL == "" { + return fmt.Errorf("push target URL is empty") + } + + // Check for valid protocols + if !strings.HasPrefix(pt.URL, "rtmp://") && + !strings.HasPrefix(pt.URL, "rtmps://") && + !strings.HasPrefix(pt.URL, "rtsp://") && + !strings.HasPrefix(pt.URL, "rtsps://") && + !strings.HasPrefix(pt.URL, "srt://") { + return fmt.Errorf("push target URL must start with rtmp://, rtmps://, rtsp://, rtsps://, or srt://") + } + + _, err := url.Parse(pt.URL) + if err != nil { + return fmt.Errorf("invalid push target URL: %w", err) + } + + return nil +} + +// PushTargets is a list of push targets. +type PushTargets []PushTarget + +// Validate validates push targets. +func (pts PushTargets) Validate() error { + for i, pt := range pts { + if err := pt.Validate(); err != nil { + return fmt.Errorf("push target %d: %w", i, err) + } + } + return nil +} diff --git a/internal/core/api_test.go b/internal/core/api_test.go index 6932cdf4847..6e95fc71c74 100644 --- a/internal/core/api_test.go +++ b/internal/core/api_test.go @@ -343,6 +343,133 @@ func TestAPIPathsGet(t *testing.T) { } } +func TestAPIPushTargetsAlwaysAvailablePushesWhileOffline(t *testing.T) { + type pushTarget struct { + ID uuid.UUID `json:"id"` + State string `json:"state"` + Error string `json:"error"` + BytesSent uint64 `json:"bytesSent"` + } + + type pushTargetList struct { + ItemCount int `json:"itemCount"` + PageCount int `json:"pageCount"` + Items []pushTarget `json:"items"` + } + + p, ok := newInstance("api: yes\n" + + "apiAddress: 127.0.0.1:19997\n" + + "rtsp: no\n" + + "rtmp: no\n" + + "hls: no\n" + + "webrtc: no\n" + + "srt: no\n" + + "paths:\n" + + " test:\n" + + " alwaysAvailable: yes\n" + + " alwaysAvailableTracks:\n" + + " - codec: H264\n") + require.Equal(t, true, ok) + defer p.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + var added pushTarget + httpRequest(t, hc, http.MethodPost, + "http://localhost:19997/v3/paths/pushtargets/add/test", + map[string]any{"url": "rtmp://127.0.0.1:29999/test"}, + &added) + + require.NotEqual(t, uuid.Nil, added.ID) + + time.Sleep(1500 * time.Millisecond) + + var listed pushTargetList + httpRequest(t, hc, http.MethodGet, + "http://localhost:19997/v3/paths/pushtargets/list/test", + nil, + &listed) + + require.Equal(t, 1, listed.ItemCount) + require.Equal(t, 1, listed.PageCount) + require.Len(t, listed.Items, 1) + require.Equal(t, added.ID, listed.Items[0].ID) + require.Equal(t, "error", listed.Items[0].State) + require.NotEqual(t, "", listed.Items[0].Error) + require.Equal(t, uint64(0), listed.Items[0].BytesSent) +} + +func TestAPIPushTargetsPersistAcrossLiveReload(t *testing.T) { + type pushTarget struct { + ID uuid.UUID `json:"id"` + URL string `json:"url"` + } + + type pushTargetList struct { + ItemCount int `json:"itemCount"` + Items []pushTarget `json:"items"` + } + + p, ok := newInstance("api: yes\n" + + "apiAddress: 127.0.0.1:19997\n" + + "rtsp: no\n" + + "rtmp: no\n" + + "hls: no\n" + + "webrtc: no\n" + + "srt: no\n" + + "paths:\n" + + " test:\n") + require.Equal(t, true, ok) + defer p.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + var added pushTarget + httpRequest(t, hc, http.MethodPost, + "http://localhost:19997/v3/paths/pushtargets/add/test", + map[string]any{"url": "rtmp://127.0.0.1:29999/test"}, + &added) + + httpRequest(t, hc, http.MethodPatch, + "http://localhost:19997/v3/config/paths/patch/test", + map[string]any{"record": true}, + nil) + + require.Eventually(t, func() bool { + var listed pushTargetList + httpRequest(t, hc, http.MethodGet, + "http://localhost:19997/v3/paths/pushtargets/list/test", + nil, + &listed) + + return listed.ItemCount == 1 && len(listed.Items) == 1 && listed.Items[0].URL == added.URL + }, 5*time.Second, 100*time.Millisecond) + + httpRequest(t, hc, http.MethodDelete, + "http://localhost:19997/v3/paths/pushtargets/remove/test/"+added.ID.String(), + nil, + nil) + + httpRequest(t, hc, http.MethodPatch, + "http://localhost:19997/v3/config/paths/patch/test", + map[string]any{"record": false}, + nil) + + require.Eventually(t, func() bool { + var listed pushTargetList + httpRequest(t, hc, http.MethodGet, + "http://localhost:19997/v3/paths/pushtargets/list/test", + nil, + &listed) + + return listed.ItemCount == 0 && len(listed.Items) == 0 + }, 5*time.Second, 100*time.Millisecond) +} + func TestAPIProtocolListGet(t *testing.T) { serverCertFpath, err := test.CreateTempFile(test.TLSCertPub) require.NoError(t, err) diff --git a/internal/core/path.go b/internal/core/path.go index 5de44e5cb1a..baa9602e15d 100644 --- a/internal/core/path.go +++ b/internal/core/path.go @@ -10,12 +10,14 @@ import ( "time" "github.com/bluenviron/gortsplib/v5/pkg/description" + "github.com/google/uuid" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/externalcmd" "github.com/bluenviron/mediamtx/internal/hooks" "github.com/bluenviron/mediamtx/internal/logger" + "github.com/bluenviron/mediamtx/internal/push" "github.com/bluenviron/mediamtx/internal/recorder" "github.com/bluenviron/mediamtx/internal/staticsources" "github.com/bluenviron/mediamtx/internal/stream" @@ -95,6 +97,7 @@ type path struct { recorder *recorder.Recorder availableTime time.Time onlineTime time.Time + pushManager *push.Manager onUnDemandHook func(string) onNotReadyHook func() readers map[defs.Reader]struct{} @@ -117,11 +120,53 @@ type path struct { chAddReader chan defs.PathAddReaderReq chRemoveReader chan defs.PathRemoveReaderReq chAPIPathsGet chan pathAPIPathsGetReq + chAPIPushTargetsList chan pathAPIPushTargetsListReq + chAPIPushTargetsGet chan pathAPIPushTargetsGetReq + chAPIPushTargetsAdd chan pathAPIPushTargetsAddReq + chAPIPushTargetsRemove chan pathAPIPushTargetsRemoveReq // out done chan struct{} } +type pathAPIPushTargetsListRes struct { + data *defs.APIPushTargetList + err error +} + +type pathAPIPushTargetsListReq struct { + res chan pathAPIPushTargetsListRes +} + +type pathAPIPushTargetsGetRes struct { + data *defs.APIPushTarget + err error +} + +type pathAPIPushTargetsGetReq struct { + id uuid.UUID + res chan pathAPIPushTargetsGetRes +} + +type pathAPIPushTargetsAddRes struct { + data *defs.APIPushTarget + err error +} + +type pathAPIPushTargetsAddReq struct { + req defs.APIPushTargetAdd + res chan pathAPIPushTargetsAddRes +} + +type pathAPIPushTargetsRemoveRes struct { + err error +} + +type pathAPIPushTargetsRemoveReq struct { + id uuid.UUID + res chan pathAPIPushTargetsRemoveRes +} + func (pa *path) initialize() { ctx, ctxCancel := context.WithCancel(pa.parentCtx) @@ -143,8 +188,14 @@ func (pa *path) initialize() { pa.chAddReader = make(chan defs.PathAddReaderReq) pa.chRemoveReader = make(chan defs.PathRemoveReaderReq) pa.chAPIPathsGet = make(chan pathAPIPathsGetReq) + pa.chAPIPushTargetsList = make(chan pathAPIPushTargetsListReq) + pa.chAPIPushTargetsGet = make(chan pathAPIPushTargetsGetReq) + pa.chAPIPushTargetsAdd = make(chan pathAPIPushTargetsAddReq) + pa.chAPIPushTargetsRemove = make(chan pathAPIPushTargetsRemoveReq) pa.done = make(chan struct{}) + pa.syncPushTargets(pa.conf.PushTargets) + pa.Log(logger.Debug, "created") pa.wg.Add(1) @@ -243,6 +294,11 @@ func (pa *path) run() { pa.setNotAvailable() } + if pa.pushManager != nil { + pa.pushManager.Close() + pa.pushManager = nil + } + if pa.source != nil { if source, ok := pa.source.(*staticsources.Handler); ok { if !pa.conf.SourceOnDemand || pa.onDemandStaticSourceState != pathOnDemandStateInitial { @@ -340,6 +396,18 @@ func (pa *path) runInner() error { case req := <-pa.chAPIPathsGet: pa.doAPIPathsGet(req) + case req := <-pa.chAPIPushTargetsList: + pa.doAPIPushTargetsList(req) + + case req := <-pa.chAPIPushTargetsGet: + pa.doAPIPushTargetsGet(req) + + case req := <-pa.chAPIPushTargetsAdd: + pa.doAPIPushTargetsAdd(req) + + case req := <-pa.chAPIPushTargetsRemove: + pa.doAPIPushTargetsRemove(req) + case <-pa.ctx.Done(): return fmt.Errorf("terminated") } @@ -392,6 +460,8 @@ func (pa *path) doReloadConf(newConf *conf.Path) { pa.conf = newConf pa.confMutex.Unlock() + pa.syncPushTargets(newConf.PushTargets) + if pa.conf.HasStaticSource() { pa.source.(*staticsources.Handler).ReloadConf(newConf) } @@ -437,6 +507,10 @@ func (pa *path) doSourceStaticSetReady(req defs.PathSourceStaticSetReadyReq) { if pa.conf.AlwaysAvailable { pa.onlineTime = time.Now() + + if pa.pushManager != nil { + pa.pushManager.SetStream(pa.stream) + } } if pa.conf.HasOnDemandStaticSource() { @@ -562,6 +636,10 @@ func (pa *path) doAddPublisher(req defs.PathAddPublisherReq) { if pa.conf.AlwaysAvailable { pa.onlineTime = time.Now() + + if pa.pushManager != nil { + pa.pushManager.SetStream(pa.stream) + } } if pa.conf.HasOnDemandPublisher() && pa.onDemandPublisherState != pathOnDemandStateInitial { @@ -689,6 +767,105 @@ func (pa *path) doAPIPathsGet(req pathAPIPathsGetReq) { } } +func (pa *path) doAPIPushTargetsList(req pathAPIPushTargetsListReq) { + if pa.pushManager == nil { + req.res <- pathAPIPushTargetsListRes{data: &defs.APIPushTargetList{Items: []*defs.APIPushTarget{}}} + return + } + req.res <- pathAPIPushTargetsListRes{data: pa.pushManager.APIItem()} +} + +func (pa *path) doAPIPushTargetsGet(req pathAPIPushTargetsGetReq) { + if pa.pushManager == nil { + req.res <- pathAPIPushTargetsGetRes{err: push.ErrTargetNotFound} + return + } + target, err := pa.pushManager.GetTarget(req.id) + if err != nil { + req.res <- pathAPIPushTargetsGetRes{err: err} + return + } + req.res <- pathAPIPushTargetsGetRes{data: target.APIItem()} +} + +func (pa *path) doAPIPushTargetsAdd(req pathAPIPushTargetsAddReq) { + target := pa.ensurePushManager().AddTarget(req.req.URL) + + req.res <- pathAPIPushTargetsAddRes{data: target.APIItem()} +} + +func (pa *path) doAPIPushTargetsRemove(req pathAPIPushTargetsRemoveReq) { + if pa.pushManager == nil { + req.res <- pathAPIPushTargetsRemoveRes{err: push.ErrTargetNotFound} + return + } + + err := pa.pushManager.RemoveTarget(req.id) + if err != nil { + req.res <- pathAPIPushTargetsRemoveRes{err: err} + return + } + + if len(pa.pushManager.TargetsList()) == 0 { + pa.pushManager.Close() + pa.pushManager = nil + } + + req.res <- pathAPIPushTargetsRemoveRes{} +} + +func (pa *path) ensurePushManager() *push.Manager { + if pa.pushManager == nil { + pa.pushManager = &push.Manager{ + ReadTimeout: pa.readTimeout, + WriteTimeout: pa.writeTimeout, + PathName: pa.name, + Parent: pa, + } + pa.pushManager.Initialize() + + if pa.stream != nil { + pa.pushManager.SetStream(pa.stream) + } + } + + return pa.pushManager +} + +func (pa *path) syncPushTargets(targets conf.PushTargets) { + if len(targets) == 0 { + if pa.pushManager != nil { + pa.pushManager.Close() + pa.pushManager = nil + } + return + } + + pm := pa.ensurePushManager() + pending := make(map[string]int, len(targets)) + for _, target := range targets { + pending[target.URL]++ + } + + for _, current := range pm.TargetsList() { + if pending[current.URL] > 0 { + pending[current.URL]-- + continue + } + + _ = pm.RemoveTarget(current.UUID()) + } + + for _, target := range targets { + if pending[target.URL] == 0 { + continue + } + + pm.AddTarget(target.URL) + pending[target.URL]-- + } +} + func (pa *path) SafeConf() *conf.Path { pa.confMutex.RLock() defer pa.confMutex.RUnlock() @@ -817,6 +994,10 @@ func (pa *path) setAvailable( sourceDesc = source.APISourceDescribe() } + if pa.pushManager != nil { + pa.pushManager.SetStream(pa.stream) + } + pa.onNotReadyHook = hooks.OnReady(hooks.OnReadyParams{ Logger: pa, ExternalCmdPool: pa.externalCmdPool, @@ -866,6 +1047,11 @@ func (pa *path) setNotAvailable() { pa.recorder = nil } + // Stop pushing to external targets + if pa.pushManager != nil { + pa.pushManager.ClearStream() + } + if pa.stream != nil { pa.stream.Close() pa.stream = nil @@ -1075,3 +1261,51 @@ func (pa *path) APIPathsGet(req pathAPIPathsGetReq) (*defs.APIPath, error) { return nil, fmt.Errorf("terminated") } } + +// APIPushTargetsList is called by api. +func (pa *path) APIPushTargetsList() (*defs.APIPushTargetList, error) { + req := pathAPIPushTargetsListReq{res: make(chan pathAPIPushTargetsListRes)} + select { + case pa.chAPIPushTargetsList <- req: + res := <-req.res + return res.data, res.err + case <-pa.ctx.Done(): + return nil, fmt.Errorf("terminated") + } +} + +// APIPushTargetsGet is called by api. +func (pa *path) APIPushTargetsGet(id uuid.UUID) (*defs.APIPushTarget, error) { + req := pathAPIPushTargetsGetReq{id: id, res: make(chan pathAPIPushTargetsGetRes)} + select { + case pa.chAPIPushTargetsGet <- req: + res := <-req.res + return res.data, res.err + case <-pa.ctx.Done(): + return nil, fmt.Errorf("terminated") + } +} + +// APIPushTargetsAdd is called by api. +func (pa *path) APIPushTargetsAdd(add defs.APIPushTargetAdd) (*defs.APIPushTarget, error) { + req := pathAPIPushTargetsAddReq{req: add, res: make(chan pathAPIPushTargetsAddRes)} + select { + case pa.chAPIPushTargetsAdd <- req: + res := <-req.res + return res.data, res.err + case <-pa.ctx.Done(): + return nil, fmt.Errorf("terminated") + } +} + +// APIPushTargetsRemove is called by api. +func (pa *path) APIPushTargetsRemove(id uuid.UUID) error { + req := pathAPIPushTargetsRemoveReq{id: id, res: make(chan pathAPIPushTargetsRemoveRes)} + select { + case pa.chAPIPushTargetsRemove <- req: + res := <-req.res + return res.err + case <-pa.ctx.Done(): + return fmt.Errorf("terminated") + } +} diff --git a/internal/core/path_manager.go b/internal/core/path_manager.go index d9800957470..17ac6177066 100644 --- a/internal/core/path_manager.go +++ b/internal/core/path_manager.go @@ -8,6 +8,8 @@ import ( "sync" "sync/atomic" + "github.com/google/uuid" + "github.com/bluenviron/mediamtx/internal/auth" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" @@ -23,6 +25,7 @@ func pathConfCanBeUpdated(oldPathConf *conf.Path, newPathConf *conf.Path) bool { clone.Name = newPathConf.Name clone.Regexp = newPathConf.Regexp + clone.PushTargets = newPathConf.PushTargets clone.Record = newPathConf.Record clone.RecordPath = newPathConf.RecordPath @@ -646,3 +649,87 @@ func (pm *pathManager) APIPathsGet(name string) (*defs.APIPath, error) { return nil, fmt.Errorf("terminated") } } + +// APIPushTargetsList is called by api. +func (pm *pathManager) APIPushTargetsList(name string) (*defs.APIPushTargetList, error) { + req := pathAPIPathsGetReq{ + name: name, + res: make(chan pathAPIPathsGetRes), + } + + select { + case pm.chAPIPathsGet <- req: + res := <-req.res + if res.err != nil { + return nil, res.err + } + + return res.path.APIPushTargetsList() + + case <-pm.ctx.Done(): + return nil, fmt.Errorf("terminated") + } +} + +// APIPushTargetsGet is called by api. +func (pm *pathManager) APIPushTargetsGet(name string, id uuid.UUID) (*defs.APIPushTarget, error) { + req := pathAPIPathsGetReq{ + name: name, + res: make(chan pathAPIPathsGetRes), + } + + select { + case pm.chAPIPathsGet <- req: + res := <-req.res + if res.err != nil { + return nil, res.err + } + + return res.path.APIPushTargetsGet(id) + + case <-pm.ctx.Done(): + return nil, fmt.Errorf("terminated") + } +} + +// APIPushTargetsAdd is called by api. +func (pm *pathManager) APIPushTargetsAdd(name string, add defs.APIPushTargetAdd) (*defs.APIPushTarget, error) { + req := pathAPIPathsGetReq{ + name: name, + res: make(chan pathAPIPathsGetRes), + } + + select { + case pm.chAPIPathsGet <- req: + res := <-req.res + if res.err != nil { + return nil, res.err + } + + return res.path.APIPushTargetsAdd(add) + + case <-pm.ctx.Done(): + return nil, fmt.Errorf("terminated") + } +} + +// APIPushTargetsRemove is called by api. +func (pm *pathManager) APIPushTargetsRemove(name string, id uuid.UUID) error { + req := pathAPIPathsGetReq{ + name: name, + res: make(chan pathAPIPathsGetRes), + } + + select { + case pm.chAPIPathsGet <- req: + res := <-req.res + if res.err != nil { + return res.err + } + + return res.path.APIPushTargetsRemove(id) + + case <-pm.ctx.Done(): + return fmt.Errorf("terminated") + } +} diff --git a/internal/defs/api.go b/internal/defs/api.go index 627d7ed49f9..7531eee86ff 100644 --- a/internal/defs/api.go +++ b/internal/defs/api.go @@ -12,6 +12,10 @@ import ( type APIPathManager interface { APIPathsList() (*APIPathList, error) APIPathsGet(string) (*APIPath, error) + APIPushTargetsList(pathName string) (*APIPushTargetList, error) + APIPushTargetsGet(pathName string, id uuid.UUID) (*APIPushTarget, error) + APIPushTargetsAdd(pathName string, req APIPushTargetAdd) (*APIPushTarget, error) + APIPushTargetsRemove(pathName string, id uuid.UUID) error } // APIHLSServer contains methods used by the API and Metrics server. @@ -397,6 +401,46 @@ type APIWebRTCSessionList struct { Items []APIWebRTCSession `json:"items"` } +// APIPushTargetState is the state of a push target. +type APIPushTargetState string + +// states. +const ( + APIPushTargetStateIdle APIPushTargetState = "idle" + APIPushTargetStateRunning APIPushTargetState = "running" + APIPushTargetStateError APIPushTargetState = "error" +) + +// APIPushTarget is a push target. +type APIPushTarget struct { + ID uuid.UUID `json:"id"` + Created time.Time `json:"created"` + URL string `json:"url"` + State APIPushTargetState `json:"state"` + Error string `json:"error,omitempty"` + BytesSent uint64 `json:"bytesSent"` +} + +// APIPushTargetList is a list of push targets. +type APIPushTargetList struct { + ItemCount int `json:"itemCount"` + PageCount int `json:"pageCount"` + Items []*APIPushTarget `json:"items"` +} + +// APIPushTargetAdd is the payload for adding a push target. +type APIPushTargetAdd struct { + URL string `json:"url"` +} + +// APIPushManager contains methods used by the API. +type APIPushManager interface { + APIPushTargetsList(pathName string) (*APIPushTargetList, error) + APIPushTargetsGet(pathName string, id uuid.UUID) (*APIPushTarget, error) + APIPushTargetsAdd(pathName string, req APIPushTargetAdd) (*APIPushTarget, error) + APIPushTargetsRemove(pathName string, id uuid.UUID) error +} + // APIRecordingSegment is a recording segment. type APIRecordingSegment struct { Start time.Time `json:"start"` diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 5e09711cbab..b47c25be6a4 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -54,6 +54,22 @@ func (dummyPathManager) APIPathsGet(string) (*defs.APIPath, error) { panic("unused") } +func (dummyPathManager) APIPushTargetsList(string) (*defs.APIPushTargetList, error) { + panic("unused") +} + +func (dummyPathManager) APIPushTargetsGet(string, uuid.UUID) (*defs.APIPushTarget, error) { + panic("unused") +} + +func (dummyPathManager) APIPushTargetsAdd(string, defs.APIPushTargetAdd) (*defs.APIPushTarget, error) { + panic("unused") +} + +func (dummyPathManager) APIPushTargetsRemove(string, uuid.UUID) error { + panic("unused") +} + type dummyHLSServer struct{} func (dummyHLSServer) APIMuxersList() (*defs.APIHLSMuxerList, error) { diff --git a/internal/push/manager.go b/internal/push/manager.go new file mode 100644 index 00000000000..c3b0d02fadc --- /dev/null +++ b/internal/push/manager.go @@ -0,0 +1,158 @@ +package push + +import ( + "fmt" + "sync" + + "github.com/google/uuid" + + "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/defs" + "github.com/bluenviron/mediamtx/internal/logger" + "github.com/bluenviron/mediamtx/internal/stream" +) + +// ErrTargetNotFound is returned when a push target is not found. +var ErrTargetNotFound = fmt.Errorf("push target not found") + +// ManagerParent is the parent interface. +type ManagerParent interface { + logger.Writer +} + +// Manager manages push targets for a path. +type Manager struct { + ReadTimeout conf.Duration + WriteTimeout conf.Duration + PathName string + Parent ManagerParent + + mutex sync.RWMutex + targets map[uuid.UUID]*Target + stream *stream.Stream +} + +// Initialize initializes the Manager. +func (m *Manager) Initialize() { + m.targets = make(map[uuid.UUID]*Target) +} + +// Close closes the Manager and all its targets. +func (m *Manager) Close() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, t := range m.targets { + t.Close() + } +} + +// Log implements logger.Writer. +func (m *Manager) Log(level logger.Level, format string, args ...any) { + m.Parent.Log(level, "[push] "+format, args...) +} + +// SetStream sets the stream for all targets. +func (m *Manager) SetStream(strm *stream.Stream) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.stream = strm + + for _, t := range m.targets { + t.SetStream(strm) + } +} + +// ClearStream clears the stream from all targets. +func (m *Manager) ClearStream() { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.stream = nil + + for _, t := range m.targets { + t.ClearStream() + } +} + +// AddTarget adds a new push target. +func (m *Manager) AddTarget(targetURL string) *Target { + m.mutex.Lock() + defer m.mutex.Unlock() + + t := &Target{ + URL: targetURL, + ReadTimeout: m.ReadTimeout, + WriteTimeout: m.WriteTimeout, + Parent: m, + PathName: m.PathName, + } + t.Initialize() + + if m.stream != nil { + t.SetStream(m.stream) + } + + m.targets[t.uuid] = t + + return t +} + +// RemoveTarget removes a push target by ID. +func (m *Manager) RemoveTarget(id uuid.UUID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + t, ok := m.targets[id] + if !ok { + return ErrTargetNotFound + } + + t.Close() + delete(m.targets, id) + + return nil +} + +// GetTarget returns a target by ID. +func (m *Manager) GetTarget(id uuid.UUID) (*Target, error) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + t, ok := m.targets[id] + if !ok { + return nil, ErrTargetNotFound + } + + return t, nil +} + +// TargetsList returns a list of all targets. +func (m *Manager) TargetsList() []*Target { + m.mutex.RLock() + defer m.mutex.RUnlock() + + list := make([]*Target, 0, len(m.targets)) + for _, t := range m.targets { + list = append(list, t) + } + + return list +} + +// APIItem returns the API list. +func (m *Manager) APIItem() *defs.APIPushTargetList { + m.mutex.RLock() + defer m.mutex.RUnlock() + + list := &defs.APIPushTargetList{ + Items: make([]*defs.APIPushTarget, 0, len(m.targets)), + } + + for _, t := range m.targets { + list.Items = append(list.Items, t.APIItem()) + } + + return list +} diff --git a/internal/push/target.go b/internal/push/target.go new file mode 100644 index 00000000000..3a66f42ed3e --- /dev/null +++ b/internal/push/target.go @@ -0,0 +1,1160 @@ +// Package push contains push target implementations. +package push + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/url" + "reflect" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bluenviron/gortmplib" + "github.com/bluenviron/gortmplib/pkg/amf0" + "github.com/bluenviron/gortmplib/pkg/bytecounter" + "github.com/bluenviron/gortmplib/pkg/codecs" + "github.com/bluenviron/gortmplib/pkg/handshake" + "github.com/bluenviron/gortmplib/pkg/message" + "github.com/bluenviron/gortsplib/v5" + "github.com/bluenviron/gortsplib/v5/pkg/base" + "github.com/bluenviron/gortsplib/v5/pkg/format" + "github.com/bluenviron/mediacommon/v2/pkg/codecs/h264" + "github.com/bluenviron/mediacommon/v2/pkg/codecs/h265" + mcmpegts "github.com/bluenviron/mediacommon/v2/pkg/formats/mpegts" + tscodecs "github.com/bluenviron/mediacommon/v2/pkg/formats/mpegts/codecs" + srt "github.com/datarhei/gosrt" + "github.com/google/uuid" + + "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/defs" + "github.com/bluenviron/mediamtx/internal/logger" + mtls "github.com/bluenviron/mediamtx/internal/protocols/tls" + "github.com/bluenviron/mediamtx/internal/stream" + "github.com/bluenviron/mediamtx/internal/unit" +) + +const ( + retryPause = 5 * time.Second + encodingAMF0 = 0 + fmleFlashVer = "FMLE/3.0 (compatible; mediamtx)" +) + +var errRTMPTrackParametersChanged = fmt.Errorf("RTMP track parameters changed") + +type fmleRTMPClient struct { + nconn net.Conn + bc *bytecounter.ReadWriter + mrw *message.ReadWriter +} + +func splitPath(u *url.URL) (string, string) { + pathsegs := strings.Split(u.Path, "/") + + var app string + var streamKey string + + switch { + case len(pathsegs) == 2: + app = pathsegs[1] + + case len(pathsegs) == 3: + app = pathsegs[1] + streamKey = pathsegs[2] + + case len(pathsegs) > 3: + app = strings.Join(pathsegs[1:3], "/") + streamKey = strings.Join(pathsegs[3:], "/") + } + + return app, streamKey +} + +func getTcURL(u *url.URL) string { + app, _ := splitPath(u) + nu, _ := url.Parse(u.String()) + nu.RawQuery = "" + nu.Path = "/" + return nu.String() + app +} + +func readCommandResult(mrw *message.ReadWriter, commandID int) (*message.CommandAMF0, error) { + for { + msg, err := mrw.Read() + if err != nil { + return nil, err + } + + if cmd, ok := msg.(*message.CommandAMF0); ok { + if cmd.CommandID == commandID || (cmd.CommandID == 0 && + (cmd.Name == "_result" || cmd.Name == "_error")) { + return cmd, nil + } + } + } +} + +func resultIsOK2(res *message.CommandAMF0) bool { + if len(res.Arguments) < 2 { + return false + } + + v, ok := res.Arguments[1].(float64) + if !ok { + return false + } + + return v == 1 +} + +func objectOrArray(v interface{}) (amf0.Object, bool) { + switch vv := v.(type) { + case amf0.Object: + return vv, true + case amf0.ECMAArray: + return amf0.Object(vv), true + } + return nil, false +} + +func resultIsOK1(res *message.CommandAMF0) bool { + if len(res.Arguments) < 2 { + return false + } + + ma, ok := objectOrArray(res.Arguments[1]) + if !ok { + return false + } + + v, ok := ma.Get("level") + if !ok { + return false + } + + return v == "status" +} + +func newFMLERTMPClient(ctx context.Context, u *url.URL, tlsConfig *tls.Config) (*fmleRTMPClient, error) { + var nconn net.Conn + var err error + + if u.Scheme == "rtmp" { + dialer := &net.Dialer{} + nconn, err = dialer.DialContext(ctx, "tcp", u.Host) + } else { + dialer := &tls.Dialer{Config: tlsConfig} + nconn, err = dialer.DialContext(ctx, "tcp", u.Host) + } + if err != nil { + return nil, err + } + + closerDone := make(chan struct{}) + closerTerminate := make(chan struct{}) + + go func() { + defer close(closerDone) + select { + case <-closerTerminate: + case <-ctx.Done(): + nconn.Close() + } + }() + + c := &fmleRTMPClient{nconn: nconn} + + err = c.initialize(u) + close(closerTerminate) + <-closerDone + + if err != nil { + nconn.Close() + return nil, err + } + + return c, nil +} + +func (c *fmleRTMPClient) initialize(u *url.URL) error { + c.bc = bytecounter.NewReadWriter(c.nconn) + + _, _, err := handshake.DoClient(c.bc, false, false) + if err != nil { + return fmt.Errorf("handshake failed: %w", err) + } + + c.mrw = message.NewReadWriter(c.bc, c.bc, false) + + err = c.mrw.Write(&message.SetWindowAckSize{Value: 2500000}) + if err != nil { + return fmt.Errorf("SetWindowAckSize failed: %w", err) + } + + err = c.mrw.Write(&message.SetPeerBandwidth{Value: 2500000, Type: 2}) + if err != nil { + return fmt.Errorf("SetPeerBandwidth failed: %w", err) + } + + err = c.mrw.Write(&message.SetChunkSize{Value: 65536}) + if err != nil { + return fmt.Errorf("SetChunkSize failed: %w", err) + } + + app, streamKey := splitPath(u) + tcURL := getTcURL(u) + + connectArg := amf0.Object{ + {Key: "app", Value: app}, + {Key: "flashVer", Value: fmleFlashVer}, + {Key: "tcUrl", Value: tcURL}, + {Key: "fpad", Value: false}, + {Key: "capabilities", Value: float64(15)}, + {Key: "audioCodecs", Value: float64(4071)}, + {Key: "videoCodecs", Value: float64(252)}, + {Key: "videoFunction", Value: float64(1)}, + {Key: "objectEncoding", Value: float64(encodingAMF0)}, + {Key: "type", Value: "nonprivate"}, + } + + err = c.mrw.Write(&message.CommandAMF0{ + ChunkStreamID: 3, + Name: "connect", + CommandID: 1, + Arguments: []any{connectArg}, + }) + if err != nil { + return fmt.Errorf("connect command failed: %w", err) + } + + res, err := readCommandResult(c.mrw, 1) + if err != nil { + return fmt.Errorf("connect result read failed: %w", err) + } + + if res.Name == "_error" { + return fmt.Errorf("connect rejected: %v", res.Arguments) + } + + if res.Name != "_result" { + return fmt.Errorf("unexpected connect result: %s", res.Name) + } + + err = c.mrw.Write(&message.CommandAMF0{ + ChunkStreamID: 3, + Name: "releaseStream", + CommandID: 2, + Arguments: []any{nil, streamKey}, + }) + if err != nil { + return fmt.Errorf("releaseStream failed: %w", err) + } + + err = c.mrw.Write(&message.CommandAMF0{ + ChunkStreamID: 3, + Name: "FCPublish", + CommandID: 3, + Arguments: []any{nil, streamKey}, + }) + if err != nil { + return fmt.Errorf("FCPublish failed: %w", err) + } + + err = c.mrw.Write(&message.CommandAMF0{ + ChunkStreamID: 3, + Name: "createStream", + CommandID: 4, + Arguments: []any{nil}, + }) + if err != nil { + return fmt.Errorf("createStream failed: %w", err) + } + + res, err = readCommandResult(c.mrw, 4) + if err != nil { + return fmt.Errorf("createStream result read failed: %w", err) + } + + if res.Name != "_result" || !resultIsOK2(res) { + return fmt.Errorf("createStream rejected: %v", res) + } + + err = c.mrw.Write(&message.CommandAMF0{ + ChunkStreamID: 4, + MessageStreamID: 0x1000000, + Name: "publish", + CommandID: 5, + Arguments: []any{nil, streamKey, "live"}, + }) + if err != nil { + return fmt.Errorf("publish command failed: %w", err) + } + + for i := 0; i < 10; i++ { + msg, err := c.mrw.Read() + if err != nil { + return fmt.Errorf("publish status read failed (attempt %d): %w", i+1, err) + } + + if cmd, ok := msg.(*message.CommandAMF0); ok { + if cmd.Name == "onStatus" { + if !resultIsOK1(cmd) { + return fmt.Errorf("publish rejected: %v", cmd) + } + return nil + } + if cmd.Name == "_error" { + return fmt.Errorf("publish error: %v", cmd.Arguments) + } + } + } + + return fmt.Errorf("no publish response received after 10 attempts") +} + +func (c *fmleRTMPClient) Close() { + c.nconn.Close() +} + +func (c *fmleRTMPClient) NetConn() net.Conn { + return c.nconn +} + +func (c *fmleRTMPClient) BytesReceived() uint64 { + return c.bc.Reader.Count() +} + +func (c *fmleRTMPClient) BytesSent() uint64 { + return c.bc.Writer.Count() +} + +func (c *fmleRTMPClient) Read() (message.Message, error) { + return c.mrw.Read() +} + +func (c *fmleRTMPClient) Write(msg message.Message) error { + return c.mrw.Write(msg) +} + +func multiplyAndDivide(v, m, d time.Duration) time.Duration { + secs := v / d + dec := v % d + return secs*m + dec*m/d +} + +func timestampToDuration(t int64, clockRate int) time.Duration { + return multiplyAndDivide(time.Duration(t), time.Second, time.Duration(clockRate)) +} + +func rtmpHostCandidates(u *url.URL) []string { + if _, _, err := net.SplitHostPort(u.Host); err == nil { + return []string{u.Host} + } + + if u.Scheme == "rtmps" { + return []string{ + net.JoinHostPort(u.Host, "443"), + net.JoinHostPort(u.Host, "1936"), + } + } + + return []string{net.JoinHostPort(u.Host, "1935")} +} + +func h264TrackParametersChanged(forma *format.H264, codec *codecs.H264) bool { + sps, pps := forma.SafeParams() + return !bytes.Equal(sps, codec.SPS) || !bytes.Equal(pps, codec.PPS) +} + +func h265TrackParametersChanged(forma *format.H265, codec *codecs.H265) bool { + vps, sps, pps := forma.SafeParams() + return !bytes.Equal(vps, codec.VPS) || !bytes.Equal(sps, codec.SPS) || !bytes.Equal(pps, codec.PPS) +} + +func mpeg4AudioTrackParametersChanged(forma *format.MPEG4Audio, codec *codecs.MPEG4Audio) bool { + return !reflect.DeepEqual(forma.Config, codec.Config) +} + +// countingWriter wraps an io.Writer and counts bytes written. +type countingWriter struct { + w io.Writer + count *uint64 +} + +func (c *countingWriter) Write(p []byte) (n int, err error) { + n, err = c.w.Write(p) + atomic.AddUint64(c.count, uint64(n)) + return n, err +} + +type targetParent interface { + logger.Writer +} + +// Target is a push target. +type Target struct { + URL string + ReadTimeout conf.Duration + WriteTimeout conf.Duration + Parent targetParent + PathName string + + ctx context.Context + ctxCancel func() + uuid uuid.UUID + created time.Time + mutex sync.RWMutex + state defs.APIPushTargetState + errorMsg string + bytesSent uint64 + stream *stream.Stream + reader *stream.Reader + streamLoaded bool + + done chan struct{} +} + +// Initialize initializes Target. +func (t *Target) Initialize() { + t.ctx, t.ctxCancel = context.WithCancel(context.Background()) + t.uuid = uuid.New() + t.created = time.Now() + t.state = defs.APIPushTargetStateIdle + t.done = make(chan struct{}) + + t.Log(logger.Info, "created push target to %s", t.URL) + + go t.run() +} + +// Close closes the Target. +func (t *Target) Close() { + t.Log(logger.Info, "closing push target to %s", t.URL) + t.ctxCancel() + <-t.done +} + +// Log implements logger.Writer. +func (t *Target) Log(level logger.Level, format string, args ...any) { + t.Parent.Log(level, "[push %s] "+format, append([]any{t.uuid.String()[:8]}, args...)...) +} + +// SetStream sets the stream to push. +func (t *Target) SetStream(strm *stream.Stream) { + t.mutex.Lock() + defer t.mutex.Unlock() + t.stream = strm + t.streamLoaded = true +} + +// ClearStream clears the stream. +func (t *Target) ClearStream() { + t.mutex.Lock() + defer t.mutex.Unlock() + t.stream = nil + t.streamLoaded = false +} + +// APIItem returns the API item. +func (t *Target) APIItem() *defs.APIPushTarget { + t.mutex.RLock() + defer t.mutex.RUnlock() + + return &defs.APIPushTarget{ + ID: t.uuid, + Created: t.created, + URL: t.URL, + State: t.state, + Error: t.errorMsg, + BytesSent: atomic.LoadUint64(&t.bytesSent), + } +} + +// UUID returns the target UUID. +func (t *Target) UUID() uuid.UUID { + return t.uuid +} + +func (t *Target) run() { + defer close(t.done) + + for { + shouldRetry, waitRetry := t.runInner() + if !shouldRetry { + return + } + if !waitRetry { + continue + } + + select { + case <-time.After(retryPause): + case <-t.ctx.Done(): + return + } + } +} + +func (t *Target) runInner() (bool, bool) { + // Wait for stream to be available + for { + t.mutex.RLock() + strm := t.stream + loaded := t.streamLoaded + t.mutex.RUnlock() + + if loaded && strm != nil { + break + } + + if loaded && strm == nil { + t.mutex.Lock() + t.state = defs.APIPushTargetStateIdle + t.mutex.Unlock() + } + + select { + case <-time.After(500 * time.Millisecond): + case <-t.ctx.Done(): + return false, false + } + } + + t.mutex.Lock() + t.state = defs.APIPushTargetStateRunning + t.errorMsg = "" + t.mutex.Unlock() + + var err error + + switch { + case strings.HasPrefix(t.URL, "rtmp://") || strings.HasPrefix(t.URL, "rtmps://"): + err = t.runRTMP() + case strings.HasPrefix(t.URL, "rtsp://") || strings.HasPrefix(t.URL, "rtsps://"): + err = t.runRTSP() + case strings.HasPrefix(t.URL, "srt://"): + err = t.runSRT() + default: + err = fmt.Errorf("unsupported protocol") + } + + if err != nil { + if err == errRTMPTrackParametersChanged { + t.Log(logger.Info, "stream parameters changed, reconnecting push target") + return true, false + } + + t.Log(logger.Error, "push error: %v", err) + + t.mutex.Lock() + t.state = defs.APIPushTargetStateError + t.errorMsg = err.Error() + t.mutex.Unlock() + + return true, true + } + + return false, false +} + +func (t *Target) addBytesSent(n uint64) { + atomic.AddUint64(&t.bytesSent, n) +} + +func (t *Target) runRTMP() error { + t.Log(logger.Debug, "connecting to RTMP server") + + // Resolve the URL with path variables + targetURL := t.resolveURL() + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + t.mutex.RLock() + strm := t.stream + t.mutex.RUnlock() + + if strm == nil { + return fmt.Errorf("stream is not available") + } + + // Create reader + reader := &stream.Reader{ + Parent: t, + } + + // Setup tracks + var tracks []*gortmplib.Track + var writer *gortmplib.Writer + + for _, media := range strm.Desc.Medias { + for _, forma := range media.Formats { + switch forma := forma.(type) { + case *format.H265: + vps, sps, pps := forma.SafeParams() + codec := &codecs.H265{ + VPS: vps, + SPS: sps, + PPS: pps, + } + track := &gortmplib.Track{ + Codec: codec, + } + tracks = append(tracks, track) + + var videoDTSExtractor *h265.DTSExtractor + + reader.OnData( + media, + forma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + if h265TrackParametersChanged(forma, codec) { + return errRTMPTrackParametersChanged + } + + if videoDTSExtractor == nil { + if !h265.IsRandomAccess(u.Payload.(unit.PayloadH265)) { + return nil + } + videoDTSExtractor = &h265.DTSExtractor{} + videoDTSExtractor.Initialize() + } + + dts, err := videoDTSExtractor.Extract(u.Payload.(unit.PayloadH265), u.PTS) + if err != nil { + return err + } + + err = writer.WriteH265( + track, + timestampToDuration(u.PTS, forma.ClockRate()), + timestampToDuration(dts, forma.ClockRate()), + u.Payload.(unit.PayloadH265)) + if err != nil { + return err + } + + // Count bytes sent (approximate size of payload) + for _, nalu := range u.Payload.(unit.PayloadH265) { + t.addBytesSent(uint64(len(nalu))) + } + return nil + }) + + case *format.H264: + sps, pps := forma.SafeParams() + codec := &codecs.H264{ + SPS: sps, + PPS: pps, + } + track := &gortmplib.Track{ + Codec: codec, + } + tracks = append(tracks, track) + + var videoDTSExtractor *h264.DTSExtractor + + reader.OnData( + media, + forma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + if h264TrackParametersChanged(forma, codec) { + return errRTMPTrackParametersChanged + } + + idrPresent := false + nonIDRPresent := false + + for _, nalu := range u.Payload.(unit.PayloadH264) { + typ := h264.NALUType(nalu[0] & 0x1F) + switch typ { + case h264.NALUTypeIDR: + idrPresent = true + case h264.NALUTypeNonIDR: + nonIDRPresent = true + } + } + + if videoDTSExtractor == nil { + if !idrPresent { + return nil + } + videoDTSExtractor = &h264.DTSExtractor{} + videoDTSExtractor.Initialize() + } else if !idrPresent && !nonIDRPresent { + return nil + } + + dts, err := videoDTSExtractor.Extract(u.Payload.(unit.PayloadH264), u.PTS) + if err != nil { + return err + } + + err = writer.WriteH264( + track, + timestampToDuration(u.PTS, forma.ClockRate()), + timestampToDuration(dts, forma.ClockRate()), + u.Payload.(unit.PayloadH264)) + if err != nil { + return err + } + + // Count bytes sent (approximate size of payload) + for _, nalu := range u.Payload.(unit.PayloadH264) { + t.addBytesSent(uint64(len(nalu))) + } + return nil + }) + + case *format.MPEG4Audio: + codec := &codecs.MPEG4Audio{ + Config: forma.Config, + } + track := &gortmplib.Track{ + Codec: codec, + } + tracks = append(tracks, track) + + reader.OnData( + media, + forma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + if mpeg4AudioTrackParametersChanged(forma, codec) { + return errRTMPTrackParametersChanged + } + + for i, au := range u.Payload.(unit.PayloadMPEG4Audio) { + pts := u.PTS + int64(i)*1024 // SamplesPerAccessUnit + + err := writer.WriteMPEG4Audio( + track, + timestampToDuration(pts, forma.ClockRate()), + au, + ) + if err != nil { + return err + } + + // Count bytes sent + t.addBytesSent(uint64(len(au))) + } + + return nil + }) + } + } + } + + if len(tracks) == 0 { + return fmt.Errorf("no supported tracks found for RTMP push") + } + + hostCandidates := rtmpHostCandidates(u) + + var conn *fmleRTMPClient + for _, host := range hostCandidates { + candidate := *u + candidate.Host = host + + connectCtx, connectCtxCancel := context.WithTimeout(t.ctx, 30*time.Second) + conn, err = newFMLERTMPClient(connectCtx, &candidate, mtls.MakeConfig(candidate.Hostname(), "")) + connectCtxCancel() + if err == nil { + break + } + + t.Log(logger.Debug, "RTMP connection to %s failed: %v", host, err) + } + if err != nil { + return err + } + + defer conn.Close() + + t.Log(logger.Info, "connected to %s", targetURL) + + // Initialize writer + writer = &gortmplib.Writer{ + Conn: conn, + Tracks: tracks, + } + err = writer.Initialize() + if err != nil { + return err + } + + // Add reader to stream + strm.AddReader(reader) + defer strm.RemoveReader(reader) + + t.mutex.Lock() + t.reader = reader + t.mutex.Unlock() + + conn.NetConn().SetReadDeadline(time.Time{}) + conn.NetConn().SetWriteDeadline(time.Time{}) + + rtmpErr := make(chan error, 1) + + go func() { + for { + _, err := conn.Read() + if err != nil { + select { + case rtmpErr <- fmt.Errorf("RTMP read error: %w", err): + default: + } + return + } + } + }() + + // Wait for error or context cancellation + select { + case err := <-reader.Error(): + return err + case err := <-rtmpErr: + return err + case <-t.ctx.Done(): + return nil + } +} + +func (t *Target) runRTSP() error { + t.Log(logger.Debug, "connecting to RTSP server") + + // Resolve the URL with path variables + targetURL := t.resolveURL() + + u, err := base.ParseURL(targetURL) + if err != nil { + return err + } + + t.mutex.RLock() + strm := t.stream + t.mutex.RUnlock() + + if strm == nil { + return fmt.Errorf("stream is not available") + } + + // Determine scheme + scheme := "rtsp" + if u.Scheme == "rtsps" { + scheme = "rtsps" + } + + // Create RTSP client for publishing + client := &gortsplib.Client{ + Scheme: scheme, + Host: u.Host, + ReadTimeout: time.Duration(t.ReadTimeout), + WriteTimeout: time.Duration(t.WriteTimeout), + TLSConfig: mtls.MakeConfig(u.Host, ""), + } + + err = client.Start() + if err != nil { + return err + } + defer client.Close() + + // Announce the stream + _, err = client.Announce(u, strm.Desc) + if err != nil { + return err + } + + // Setup all medias + for _, media := range strm.Desc.Medias { + _, err = client.Setup(u, media, 0, 0) + if err != nil { + return err + } + } + + // Start recording (publishing) + _, err = client.Record() + if err != nil { + return err + } + + t.Log(logger.Info, "connected to %s", targetURL) + + // Create reader + reader := &stream.Reader{ + Parent: t, + } + + // Setup data handlers for each media + for _, media := range strm.Desc.Medias { + for _, forma := range media.Formats { + cmedia := media + cforma := forma + + reader.OnData( + cmedia, + cforma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + // Write RTP packets to the client + for _, pkt := range u.RTPPackets { + err := client.WritePacketRTP(cmedia, pkt) + if err != nil { + return err + } + + // Count bytes sent + t.addBytesSent(uint64(pkt.MarshalSize())) + } + + return nil + }) + } + } + + // Add reader to stream + strm.AddReader(reader) + defer strm.RemoveReader(reader) + + t.mutex.Lock() + t.reader = reader + t.mutex.Unlock() + + // Wait for error or context cancellation + select { + case err := <-reader.Error(): + return err + case <-t.ctx.Done(): + return nil + } +} + +func (t *Target) runSRT() error { + t.Log(logger.Debug, "connecting to SRT server") + + // Resolve the URL with path variables + targetURL := t.resolveURL() + + conf := srt.DefaultConfig() + address, err := conf.UnmarshalURL(targetURL) + if err != nil { + return err + } + + err = conf.Validate() + if err != nil { + return err + } + + t.mutex.RLock() + strm := t.stream + t.mutex.RUnlock() + + if strm == nil { + return fmt.Errorf("stream is not available") + } + + // Connect to SRT server + sconn, err := srt.Dial("srt", address, conf) + if err != nil { + return err + } + defer sconn.Close() + + t.Log(logger.Info, "connected to %s", targetURL) + + // Create a counting writer to track bytes sent + cw := &countingWriter{w: sconn, count: &t.bytesSent} + bw := bufio.NewWriterSize(cw, 1316) // SRT max payload size + + // Create MPEG-TS writer + var mpegtsWriter *mcmpegts.Writer + var tracks []*mcmpegts.Track + + // Create reader + reader := &stream.Reader{ + Parent: t, + } + + // Setup tracks based on the stream description + for _, media := range strm.Desc.Medias { + for _, forma := range media.Formats { + switch forma := forma.(type) { + case *format.H265: + track := &mcmpegts.Track{Codec: &tscodecs.H265{}} + tracks = append(tracks, track) + + var dtsExtractor *h265.DTSExtractor + + reader.OnData( + media, + forma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + if dtsExtractor == nil { + if !h265.IsRandomAccess(u.Payload.(unit.PayloadH265)) { + return nil + } + dtsExtractor = &h265.DTSExtractor{} + dtsExtractor.Initialize() + } + + dts, err := dtsExtractor.Extract(u.Payload.(unit.PayloadH265), u.PTS) + if err != nil { + return err + } + + sconn.SetWriteDeadline(time.Now().Add(time.Duration(t.WriteTimeout))) + err = mpegtsWriter.WriteH265(track, u.PTS, dts, u.Payload.(unit.PayloadH265)) + if err != nil { + return err + } + return bw.Flush() + }) + + case *format.H264: + track := &mcmpegts.Track{Codec: &tscodecs.H264{}} + tracks = append(tracks, track) + + var dtsExtractor *h264.DTSExtractor + + reader.OnData( + media, + forma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + idrPresent := h264.IsRandomAccess(u.Payload.(unit.PayloadH264)) + + if dtsExtractor == nil { + if !idrPresent { + return nil + } + dtsExtractor = &h264.DTSExtractor{} + dtsExtractor.Initialize() + } + + dts, err := dtsExtractor.Extract(u.Payload.(unit.PayloadH264), u.PTS) + if err != nil { + return err + } + + sconn.SetWriteDeadline(time.Now().Add(time.Duration(t.WriteTimeout))) + err = mpegtsWriter.WriteH264(track, u.PTS, dts, u.Payload.(unit.PayloadH264)) + if err != nil { + return err + } + return bw.Flush() + }) + + case *format.MPEG4Audio: + track := &mcmpegts.Track{Codec: &tscodecs.MPEG4Audio{ + Config: *forma.Config, + }} + tracks = append(tracks, track) + + reader.OnData( + media, + forma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + sconn.SetWriteDeadline(time.Now().Add(time.Duration(t.WriteTimeout))) + err := mpegtsWriter.WriteMPEG4Audio(track, u.PTS, u.Payload.(unit.PayloadMPEG4Audio)) + if err != nil { + return err + } + return bw.Flush() + }) + + case *format.Opus: + track := &mcmpegts.Track{Codec: &tscodecs.Opus{ + ChannelCount: forma.ChannelCount, + }} + tracks = append(tracks, track) + + reader.OnData( + media, + forma, + func(u *unit.Unit) error { + if u.NilPayload() { + return nil + } + + sconn.SetWriteDeadline(time.Now().Add(time.Duration(t.WriteTimeout))) + err := mpegtsWriter.WriteOpus(track, u.PTS, u.Payload.(unit.PayloadOpus)) + if err != nil { + return err + } + return bw.Flush() + }) + } + } + } + + if len(tracks) == 0 { + return fmt.Errorf("no supported tracks found for SRT push") + } + + // Initialize MPEG-TS writer + mpegtsWriter = &mcmpegts.Writer{W: bw, Tracks: tracks} + err = mpegtsWriter.Initialize() + if err != nil { + return err + } + + // Add reader to stream + strm.AddReader(reader) + defer strm.RemoveReader(reader) + + t.mutex.Lock() + t.reader = reader + t.mutex.Unlock() + + // Wait for error or context cancellation + select { + case err := <-reader.Error(): + return err + case <-t.ctx.Done(): + return nil + } +} + +func (t *Target) resolveURL() string { + result := t.URL + result = strings.ReplaceAll(result, "$MTX_PATH", t.PathName) + result = strings.ReplaceAll(result, "$path", t.PathName) + return result +} diff --git a/internal/push/target_test.go b/internal/push/target_test.go new file mode 100644 index 00000000000..b242e4e0d47 --- /dev/null +++ b/internal/push/target_test.go @@ -0,0 +1,62 @@ +package push + +import ( + "net/url" + "testing" + + "github.com/bluenviron/gortmplib/pkg/codecs" + "github.com/bluenviron/gortsplib/v5/pkg/format" + "github.com/bluenviron/mediacommon/v2/pkg/codecs/mpeg4audio" + "github.com/stretchr/testify/require" +) + +func TestRTMPHostCandidates(t *testing.T) { + for _, ca := range []struct { + name string + rawURL string + expected []string + }{ + { + name: "rtmp default port", + rawURL: "rtmp://example.com/app/stream", + expected: []string{"example.com:1935"}, + }, + { + name: "rtmps fallback ports", + rawURL: "rtmps://example.com/app/stream", + expected: []string{"example.com:443", "example.com:1936"}, + }, + { + name: "explicit port preserved", + rawURL: "rtmps://example.com:8443/app/stream", + expected: []string{"example.com:8443"}, + }, + } { + t.Run(ca.name, func(t *testing.T) { + u, err := url.Parse(ca.rawURL) + require.NoError(t, err) + require.Equal(t, ca.expected, rtmpHostCandidates(u)) + }) + } +} + +func TestRTMPTrackParametersChanged(t *testing.T) { + t.Run("h264", func(t *testing.T) { + forma := &format.H264{SPS: []byte{1}, PPS: []byte{2}} + require.False(t, h264TrackParametersChanged(forma, &codecs.H264{SPS: []byte{1}, PPS: []byte{2}})) + require.True(t, h264TrackParametersChanged(forma, &codecs.H264{SPS: []byte{9}, PPS: []byte{2}})) + }) + + t.Run("h265", func(t *testing.T) { + forma := &format.H265{VPS: []byte{1}, SPS: []byte{2}, PPS: []byte{3}} + require.False(t, h265TrackParametersChanged(forma, &codecs.H265{VPS: []byte{1}, SPS: []byte{2}, PPS: []byte{3}})) + require.True(t, h265TrackParametersChanged(forma, &codecs.H265{VPS: []byte{1}, SPS: []byte{7}, PPS: []byte{3}})) + }) + + t.Run("mpeg4audio", func(t *testing.T) { + config := &mpeg4audio.AudioSpecificConfig{Type: 2, SampleRate: 48000, ChannelConfig: 2} + forma := &format.MPEG4Audio{Config: config} + require.False(t, mpeg4AudioTrackParametersChanged(forma, &codecs.MPEG4Audio{Config: config})) + require.True(t, mpeg4AudioTrackParametersChanged(forma, &codecs.MPEG4Audio{Config: &mpeg4audio.AudioSpecificConfig{Type: 2, SampleRate: 44100, ChannelConfig: 2}})) + }) +} diff --git a/mediamtx.yml b/mediamtx.yml index 260f7566a6f..4813ff51900 100644 --- a/mediamtx.yml +++ b/mediamtx.yml @@ -536,6 +536,18 @@ pathDefaults: # Set to 0s to disable automatic deletion. recordDeleteAfter: 1d + ############################################### + # Default path settings -> Push Targets + + # Push the stream to external servers (e.g., YouTube Live, Twitch, etc.). + # Supported protocols: rtmp, rtmps. + # The URL can contain variables such as $MTX_PATH or $path (the path name). + # Push targets can also be added/removed dynamically via the API. + # Example: + # pushTargets: + # - url: rtmp://a.rtmp.youtube.com/live2/xxxx-xxxx-xxxx-xxxx-xxxx + # - url: rtmps://live.twitch.tv/app/live_xxxxxxxxx_xxxxxxxxxxxxxxxxxxxx + ############################################### # Default path settings -> Publisher source (when source is "publisher") @@ -685,6 +697,20 @@ pathDefaults: # M-JPEG JPEG quality (when codec is mjpeg). rpiCameraMJPEGQuality: 60 + ############################################### + # Default path settings -> Push + + # List of push targets to forward the stream to. + # The stream will be automatically pushed to these targets when available. + # If the source stream is interrupted, pushing will resume automatically + # when the stream becomes available again. + # Supported protocols: RTMP, RTSP, SRT + # Example: + # pushTargets: + # - url: rtmp://youtube.com/live/YOUR_STREAM_KEY + # - url: srt://server.com:1234/?streamid=publish:mystream + # - url: rtsp://server.com:8554/mystream + ############################################### # Default path settings -> Hooks