mirror of
https://github.com/bluenviron/mediamtx
synced 2024-12-24 15:42:28 +00:00
webrtc muxer: keep the WebSocket connection
The WebSocket connection is kept open in order to use it to notify shutdowns.
This commit is contained in:
parent
f3f55452e5
commit
b20abbed6c
@ -5,7 +5,6 @@ import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@ -22,7 +21,6 @@ import (
|
||||
"github.com/aler9/gortsplib/v2/pkg/media"
|
||||
"github.com/aler9/gortsplib/v2/pkg/ringbuffer"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/webrtc/v3"
|
||||
@ -30,11 +28,13 @@ import (
|
||||
"github.com/aler9/rtsp-simple-server/internal/conf"
|
||||
"github.com/aler9/rtsp-simple-server/internal/formatprocessor"
|
||||
"github.com/aler9/rtsp-simple-server/internal/logger"
|
||||
"github.com/aler9/rtsp-simple-server/internal/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
webrtcHandshakeDeadline = 10 * time.Second
|
||||
webrtcPayloadMaxSize = 1200
|
||||
webrtcWsWriteDeadline = 2 * time.Second
|
||||
webrtcPayloadMaxSize = 1188 // 1200 - 12 (RTP header)
|
||||
)
|
||||
|
||||
// newPeerConnection creates a PeerConnection with the default codecs and
|
||||
@ -91,7 +91,7 @@ type webRTCConnParent interface {
|
||||
type webRTCConn struct {
|
||||
readBufferCount int
|
||||
pathName string
|
||||
wsconn *websocket.Conn
|
||||
wsconn *websocket.ServerConn
|
||||
iceServers []string
|
||||
wg *sync.WaitGroup
|
||||
pathManager webRTCConnPathManager
|
||||
@ -114,7 +114,7 @@ func newWebRTCConn(
|
||||
parentCtx context.Context,
|
||||
readBufferCount int,
|
||||
pathName string,
|
||||
wsconn *websocket.Conn,
|
||||
wsconn *websocket.ServerConn,
|
||||
iceServers []string,
|
||||
wg *sync.WaitGroup,
|
||||
pathManager webRTCConnPathManager,
|
||||
@ -311,10 +311,6 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// maximum deadline to complete the handshake
|
||||
c.wsconn.SetReadDeadline(time.Now().Add(webrtcHandshakeDeadline))
|
||||
c.wsconn.SetWriteDeadline(time.Now().Add(webrtcHandshakeDeadline))
|
||||
|
||||
err = c.writeICEServers(c.genICEServers())
|
||||
if err != nil {
|
||||
return err
|
||||
@ -425,7 +421,7 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
readError := make(chan error)
|
||||
wsReadError := make(chan error)
|
||||
remoteCandidate := make(chan *webrtc.ICECandidateInit)
|
||||
|
||||
go func() {
|
||||
@ -433,8 +429,7 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
|
||||
candidate, err := c.readCandidate()
|
||||
if err != nil {
|
||||
select {
|
||||
case readError <- err:
|
||||
case <-pcConnected:
|
||||
case wsReadError <- err:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return
|
||||
@ -448,6 +443,9 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
t := time.NewTimer(webrtcHandshakeDeadline)
|
||||
defer t.Stop()
|
||||
|
||||
outer:
|
||||
for {
|
||||
select {
|
||||
@ -462,9 +460,12 @@ outer:
|
||||
return err
|
||||
}
|
||||
|
||||
case err := <-readError:
|
||||
case err := <-wsReadError:
|
||||
return err
|
||||
|
||||
case <-t.C:
|
||||
return fmt.Errorf("deadline exceeded")
|
||||
|
||||
case <-pcConnected:
|
||||
break outer
|
||||
|
||||
@ -473,9 +474,11 @@ outer:
|
||||
}
|
||||
}
|
||||
|
||||
// do NOT close the WebSocket connection
|
||||
// in order to allow the other side of the connection
|
||||
// to switch to the "connected" state before WebSocket is closed.
|
||||
// Keep WebSocket connection open and use it to notify shutdowns.
|
||||
// This is because pion/webrtc doesn't write yet a WebRTC shutdown
|
||||
// message to clients (like a DTLS close alert or a RTCP BYE),
|
||||
// therefore browsers do not properly detect shutdowns and do not
|
||||
// attempt to restart the connection immediately.
|
||||
|
||||
c.mutex.Lock()
|
||||
c.curPC = pc
|
||||
@ -516,6 +519,9 @@ outer:
|
||||
case <-pcDisconnected:
|
||||
return fmt.Errorf("peer connection closed")
|
||||
|
||||
case err := <-wsReadError:
|
||||
return fmt.Errorf("websocket error: %v", err)
|
||||
|
||||
case err := <-writeError:
|
||||
return err
|
||||
|
||||
@ -833,18 +839,12 @@ func (c *webRTCConn) genICEServers() []webrtc.ICEServer {
|
||||
}
|
||||
|
||||
func (c *webRTCConn) writeICEServers(iceServers []webrtc.ICEServer) error {
|
||||
enc, _ := json.Marshal(iceServers)
|
||||
return c.wsconn.WriteMessage(websocket.TextMessage, enc)
|
||||
return c.wsconn.WriteJSON(iceServers)
|
||||
}
|
||||
|
||||
func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) {
|
||||
_, enc, err := c.wsconn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offer webrtc.SessionDescription
|
||||
err = json.Unmarshal(enc, &offer)
|
||||
err := c.wsconn.ReadJSON(&offer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -857,23 +857,16 @@ func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) {
|
||||
}
|
||||
|
||||
func (c *webRTCConn) writeAnswer(answer *webrtc.SessionDescription) error {
|
||||
enc, _ := json.Marshal(answer)
|
||||
return c.wsconn.WriteMessage(websocket.TextMessage, enc)
|
||||
return c.wsconn.WriteJSON(answer)
|
||||
}
|
||||
|
||||
func (c *webRTCConn) writeCandidate(candidate *webrtc.ICECandidate) error {
|
||||
enc, _ := json.Marshal(candidate.ToJSON())
|
||||
return c.wsconn.WriteMessage(websocket.TextMessage, enc)
|
||||
return c.wsconn.WriteJSON(candidate)
|
||||
}
|
||||
|
||||
func (c *webRTCConn) readCandidate() (*webrtc.ICECandidateInit, error) {
|
||||
_, enc, err := c.wsconn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var candidate webrtc.ICECandidateInit
|
||||
err = json.Unmarshal(enc, &candidate)
|
||||
err := c.wsconn.ReadJSON(&candidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -78,17 +78,6 @@ class Receiver {
|
||||
console.log("peer connection state:", this.pc.iceConnectionState);
|
||||
|
||||
switch (this.pc.iceConnectionState) {
|
||||
case "connected":
|
||||
this.pc.onicecandidate = undefined;
|
||||
// do not unbind ws.onmessage due to a strange Firefox bug
|
||||
// if all callbacks are removed from WS, video freezes after some seconds.
|
||||
this.ws.onerror = undefined
|
||||
this.ws.onclose = undefined;
|
||||
// do not close the WebSocket connection
|
||||
// in order to allow the other side of the connection
|
||||
// to switch to the "connected" state before WebSocket is closed.
|
||||
break;
|
||||
|
||||
case "disconnected":
|
||||
this.scheduleRestart();
|
||||
}
|
||||
|
@ -14,23 +14,17 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/webrtc/v3"
|
||||
|
||||
"github.com/aler9/rtsp-simple-server/internal/conf"
|
||||
"github.com/aler9/rtsp-simple-server/internal/logger"
|
||||
"github.com/aler9/rtsp-simple-server/internal/websocket"
|
||||
)
|
||||
|
||||
//go:embed webrtc_index.html
|
||||
var webrtcIndex []byte
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
type webRTCServerAPIConnsListItem struct {
|
||||
Created time.Time `json:"created"`
|
||||
RemoteAddr string `json:"remoteAddr"`
|
||||
@ -65,7 +59,7 @@ type webRTCServerAPIConnsKickReq struct {
|
||||
|
||||
type webRTCConnNewReq struct {
|
||||
pathName string
|
||||
wsconn *websocket.Conn
|
||||
wsconn *websocket.ServerConn
|
||||
res chan *webRTCConn
|
||||
}
|
||||
|
||||
@ -396,7 +390,7 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) {
|
||||
return
|
||||
|
||||
case "ws":
|
||||
wsconn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
|
||||
wsconn, err := websocket.NewServerConn(ctx.Writer, ctx.Request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -411,7 +405,7 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *webRTCServer) newConn(dir string, wsconn *websocket.Conn) *webRTCConn {
|
||||
func (s *webRTCServer) newConn(dir string, wsconn *websocket.ServerConn) *webRTCConn {
|
||||
req := webRTCConnNewReq{
|
||||
pathName: dir,
|
||||
wsconn: wsconn,
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Package hls contains a HLS muxer and client implementation.
|
||||
// Package hls contains a HLS muxer and client.
|
||||
package hls
|
||||
|
||||
import (
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Package rtmp implements a RTMP connection.
|
||||
// Package rtmp provides RTMP connectivity.
|
||||
package rtmp
|
||||
|
||||
import (
|
||||
|
114
internal/websocket/serverconn.go
Normal file
114
internal/websocket/serverconn.go
Normal file
@ -0,0 +1,114 @@
|
||||
// Package websocket provides WebSocket connectivity.
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
pingInterval = 30 * time.Second
|
||||
pingTimeout = 5 * time.Second
|
||||
writeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// ServerConn is a server-side WebSocket connection with automatic, periodic ping / pong.
|
||||
type ServerConn struct {
|
||||
wc *websocket.Conn
|
||||
|
||||
// in
|
||||
terminate chan struct{}
|
||||
write chan []byte
|
||||
|
||||
// out
|
||||
writeErr chan error
|
||||
}
|
||||
|
||||
// NewServerConn allocates a ServerConn.
|
||||
func NewServerConn(w http.ResponseWriter, req *http.Request) (*ServerConn, error) {
|
||||
wc, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &ServerConn{
|
||||
wc: wc,
|
||||
terminate: make(chan struct{}),
|
||||
write: make(chan []byte),
|
||||
writeErr: make(chan error),
|
||||
}
|
||||
|
||||
go c.run()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Close closes a ServerConn.
|
||||
func (c *ServerConn) Close() {
|
||||
c.wc.Close()
|
||||
close(c.terminate)
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address.
|
||||
func (c *ServerConn) RemoteAddr() net.Addr {
|
||||
return c.wc.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *ServerConn) run() {
|
||||
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout))
|
||||
|
||||
c.wc.SetPongHandler(func(string) error {
|
||||
c.wc.SetReadDeadline(time.Now().Add(pingInterval + pingTimeout))
|
||||
return nil
|
||||
})
|
||||
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case byts := <-c.write:
|
||||
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
err := c.wc.WriteMessage(websocket.TextMessage, byts)
|
||||
c.writeErr <- err
|
||||
|
||||
case <-pingTicker.C:
|
||||
c.wc.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
c.wc.WriteMessage(websocket.PingMessage, nil)
|
||||
|
||||
case <-c.terminate:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadJSON reads a JSON object.
|
||||
func (c *ServerConn) ReadJSON(in interface{}) error {
|
||||
return c.wc.ReadJSON(in)
|
||||
}
|
||||
|
||||
// WriteJSON writes a JSON object.
|
||||
func (c *ServerConn) WriteJSON(in interface{}) error {
|
||||
byts, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case c.write <- byts:
|
||||
return <-c.writeErr
|
||||
case <-c.terminate:
|
||||
return fmt.Errorf("terminated")
|
||||
}
|
||||
}
|
55
internal/websocket/serverconn_test.go
Normal file
55
internal/websocket/serverconn_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServerConn(t *testing.T) {
|
||||
pingReceived := make(chan struct{})
|
||||
pingInterval = 100 * time.Millisecond
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := NewServerConn(w, r)
|
||||
require.NoError(t, err)
|
||||
defer c.Close()
|
||||
|
||||
err = c.WriteJSON("testing")
|
||||
require.NoError(t, err)
|
||||
|
||||
<-pingReceived
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "localhost:6344")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
s := &http.Server{Handler: http.HandlerFunc(handler)}
|
||||
go s.Serve(ln)
|
||||
defer s.Shutdown(context.Background())
|
||||
|
||||
c, res, err := websocket.DefaultDialer.Dial("ws://localhost:6344/", nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
defer c.Close()
|
||||
|
||||
c.SetPingHandler(func(msg string) error {
|
||||
close(pingReceived)
|
||||
return nil
|
||||
})
|
||||
|
||||
var msg string
|
||||
err = c.ReadJSON(&msg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "testing", msg)
|
||||
|
||||
c.ReadMessage()
|
||||
|
||||
<-pingReceived
|
||||
}
|
Loading…
Reference in New Issue
Block a user