mirror of
https://github.com/bluenviron/mediamtx
synced 2025-01-20 14:11:09 +00:00
rtmp: simplify API (#2130)
This commit is contained in:
parent
959b017d72
commit
d696a782f7
@ -573,9 +573,8 @@ func TestAPIProtocolList(t *testing.T) {
|
||||
}()
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := rtmp.NewConn(nconn)
|
||||
|
||||
err = conn.InitializeClient(u, true)
|
||||
conn, err := rtmp.NewClientConn(nconn, u, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rtmp.NewWriter(conn, testFormatH264, nil)
|
||||
@ -828,9 +827,8 @@ func TestAPIProtocolGet(t *testing.T) {
|
||||
}()
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := rtmp.NewConn(nconn)
|
||||
|
||||
err = conn.InitializeClient(u, true)
|
||||
conn, err := rtmp.NewClientConn(nconn, u, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rtmp.NewWriter(conn, testFormatH264, nil)
|
||||
@ -1150,9 +1148,8 @@ func TestAPIProtocolKick(t *testing.T) {
|
||||
nconn, err := net.Dial("tcp", u.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := rtmp.NewConn(nconn)
|
||||
|
||||
err = conn.InitializeClient(u, true)
|
||||
conn, err := rtmp.NewClientConn(nconn, u, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rtmp.NewWriter(conn, testFormatH264, nil)
|
||||
|
@ -85,9 +85,8 @@ webrtc_sessions_bytes_sent 0
|
||||
nconn, err := net.Dial("tcp", u.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := rtmp.NewConn(nconn)
|
||||
|
||||
err = conn.InitializeClient(u, true)
|
||||
conn, err := rtmp.NewClientConn(nconn, u, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
videoTrack := &formats.H264{
|
||||
|
@ -65,7 +65,6 @@ type rtmpConn struct {
|
||||
runOnConnect string
|
||||
runOnConnectRestart bool
|
||||
wg *sync.WaitGroup
|
||||
conn *rtmp.Conn
|
||||
nconn net.Conn
|
||||
externalCmdPool *externalcmd.Pool
|
||||
pathManager rtmpConnPathManager
|
||||
@ -75,7 +74,8 @@ type rtmpConn struct {
|
||||
ctxCancel func()
|
||||
uuid uuid.UUID
|
||||
created time.Time
|
||||
mutex sync.Mutex
|
||||
mutex sync.RWMutex
|
||||
conn *rtmp.Conn
|
||||
state rtmpConnState
|
||||
pathName string
|
||||
}
|
||||
@ -106,7 +106,6 @@ func newRTMPConn(
|
||||
runOnConnect: runOnConnect,
|
||||
runOnConnectRestart: runOnConnectRestart,
|
||||
wg: wg,
|
||||
conn: rtmp.NewConn(nconn),
|
||||
nconn: nconn,
|
||||
externalCmdPool: externalCmdPool,
|
||||
pathManager: pathManager,
|
||||
@ -194,18 +193,22 @@ func (c *rtmpConn) run() {
|
||||
func (c *rtmpConn) runReader() error {
|
||||
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
|
||||
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
|
||||
u, publish, err := c.conn.InitializeServer()
|
||||
conn, u, publish, err := rtmp.NewServerConn(c.nconn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.conn = conn
|
||||
c.mutex.Unlock()
|
||||
|
||||
if !publish {
|
||||
return c.runRead(u)
|
||||
return c.runRead(conn, u)
|
||||
}
|
||||
return c.runPublish(u)
|
||||
return c.runPublish(conn, u)
|
||||
}
|
||||
|
||||
func (c *rtmpConn) runRead(u *url.URL) error {
|
||||
func (c *rtmpConn) runRead(conn *rtmp.Conn, u *url.URL) error {
|
||||
pathName, query, rawQuery := pathNameAndQuery(u)
|
||||
|
||||
res := c.pathManager.addReader(pathAddReaderReq{
|
||||
@ -298,7 +301,7 @@ func (c *rtmpConn) runRead(u *url.URL) error {
|
||||
}
|
||||
|
||||
var err error
|
||||
w, err = rtmp.NewWriter(c.conn, videoFormat, audioFormat)
|
||||
w, err = rtmp.NewWriter(conn, videoFormat, audioFormat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -569,7 +572,7 @@ func (c *rtmpConn) setupAudio(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *rtmpConn) runPublish(u *url.URL) error {
|
||||
func (c *rtmpConn) runPublish(conn *rtmp.Conn, u *url.URL) error {
|
||||
pathName, query, rawQuery := pathNameAndQuery(u)
|
||||
|
||||
res := c.pathManager.addPublisher(pathAddPublisherReq{
|
||||
@ -601,7 +604,7 @@ func (c *rtmpConn) runPublish(u *url.URL) error {
|
||||
c.pathName = pathName
|
||||
c.mutex.Unlock()
|
||||
|
||||
r, err := rtmp.NewReader(c.conn)
|
||||
r, err := rtmp.NewReader(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -731,8 +734,16 @@ func (c *rtmpConn) apiSourceDescribe() pathAPISourceOrReader {
|
||||
}
|
||||
|
||||
func (c *rtmpConn) apiItem() *apiRTMPConn {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
bytesReceived := uint64(0)
|
||||
bytesSent := uint64(0)
|
||||
|
||||
if c.conn != nil {
|
||||
bytesReceived = c.conn.BytesReceived()
|
||||
bytesSent = c.conn.BytesSent()
|
||||
}
|
||||
|
||||
return &apiRTMPConn{
|
||||
ID: c.uuid,
|
||||
@ -749,7 +760,7 @@ func (c *rtmpConn) apiItem() *apiRTMPConn {
|
||||
return "idle"
|
||||
}(),
|
||||
Path: c.pathName,
|
||||
BytesReceived: c.conn.BytesReceived(),
|
||||
BytesSent: c.conn.BytesSent(),
|
||||
BytesReceived: bytesReceived,
|
||||
BytesSent: bytesSent,
|
||||
}
|
||||
}
|
||||
|
@ -34,9 +34,8 @@ func TestRTMPServerRunOnConnect(t *testing.T) {
|
||||
nconn, err := net.Dial("tcp", u.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := rtmp.NewConn(nconn)
|
||||
|
||||
err = conn.InitializeClient(u, true)
|
||||
_, err = rtmp.NewClientConn(nconn, u, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@ -125,9 +124,8 @@ func TestRTMPServer(t *testing.T) {
|
||||
}()
|
||||
require.NoError(t, err)
|
||||
defer nconn1.Close()
|
||||
conn1 := rtmp.NewConn(nconn1)
|
||||
|
||||
err = conn1.InitializeClient(u1, true)
|
||||
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
videoTrack := &formats.H264{
|
||||
@ -175,9 +173,8 @@ func TestRTMPServer(t *testing.T) {
|
||||
}()
|
||||
require.NoError(t, err)
|
||||
defer nconn2.Close()
|
||||
conn2 := rtmp.NewConn(nconn2)
|
||||
|
||||
err = conn2.InitializeClient(u2, false)
|
||||
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := rtmp.NewReader(conn2)
|
||||
@ -237,9 +234,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
|
||||
nconn1, err := net.Dial("tcp", u1.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn1.Close()
|
||||
conn1 := rtmp.NewConn(nconn1)
|
||||
|
||||
err = conn1.InitializeClient(u1, true)
|
||||
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
videoTrack := &formats.H264{
|
||||
@ -266,9 +262,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
|
||||
nconn2, err := net.Dial("tcp", u2.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn2.Close()
|
||||
conn2 := rtmp.NewConn(nconn2)
|
||||
|
||||
err = conn2.InitializeClient(u2, false)
|
||||
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rtmp.NewReader(conn2)
|
||||
@ -291,9 +286,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
|
||||
nconn1, err := net.Dial("tcp", u1.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn1.Close()
|
||||
conn1 := rtmp.NewConn(nconn1)
|
||||
|
||||
err = conn1.InitializeClient(u1, true)
|
||||
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
videoTrack := &formats.H264{
|
||||
@ -320,9 +314,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
|
||||
nconn2, err := net.Dial("tcp", u2.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn2.Close()
|
||||
conn2 := rtmp.NewConn(nconn2)
|
||||
|
||||
err = conn2.InitializeClient(u2, false)
|
||||
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rtmp.NewReader(conn2)
|
||||
@ -346,9 +339,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
|
||||
nconn1, err := net.Dial("tcp", u1.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn1.Close()
|
||||
conn1 := rtmp.NewConn(nconn1)
|
||||
|
||||
err = conn1.InitializeClient(u1, true)
|
||||
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
videoTrack := &formats.H264{
|
||||
@ -375,9 +367,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
|
||||
nconn2, err := net.Dial("tcp", u2.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn2.Close()
|
||||
conn2 := rtmp.NewConn(nconn2)
|
||||
|
||||
err = conn2.InitializeClient(u2, false)
|
||||
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = rtmp.NewReader(conn2)
|
||||
|
@ -99,11 +99,9 @@ func (s *rtmpSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf cha
|
||||
}
|
||||
|
||||
func (s *rtmpSource) runReader(u *url.URL, nconn net.Conn) error {
|
||||
conn := rtmp.NewConn(nconn)
|
||||
|
||||
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
|
||||
nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout)))
|
||||
err := conn.InitializeClient(u, false)
|
||||
conn, err := rtmp.NewClientConn(nconn, u, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -52,9 +52,8 @@ func TestRTMPSource(t *testing.T) {
|
||||
nconn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := rtmp.NewConn(nconn)
|
||||
|
||||
_, _, err = conn.InitializeServer()
|
||||
conn, _, _, err := rtmp.NewServerConn(nconn)
|
||||
require.NoError(t, err)
|
||||
|
||||
videoTrack := &formats.H264{
|
||||
|
@ -139,29 +139,21 @@ type Conn struct {
|
||||
mrw *message.ReadWriter
|
||||
}
|
||||
|
||||
// NewConn initializes a connection.
|
||||
func NewConn(rw io.ReadWriter) *Conn {
|
||||
return &Conn{
|
||||
// NewClientConn initializes a client-side connection.
|
||||
func NewClientConn(rw io.ReadWriter, u *url.URL, publish bool) (*Conn, error) {
|
||||
c := &Conn{
|
||||
bc: bytecounter.NewReadWriter(rw),
|
||||
}
|
||||
|
||||
err := c.initializeClient(u, publish)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// BytesReceived returns the number of bytes received.
|
||||
func (c *Conn) BytesReceived() uint64 {
|
||||
return c.bc.Reader.Count()
|
||||
}
|
||||
|
||||
// BytesSent returns the number of bytes sent.
|
||||
func (c *Conn) BytesSent() uint64 {
|
||||
return c.bc.Writer.Count()
|
||||
}
|
||||
|
||||
func (c *Conn) skipInitialization() {
|
||||
c.mrw = message.NewReadWriter(c.bc, false)
|
||||
}
|
||||
|
||||
// InitializeClient performs the initialization of a client-side connection.
|
||||
func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
|
||||
func (c *Conn) initializeClient(u *url.URL, publish bool) error {
|
||||
connectpath, actionpath := splitPath(u)
|
||||
|
||||
err := handshake.DoClient(c.bc, false)
|
||||
@ -219,7 +211,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isPublishing {
|
||||
if !publish {
|
||||
err = c.mrw.Write(&message.CommandAMF0{
|
||||
ChunkStreamID: 3,
|
||||
Name: "createStream",
|
||||
@ -322,8 +314,21 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
|
||||
return readCommandResult(c.mrw, 5, "onStatus", resultIsOK1)
|
||||
}
|
||||
|
||||
// InitializeServer performs the initialization of a server-side connection.
|
||||
func (c *Conn) InitializeServer() (*url.URL, bool, error) {
|
||||
// NewServerConn initializes a server-side connection.
|
||||
func NewServerConn(rw io.ReadWriter) (*Conn, *url.URL, bool, error) {
|
||||
c := &Conn{
|
||||
bc: bytecounter.NewReadWriter(rw),
|
||||
}
|
||||
|
||||
u, publish, err := c.initializeServer()
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
return c, u, publish, nil
|
||||
}
|
||||
|
||||
func (c *Conn) initializeServer() (*url.URL, bool, error) {
|
||||
err := handshake.DoServer(c.bc, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
@ -571,6 +576,26 @@ func (c *Conn) InitializeServer() (*url.URL, bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func newNoHandshakeConn(rw io.ReadWriter) *Conn {
|
||||
c := &Conn{
|
||||
bc: bytecounter.NewReadWriter(rw),
|
||||
}
|
||||
|
||||
c.mrw = message.NewReadWriter(c.bc, false)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// BytesReceived returns the number of bytes received.
|
||||
func (c *Conn) BytesReceived() uint64 {
|
||||
return c.bc.Reader.Count()
|
||||
}
|
||||
|
||||
// BytesSent returns the number of bytes sent.
|
||||
func (c *Conn) BytesSent() uint64 {
|
||||
return c.bc.Writer.Count()
|
||||
}
|
||||
|
||||
// Read reads a message.
|
||||
func (c *Conn) Read() (message.Message, error) {
|
||||
return c.mrw.Read()
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
"github.com/bluenviron/mediamtx/internal/rtmp/message"
|
||||
)
|
||||
|
||||
func TestInitializeClient(t *testing.T) {
|
||||
func TestNewClientConn(t *testing.T) {
|
||||
for _, ca := range []string{"read", "publish"} {
|
||||
t.Run(ca, func(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:9121")
|
||||
@ -236,9 +236,8 @@ func TestInitializeClient(t *testing.T) {
|
||||
nconn, err := net.Dial("tcp", u.Host)
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := NewConn(nconn)
|
||||
|
||||
err = conn.InitializeClient(u, ca == "publish")
|
||||
conn, err := NewClientConn(nconn, u, ca == "publish")
|
||||
require.NoError(t, err)
|
||||
|
||||
if ca == "read" {
|
||||
@ -254,7 +253,7 @@ func TestInitializeClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitializeServer(t *testing.T) {
|
||||
func TestNewServerConn(t *testing.T) {
|
||||
for _, ca := range []string{
|
||||
"read",
|
||||
"publish",
|
||||
@ -272,9 +271,9 @@ func TestInitializeServer(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
|
||||
conn := NewConn(nconn)
|
||||
u, isPublishing, err := conn.InitializeServer()
|
||||
_, u, isPublishing, err := NewServerConn(nconn)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &url.URL{
|
||||
Scheme: "rtmp",
|
||||
Host: "127.0.0.1:9121",
|
||||
@ -488,7 +487,7 @@ func BenchmarkRead(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
conn := NewConn(&buf)
|
||||
conn := newNoHandshakeConn(&buf)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
conn.Read()
|
||||
|
@ -536,8 +536,7 @@ func TestReadTracks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
c := NewConn(&buf)
|
||||
c.skipInitialization()
|
||||
c := newNoHandshakeConn(&buf)
|
||||
|
||||
r, err := NewReader(c)
|
||||
require.NoError(t, err)
|
||||
|
@ -40,8 +40,7 @@ func TestWriteTracks(t *testing.T) {
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
c := NewConn(&buf)
|
||||
c.skipInitialization()
|
||||
c := newNoHandshakeConn(&buf)
|
||||
|
||||
_, err := NewWriter(c, videoTrack, audioTrack)
|
||||
require.NoError(t, err)
|
||||
|
Loading…
Reference in New Issue
Block a user