webrtc muxer: keep the WebSocket connection

The WebSocket connection is kept open in order to use it to notify
shutdowns.
This commit is contained in:
aler9 2023-01-08 15:36:33 +01:00
parent f3f55452e5
commit b20abbed6c
7 changed files with 202 additions and 57 deletions

View File

@ -5,7 +5,6 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
@ -22,7 +21,6 @@ import (
"github.com/aler9/gortsplib/v2/pkg/media" "github.com/aler9/gortsplib/v2/pkg/media"
"github.com/aler9/gortsplib/v2/pkg/ringbuffer" "github.com/aler9/gortsplib/v2/pkg/ringbuffer"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"github.com/pion/interceptor" "github.com/pion/interceptor"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
@ -30,11 +28,13 @@ import (
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/formatprocessor" "github.com/aler9/rtsp-simple-server/internal/formatprocessor"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/websocket"
) )
const ( const (
webrtcHandshakeDeadline = 10 * time.Second webrtcHandshakeDeadline = 10 * time.Second
webrtcPayloadMaxSize = 1200 webrtcWsWriteDeadline = 2 * time.Second
webrtcPayloadMaxSize = 1188 // 1200 - 12 (RTP header)
) )
// newPeerConnection creates a PeerConnection with the default codecs and // newPeerConnection creates a PeerConnection with the default codecs and
@ -91,7 +91,7 @@ type webRTCConnParent interface {
type webRTCConn struct { type webRTCConn struct {
readBufferCount int readBufferCount int
pathName string pathName string
wsconn *websocket.Conn wsconn *websocket.ServerConn
iceServers []string iceServers []string
wg *sync.WaitGroup wg *sync.WaitGroup
pathManager webRTCConnPathManager pathManager webRTCConnPathManager
@ -114,7 +114,7 @@ func newWebRTCConn(
parentCtx context.Context, parentCtx context.Context,
readBufferCount int, readBufferCount int,
pathName string, pathName string,
wsconn *websocket.Conn, wsconn *websocket.ServerConn,
iceServers []string, iceServers []string,
wg *sync.WaitGroup, wg *sync.WaitGroup,
pathManager webRTCConnPathManager, pathManager webRTCConnPathManager,
@ -311,10 +311,6 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
return err return err
} }
// maximum deadline to complete the handshake
c.wsconn.SetReadDeadline(time.Now().Add(webrtcHandshakeDeadline))
c.wsconn.SetWriteDeadline(time.Now().Add(webrtcHandshakeDeadline))
err = c.writeICEServers(c.genICEServers()) err = c.writeICEServers(c.genICEServers())
if err != nil { if err != nil {
return err return err
@ -425,7 +421,7 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
return err return err
} }
readError := make(chan error) wsReadError := make(chan error)
remoteCandidate := make(chan *webrtc.ICECandidateInit) remoteCandidate := make(chan *webrtc.ICECandidateInit)
go func() { go func() {
@ -433,8 +429,7 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
candidate, err := c.readCandidate() candidate, err := c.readCandidate()
if err != nil { if err != nil {
select { select {
case readError <- err: case wsReadError <- err:
case <-pcConnected:
case <-ctx.Done(): case <-ctx.Done():
} }
return return
@ -448,6 +443,9 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
} }
}() }()
t := time.NewTimer(webrtcHandshakeDeadline)
defer t.Stop()
outer: outer:
for { for {
select { select {
@ -462,9 +460,12 @@ outer:
return err return err
} }
case err := <-readError: case err := <-wsReadError:
return err return err
case <-t.C:
return fmt.Errorf("deadline exceeded")
case <-pcConnected: case <-pcConnected:
break outer break outer
@ -473,9 +474,11 @@ outer:
} }
} }
// do NOT close the WebSocket connection // Keep WebSocket connection open and use it to notify shutdowns.
// in order to allow the other side of the connection // This is because pion/webrtc doesn't write yet a WebRTC shutdown
// to switch to the "connected" state before WebSocket is closed. // message to clients (like a DTLS close alert or a RTCP BYE),
// therefore browsers do not properly detect shutdowns and do not
// attempt to restart the connection immediately.
c.mutex.Lock() c.mutex.Lock()
c.curPC = pc c.curPC = pc
@ -516,6 +519,9 @@ outer:
case <-pcDisconnected: case <-pcDisconnected:
return fmt.Errorf("peer connection closed") return fmt.Errorf("peer connection closed")
case err := <-wsReadError:
return fmt.Errorf("websocket error: %v", err)
case err := <-writeError: case err := <-writeError:
return err return err
@ -833,18 +839,12 @@ func (c *webRTCConn) genICEServers() []webrtc.ICEServer {
} }
func (c *webRTCConn) writeICEServers(iceServers []webrtc.ICEServer) error { func (c *webRTCConn) writeICEServers(iceServers []webrtc.ICEServer) error {
enc, _ := json.Marshal(iceServers) return c.wsconn.WriteJSON(iceServers)
return c.wsconn.WriteMessage(websocket.TextMessage, enc)
} }
func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) { func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) {
_, enc, err := c.wsconn.ReadMessage()
if err != nil {
return nil, err
}
var offer webrtc.SessionDescription var offer webrtc.SessionDescription
err = json.Unmarshal(enc, &offer) err := c.wsconn.ReadJSON(&offer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -857,23 +857,16 @@ func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) {
} }
func (c *webRTCConn) writeAnswer(answer *webrtc.SessionDescription) error { func (c *webRTCConn) writeAnswer(answer *webrtc.SessionDescription) error {
enc, _ := json.Marshal(answer) return c.wsconn.WriteJSON(answer)
return c.wsconn.WriteMessage(websocket.TextMessage, enc)
} }
func (c *webRTCConn) writeCandidate(candidate *webrtc.ICECandidate) error { func (c *webRTCConn) writeCandidate(candidate *webrtc.ICECandidate) error {
enc, _ := json.Marshal(candidate.ToJSON()) return c.wsconn.WriteJSON(candidate)
return c.wsconn.WriteMessage(websocket.TextMessage, enc)
} }
func (c *webRTCConn) readCandidate() (*webrtc.ICECandidateInit, error) { func (c *webRTCConn) readCandidate() (*webrtc.ICECandidateInit, error) {
_, enc, err := c.wsconn.ReadMessage()
if err != nil {
return nil, err
}
var candidate webrtc.ICECandidateInit var candidate webrtc.ICECandidateInit
err = json.Unmarshal(enc, &candidate) err := c.wsconn.ReadJSON(&candidate)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -78,17 +78,6 @@ class Receiver {
console.log("peer connection state:", this.pc.iceConnectionState); console.log("peer connection state:", this.pc.iceConnectionState);
switch (this.pc.iceConnectionState) { switch (this.pc.iceConnectionState) {
case "connected":
this.pc.onicecandidate = undefined;
// do not unbind ws.onmessage due to a strange Firefox bug
// if all callbacks are removed from WS, video freezes after some seconds.
this.ws.onerror = undefined
this.ws.onclose = undefined;
// do not close the WebSocket connection
// in order to allow the other side of the connection
// to switch to the "connected" state before WebSocket is closed.
break;
case "disconnected": case "disconnected":
this.scheduleRestart(); this.scheduleRestart();
} }

View File

@ -14,23 +14,17 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/websocket"
) )
//go:embed webrtc_index.html //go:embed webrtc_index.html
var webrtcIndex []byte var webrtcIndex []byte
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type webRTCServerAPIConnsListItem struct { type webRTCServerAPIConnsListItem struct {
Created time.Time `json:"created"` Created time.Time `json:"created"`
RemoteAddr string `json:"remoteAddr"` RemoteAddr string `json:"remoteAddr"`
@ -65,7 +59,7 @@ type webRTCServerAPIConnsKickReq struct {
type webRTCConnNewReq struct { type webRTCConnNewReq struct {
pathName string pathName string
wsconn *websocket.Conn wsconn *websocket.ServerConn
res chan *webRTCConn res chan *webRTCConn
} }
@ -396,7 +390,7 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) {
return return
case "ws": case "ws":
wsconn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil) wsconn, err := websocket.NewServerConn(ctx.Writer, ctx.Request)
if err != nil { if err != nil {
return return
} }
@ -411,7 +405,7 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) {
} }
} }
func (s *webRTCServer) newConn(dir string, wsconn *websocket.Conn) *webRTCConn { func (s *webRTCServer) newConn(dir string, wsconn *websocket.ServerConn) *webRTCConn {
req := webRTCConnNewReq{ req := webRTCConnNewReq{
pathName: dir, pathName: dir,
wsconn: wsconn, wsconn: wsconn,

View File

@ -1,4 +1,4 @@
// Package hls contains a HLS muxer and client implementation. // Package hls contains a HLS muxer and client.
package hls package hls
import ( import (

View File

@ -1,4 +1,4 @@
// Package rtmp implements a RTMP connection. // Package rtmp provides RTMP connectivity.
package rtmp package rtmp
import ( import (

View File

@ -0,0 +1,114 @@
// Package websocket provides WebSocket connectivity.
package websocket
import (
"encoding/json"
"fmt"
"net"
"net/http"
"time"
"github.com/gorilla/websocket"
)
var (
pingInterval = 30 * time.Second
pingTimeout = 5 * time.Second
writeTimeout = 2 * time.Second
)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// ServerConn is a server-side WebSocket connection with automatic, periodic ping / pong.
type ServerConn struct {
wc *websocket.Conn
// in
terminate chan struct{}
write chan []byte
// out
writeErr chan error
}
// NewServerConn allocates a ServerConn.
func NewServerConn(w http.ResponseWriter, req *http.Request) (*ServerConn, error) {
wc, err := upgrader.Upgrade(w, req, nil)
if err != nil {
return nil, err
}
c := &ServerConn{
wc: wc,
terminate: make(chan struct{}),
write: make(chan []byte),
writeErr: make(chan error),
}
go c.run()
return c, nil
}
// Close closes a ServerConn.
func (c *ServerConn) Close() {
c.wc.Close()
close(c.terminate)
}
// RemoteAddr returns the remote address.
func (c *ServerConn) RemoteAddr() net.Addr {
return c.wc.RemoteAddr()
}
func (c *ServerConn) run() {
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout))
c.wc.SetPongHandler(func(string) error {
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout))
return nil
})
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
for {
select {
case byts := <-c.write:
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout))
err := c.wc.WriteMessage(websocket.TextMessage, byts)
c.writeErr <- err
case <-pingTicker.C:
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout))
c.wc.WriteMessage(websocket.PingMessage, nil)
case <-c.terminate:
return
}
}
}
// ReadJSON reads a JSON object.
func (c *ServerConn) ReadJSON(in interface{}) error {
return c.wc.ReadJSON(in)
}
// WriteJSON writes a JSON object.
func (c *ServerConn) WriteJSON(in interface{}) error {
byts, err := json.Marshal(in)
if err != nil {
return err
}
select {
case c.write <- byts:
return <-c.writeErr
case <-c.terminate:
return fmt.Errorf("terminated")
}
}

View File

@ -0,0 +1,55 @@
package websocket
import (
"context"
"net"
"net/http"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
func TestServerConn(t *testing.T) {
pingReceived := make(chan struct{})
pingInterval = 100 * time.Millisecond
handler := func(w http.ResponseWriter, r *http.Request) {
c, err := NewServerConn(w, r)
require.NoError(t, err)
defer c.Close()
err = c.WriteJSON("testing")
require.NoError(t, err)
<-pingReceived
}
ln, err := net.Listen("tcp", "localhost:6344")
require.NoError(t, err)
defer ln.Close()
s := &http.Server{Handler: http.HandlerFunc(handler)}
go s.Serve(ln)
defer s.Shutdown(context.Background())
c, res, err := websocket.DefaultDialer.Dial("ws://localhost:6344/", nil)
require.NoError(t, err)
defer res.Body.Close()
defer c.Close()
c.SetPingHandler(func(msg string) error {
close(pingReceived)
return nil
})
var msg string
err = c.ReadJSON(&msg)
require.NoError(t, err)
require.Equal(t, "testing", msg)
c.ReadMessage()
<-pingReceived
}