diff --git a/internal/rtmp/conn.go b/internal/rtmp/conn.go index 431ecc0d..08bb74a6 100644 --- a/internal/rtmp/conn.go +++ b/internal/rtmp/conn.go @@ -122,7 +122,7 @@ func readCommandResult( } if cmd, ok := msg.(*message.CommandAMF0); ok { - if cmd.CommandID == commandID && cmd.Name == commandName { + if (cmd.CommandID == commandID || cmd.CommandID == 0) && cmd.Name == commandName { if !isValid(cmd) { return fmt.Errorf("server refused connect request") } diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index d09fd405..929385b7 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -15,7 +15,11 @@ import ( ) func TestNewClientConn(t *testing.T) { - for _, ca := range []string{"read", "publish"} { + for _, ca := range []string{ + "read", + "read nginx rtmp", + "publish", + } { t.Run(ca, func(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:9121") require.NoError(t, err) @@ -92,7 +96,8 @@ func TestNewClientConn(t *testing.T) { }) require.NoError(t, err) - if ca == "read" { + switch ca { + case "read", "read nginx rtmp": msg, err = mrw.Read() require.NoError(t, err) require.Equal(t, &message.CommandAMF0{ @@ -138,7 +143,12 @@ func TestNewClientConn(t *testing.T) { ChunkStreamID: 5, MessageStreamID: 0x1000000, Name: "onStatus", - CommandID: 3, + CommandID: func() int { + if ca == "read nginx rtmp" { + return 0 + } + return 3 + }(), Arguments: []interface{}{ nil, flvio.AMFMap{ @@ -149,7 +159,8 @@ func TestNewClientConn(t *testing.T) { }, }) require.NoError(t, err) - } else { + + case "publish": msg, err = mrw.Read() require.NoError(t, err) require.Equal(t, &message.CommandAMF0{ @@ -240,10 +251,12 @@ func TestNewClientConn(t *testing.T) { conn, err := NewClientConn(nconn, u, ca == "publish") require.NoError(t, err) - if ca == "read" { + switch ca { + case "read", "read nginx rtmp": require.Equal(t, uint64(3421), conn.BytesReceived()) require.Equal(t, uint64(3409), conn.BytesSent()) - } else { + + case "publish": require.Equal(t, uint64(3427), conn.BytesReceived()) require.Equal(t, uint64(3466), conn.BytesSent()) }