From baf2100ad65e383ecce2d7b7f20336692efc5b19 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 9 May 2021 16:32:54 +0200 Subject: [PATCH] connrtmp: use contexts --- internal/connrtmp/conn.go | 55 ++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/internal/connrtmp/conn.go b/internal/connrtmp/conn.go index 502eb3b8..3e66dd46 100644 --- a/internal/connrtmp/conn.go +++ b/internal/connrtmp/conn.go @@ -1,6 +1,7 @@ package connrtmp import ( + "context" "fmt" "io" "net" @@ -159,36 +160,22 @@ func (c *Conn) run() { defer onConnectCmd.Close() } - c.ringBuffer = ringbuffer.New(uint64(c.readBufferCount)) - + ctx, cancel := context.WithCancel(context.Background()) connErr := make(chan error) go func() { - connErr <- func() error { - c.conn.NetConn().SetReadDeadline(time.Now().Add(c.readTimeout)) - c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) - err := c.conn.ServerHandshake() - if err != nil { - return err - } - - if c.conn.IsPublishing() { - return c.runPublish() - } - return c.runRead() - }() + connErr <- c.runInner(ctx) }() select { case err := <-connErr: + cancel() + if err != io.EOF { c.log(logger.Info, "ERR: %s", err) } - c.conn.NetConn().Close() - case <-c.terminate: - c.ringBuffer.Close() - c.conn.NetConn().Close() + cancel() <-connErr } @@ -202,7 +189,26 @@ func (c *Conn) run() { <-c.parentTerminate } -func (c *Conn) runRead() error { +func (c *Conn) runInner(ctx context.Context) error { + go func() { + <-ctx.Done() + c.conn.NetConn().Close() + }() + + c.conn.NetConn().SetReadDeadline(time.Now().Add(c.readTimeout)) + c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) + err := c.conn.ServerHandshake() + if err != nil { + return err + } + + if c.conn.IsPublishing() { + return c.runPublish(ctx) + } + return c.runRead(ctx) +} + +func (c *Conn) runRead(ctx context.Context) error { pathName, query := pathNameAndQuery(c.conn.URL()) sres := make(chan readpublisher.SetupPlayRes) @@ -259,6 +265,13 @@ func (c *Conn) runRead() error { c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) c.conn.WriteMetadata(videoTrack, audioTrack) + c.ringBuffer = ringbuffer.New(uint64(c.readBufferCount)) + + go func() { + <-ctx.Done() + c.ringBuffer.Close() + }() + pres := make(chan readpublisher.PlayRes) c.path.OnReadPublisherPlay(readpublisher.PlayReq{c, pres}) //nolint:govet <-pres @@ -348,7 +361,7 @@ func (c *Conn) runRead() error { } } -func (c *Conn) runPublish() error { +func (c *Conn) runPublish(ctx context.Context) error { c.conn.NetConn().SetReadDeadline(time.Now().Add(c.readTimeout)) videoTrack, audioTrack, err := c.conn.ReadMetadata() if err != nil {