From cd19332350ba39ef36daaaad6b9ac44804425e49 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 7 Jun 2022 20:00:24 +0200 Subject: [PATCH] rtmp: use bufio reader during handshake --- internal/rtmp/conn_test.go | 19 +++++++++++-------- internal/rtmp/handshake/c0s0.go | 3 ++- internal/rtmp/handshake/c0s0_test.go | 3 ++- internal/rtmp/handshake/c1s1.go | 3 ++- internal/rtmp/handshake/c1s1_test.go | 3 ++- internal/rtmp/handshake/c2s2.go | 3 ++- internal/rtmp/handshake/c2s2_test.go | 3 ++- internal/rtmp/message/reader.go | 4 ++-- internal/rtmp/rawmessage/reader.go | 5 ++--- 9 files changed, 27 insertions(+), 19 deletions(-) diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index 9c5fc7ab..dbaae924 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -1,6 +1,7 @@ package rtmp import ( + "bufio" "net" "net/url" "strings" @@ -113,6 +114,7 @@ func TestReadTracks(t *testing.T) { conn, err := net.Dial("tcp", "127.0.0.1:9121") require.NoError(t, err) defer conn.Close() + br := bufio.NewReader(conn) // C->S handshake C0 err = handshake.C0S0{}.Write(conn) @@ -124,16 +126,16 @@ func TestReadTracks(t *testing.T) { require.NoError(t, err) // S->C handshake S0 - err = handshake.C0S0{}.Read(conn) + err = handshake.C0S0{}.Read(br) require.NoError(t, err) // S->C handshake S1 s1 := handshake.C1S1{} - err = s1.Read(conn, false) + err = s1.Read(br, false) require.NoError(t, err) // S->C handshake S2 - err = (&handshake.C2S2{Digest: c1.Digest}).Read(conn) + err = (&handshake.C2S2{Digest: c1.Digest}).Read(br) require.NoError(t, err) // C->S handshake C2 @@ -141,7 +143,7 @@ func TestReadTracks(t *testing.T) { require.NoError(t, err) mw := message.NewWriter(conn) - mr := message.NewReader(conn) + mr := message.NewReader(br) // C->S connect err = mw.Write(&message.MsgCommandAMF0{ @@ -473,6 +475,7 @@ func TestWriteTracks(t *testing.T) { conn, err := net.Dial("tcp", "127.0.0.1:9121") require.NoError(t, err) defer conn.Close() + br := bufio.NewReader(conn) // C->S handshake C0 err = handshake.C0S0{}.Write(conn) @@ -484,16 +487,16 @@ func TestWriteTracks(t *testing.T) { require.NoError(t, err) // S->C handshake S0 - err = handshake.C0S0{}.Read(conn) + err = handshake.C0S0{}.Read(br) require.NoError(t, err) // S->C handshake S1 s1 := handshake.C1S1{} - err = s1.Read(conn, false) + err = s1.Read(br, false) require.NoError(t, err) // S->C handshake S2 - err = (&handshake.C2S2{Digest: c1.Digest}).Read(conn) + err = (&handshake.C2S2{Digest: c1.Digest}).Read(br) require.NoError(t, err) // C->S handshake C2 @@ -501,7 +504,7 @@ func TestWriteTracks(t *testing.T) { require.NoError(t, err) mw := message.NewWriter(conn) - mr := message.NewReader(conn) + mr := message.NewReader(br) // C->S connect err = mw.Write(&message.MsgCommandAMF0{ diff --git a/internal/rtmp/handshake/c0s0.go b/internal/rtmp/handshake/c0s0.go index 0e650c38..228875ee 100644 --- a/internal/rtmp/handshake/c0s0.go +++ b/internal/rtmp/handshake/c0s0.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "fmt" "io" ) @@ -13,7 +14,7 @@ const ( type C0S0 struct{} // Read reads a C0S0. -func (C0S0) Read(r io.Reader) error { +func (C0S0) Read(r *bufio.Reader) error { buf := make([]byte, 1) _, err := io.ReadFull(r, buf) if err != nil { diff --git a/internal/rtmp/handshake/c0s0_test.go b/internal/rtmp/handshake/c0s0_test.go index 3ea439eb..ef006b18 100644 --- a/internal/rtmp/handshake/c0s0_test.go +++ b/internal/rtmp/handshake/c0s0_test.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "bytes" "testing" @@ -13,7 +14,7 @@ var c0s0dec = C0S0{} func TestC0S0Read(t *testing.T) { var c0s0 C0S0 - err := c0s0.Read(bytes.NewReader(c0s0enc)) + err := c0s0.Read(bufio.NewReader(bytes.NewReader(c0s0enc))) require.NoError(t, err) require.Equal(t, c0s0dec, c0s0) } diff --git a/internal/rtmp/handshake/c1s1.go b/internal/rtmp/handshake/c1s1.go index 3d1a1af3..ff4284fe 100644 --- a/internal/rtmp/handshake/c1s1.go +++ b/internal/rtmp/handshake/c1s1.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "bytes" "crypto/hmac" "crypto/rand" @@ -78,7 +79,7 @@ type C1S1 struct { } // Read reads a C1S1. -func (c *C1S1) Read(r io.Reader, isC1 bool) error { +func (c *C1S1) Read(r *bufio.Reader, isC1 bool) error { buf := make([]byte, 1536) _, err := io.ReadFull(r, buf) if err != nil { diff --git a/internal/rtmp/handshake/c1s1_test.go b/internal/rtmp/handshake/c1s1_test.go index 01937392..f93fb2a8 100644 --- a/internal/rtmp/handshake/c1s1_test.go +++ b/internal/rtmp/handshake/c1s1_test.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "bytes" "testing" @@ -43,7 +44,7 @@ func TestC1S1Read(t *testing.T) { ) var c1s1 C1S1 - err := c1s1.Read(bytes.NewReader(c1s1enc), true) + err := c1s1.Read(bufio.NewReader(bytes.NewReader(c1s1enc)), true) require.NoError(t, err) require.Equal(t, c1s1dec, c1s1) } diff --git a/internal/rtmp/handshake/c2s2.go b/internal/rtmp/handshake/c2s2.go index 58d6de54..45b71c7f 100644 --- a/internal/rtmp/handshake/c2s2.go +++ b/internal/rtmp/handshake/c2s2.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "bytes" "crypto/rand" "encoding/binary" @@ -17,7 +18,7 @@ type C2S2 struct { } // Read reads a C2S2. -func (c *C2S2) Read(r io.Reader) error { +func (c *C2S2) Read(r *bufio.Reader) error { buf := make([]byte, 1536) _, err := io.ReadFull(r, buf) if err != nil { diff --git a/internal/rtmp/handshake/c2s2_test.go b/internal/rtmp/handshake/c2s2_test.go index fd092a3d..2f08446a 100644 --- a/internal/rtmp/handshake/c2s2_test.go +++ b/internal/rtmp/handshake/c2s2_test.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "bytes" "testing" @@ -42,7 +43,7 @@ func TestC2S2Read(t *testing.T) { var c2s2 C2S2 c2s2.Digest = c2s2dec.Digest - err := c2s2.Read(bytes.NewReader(c2s2enc)) + err := c2s2.Read(bufio.NewReader(bytes.NewReader(c2s2enc))) require.NoError(t, err) require.Equal(t, c2s2dec, c2s2) } diff --git a/internal/rtmp/message/reader.go b/internal/rtmp/message/reader.go index 3932c4e5..991d53e3 100644 --- a/internal/rtmp/message/reader.go +++ b/internal/rtmp/message/reader.go @@ -1,9 +1,9 @@ package message import ( + "bufio" "encoding/binary" "fmt" - "io" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage" @@ -75,7 +75,7 @@ type Reader struct { } // NewReader allocates a Reader. -func NewReader(r io.Reader) *Reader { +func NewReader(r *bufio.Reader) *Reader { return &Reader{ r: rawmessage.NewReader(r), } diff --git a/internal/rtmp/rawmessage/reader.go b/internal/rtmp/rawmessage/reader.go index e87ef364..09c1b0c4 100644 --- a/internal/rtmp/rawmessage/reader.go +++ b/internal/rtmp/rawmessage/reader.go @@ -4,7 +4,6 @@ import ( "bufio" "errors" "fmt" - "io" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" ) @@ -169,9 +168,9 @@ type Reader struct { } // NewReader allocates a Reader. -func NewReader(r io.Reader) *Reader { +func NewReader(r *bufio.Reader) *Reader { return &Reader{ - r: bufio.NewReader(r), + r: r, chunkSize: 128, chunkStreams: make(map[byte]*readerChunkStream), }