From 7abb85ab2099cbd2f30982c86b925fabfa5628b3 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 7 Jun 2022 21:09:57 +0200 Subject: [PATCH] rtmp: improve tests --- internal/rtmp/chunk/chunk2.go | 2 +- internal/rtmp/chunk/chunk2_test.go | 2 +- internal/rtmp/handshake/c1s1.go | 8 +- internal/rtmp/rawmessage/reader.go | 57 ++++++--- internal/rtmp/rawmessage/reader_test.go | 153 +++++++++++++++++++++++ internal/rtmp/rawmessage/writer_test.go | 156 ++++++++++++++++++++++++ 6 files changed, 356 insertions(+), 22 deletions(-) create mode 100644 internal/rtmp/rawmessage/reader_test.go create mode 100644 internal/rtmp/rawmessage/writer_test.go diff --git a/internal/rtmp/chunk/chunk2.go b/internal/rtmp/chunk/chunk2.go index 18040d11..5d552c4f 100644 --- a/internal/rtmp/chunk/chunk2.go +++ b/internal/rtmp/chunk/chunk2.go @@ -33,7 +33,7 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error { // Write writes the chunk. func (c Chunk2) Write(w io.Writer) error { header := make([]byte, 4) - header[0] = 1<<6 | c.ChunkStreamID + header[0] = 2<<6 | c.ChunkStreamID header[1] = byte(c.TimestampDelta >> 16) header[2] = byte(c.TimestampDelta >> 8) header[3] = byte(c.TimestampDelta) diff --git a/internal/rtmp/chunk/chunk2_test.go b/internal/rtmp/chunk/chunk2_test.go index 066f810d..99bfac68 100644 --- a/internal/rtmp/chunk/chunk2_test.go +++ b/internal/rtmp/chunk/chunk2_test.go @@ -8,7 +8,7 @@ import ( ) var chunk2enc = []byte{ - 0x59, 0xb1, 0xa1, 0x91, 0x1, 0x2, 0x3, 0x4, + 0x99, 0xb1, 0xa1, 0x91, 0x1, 0x2, 0x3, 0x4, } var chunk2dec = Chunk2{ diff --git a/internal/rtmp/handshake/c1s1.go b/internal/rtmp/handshake/c1s1.go index ff4284fe..ed2045fe 100644 --- a/internal/rtmp/handshake/c1s1.go +++ b/internal/rtmp/handshake/c1s1.go @@ -33,15 +33,15 @@ var ( hsServerPartialKey = hsServerFullKey[:36] ) -func hsCalcDigestPos(p []byte, base int) (pos int) { +func hsCalcDigestPos(p []byte, base int) int { + pos := 0 for i := 0; i < 4; i++ { pos += int(p[base+i]) } - pos = (pos % 728) + base + 4 - return + return (pos % 728) + base + 4 } -func hsMakeDigest(key []byte, src []byte, gap int) (dst []byte) { +func hsMakeDigest(key []byte, src []byte, gap int) []byte { h := hmac.New(sha256.New, key) if gap <= 0 { h.Write(src) diff --git a/internal/rtmp/rawmessage/reader.go b/internal/rtmp/rawmessage/reader.go index 09c1b0c4..7552ad71 100644 --- a/internal/rtmp/rawmessage/reader.go +++ b/internal/rtmp/rawmessage/reader.go @@ -17,6 +17,7 @@ type readerChunkStream struct { curMessageStreamID *uint32 curBodyLen *uint32 curBody *[]byte + curTimestampDelta *uint32 } func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -40,6 +41,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { rc.curTimestamp = &v3 v4 := c0.BodyLen rc.curBodyLen = &v4 + rc.curTimestampDelta = nil if c0.BodyLen != uint32(len(c0.Body)) { rc.curBody = &c0.Body @@ -74,6 +76,8 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { rc.curTimestamp = &v3 v4 := c1.BodyLen rc.curBodyLen = &v4 + v5 := c1.TimestampDelta + rc.curTimestampDelta = &v5 if c1.BodyLen != uint32(len(c1.Body)) { rc.curBody = &c1.Body @@ -81,7 +85,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { } return &Message{ - Timestamp: *rc.curTimestamp + c1.TimestampDelta, + Timestamp: *rc.curTimestamp, Type: c1.Type, MessageStreamID: *rc.curMessageStreamID, Body: c1.Body, @@ -107,8 +111,10 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { return nil, err } - v3 := *rc.curTimestamp + c2.TimestampDelta - rc.curTimestamp = &v3 + v1 := *rc.curTimestamp + c2.TimestampDelta + rc.curTimestamp = &v1 + v2 := c2.TimestampDelta + rc.curTimestampDelta = &v2 if chunkBodyLen != len(c2.Body) { rc.curBody = &c2.Body @@ -116,19 +122,44 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { } return &Message{ - Timestamp: *rc.curTimestamp + c2.TimestampDelta, + Timestamp: *rc.curTimestamp, Type: *rc.curType, MessageStreamID: *rc.curMessageStreamID, Body: c2.Body, }, nil default: // 3 - if rc.curTimestamp == nil { + if rc.curBody == nil && rc.curTimestampDelta == nil { return nil, fmt.Errorf("received type 3 chunk without previous chunk") } - if rc.curBody == nil { - return nil, fmt.Errorf("unsupported") + if rc.curBody != nil { + chunkBodyLen := int(*rc.curBodyLen) - len(*rc.curBody) + if chunkBodyLen > rc.mr.chunkSize { + chunkBodyLen = rc.mr.chunkSize + } + + var c3 chunk.Chunk3 + err := c3.Read(rc.mr.r, chunkBodyLen) + if err != nil { + return nil, err + } + + *rc.curBody = append(*rc.curBody, c3.Body...) + + if *rc.curBodyLen != uint32(len(*rc.curBody)) { + return nil, errMoreChunksNeeded + } + + body := *rc.curBody + rc.curBody = nil + + return &Message{ + Timestamp: *rc.curTimestamp, + Type: *rc.curType, + MessageStreamID: *rc.curMessageStreamID, + Body: body, + }, nil } chunkBodyLen := int(*rc.curBodyLen) @@ -142,20 +173,14 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { return nil, err } - *rc.curBody = append(*rc.curBody, c3.Body...) - - if *rc.curBodyLen != uint32(len(*rc.curBody)) { - return nil, errMoreChunksNeeded - } - - body := *rc.curBody - rc.curBody = nil + v1 := *rc.curTimestamp + *rc.curTimestampDelta + rc.curTimestamp = &v1 return &Message{ Timestamp: *rc.curTimestamp, Type: *rc.curType, MessageStreamID: *rc.curMessageStreamID, - Body: body, + Body: c3.Body, }, nil } } diff --git a/internal/rtmp/rawmessage/reader_test.go b/internal/rtmp/rawmessage/reader_test.go new file mode 100644 index 00000000..4b7cddd0 --- /dev/null +++ b/internal/rtmp/rawmessage/reader_test.go @@ -0,0 +1,153 @@ +package rawmessage + +import ( + "bufio" + "bytes" + "io" + "testing" + + "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" + "github.com/stretchr/testify/require" +) + +type writableChunk interface { + Write(w io.Writer) error +} + +type sequenceEntry struct { + chunk writableChunk + msg *Message +} + +func TestReader(t *testing.T) { + testSequence := func(t *testing.T, seq []sequenceEntry) { + var buf bytes.Buffer + r := NewReader(bufio.NewReader(&buf)) + + for _, entry := range seq { + err := entry.chunk.Write(&buf) + require.NoError(t, err) + msg, err := r.Read() + require.NoError(t, err) + require.Equal(t, entry.msg, msg) + } + } + + t.Run("chunk0 + chunk1", func(t *testing.T) { + testSequence(t, []sequenceEntry{ + { + &chunk.Chunk0{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + BodyLen: 64, + Body: bytes.Repeat([]byte{0x02}, 64), + }, + &Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x02}, 64), + }, + }, + { + &chunk.Chunk1{ + ChunkStreamID: 27, + TimestampDelta: 15, + Type: chunk.MessageTypeSetPeerBandwidth, + BodyLen: 64, + Body: bytes.Repeat([]byte{0x03}, 64), + }, + &Message{ + ChunkStreamID: 27, + Timestamp: 18576 + 15, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 64), + }, + }, + }) + }) + + t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) { + testSequence(t, []sequenceEntry{ + { + &chunk.Chunk0{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + BodyLen: 64, + Body: bytes.Repeat([]byte{0x02}, 64), + }, + &Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x02}, 64), + }, + }, + { + &chunk.Chunk2{ + ChunkStreamID: 27, + TimestampDelta: 15, + Body: bytes.Repeat([]byte{0x03}, 64), + }, + &Message{ + ChunkStreamID: 27, + Timestamp: 18576 + 15, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 64), + }, + }, + { + &chunk.Chunk3{ + ChunkStreamID: 27, + Body: bytes.Repeat([]byte{0x04}, 64), + }, + &Message{ + ChunkStreamID: 27, + Timestamp: 18576 + 15 + 15, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x04}, 64), + }, + }, + }) + }) + + t.Run("chunk0 + chunk3", func(t *testing.T) { + var buf bytes.Buffer + r := NewReader(bufio.NewReader(&buf)) + + err := chunk.Chunk0{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + BodyLen: 192, + Body: bytes.Repeat([]byte{0x03}, 128), + }.Write(&buf) + require.NoError(t, err) + + err = chunk.Chunk3{ + ChunkStreamID: 27, + Body: bytes.Repeat([]byte{0x03}, 64), + }.Write(&buf) + require.NoError(t, err) + + msg, err := r.Read() + require.NoError(t, err) + require.Equal(t, &Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 192), + }, msg) + }) +} diff --git a/internal/rtmp/rawmessage/writer_test.go b/internal/rtmp/rawmessage/writer_test.go new file mode 100644 index 00000000..18822ff1 --- /dev/null +++ b/internal/rtmp/rawmessage/writer_test.go @@ -0,0 +1,156 @@ +package rawmessage + +import ( + "bufio" + "bytes" + "testing" + + "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" + "github.com/stretchr/testify/require" +) + +func TestWriter(t *testing.T) { + t.Run("chunk0 + chunk1", func(t *testing.T) { + var buf bytes.Buffer + br := bufio.NewReader(&buf) + w := NewWriter(&buf) + + err := w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 64), + }) + require.NoError(t, err) + + var c0 chunk.Chunk0 + err = c0.Read(br, 128) + require.NoError(t, err) + require.Equal(t, chunk.Chunk0{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + BodyLen: 64, + Body: bytes.Repeat([]byte{0x03}, 64), + }, c0) + + err = w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576 + 15, + Type: chunk.MessageTypeSetWindowAckSize, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x04}, 64), + }) + require.NoError(t, err) + + var c1 chunk.Chunk1 + err = c1.Read(br, 128) + require.NoError(t, err) + require.Equal(t, chunk.Chunk1{ + ChunkStreamID: 27, + TimestampDelta: 15, + Type: chunk.MessageTypeSetWindowAckSize, + BodyLen: 64, + Body: bytes.Repeat([]byte{0x04}, 64), + }, c1) + }) + + t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) { + var buf bytes.Buffer + br := bufio.NewReader(&buf) + w := NewWriter(&buf) + + err := w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 64), + }) + require.NoError(t, err) + + var c0 chunk.Chunk0 + err = c0.Read(br, 128) + require.NoError(t, err) + require.Equal(t, chunk.Chunk0{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + BodyLen: 64, + Body: bytes.Repeat([]byte{0x03}, 64), + }, c0) + + err = w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576 + 15, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x04}, 64), + }) + require.NoError(t, err) + + var c2 chunk.Chunk2 + err = c2.Read(br, 64) + require.NoError(t, err) + require.Equal(t, chunk.Chunk2{ + ChunkStreamID: 27, + TimestampDelta: 15, + Body: bytes.Repeat([]byte{0x04}, 64), + }, c2) + + err = w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576 + 15 + 15, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x05}, 64), + }) + require.NoError(t, err) + + var c3 chunk.Chunk3 + err = c3.Read(br, 64) + require.NoError(t, err) + require.Equal(t, chunk.Chunk3{ + ChunkStreamID: 27, + Body: bytes.Repeat([]byte{0x05}, 64), + }, c3) + }) + + t.Run("chunk0 + chunk3", func(t *testing.T) { + var buf bytes.Buffer + br := bufio.NewReader(&buf) + w := NewWriter(&buf) + + err := w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 192), + }) + require.NoError(t, err) + + var c0 chunk.Chunk0 + err = c0.Read(br, 128) + require.NoError(t, err) + require.Equal(t, chunk.Chunk0{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + BodyLen: 192, + Body: bytes.Repeat([]byte{0x03}, 128), + }, c0) + + var c3 chunk.Chunk3 + err = c3.Read(br, 64) + require.NoError(t, err) + require.Equal(t, chunk.Chunk3{ + ChunkStreamID: 27, + Body: bytes.Repeat([]byte{0x03}, 64), + }, c3) + }) +}