diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index 727bdb23..38391b0a 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -115,31 +115,34 @@ func TestReadTracks(t *testing.T) { defer conn.Close() // C->S handshake C0 - err = handshake.C0{}.Write(conn) + err = handshake.C0S0{}.Write(conn) require.NoError(t, err) // C->S handshake C1 - err = handshake.C1{}.Write(conn) + err = handshake.C1S1{}.Write(conn, true) require.NoError(t, err) // S->C handshake S0 - err = handshake.S0{}.Read(conn) + err = handshake.C0S0{}.Read(conn) require.NoError(t, err) - // S->C handshake S1+S2 - s1s2 := make([]byte, 1536*2) - _, err = conn.Read(s1s2) + // S->C handshake S1 + s1 := handshake.C1S1{} + err = s1.Read(conn, false) + require.NoError(t, err) + + // S->C handshake S2 + err = (&handshake.C2S2{}).Read(conn) require.NoError(t, err) // C->S handshake C2 - err = handshake.C2{}.Write(conn, s1s2) + err = handshake.C2S2{}.Write(conn, s1.Key) require.NoError(t, err) mw := message.NewWriter(conn) mr := message.NewReader(conn) // C->S connect - err = mw.Write(&message.MsgCommandAMF0{ ChunkStreamID: 3, Payload: []interface{}{ @@ -471,24 +474,28 @@ func TestWriteTracks(t *testing.T) { defer conn.Close() // C->S handshake C0 - err = handshake.C0{}.Write(conn) + err = handshake.C0S0{}.Write(conn) require.NoError(t, err) // C-> handshake C1 - err = handshake.C1{}.Write(conn) + err = handshake.C1S1{}.Write(conn, true) require.NoError(t, err) // S->C handshake S0 - err = handshake.S0{}.Read(conn) + err = handshake.C0S0{}.Read(conn) require.NoError(t, err) - // S->C handshake S1+S2 - s1s2 := make([]byte, 1536*2) - _, err = conn.Read(s1s2) + // S->C handshake S1 + s1 := handshake.C1S1{} + err = s1.Read(conn, false) + require.NoError(t, err) + + // S->C handshake S2 + err = (&handshake.C2S2{}).Read(conn) require.NoError(t, err) // C->S handshake C2 - err = handshake.C2{}.Write(conn, s1s2) + err = handshake.C2S2{}.Write(conn, s1.Key) require.NoError(t, err) mw := message.NewWriter(conn) diff --git a/internal/rtmp/handshake/c0.go b/internal/rtmp/handshake/c0.go deleted file mode 100644 index fe5ea892..00000000 --- a/internal/rtmp/handshake/c0.go +++ /dev/null @@ -1,18 +0,0 @@ -package handshake - -import ( - "io" -) - -const ( - rtmpVersion = 0x03 -) - -// C0 is the C0 part of an handshake. -type C0 struct{} - -// Read reads a C0. -func (C0) Write(w io.Writer) error { - _, err := w.Write([]byte{rtmpVersion}) - return err -} diff --git a/internal/rtmp/handshake/c0s0.go b/internal/rtmp/handshake/c0s0.go new file mode 100644 index 00000000..0e650c38 --- /dev/null +++ b/internal/rtmp/handshake/c0s0.go @@ -0,0 +1,34 @@ +package handshake + +import ( + "fmt" + "io" +) + +const ( + rtmpVersion = 0x03 +) + +// C0S0 is a C0 or S0 packet. +type C0S0 struct{} + +// Read reads a C0S0. +func (C0S0) Read(r io.Reader) error { + buf := make([]byte, 1) + _, err := io.ReadFull(r, buf) + if err != nil { + return err + } + + if buf[0] != rtmpVersion { + return fmt.Errorf("invalid rtmp version (%d)", buf[0]) + } + + return nil +} + +// Write writes a C0S0. +func (C0S0) Write(w io.Writer) error { + _, err := w.Write([]byte{rtmpVersion}) + return err +} diff --git a/internal/rtmp/handshake/c0s0_test.go b/internal/rtmp/handshake/c0s0_test.go new file mode 100644 index 00000000..3ea439eb --- /dev/null +++ b/internal/rtmp/handshake/c0s0_test.go @@ -0,0 +1,26 @@ +package handshake + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +var c0s0enc = []byte{0x03} + +var c0s0dec = C0S0{} + +func TestC0S0Read(t *testing.T) { + var c0s0 C0S0 + err := c0s0.Read(bytes.NewReader(c0s0enc)) + require.NoError(t, err) + require.Equal(t, c0s0dec, c0s0) +} + +func TestC0S0Write(t *testing.T) { + var buf bytes.Buffer + err := c0s0dec.Write(&buf) + require.NoError(t, err) + require.Equal(t, c0s0enc, buf.Bytes()) +} diff --git a/internal/rtmp/handshake/c1.go b/internal/rtmp/handshake/c1.go deleted file mode 100644 index fdce9221..00000000 --- a/internal/rtmp/handshake/c1.go +++ /dev/null @@ -1,67 +0,0 @@ -package handshake - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "io" -) - -var ( - hsClientFullKey = []byte{ - 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', - 'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ', - '0', '0', '1', - 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, - 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, - 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, - } - hsServerFullKey = []byte{ - 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', - 'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ', - 'S', 'e', 'r', 'v', 'e', 'r', ' ', - '0', '0', '1', - 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, - 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, - 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, - } - hsClientPartialKey = hsClientFullKey[:30] - hsServerPartialKey = hsServerFullKey[:36] -) - -func hsCalcDigestPos(p []byte, base int) (pos int) { - for i := 0; i < 4; i++ { - pos += int(p[base+i]) - } - pos = (pos % 728) + base + 4 - return -} - -func hsMakeDigest(key []byte, src []byte, gap int) (dst []byte) { - h := hmac.New(sha256.New, key) - if gap <= 0 { - h.Write(src) - } else { - h.Write(src[:gap]) - h.Write(src[gap+32:]) - } - return h.Sum(nil) -} - -// C1 is the C1 part of an handshake. -type C1 struct{} - -// Read reads a C1. -func (C1) Write(w io.Writer) error { - buf := make([]byte, 1536) - copy(buf[0:4], []byte{0x00, 0x00, 0x00, 0x00}) - copy(buf[4:8], []byte{0x09, 0x00, 0x7c, 0x02}) - - rand.Read(buf[8:]) - gap := hsCalcDigestPos(buf[0:], 8) - digest := hsMakeDigest(hsClientPartialKey, buf[0:], gap) - copy(buf[gap+0:], digest) - - _, err := w.Write(buf[0:]) - return err -} diff --git a/internal/rtmp/handshake/c1s1.go b/internal/rtmp/handshake/c1s1.go new file mode 100644 index 00000000..f246fd81 --- /dev/null +++ b/internal/rtmp/handshake/c1s1.go @@ -0,0 +1,138 @@ +package handshake + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "fmt" + "io" +) + +var ( + hsClientFullKey = []byte{ + 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', + 'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ', + '0', '0', '1', + 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, + 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, + 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, + } + hsServerFullKey = []byte{ + 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', + 'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ', + 'S', 'e', 'r', 'v', 'e', 'r', ' ', + '0', '0', '1', + 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, + 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, + 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, + } + hsClientPartialKey = hsClientFullKey[:30] + hsServerPartialKey = hsServerFullKey[:36] +) + +func hsCalcDigestPos(p []byte, base int) (pos int) { + for i := 0; i < 4; i++ { + pos += int(p[base+i]) + } + pos = (pos % 728) + base + 4 + return +} + +func hsMakeDigest(key []byte, src []byte, gap int) (dst []byte) { + h := hmac.New(sha256.New, key) + if gap <= 0 { + h.Write(src) + } else { + h.Write(src[:gap]) + h.Write(src[gap+32:]) + } + return h.Sum(nil) +} + +func hsFindDigest(p []byte, key []byte, base int) int { + gap := hsCalcDigestPos(p, base) + digest := hsMakeDigest(key, p, gap) + if !bytes.Equal(p[gap:gap+32], digest) { + return -1 + } + return gap +} + +func hsParse1(p []byte, peerkey []byte, key []byte) (ok bool, digest []byte) { + var pos int + if pos = hsFindDigest(p, peerkey, 772); pos == -1 { + if pos = hsFindDigest(p, peerkey, 8); pos == -1 { + return + } + } + ok = true + digest = hsMakeDigest(key, p[pos:pos+32], -1) + return +} + +// C1S1 is a C1 or S1 packet. +type C1S1 struct { + Time uint32 + Random []byte + Key []byte +} + +// Read reads a C1S1. +func (c *C1S1) Read(r io.Reader, isC1 bool) error { + buf := make([]byte, 1536) + _, err := io.ReadFull(r, buf) + if err != nil { + return err + } + + // validate signature + var peerKey []byte + var key []byte + if isC1 { + peerKey = hsClientPartialKey + key = hsServerFullKey + } else { + peerKey = hsServerPartialKey + key = hsClientFullKey + } + ok, key := hsParse1(buf, peerKey, key) + if !ok { + return fmt.Errorf("unable to validate C1/S1 signature") + } + + c.Time = binary.BigEndian.Uint32(buf) + c.Random = buf[8:] + c.Key = key + + return nil +} + +// Write writes a C1S1. +func (c C1S1) Write(w io.Writer, isC1 bool) error { + buf := make([]byte, 1536) + + binary.BigEndian.PutUint32(buf, c.Time) + copy(buf[4:], []byte{0, 0, 0, 0}) + + if c.Random == nil { + rand.Read(buf[8:]) + } else { + copy(buf[8:], c.Random) + } + + // signature + gap := hsCalcDigestPos(buf, 8) + var key []byte + if isC1 { + key = hsClientPartialKey + } else { + key = hsServerPartialKey + } + digest := hsMakeDigest(key, buf, gap) + copy(buf[gap:], digest) + + _, err := w.Write(buf) + return err +} diff --git a/internal/rtmp/handshake/c1s1_test.go b/internal/rtmp/handshake/c1s1_test.go new file mode 100644 index 00000000..7e09dead --- /dev/null +++ b/internal/rtmp/handshake/c1s1_test.go @@ -0,0 +1,80 @@ +package handshake + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestC1S1Read(t *testing.T) { + c1s1dec := C1S1{ + Time: 435234723, + Random: append( + []byte{ + 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a, + 0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4, + 0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad, + 0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4, + 0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04, + }, + bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)..., + ), + Key: []byte{ + 0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3, + 0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a, + 0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d, + 0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99, + }, + } + + c1s1enc := append( + []byte{ + 0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a, + 0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4, + 0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad, + 0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4, + 0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04, + }, + bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)..., + ) + + var c1s1 C1S1 + err := c1s1.Read(bytes.NewReader(c1s1enc), true) + require.NoError(t, err) + require.Equal(t, c1s1dec, c1s1) +} + +func TestC1S1Write(t *testing.T) { + c1s1dec := C1S1{ + Time: 435234723, + Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382), + Key: []byte{ + 0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3, + 0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a, + 0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d, + 0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99, + }, + } + + c1s1enc := append( + []byte{ + 0x19, 0xf1, 0x27, 0xa3, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x2d, 0x0a, + 0x37, 0x6f, 0x63, 0x2e, 0xa0, 0x21, 0xa0, 0xa4, + 0x81, 0xb1, 0x50, 0x21, 0x5a, 0x6d, 0x81, 0xad, + 0xf8, 0x44, 0x69, 0x13, 0xcc, 0x02, 0x8c, 0xd4, + 0x64, 0x43, 0xc9, 0x9f, 0xcf, 0xc6, 0x03, 0x04, + }, + bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)..., + ) + + var buf bytes.Buffer + err := c1s1dec.Write(&buf, true) + require.NoError(t, err) + require.Equal(t, c1s1enc, buf.Bytes()) +} diff --git a/internal/rtmp/handshake/c2.go b/internal/rtmp/handshake/c2.go deleted file mode 100644 index 6c156ed3..00000000 --- a/internal/rtmp/handshake/c2.go +++ /dev/null @@ -1,48 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" -) - -func hsFindDigest(p []byte, key []byte, base int) int { - gap := hsCalcDigestPos(p, base) - digest := hsMakeDigest(key, p, gap) - if !bytes.Equal(p[gap:gap+32], digest) { - return -1 - } - return gap -} - -func hsParse1(p []byte, peerkey []byte, key []byte) (ok bool, digest []byte) { - var pos int - if pos = hsFindDigest(p, peerkey, 772); pos == -1 { - if pos = hsFindDigest(p, peerkey, 8); pos == -1 { - return - } - } - ok = true - digest = hsMakeDigest(key, p[pos:pos+32], -1) - return -} - -// C2 is the C2 part of an handshake. -type C2 struct{} - -// Read reads a C2. -func (C2) Write(w io.Writer, s1s2 []byte) error { - ok, key := hsParse1(s1s2[:1536], hsServerPartialKey, hsClientFullKey) - if !ok { - return fmt.Errorf("unable to parse S1+S2") - } - - buf := make([]byte, 1536) - rand.Read(buf) - gap := len(buf) - 32 - digest := hsMakeDigest(key, buf, gap) - copy(buf[gap:], digest) - _, err := w.Write(buf) - return err -} diff --git a/internal/rtmp/handshake/c2s2.go b/internal/rtmp/handshake/c2s2.go new file mode 100644 index 00000000..9d2b7fe9 --- /dev/null +++ b/internal/rtmp/handshake/c2s2.go @@ -0,0 +1,50 @@ +package handshake + +import ( + "crypto/rand" + "encoding/binary" + "io" +) + +// C2S2 is a C2 or S2 packet. +type C2S2 struct { + Time uint32 + Time2 uint32 + Random []byte +} + +// Read reads a C2S2. +func (c *C2S2) Read(r io.Reader) error { + buf := make([]byte, 1536) + _, err := io.ReadFull(r, buf) + if err != nil { + return err + } + + c.Time = binary.BigEndian.Uint32(buf) + c.Time2 = binary.BigEndian.Uint32(buf[4:]) + c.Random = buf[8:] + + return nil +} + +// Write writes a C2S2. +func (c C2S2) Write(w io.Writer, key []byte) error { + buf := make([]byte, 1536) + binary.BigEndian.PutUint32(buf, c.Time) + binary.BigEndian.PutUint32(buf[4:], c.Time2) + + if c.Random == nil { + rand.Read(buf[8:]) + } else { + copy(buf[8:], c.Random) + } + + // signature + gap := len(buf) - 32 + digest := hsMakeDigest(key, buf, gap) + copy(buf[gap:], digest) + + _, err := w.Write(buf) + return err +} diff --git a/internal/rtmp/handshake/s0.go b/internal/rtmp/handshake/s0.go deleted file mode 100644 index d7ddcc2a..00000000 --- a/internal/rtmp/handshake/s0.go +++ /dev/null @@ -1,24 +0,0 @@ -package handshake - -import ( - "fmt" - "io" -) - -// S0 is the S0 part of an handshake. -type S0 struct{} - -// Read reads a S0. -func (S0) Read(r io.Reader) error { - buf := make([]byte, 1) - _, err := io.ReadFull(r, buf) - if err != nil { - return err - } - - if buf[0] != rtmpVersion { - return fmt.Errorf("invalid rtmp version (%d)", buf[0]) - } - - return nil -}