rtmp: simplify API (#2130)

This commit is contained in:
Alessandro Ros 2023-07-31 19:41:59 +02:00 committed by GitHub
parent 959b017d72
commit d696a782f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 95 additions and 78 deletions

View File

@ -573,9 +573,8 @@ func TestAPIProtocolList(t *testing.T) {
}()
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)
err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)
_, err = rtmp.NewWriter(conn, testFormatH264, nil)
@ -828,9 +827,8 @@ func TestAPIProtocolGet(t *testing.T) {
}()
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)
err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)
_, err = rtmp.NewWriter(conn, testFormatH264, nil)
@ -1150,9 +1148,8 @@ func TestAPIProtocolKick(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)
err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)
_, err = rtmp.NewWriter(conn, testFormatH264, nil)

View File

@ -85,9 +85,8 @@ webrtc_sessions_bytes_sent 0
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)
err = conn.InitializeClient(u, true)
conn, err := rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)
videoTrack := &formats.H264{

View File

@ -65,7 +65,6 @@ type rtmpConn struct {
runOnConnect string
runOnConnectRestart bool
wg *sync.WaitGroup
conn *rtmp.Conn
nconn net.Conn
externalCmdPool *externalcmd.Pool
pathManager rtmpConnPathManager
@ -75,7 +74,8 @@ type rtmpConn struct {
ctxCancel func()
uuid uuid.UUID
created time.Time
mutex sync.Mutex
mutex sync.RWMutex
conn *rtmp.Conn
state rtmpConnState
pathName string
}
@ -106,7 +106,6 @@ func newRTMPConn(
runOnConnect: runOnConnect,
runOnConnectRestart: runOnConnectRestart,
wg: wg,
conn: rtmp.NewConn(nconn),
nconn: nconn,
externalCmdPool: externalCmdPool,
pathManager: pathManager,
@ -194,18 +193,22 @@ func (c *rtmpConn) run() {
func (c *rtmpConn) runReader() error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
u, publish, err := c.conn.InitializeServer()
conn, u, publish, err := rtmp.NewServerConn(c.nconn)
if err != nil {
return err
}
c.mutex.Lock()
c.conn = conn
c.mutex.Unlock()
if !publish {
return c.runRead(u)
return c.runRead(conn, u)
}
return c.runPublish(u)
return c.runPublish(conn, u)
}
func (c *rtmpConn) runRead(u *url.URL) error {
func (c *rtmpConn) runRead(conn *rtmp.Conn, u *url.URL) error {
pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.addReader(pathAddReaderReq{
@ -298,7 +301,7 @@ func (c *rtmpConn) runRead(u *url.URL) error {
}
var err error
w, err = rtmp.NewWriter(c.conn, videoFormat, audioFormat)
w, err = rtmp.NewWriter(conn, videoFormat, audioFormat)
if err != nil {
return err
}
@ -569,7 +572,7 @@ func (c *rtmpConn) setupAudio(
return nil, nil
}
func (c *rtmpConn) runPublish(u *url.URL) error {
func (c *rtmpConn) runPublish(conn *rtmp.Conn, u *url.URL) error {
pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.addPublisher(pathAddPublisherReq{
@ -601,7 +604,7 @@ func (c *rtmpConn) runPublish(u *url.URL) error {
c.pathName = pathName
c.mutex.Unlock()
r, err := rtmp.NewReader(c.conn)
r, err := rtmp.NewReader(conn)
if err != nil {
return err
}
@ -731,8 +734,16 @@ func (c *rtmpConn) apiSourceDescribe() pathAPISourceOrReader {
}
func (c *rtmpConn) apiItem() *apiRTMPConn {
c.mutex.Lock()
defer c.mutex.Unlock()
c.mutex.RLock()
defer c.mutex.RUnlock()
bytesReceived := uint64(0)
bytesSent := uint64(0)
if c.conn != nil {
bytesReceived = c.conn.BytesReceived()
bytesSent = c.conn.BytesSent()
}
return &apiRTMPConn{
ID: c.uuid,
@ -749,7 +760,7 @@ func (c *rtmpConn) apiItem() *apiRTMPConn {
return "idle"
}(),
Path: c.pathName,
BytesReceived: c.conn.BytesReceived(),
BytesSent: c.conn.BytesSent(),
BytesReceived: bytesReceived,
BytesSent: bytesSent,
}
}

View File

@ -34,9 +34,8 @@ func TestRTMPServerRunOnConnect(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)
err = conn.InitializeClient(u, true)
_, err = rtmp.NewClientConn(nconn, u, true)
require.NoError(t, err)
time.Sleep(500 * time.Millisecond)
@ -125,9 +124,8 @@ func TestRTMPServer(t *testing.T) {
}()
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)
err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)
videoTrack := &formats.H264{
@ -175,9 +173,8 @@ func TestRTMPServer(t *testing.T) {
}()
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)
err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)
r, err := rtmp.NewReader(conn2)
@ -237,9 +234,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn1, err := net.Dial("tcp", u1.Host)
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)
err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)
videoTrack := &formats.H264{
@ -266,9 +262,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn2, err := net.Dial("tcp", u2.Host)
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)
err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)
_, err = rtmp.NewReader(conn2)
@ -291,9 +286,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn1, err := net.Dial("tcp", u1.Host)
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)
err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)
videoTrack := &formats.H264{
@ -320,9 +314,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn2, err := net.Dial("tcp", u2.Host)
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)
err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)
_, err = rtmp.NewReader(conn2)
@ -346,9 +339,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn1, err := net.Dial("tcp", u1.Host)
require.NoError(t, err)
defer nconn1.Close()
conn1 := rtmp.NewConn(nconn1)
err = conn1.InitializeClient(u1, true)
conn1, err := rtmp.NewClientConn(nconn1, u1, true)
require.NoError(t, err)
videoTrack := &formats.H264{
@ -375,9 +367,8 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn2, err := net.Dial("tcp", u2.Host)
require.NoError(t, err)
defer nconn2.Close()
conn2 := rtmp.NewConn(nconn2)
err = conn2.InitializeClient(u2, false)
conn2, err := rtmp.NewClientConn(nconn2, u2, false)
require.NoError(t, err)
_, err = rtmp.NewReader(conn2)

View File

@ -99,11 +99,9 @@ func (s *rtmpSource) run(ctx context.Context, cnf *conf.PathConf, reloadConf cha
}
func (s *rtmpSource) runReader(u *url.URL, nconn net.Conn) error {
conn := rtmp.NewConn(nconn)
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout)))
err := conn.InitializeClient(u, false)
conn, err := rtmp.NewClientConn(nconn, u, false)
if err != nil {
return err
}

View File

@ -52,9 +52,8 @@ func TestRTMPSource(t *testing.T) {
nconn, err := ln.Accept()
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewConn(nconn)
_, _, err = conn.InitializeServer()
conn, _, _, err := rtmp.NewServerConn(nconn)
require.NoError(t, err)
videoTrack := &formats.H264{

View File

@ -139,29 +139,21 @@ type Conn struct {
mrw *message.ReadWriter
}
// NewConn initializes a connection.
func NewConn(rw io.ReadWriter) *Conn {
return &Conn{
// NewClientConn initializes a client-side connection.
func NewClientConn(rw io.ReadWriter, u *url.URL, publish bool) (*Conn, error) {
c := &Conn{
bc: bytecounter.NewReadWriter(rw),
}
err := c.initializeClient(u, publish)
if err != nil {
return nil, err
}
return c, nil
}
// BytesReceived returns the number of bytes received.
func (c *Conn) BytesReceived() uint64 {
return c.bc.Reader.Count()
}
// BytesSent returns the number of bytes sent.
func (c *Conn) BytesSent() uint64 {
return c.bc.Writer.Count()
}
func (c *Conn) skipInitialization() {
c.mrw = message.NewReadWriter(c.bc, false)
}
// InitializeClient performs the initialization of a client-side connection.
func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
func (c *Conn) initializeClient(u *url.URL, publish bool) error {
connectpath, actionpath := splitPath(u)
err := handshake.DoClient(c.bc, false)
@ -219,7 +211,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
return err
}
if !isPublishing {
if !publish {
err = c.mrw.Write(&message.CommandAMF0{
ChunkStreamID: 3,
Name: "createStream",
@ -322,8 +314,21 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error {
return readCommandResult(c.mrw, 5, "onStatus", resultIsOK1)
}
// InitializeServer performs the initialization of a server-side connection.
func (c *Conn) InitializeServer() (*url.URL, bool, error) {
// NewServerConn initializes a server-side connection.
func NewServerConn(rw io.ReadWriter) (*Conn, *url.URL, bool, error) {
c := &Conn{
bc: bytecounter.NewReadWriter(rw),
}
u, publish, err := c.initializeServer()
if err != nil {
return nil, nil, false, err
}
return c, u, publish, nil
}
func (c *Conn) initializeServer() (*url.URL, bool, error) {
err := handshake.DoServer(c.bc, false)
if err != nil {
return nil, false, err
@ -571,6 +576,26 @@ func (c *Conn) InitializeServer() (*url.URL, bool, error) {
}
}
func newNoHandshakeConn(rw io.ReadWriter) *Conn {
c := &Conn{
bc: bytecounter.NewReadWriter(rw),
}
c.mrw = message.NewReadWriter(c.bc, false)
return c
}
// BytesReceived returns the number of bytes received.
func (c *Conn) BytesReceived() uint64 {
return c.bc.Reader.Count()
}
// BytesSent returns the number of bytes sent.
func (c *Conn) BytesSent() uint64 {
return c.bc.Writer.Count()
}
// Read reads a message.
func (c *Conn) Read() (message.Message, error) {
return c.mrw.Read()

View File

@ -14,7 +14,7 @@ import (
"github.com/bluenviron/mediamtx/internal/rtmp/message"
)
func TestInitializeClient(t *testing.T) {
func TestNewClientConn(t *testing.T) {
for _, ca := range []string{"read", "publish"} {
t.Run(ca, func(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9121")
@ -236,9 +236,8 @@ func TestInitializeClient(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := NewConn(nconn)
err = conn.InitializeClient(u, ca == "publish")
conn, err := NewClientConn(nconn, u, ca == "publish")
require.NoError(t, err)
if ca == "read" {
@ -254,7 +253,7 @@ func TestInitializeClient(t *testing.T) {
}
}
func TestInitializeServer(t *testing.T) {
func TestNewServerConn(t *testing.T) {
for _, ca := range []string{
"read",
"publish",
@ -272,9 +271,9 @@ func TestInitializeServer(t *testing.T) {
require.NoError(t, err)
defer nconn.Close()
conn := NewConn(nconn)
u, isPublishing, err := conn.InitializeServer()
_, u, isPublishing, err := NewServerConn(nconn)
require.NoError(t, err)
require.Equal(t, &url.URL{
Scheme: "rtmp",
Host: "127.0.0.1:9121",
@ -488,7 +487,7 @@ func BenchmarkRead(b *testing.B) {
})
}
conn := NewConn(&buf)
conn := newNoHandshakeConn(&buf)
for n := 0; n < b.N; n++ {
conn.Read()

View File

@ -536,8 +536,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
}
c := NewConn(&buf)
c.skipInitialization()
c := newNoHandshakeConn(&buf)
r, err := NewReader(c)
require.NoError(t, err)

View File

@ -40,8 +40,7 @@ func TestWriteTracks(t *testing.T) {
}
var buf bytes.Buffer
c := NewConn(&buf)
c.skipInitialization()
c := newNoHandshakeConn(&buf)
_, err := NewWriter(c, videoTrack, audioTrack)
require.NoError(t, err)