rtmp: add handshake functions

This commit is contained in:
aler9 2022-07-16 15:18:04 +02:00
parent 772c5b2363
commit ef3e18a9e9
3 changed files with 119 additions and 96 deletions

View File

@ -63,30 +63,7 @@ func TestClientHandshake(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// C->S handshake C1
c1 := handshake.C1S1{}
err = c1.Read(bc, true)
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Write(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Write(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = handshake.C2S2{Digest: c1.Digest}.Write(bc)
require.NoError(t, err)
// C->S handshake C2
err = (&handshake.C2S2{Digest: s1.Digest}).Read(bc)
err = handshake.DoServer(bc)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
@ -345,30 +322,7 @@ func TestServerHandshake(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Write(bc)
require.NoError(t, err)
// C->S handshake C1
c1 := handshake.C1S1{}
err = c1.Write(bc, true)
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Read(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{Digest: s1.Digest}.Write(bc)
err = handshake.DoClient(bc)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
@ -656,30 +610,7 @@ func TestReadTracks(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Write(bc)
require.NoError(t, err)
// C->S handshake C1
c1 := handshake.C1S1{}
err = c1.Write(bc, true)
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Read(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{Digest: s1.Digest}.Write(bc)
err = handshake.DoClient(bc)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
@ -1061,30 +992,7 @@ func TestWriteTracks(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Write(bc)
require.NoError(t, err)
// C-> handshake C1
c1 := handshake.C1S1{}
err = c1.Write(bc, true)
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Read(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{Digest: s1.Digest}.Write(bc)
err = handshake.DoClient(bc)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)

View File

@ -0,0 +1,79 @@
package handshake
import (
"io"
)
// DoClient performs a client-side handshake.
func DoClient(rw io.ReadWriter) error {
err := C0S0{}.Write(rw)
if err != nil {
return err
}
c1 := C1S1{}
err = c1.Write(rw, true)
if err != nil {
return err
}
err = C0S0{}.Read(rw)
if err != nil {
return err
}
s1 := C1S1{}
err = s1.Read(rw, false)
if err != nil {
return err
}
err = (&C2S2{Digest: c1.Digest}).Read(rw)
if err != nil {
return err
}
err = C2S2{Digest: s1.Digest}.Write(rw)
if err != nil {
return err
}
return nil
}
// DoServer performs a server-side handshake.
func DoServer(rw io.ReadWriter) error {
err := C0S0{}.Read(rw)
if err != nil {
return err
}
c1 := C1S1{}
err = c1.Read(rw, true)
if err != nil {
return err
}
err = C0S0{}.Write(rw)
if err != nil {
return err
}
s1 := C1S1{}
err = s1.Write(rw, false)
if err != nil {
return err
}
err = C2S2{Digest: c1.Digest}.Write(rw)
if err != nil {
return err
}
err = (&C2S2{Digest: s1.Digest}).Read(rw)
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,36 @@
package handshake
import (
"net"
"testing"
"github.com/stretchr/testify/require"
)
func TestHandshake(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer ln.Close()
done := make(chan struct{})
go func() {
conn, err := ln.Accept()
require.NoError(t, err)
defer conn.Close()
err = DoServer(conn)
require.NoError(t, err)
close(done)
}()
conn, err := net.Dial("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer conn.Close()
err = DoClient(conn)
require.NoError(t, err)
<-done
}