From b20abbed6c6d51f68cc8d2a5132c27d9298221b6 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 8 Jan 2023 15:36:33 +0100 Subject: [PATCH] webrtc muxer: keep the WebSocket connection The WebSocket connection is kept open in order to use it to notify shutdowns. --- internal/core/webrtc_conn.go | 61 ++++++-------- internal/core/webrtc_index.html | 11 --- internal/core/webrtc_server.go | 14 +--- internal/hls/muxer.go | 2 +- internal/rtmp/conn.go | 2 +- internal/websocket/serverconn.go | 114 ++++++++++++++++++++++++++ internal/websocket/serverconn_test.go | 55 +++++++++++++ 7 files changed, 202 insertions(+), 57 deletions(-) create mode 100644 internal/websocket/serverconn.go create mode 100644 internal/websocket/serverconn_test.go diff --git a/internal/core/webrtc_conn.go b/internal/core/webrtc_conn.go index 4212dfbc..68e455c0 100644 --- a/internal/core/webrtc_conn.go +++ b/internal/core/webrtc_conn.go @@ -5,7 +5,6 @@ import ( "crypto/hmac" "crypto/sha1" "encoding/base64" - "encoding/json" "errors" "fmt" "math/rand" @@ -22,7 +21,6 @@ import ( "github.com/aler9/gortsplib/v2/pkg/media" "github.com/aler9/gortsplib/v2/pkg/ringbuffer" "github.com/google/uuid" - "github.com/gorilla/websocket" "github.com/pion/ice/v2" "github.com/pion/interceptor" "github.com/pion/webrtc/v3" @@ -30,11 +28,13 @@ import ( "github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/formatprocessor" "github.com/aler9/rtsp-simple-server/internal/logger" + "github.com/aler9/rtsp-simple-server/internal/websocket" ) const ( webrtcHandshakeDeadline = 10 * time.Second - webrtcPayloadMaxSize = 1200 + webrtcWsWriteDeadline = 2 * time.Second + webrtcPayloadMaxSize = 1188 // 1200 - 12 (RTP header) ) // newPeerConnection creates a PeerConnection with the default codecs and @@ -91,7 +91,7 @@ type webRTCConnParent interface { type webRTCConn struct { readBufferCount int pathName string - wsconn *websocket.Conn + wsconn *websocket.ServerConn iceServers []string wg *sync.WaitGroup pathManager webRTCConnPathManager @@ -114,7 +114,7 @@ func newWebRTCConn( parentCtx context.Context, readBufferCount int, pathName string, - wsconn *websocket.Conn, + wsconn *websocket.ServerConn, iceServers []string, wg *sync.WaitGroup, pathManager webRTCConnPathManager, @@ -311,10 +311,6 @@ func (c *webRTCConn) runInner(ctx context.Context) error { return err } - // maximum deadline to complete the handshake - c.wsconn.SetReadDeadline(time.Now().Add(webrtcHandshakeDeadline)) - c.wsconn.SetWriteDeadline(time.Now().Add(webrtcHandshakeDeadline)) - err = c.writeICEServers(c.genICEServers()) if err != nil { return err @@ -425,7 +421,7 @@ func (c *webRTCConn) runInner(ctx context.Context) error { return err } - readError := make(chan error) + wsReadError := make(chan error) remoteCandidate := make(chan *webrtc.ICECandidateInit) go func() { @@ -433,8 +429,7 @@ func (c *webRTCConn) runInner(ctx context.Context) error { candidate, err := c.readCandidate() if err != nil { select { - case readError <- err: - case <-pcConnected: + case wsReadError <- err: case <-ctx.Done(): } return @@ -448,6 +443,9 @@ func (c *webRTCConn) runInner(ctx context.Context) error { } }() + t := time.NewTimer(webrtcHandshakeDeadline) + defer t.Stop() + outer: for { select { @@ -462,9 +460,12 @@ outer: return err } - case err := <-readError: + case err := <-wsReadError: return err + case <-t.C: + return fmt.Errorf("deadline exceeded") + case <-pcConnected: break outer @@ -473,9 +474,11 @@ outer: } } - // do NOT close the WebSocket connection - // in order to allow the other side of the connection - // to switch to the "connected" state before WebSocket is closed. + // Keep WebSocket connection open and use it to notify shutdowns. + // This is because pion/webrtc doesn't write yet a WebRTC shutdown + // message to clients (like a DTLS close alert or a RTCP BYE), + // therefore browsers do not properly detect shutdowns and do not + // attempt to restart the connection immediately. c.mutex.Lock() c.curPC = pc @@ -516,6 +519,9 @@ outer: case <-pcDisconnected: return fmt.Errorf("peer connection closed") + case err := <-wsReadError: + return fmt.Errorf("websocket error: %v", err) + case err := <-writeError: return err @@ -833,18 +839,12 @@ func (c *webRTCConn) genICEServers() []webrtc.ICEServer { } func (c *webRTCConn) writeICEServers(iceServers []webrtc.ICEServer) error { - enc, _ := json.Marshal(iceServers) - return c.wsconn.WriteMessage(websocket.TextMessage, enc) + return c.wsconn.WriteJSON(iceServers) } func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) { - _, enc, err := c.wsconn.ReadMessage() - if err != nil { - return nil, err - } - var offer webrtc.SessionDescription - err = json.Unmarshal(enc, &offer) + err := c.wsconn.ReadJSON(&offer) if err != nil { return nil, err } @@ -857,23 +857,16 @@ func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) { } func (c *webRTCConn) writeAnswer(answer *webrtc.SessionDescription) error { - enc, _ := json.Marshal(answer) - return c.wsconn.WriteMessage(websocket.TextMessage, enc) + return c.wsconn.WriteJSON(answer) } func (c *webRTCConn) writeCandidate(candidate *webrtc.ICECandidate) error { - enc, _ := json.Marshal(candidate.ToJSON()) - return c.wsconn.WriteMessage(websocket.TextMessage, enc) + return c.wsconn.WriteJSON(candidate) } func (c *webRTCConn) readCandidate() (*webrtc.ICECandidateInit, error) { - _, enc, err := c.wsconn.ReadMessage() - if err != nil { - return nil, err - } - var candidate webrtc.ICECandidateInit - err = json.Unmarshal(enc, &candidate) + err := c.wsconn.ReadJSON(&candidate) if err != nil { return nil, err } diff --git a/internal/core/webrtc_index.html b/internal/core/webrtc_index.html index 6b26a0b8..278879e8 100644 --- a/internal/core/webrtc_index.html +++ b/internal/core/webrtc_index.html @@ -78,17 +78,6 @@ class Receiver { console.log("peer connection state:", this.pc.iceConnectionState); switch (this.pc.iceConnectionState) { - case "connected": - this.pc.onicecandidate = undefined; - // do not unbind ws.onmessage due to a strange Firefox bug - // if all callbacks are removed from WS, video freezes after some seconds. - this.ws.onerror = undefined - this.ws.onclose = undefined; - // do not close the WebSocket connection - // in order to allow the other side of the connection - // to switch to the "connected" state before WebSocket is closed. - break; - case "disconnected": this.scheduleRestart(); } diff --git a/internal/core/webrtc_server.go b/internal/core/webrtc_server.go index 4dbdd85c..dc5e085e 100644 --- a/internal/core/webrtc_server.go +++ b/internal/core/webrtc_server.go @@ -14,23 +14,17 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" "github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/logger" + "github.com/aler9/rtsp-simple-server/internal/websocket" ) //go:embed webrtc_index.html var webrtcIndex []byte -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, -} - type webRTCServerAPIConnsListItem struct { Created time.Time `json:"created"` RemoteAddr string `json:"remoteAddr"` @@ -65,7 +59,7 @@ type webRTCServerAPIConnsKickReq struct { type webRTCConnNewReq struct { pathName string - wsconn *websocket.Conn + wsconn *websocket.ServerConn res chan *webRTCConn } @@ -396,7 +390,7 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) { return case "ws": - wsconn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil) + wsconn, err := websocket.NewServerConn(ctx.Writer, ctx.Request) if err != nil { return } @@ -411,7 +405,7 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) { } } -func (s *webRTCServer) newConn(dir string, wsconn *websocket.Conn) *webRTCConn { +func (s *webRTCServer) newConn(dir string, wsconn *websocket.ServerConn) *webRTCConn { req := webRTCConnNewReq{ pathName: dir, wsconn: wsconn, diff --git a/internal/hls/muxer.go b/internal/hls/muxer.go index 43778542..771c01b2 100644 --- a/internal/hls/muxer.go +++ b/internal/hls/muxer.go @@ -1,4 +1,4 @@ -// Package hls contains a HLS muxer and client implementation. +// Package hls contains a HLS muxer and client. package hls import ( diff --git a/internal/rtmp/conn.go b/internal/rtmp/conn.go index 16c1e839..b0d4124b 100644 --- a/internal/rtmp/conn.go +++ b/internal/rtmp/conn.go @@ -1,4 +1,4 @@ -// Package rtmp implements a RTMP connection. +// Package rtmp provides RTMP connectivity. package rtmp import ( diff --git a/internal/websocket/serverconn.go b/internal/websocket/serverconn.go new file mode 100644 index 00000000..a646ddc0 --- /dev/null +++ b/internal/websocket/serverconn.go @@ -0,0 +1,114 @@ +// Package websocket provides WebSocket connectivity. +package websocket + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +var ( + pingInterval = 30 * time.Second + pingTimeout = 5 * time.Second + writeTimeout = 2 * time.Second +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +// ServerConn is a server-side WebSocket connection with automatic, periodic ping / pong. +type ServerConn struct { + wc *websocket.Conn + + // in + terminate chan struct{} + write chan []byte + + // out + writeErr chan error +} + +// NewServerConn allocates a ServerConn. +func NewServerConn(w http.ResponseWriter, req *http.Request) (*ServerConn, error) { + wc, err := upgrader.Upgrade(w, req, nil) + if err != nil { + return nil, err + } + + c := &ServerConn{ + wc: wc, + terminate: make(chan struct{}), + write: make(chan []byte), + writeErr: make(chan error), + } + + go c.run() + + return c, nil +} + +// Close closes a ServerConn. +func (c *ServerConn) Close() { + c.wc.Close() + close(c.terminate) +} + +// RemoteAddr returns the remote address. +func (c *ServerConn) RemoteAddr() net.Addr { + return c.wc.RemoteAddr() +} + +func (c *ServerConn) run() { + c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout)) + + c.wc.SetPongHandler(func(string) error { + c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout)) + return nil + }) + + pingTicker := time.NewTicker(pingInterval) + defer pingTicker.Stop() + + for { + select { + case byts := <-c.write: + c.wc.SetWriteDeadline(time.Now().Add(writeTimeout)) + err := c.wc.WriteMessage(websocket.TextMessage, byts) + c.writeErr <- err + + case <-pingTicker.C: + c.wc.SetWriteDeadline(time.Now().Add(writeTimeout)) + c.wc.WriteMessage(websocket.PingMessage, nil) + + case <-c.terminate: + return + } + } +} + +// ReadJSON reads a JSON object. +func (c *ServerConn) ReadJSON(in interface{}) error { + return c.wc.ReadJSON(in) +} + +// WriteJSON writes a JSON object. +func (c *ServerConn) WriteJSON(in interface{}) error { + byts, err := json.Marshal(in) + if err != nil { + return err + } + + select { + case c.write <- byts: + return <-c.writeErr + case <-c.terminate: + return fmt.Errorf("terminated") + } +} diff --git a/internal/websocket/serverconn_test.go b/internal/websocket/serverconn_test.go new file mode 100644 index 00000000..cc0407fa --- /dev/null +++ b/internal/websocket/serverconn_test.go @@ -0,0 +1,55 @@ +package websocket + +import ( + "context" + "net" + "net/http" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +func TestServerConn(t *testing.T) { + pingReceived := make(chan struct{}) + pingInterval = 100 * time.Millisecond + + handler := func(w http.ResponseWriter, r *http.Request) { + c, err := NewServerConn(w, r) + require.NoError(t, err) + defer c.Close() + + err = c.WriteJSON("testing") + require.NoError(t, err) + + <-pingReceived + } + + ln, err := net.Listen("tcp", "localhost:6344") + require.NoError(t, err) + defer ln.Close() + + s := &http.Server{Handler: http.HandlerFunc(handler)} + go s.Serve(ln) + defer s.Shutdown(context.Background()) + + c, res, err := websocket.DefaultDialer.Dial("ws://localhost:6344/", nil) + require.NoError(t, err) + defer res.Body.Close() + defer c.Close() + + c.SetPingHandler(func(msg string) error { + close(pingReceived) + return nil + }) + + var msg string + err = c.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "testing", msg) + + c.ReadMessage() + + <-pingReceived +}