From 772c5b23633421f2ca4b0dca0959ada7045c06ae Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 16 Jul 2022 12:42:48 +0200 Subject: [PATCH] rtmp: improve tests --- internal/rtmp/bytecounter/reader.go | 5 ++ internal/rtmp/bytecounter/reader_test.go | 4 +- internal/rtmp/bytecounter/writer.go | 5 ++ internal/rtmp/bytecounter/writer_test.go | 5 +- internal/rtmp/rawmessage/reader.go | 1 - internal/rtmp/rawmessage/reader_test.go | 58 +++++++++++++----------- internal/rtmp/rawmessage/writer.go | 19 ++++---- internal/rtmp/rawmessage/writer_test.go | 49 ++++++++++++-------- 8 files changed, 87 insertions(+), 59 deletions(-) diff --git a/internal/rtmp/bytecounter/reader.go b/internal/rtmp/bytecounter/reader.go index c07a15a1..1babd89c 100644 --- a/internal/rtmp/bytecounter/reader.go +++ b/internal/rtmp/bytecounter/reader.go @@ -35,3 +35,8 @@ func NewReader(r io.Reader) *Reader { func (r Reader) Count() uint32 { return r.ri.count } + +// SetCount sets read bytes. +func (r *Reader) SetCount(v uint32) { + r.ri.count = v +} diff --git a/internal/rtmp/bytecounter/reader_test.go b/internal/rtmp/bytecounter/reader_test.go index 6778f375..0093bfa2 100644 --- a/internal/rtmp/bytecounter/reader_test.go +++ b/internal/rtmp/bytecounter/reader_test.go @@ -12,10 +12,12 @@ func TestReader(t *testing.T) { buf.Write(bytes.Repeat([]byte{0x01}, 1024)) r := NewReader(&buf) + r.SetCount(100) + buf2 := make([]byte, 64) n, err := r.Read(buf2) require.NoError(t, err) require.Equal(t, 64, n) - require.Equal(t, uint32(1024), r.Count()) + require.Equal(t, uint32(100+1024), r.Count()) } diff --git a/internal/rtmp/bytecounter/writer.go b/internal/rtmp/bytecounter/writer.go index 53de4332..413adc1b 100644 --- a/internal/rtmp/bytecounter/writer.go +++ b/internal/rtmp/bytecounter/writer.go @@ -28,3 +28,8 @@ func (w *Writer) Write(p []byte) (int, error) { func (w Writer) Count() uint32 { return w.count } + +// SetCount sets written bytes. +func (w *Writer) SetCount(v uint32) { + w.count = v +} diff --git a/internal/rtmp/bytecounter/writer_test.go b/internal/rtmp/bytecounter/writer_test.go index b2ecae4d..c2684523 100644 --- a/internal/rtmp/bytecounter/writer_test.go +++ b/internal/rtmp/bytecounter/writer_test.go @@ -9,7 +9,10 @@ import ( func TestWriter(t *testing.T) { var buf bytes.Buffer + w := NewWriter(&buf) + w.SetCount(100) + w.Write(bytes.Repeat([]byte{0x01}, 64)) - require.Equal(t, uint32(64), w.Count()) + require.Equal(t, uint32(100+64), w.Count()) } diff --git a/internal/rtmp/rawmessage/reader.go b/internal/rtmp/rawmessage/reader.go index ab5bdad6..ab57fcdb 100644 --- a/internal/rtmp/rawmessage/reader.go +++ b/internal/rtmp/rawmessage/reader.go @@ -30,7 +30,6 @@ func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) erro if rc.mr.ackWindowSize != 0 { count := rc.mr.r.Count() diff := count - rc.mr.lastAckCount - // TODO: handle overflow if diff > (rc.mr.ackWindowSize) { err := rc.mr.onAckNeeded(count) diff --git a/internal/rtmp/rawmessage/reader_test.go b/internal/rtmp/rawmessage/reader_test.go index eff82407..0c93c99d 100644 --- a/internal/rtmp/rawmessage/reader_test.go +++ b/internal/rtmp/rawmessage/reader_test.go @@ -158,34 +158,40 @@ func TestReader(t *testing.T) { } func TestReaderAcknowledge(t *testing.T) { - onAckCalled := make(chan struct{}) + for _, ca := range []string{"standard", "overflow"} { + t.Run(ca, func(t *testing.T) { + onAckCalled := make(chan struct{}) - var buf bytes.Buffer - bcr := bytecounter.NewReader(&buf) - r := NewReader(bcr, func(count uint32) error { - close(onAckCalled) - return nil - }) + var buf bytes.Buffer + bcr := bytecounter.NewReader(&buf) + r := NewReader(bcr, func(count uint32) error { + close(onAckCalled) + return nil + }) - r.SetWindowAckSize(100) + if ca == "overflow" { + bcr.SetCount(4294967096) + r.lastAckCount = 4294967096 + } - for i := 0; i < 2; i++ { - buf2, err := chunk.Chunk0{ - ChunkStreamID: 27, - Timestamp: 18576, - Type: chunk.MessageTypeSetPeerBandwidth, - MessageStreamID: 3123, - BodyLen: 64, - Body: bytes.Repeat([]byte{0x03}, 64), - }.Marshal() - require.NoError(t, err) - buf.Write(buf2) + r.SetChunkSize(65536) + r.SetWindowAckSize(100) + + buf2, err := chunk.Chunk0{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + BodyLen: 200, + Body: bytes.Repeat([]byte{0x03}, 200), + }.Marshal() + require.NoError(t, err) + buf.Write(buf2) + + _, err = r.Read() + require.NoError(t, err) + + <-onAckCalled + }) } - - for i := 0; i < 2; i++ { - _, err := r.Read() - require.NoError(t, err) - } - - <-onAckCalled } diff --git a/internal/rtmp/rawmessage/writer.go b/internal/rtmp/rawmessage/writer.go index 7087fa60..011498da 100644 --- a/internal/rtmp/rawmessage/writer.go +++ b/internal/rtmp/rawmessage/writer.go @@ -17,6 +17,15 @@ type writerChunkStream struct { } func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error { + // check if we received an acknowledge + if wc.mw.ackWindowSize != 0 { + diff := wc.mw.w.Count() - (wc.mw.ackValue) + + if diff > (wc.mw.ackWindowSize * 3 / 2) { + return fmt.Errorf("no acknowledge received within window") + } + } + buf, err := c.Marshal() if err != nil { return err @@ -27,16 +36,6 @@ func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error { return err } - // check if we received an acknowledge - if wc.mw.ackWindowSize != 0 { - diff := wc.mw.w.Count() - (wc.mw.ackValue) - // TODO: handle overflow - - if diff > (wc.mw.ackWindowSize * 3 / 2) { - return fmt.Errorf("no acknowledge received within window") - } - } - return nil } diff --git a/internal/rtmp/rawmessage/writer_test.go b/internal/rtmp/rawmessage/writer_test.go index 9b2f557f..ae4832cc 100644 --- a/internal/rtmp/rawmessage/writer_test.go +++ b/internal/rtmp/rawmessage/writer_test.go @@ -153,28 +153,37 @@ func TestWriter(t *testing.T) { } func TestWriterAcknowledge(t *testing.T) { - var buf bytes.Buffer - w := NewWriter(bytecounter.NewWriter(&buf)) + for _, ca := range []string{"standard", "overflow"} { + t.Run(ca, func(t *testing.T) { + var buf bytes.Buffer + bcw := bytecounter.NewWriter(&buf) + w := NewWriter(bcw) - w.SetWindowAckSize(100) + if ca == "overflow" { + bcw.SetCount(4294967096) + w.ackValue = 4294967096 + } - for i := 0; i < 2; i++ { - err := w.Write(&Message{ - ChunkStreamID: 27, - Timestamp: 18576, - Type: chunk.MessageTypeSetPeerBandwidth, - MessageStreamID: 3123, - Body: bytes.Repeat([]byte{0x03}, 64), + w.SetChunkSize(65536) + w.SetWindowAckSize(100) + + err := w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 200), + }) + require.NoError(t, err) + + err = w.Write(&Message{ + ChunkStreamID: 27, + Timestamp: 18576, + Type: chunk.MessageTypeSetPeerBandwidth, + MessageStreamID: 3123, + Body: bytes.Repeat([]byte{0x03}, 200), + }) + require.EqualError(t, err, "no acknowledge received within window") }) - require.NoError(t, err) } - - err := w.Write(&Message{ - ChunkStreamID: 27, - Timestamp: 18576, - Type: chunk.MessageTypeSetPeerBandwidth, - MessageStreamID: 3123, - Body: bytes.Repeat([]byte{0x03}, 64), - }) - require.EqualError(t, err, "no acknowledge received within window") }