mediamtx/internal/core/webrtc_conn.go
2023-03-10 12:44:59 +01:00

880 lines
19 KiB
Go

package core
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/aler9/gortsplib/v2/pkg/format"
"github.com/aler9/gortsplib/v2/pkg/formatdecenc/rtph264"
"github.com/aler9/gortsplib/v2/pkg/formatdecenc/rtpvp8"
"github.com/aler9/gortsplib/v2/pkg/formatdecenc/rtpvp9"
"github.com/aler9/gortsplib/v2/pkg/media"
"github.com/aler9/gortsplib/v2/pkg/ringbuffer"
"github.com/google/uuid"
"github.com/pion/ice/v2"
"github.com/pion/interceptor"
"github.com/pion/webrtc/v3"
"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/formatprocessor"
"github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/websocket"
)
// Timing and sizing limits of WebRTC sessions.
const (
	// webrtcHandshakeDeadline is the maximum time allowed between sending the
	// answer and the peer connection reaching the "connected" state.
	webrtcHandshakeDeadline = 10 * time.Second

	// webrtcWsWriteDeadline is the write deadline of the signaling WebSocket.
	webrtcWsWriteDeadline = 2 * time.Second

	// webrtcPayloadMaxSize is the maximum RTP payload size produced by the
	// video encoders: 1200 bytes minus the 12-byte RTP header.
	webrtcPayloadMaxSize = 1188
)
// newPeerConnection creates a PeerConnection that uses the default codecs
// (RegisterDefaultCodecs) and the default interceptors
// (RegisterDefaultInterceptors).
//
// It mirrors the constructor found in pion's webrtc/peerconnection.go, but
// additionally accepts custom API options (for instance a SettingEngine).
// The media engine and interceptor registry options are appended after the
// caller-provided ones.
func newPeerConnection(configuration webrtc.Configuration,
	options ...func(*webrtc.API),
) (*webrtc.PeerConnection, error) {
	mediaEngine := &webrtc.MediaEngine{}
	err := mediaEngine.RegisterDefaultCodecs()
	if err != nil {
		return nil, err
	}

	registry := &interceptor.Registry{}
	err = webrtc.RegisterDefaultInterceptors(mediaEngine, registry)
	if err != nil {
		return nil, err
	}

	options = append(options,
		webrtc.WithMediaEngine(mediaEngine),
		webrtc.WithInterceptorRegistry(registry),
	)

	return webrtc.NewAPI(options...).NewPeerConnection(configuration)
}
// webRTCTrack binds a stream format to the WebRTC track it is sent through.
type webRTCTrack struct {
	// media and format identify the source inside the stream.
	media  *media.Media
	format format.Format

	// webRTCTrack is the outgoing WebRTC track.
	webRTCTrack *webrtc.TrackLocalStaticRTP

	// cb converts a unit of the stream into RTP packets and writes them into
	// webRTCTrack; fatal errors may be reported through the error channel.
	cb func(formatprocessor.Unit, context.Context, chan error)
}
// gatherMedias returns the medias of the given tracks, in the same order.
// The result is nil when tracks is empty.
func gatherMedias(tracks []*webRTCTrack) media.Medias {
	var medias media.Medias
	for i := range tracks {
		medias = append(medias, tracks[i].media)
	}
	return medias
}
// webRTCConnPathManager is the subset of the path manager used by webRTCConn:
// it registers the connection as a reader of a path.
type webRTCConnPathManager interface {
	readerAdd(req pathReaderAddReq) pathReaderSetupPlayRes
}

// webRTCConnParent is the entity that owns webRTCConn instances.
type webRTCConnParent interface {
	// log writes a log entry on behalf of a connection.
	log(logger.Level, string, ...interface{})

	// connClose is called by the connection once it has terminated.
	connClose(*webRTCConn)
}
// webRTCConn is a WebRTC reader session whose signaling happens over a
// WebSocket.
type webRTCConn struct {
	// readBufferCount is the size of the ring buffer between the stream and
	// the WebRTC tracks.
	readBufferCount int
	// pathName is the name of the path to read from.
	pathName string
	// wsconn is the signaling WebSocket.
	wsconn *websocket.ServerConn
	// iceServers are the raw ICE server entries, parsed by genICEServers.
	iceServers []string
	// wg is released when run() exits.
	wg          *sync.WaitGroup
	pathManager webRTCConnPathManager
	parent      webRTCConnParent
	// iceUDPMux and iceTCPMux are optional shared ICE muxes (nil when unused).
	iceUDPMux ice.UDPMux
	iceTCPMux ice.TCPMux
	// iceHostNAT1To1IPs are the IPs advertised for host candidates, if any.
	iceHostNAT1To1IPs []string

	// ctx is cancelled to terminate the connection.
	ctx       context.Context
	ctxCancel func()

	// uuid identifies the connection in the API.
	uuid uuid.UUID
	// created is the creation time of the connection.
	created time.Time

	// curPC is the established peer connection, nil until the handshake
	// completes. It is protected by mutex.
	curPC *webrtc.PeerConnection
	mutex sync.RWMutex

	// closed is closed when run() returns.
	closed chan struct{}
}
// newWebRTCConn allocates a webRTCConn bound to the given WebSocket, logs
// its opening and starts its main routine (run) in a new goroutine tracked
// by wg.
//
// The connection terminates when parentCtx is cancelled or close() is called.
func newWebRTCConn(
	parentCtx context.Context,
	readBufferCount int,
	pathName string,
	wsconn *websocket.ServerConn,
	iceServers []string,
	wg *sync.WaitGroup,
	pathManager webRTCConnPathManager,
	parent webRTCConnParent,
	iceHostNAT1To1IPs []string,
	iceUDPMux ice.UDPMux,
	iceTCPMux ice.TCPMux,
) *webRTCConn {
	ctx, ctxCancel := context.WithCancel(parentCtx)

	conn := &webRTCConn{
		readBufferCount:   readBufferCount,
		pathName:          pathName,
		wsconn:            wsconn,
		iceServers:        iceServers,
		wg:                wg,
		pathManager:       pathManager,
		parent:            parent,
		iceUDPMux:         iceUDPMux,
		iceTCPMux:         iceTCPMux,
		iceHostNAT1To1IPs: iceHostNAT1To1IPs,
		ctx:               ctx,
		ctxCancel:         ctxCancel,
		uuid:              uuid.New(),
		created:           time.Now(),
		closed:            make(chan struct{}),
	}

	conn.log(logger.Info, "opened")

	// run() calls wg.Done() when it exits.
	wg.Add(1)
	go conn.run()

	return conn
}
// close asks the connection to terminate by cancelling its context.
// It does not wait; use wait() to block until run() has returned.
func (c *webRTCConn) close() {
	c.ctxCancel()
}

// wait blocks until run() has returned.
func (c *webRTCConn) wait() {
	<-c.closed
}

// remoteAddr returns the remote address of the signaling WebSocket.
func (c *webRTCConn) remoteAddr() net.Addr {
	return c.wsconn.RemoteAddr()
}

// peerConnectionEstablished reports whether the WebRTC handshake has
// completed, i.e. whether curPC has been set.
func (c *webRTCConn) peerConnectionEstablished() bool {
	c.mutex.RLock()
	pc := c.curPC
	c.mutex.RUnlock()

	return pc != nil
}
// localCandidate describes the local ICE candidate of the nominated
// candidate pair as "type/protocol/ip/port".
//
// It returns "" when no peer connection is established, when no pair is
// nominated, or when the candidate cannot be found in the stats report.
func (c *webRTCConn) localCandidate() string {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	return c.nominatedCandidate(true)
}

// remoteCandidate describes the remote ICE candidate of the nominated
// candidate pair as "type/protocol/ip/port".
//
// It returns "" when no peer connection is established, when no pair is
// nominated, or when the candidate cannot be found in the stats report.
func (c *webRTCConn) remoteCandidate() string {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	return c.nominatedCandidate(false)
}

// nominatedCandidate describes one side (local if local is true, remote
// otherwise) of the nominated ICE candidate pair of c.curPC.
// The caller must hold c.mutex.
//
// A single stats snapshot is used both to find the nominated pair and to
// resolve its candidate. The two lookups are therefore consistent, and
// GetStats() is called once instead of twice.
func (c *webRTCConn) nominatedCandidate(local bool) string {
	if c.curPC == nil {
		return ""
	}

	report := c.curPC.GetStats()

	var cid string
	for _, stats := range report {
		pair, ok := stats.(webrtc.ICECandidatePairStats)
		if !ok || !pair.Nominated {
			continue
		}
		if local {
			cid = pair.LocalCandidateID
		} else {
			cid = pair.RemoteCandidateID
		}
		break
	}

	if cid == "" {
		return ""
	}

	for _, stats := range report {
		if cand, ok := stats.(webrtc.ICECandidateStats); ok && cand.ID == cid {
			return cand.CandidateType.String() + "/" + cand.Protocol + "/" +
				cand.IP + "/" + strconv.FormatInt(int64(cand.Port), 10)
		}
	}

	return ""
}
// bytesReceived returns the number of bytes received by the ICE transport
// of the current peer connection, or 0 when it is not available.
func (c *webRTCConn) bytesReceived() uint64 {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	if ts, ok := c.iceTransportStats(); ok {
		return ts.BytesReceived
	}
	return 0
}

// bytesSent returns the number of bytes sent by the ICE transport of the
// current peer connection, or 0 when it is not available.
func (c *webRTCConn) bytesSent() uint64 {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	if ts, ok := c.iceTransportStats(); ok {
		return ts.BytesSent
	}
	return 0
}

// iceTransportStats looks up the "iceTransport" entry among the transport
// stats of c.curPC. The boolean is false when there is no peer connection
// or no such entry. The caller must hold c.mutex.
func (c *webRTCConn) iceTransportStats() (webrtc.TransportStats, bool) {
	if c.curPC == nil {
		return webrtc.TransportStats{}, false
	}

	for _, stats := range c.curPC.GetStats() {
		ts, isTransport := stats.(webrtc.TransportStats)
		if isTransport && ts.ID == "iceTransport" {
			return ts, true
		}
	}

	return webrtc.TransportStats{}, false
}
// log forwards a log entry to the parent, prefixed with the remote address
// of the signaling WebSocket.
func (c *webRTCConn) log(level logger.Level, format string, args ...interface{}) {
	prefixed := make([]interface{}, 0, len(args)+1)
	prefixed = append(prefixed, c.wsconn.RemoteAddr())
	prefixed = append(prefixed, args...)

	c.parent.log(level, "[conn %v] "+format, prefixed...)
}
// run is the main routine of the connection. It executes runInner in a
// separate goroutine and waits until either runInner returns or the
// connection context is cancelled. In the latter case, it cancels runInner
// and waits for it to return.
//
// It then cancels the connection context, notifies the parent, logs the
// reason of termination, releases wg and closes c.closed.
func (c *webRTCConn) run() {
	defer close(c.closed)
	defer c.wg.Done()

	innerCtx, cancelInner := context.WithCancel(c.ctx)

	innerDone := make(chan error)
	go func() {
		innerDone <- c.runInner(innerCtx)
	}()

	var exitErr error

	select {
	case exitErr = <-innerDone:
		cancelInner()

	case <-c.ctx.Done():
		cancelInner()
		<-innerDone
		exitErr = errors.New("terminated")
	}

	c.ctxCancel()

	c.parent.connClose(c)

	c.log(logger.Info, "closed (%v)", exitErr)
}
// runInner performs a whole WebRTC read session:
//  1. registers the connection as a reader of c.pathName;
//  2. selects the tracks to send (see allocateTracks);
//  3. sends the ICE servers to the client and reads its SDP offer;
//  4. creates the peer connection, sends the answer and exchanges ICE
//     candidates over the WebSocket until the peer connection is connected
//     or webrtcHandshakeDeadline expires;
//  5. forwards stream units to the WebRTC tracks until the peer disconnects,
//     the WebSocket fails, a track reports an error or ctx is cancelled.
//
// It returns the reason why the session ended.
func (c *webRTCConn) runInner(ctx context.Context) error {
	res := c.pathManager.readerAdd(pathReaderAddReq{
		author:   c,
		pathName: c.pathName,
		// this callback accepts every reader.
		authenticate: func(
			pathIPs []fmt.Stringer,
			pathUser conf.Credential,
			pathPass conf.Credential,
		) error {
			return nil
		},
	})
	if res.err != nil {
		return res.err
	}

	path := res.path

	defer func() {
		path.readerRemove(pathReaderRemoveReq{author: c})
	}()

	tracks, err := c.allocateTracks(res.stream.medias())
	if err != nil {
		return err
	}

	// the first message sent to the client is the list of ICE servers.
	err = c.wsconn.WriteJSON(c.genICEServers())
	if err != nil {
		return err
	}

	offer, err := c.readOffer()
	if err != nil {
		return err
	}

	configuration := webrtc.Configuration{ICEServers: c.genICEServers()}

	settingsEngine := webrtc.SettingEngine{}

	if len(c.iceHostNAT1To1IPs) != 0 {
		settingsEngine.SetNAT1To1IPs(c.iceHostNAT1To1IPs, webrtc.ICECandidateTypeHost)
	}

	if c.iceUDPMux != nil {
		settingsEngine.SetICEUDPMux(c.iceUDPMux)
	}

	if c.iceTCPMux != nil {
		settingsEngine.SetICETCPMux(c.iceTCPMux)
		// when a TCP mux is configured, only TCP/IPv4 is used.
		settingsEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4})
	}

	pc, err := newPeerConnection(configuration, webrtc.WithSettingEngine(settingsEngine))
	if err != nil {
		return err
	}

	pcConnected := make(chan struct{})
	pcDisconnected := make(chan struct{})
	pcClosed := make(chan struct{})
	var stateChangeMutex sync.Mutex

	// closeOnce closes ch unless it is already closed. All calls happen with
	// stateChangeMutex held, so the check and the close cannot race.
	closeOnce := func(ch chan struct{}) {
		select {
		case <-ch:
		default:
			close(ch)
		}
	}

	pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
		stateChangeMutex.Lock()
		defer stateChangeMutex.Unlock()

		select {
		case <-pcClosed:
			return
		default:
		}

		c.log(logger.Debug, "peer connection state: %s", state.String())

		// The state can switch from "disconnected" back to "connected", so
		// the same state can be reported more than once. Closing an already
		// closed channel would panic.
		switch state {
		case webrtc.PeerConnectionStateConnected:
			closeOnce(pcConnected)

		case webrtc.PeerConnectionStateDisconnected:
			closeOnce(pcDisconnected)

		case webrtc.PeerConnectionStateClosed:
			close(pcClosed)
		}
	})

	defer func() {
		pc.Close()
		<-pcClosed
	}()

	for _, track := range tracks {
		rtpSender, err := pc.AddTrack(track.webRTCTrack)
		if err != nil {
			return err
		}

		// read incoming RTCP packets in order to make interceptors work
		go func() {
			buf := make([]byte, 1500)
			for {
				_, _, err := rtpSender.Read(buf)
				if err != nil {
					return
				}
			}
		}()
	}

	localCandidate := make(chan *webrtc.ICECandidateInit)

	pc.OnICECandidate(func(i *webrtc.ICECandidate) {
		if i != nil {
			v := i.ToJSON()
			select {
			case localCandidate <- &v:
			case <-pcConnected:
			case <-ctx.Done():
			}
		}
	})

	err = pc.SetRemoteDescription(*offer)
	if err != nil {
		return err
	}

	answer, err := pc.CreateAnswer(nil)
	if err != nil {
		return err
	}

	err = pc.SetLocalDescription(answer)
	if err != nil {
		return err
	}

	err = c.wsconn.WriteJSON(&answer)
	if err != nil {
		return err
	}

	// Read remote candidates from the WebSocket. After the handshake this
	// goroutine keeps reading, so that WebSocket errors (i.e. the client
	// going away) are still reported through wsReadError.
	wsReadError := make(chan error)
	remoteCandidate := make(chan *webrtc.ICECandidateInit)

	go func() {
		for {
			candidate, err := c.readCandidate()
			if err != nil {
				select {
				case wsReadError <- err:
				case <-ctx.Done():
				}
				return
			}

			select {
			case remoteCandidate <- candidate:
			case <-pcConnected:
			case <-ctx.Done():
			}
		}
	}()

	t := time.NewTimer(webrtcHandshakeDeadline)
	defer t.Stop()

outer:
	for {
		select {
		case candidate := <-localCandidate:
			c.log(logger.Debug, "local candidate: %+v", candidate.Candidate)
			err := c.wsconn.WriteJSON(candidate)
			if err != nil {
				return err
			}

		case candidate := <-remoteCandidate:
			c.log(logger.Debug, "remote candidate: %+v", candidate.Candidate)
			err := pc.AddICECandidate(*candidate)
			if err != nil {
				return err
			}

		case err := <-wsReadError:
			return err

		case <-t.C:
			return fmt.Errorf("deadline exceeded")

		case <-pcConnected:
			break outer

		case <-ctx.Done():
			return fmt.Errorf("terminated")
		}
	}

	// Keep WebSocket connection open and use it to notify shutdowns.
	// This is because pion/webrtc doesn't write yet a WebRTC shutdown
	// message to clients (like a DTLS close alert or a RTCP BYE),
	// therefore browsers do not properly detect shutdowns and do not
	// attempt to restart the connection immediately.

	c.mutex.Lock()
	c.curPC = pc
	c.mutex.Unlock()

	c.log(logger.Info, "peer connection established, local candidate: %v, remote candidate: %v",
		c.localCandidate(), c.remoteCandidate())

	ringBuffer, err := ringbuffer.New(uint64(c.readBufferCount))
	if err != nil {
		return err
	}
	defer ringBuffer.Close()

	writeError := make(chan error)

	for _, track := range tracks {
		ctrack := track
		res.stream.readerAdd(c, track.media, track.format, func(unit formatprocessor.Unit) {
			ringBuffer.Push(func() {
				ctrack.cb(unit, ctx, writeError)
			})
		})
	}
	defer res.stream.readerRemove(c)

	c.log(logger.Info, "is reading from path '%s', %s",
		path.name, sourceMediaInfo(gatherMedias(tracks)))

	// consume the ring buffer until it is closed by the deferred Close().
	go func() {
		for {
			item, ok := ringBuffer.Pull()
			if !ok {
				return
			}
			item.(func())()
		}
	}()

	select {
	case <-pcDisconnected:
		return fmt.Errorf("peer connection closed")

	case err := <-wsReadError:
		return fmt.Errorf("websocket error: %v", err)

	case err := <-writeError:
		return err

	case <-ctx.Done():
		return fmt.Errorf("terminated")
	}
}
// allocateTracks selects the formats of the stream that can be sent over
// WebRTC and creates an output track for each of them.
//
// At most one video format is selected, preferring VP9, then VP8, then H264.
// At most one audio format is selected, preferring Opus, then G722, then
// G711 (sent as PCMU or PCMA depending on the G711 variant).
//
// Video units are re-packetized with a payload limit of webrtcPayloadMaxSize.
// Audio units are forwarded using the RTP packets they already carry.
// The H264 track reports an error through writeError when a PTS goes
// backwards, since B-frames are not supported.
//
// An error is returned when the stream contains no supported format.
func (c *webRTCConn) allocateTracks(medias media.Medias) ([]*webRTCTrack, error) {
	// newOutputTrack creates a static RTP track with the given codec and track
	// ID, using the stream ID shared by all tracks.
	newOutputTrack := func(mimeType string, clockRate uint32, id string) (*webrtc.TrackLocalStaticRTP, error) {
		return webrtc.NewTrackLocalStaticRTP(
			webrtc.RTPCodecCapability{
				MimeType:  mimeType,
				ClockRate: clockRate,
			},
			id,
			"rtspss",
		)
	}

	// forwardPackets returns a callback that writes into out the RTP packets
	// carried by each unit, without re-encoding them.
	forwardPackets := func(out *webrtc.TrackLocalStaticRTP) func(formatprocessor.Unit, context.Context, chan error) {
		return func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) {
			for _, pkt := range unit.GetRTPPackets() {
				out.WriteRTP(pkt)
			}
		}
	}

	var tracks []*webRTCTrack

	// video

	var vp9Format *format.VP9
	var vp8Format *format.VP8
	var h264Format *format.H264

	if vp9Media := medias.FindFormat(&vp9Format); vp9Format != nil {
		out, err := newOutputTrack(webrtc.MimeTypeVP9, uint32(vp9Format.ClockRate()), "vp9")
		if err != nil {
			return nil, err
		}

		encoder := &rtpvp9.Encoder{
			PayloadType:    96,
			PayloadMaxSize: webrtcPayloadMaxSize,
		}
		encoder.Init()

		tracks = append(tracks, &webRTCTrack{
			media:       vp9Media,
			format:      vp9Format,
			webRTCTrack: out,
			cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) {
				tunit := unit.(*formatprocessor.UnitVP9)
				if tunit.Frame == nil {
					return
				}

				packets, err := encoder.Encode(tunit.Frame, tunit.PTS)
				if err != nil {
					return
				}

				for _, pkt := range packets {
					out.WriteRTP(pkt)
				}
			},
		})
	} else if vp8Media := medias.FindFormat(&vp8Format); vp8Format != nil {
		out, err := newOutputTrack(webrtc.MimeTypeVP8, uint32(vp8Format.ClockRate()), "vp8")
		if err != nil {
			return nil, err
		}

		encoder := &rtpvp8.Encoder{
			PayloadType:    96,
			PayloadMaxSize: webrtcPayloadMaxSize,
		}
		encoder.Init()

		tracks = append(tracks, &webRTCTrack{
			media:       vp8Media,
			format:      vp8Format,
			webRTCTrack: out,
			cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) {
				tunit := unit.(*formatprocessor.UnitVP8)
				if tunit.Frame == nil {
					return
				}

				packets, err := encoder.Encode(tunit.Frame, tunit.PTS)
				if err != nil {
					return
				}

				for _, pkt := range packets {
					out.WriteRTP(pkt)
				}
			},
		})
	} else if h264Media := medias.FindFormat(&h264Format); h264Format != nil {
		out, err := newOutputTrack(webrtc.MimeTypeH264, uint32(h264Format.ClockRate()), "h264")
		if err != nil {
			return nil, err
		}

		encoder := &rtph264.Encoder{
			PayloadType:    96,
			PayloadMaxSize: webrtcPayloadMaxSize,
		}
		encoder.Init()

		var lastPTS time.Duration
		firstNALUReceived := false

		tracks = append(tracks, &webRTCTrack{
			media:       h264Media,
			format:      h264Format,
			webRTCTrack: out,
			cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) {
				tunit := unit.(*formatprocessor.UnitH264)
				if tunit.AU == nil {
					return
				}

				// a PTS lower than the previous one means that the stream
				// contains B-frames, which are not supported.
				if firstNALUReceived && tunit.PTS < lastPTS {
					select {
					case writeError <- fmt.Errorf("WebRTC doesn't support H264 streams with B-frames"):
					case <-ctx.Done():
					}
					return
				}
				firstNALUReceived = true
				lastPTS = tunit.PTS

				packets, err := encoder.Encode(tunit.AU, tunit.PTS)
				if err != nil {
					return
				}

				for _, pkt := range packets {
					out.WriteRTP(pkt)
				}
			},
		})
	}

	// audio

	var opusFormat *format.Opus
	var g722Format *format.G722
	var g711Format *format.G711

	if opusMedia := medias.FindFormat(&opusFormat); opusFormat != nil {
		out, err := newOutputTrack(webrtc.MimeTypeOpus, uint32(opusFormat.ClockRate()), "opus")
		if err != nil {
			return nil, err
		}

		tracks = append(tracks, &webRTCTrack{
			media:       opusMedia,
			format:      opusFormat,
			webRTCTrack: out,
			cb:          forwardPackets(out),
		})
	} else if g722Media := medias.FindFormat(&g722Format); g722Format != nil {
		out, err := newOutputTrack(webrtc.MimeTypeG722, uint32(g722Format.ClockRate()), "g722")
		if err != nil {
			return nil, err
		}

		tracks = append(tracks, &webRTCTrack{
			media:       g722Media,
			format:      g722Format,
			webRTCTrack: out,
			cb:          forwardPackets(out),
		})
	} else if g711Media := medias.FindFormat(&g711Format); g711Format != nil {
		mimeType := webrtc.MimeTypePCMA
		if g711Format.MULaw {
			mimeType = webrtc.MimeTypePCMU
		}

		out, err := newOutputTrack(mimeType, uint32(g711Format.ClockRate()), "g711")
		if err != nil {
			return nil, err
		}

		tracks = append(tracks, &webRTCTrack{
			media:       g711Media,
			format:      g711Format,
			webRTCTrack: out,
			cb:          forwardPackets(out),
		})
	}

	if len(tracks) == 0 {
		return nil, fmt.Errorf(
			"the stream doesn't contain any supported codec (which are currently VP9, VP8, H264, Opus, G722, G711)")
	}

	return tracks, nil
}
// genICEServers converts the configured ICE server entries into pion
// ICEServer values. Each entry is handled as follows:
//
//   - an entry with 5 colon-separated fields, "scheme:user:pass:host:port",
//     becomes URL "scheme:host:port" with the given username and credential;
//   - if user is "AUTH_SECRET", pass is used as a shared secret instead.
//     The username is "<unix expiry 24h from now>:<20 random chars>" and the
//     credential is base64(HMAC-SHA1(secret, username));
//   - any other entry is used verbatim as the URL.
func (c *webRTCConn) genICEServers() []webrtc.ICEServer {
	servers := make([]webrtc.ICEServer, len(c.iceServers))

	for idx, entry := range c.iceServers {
		fields := strings.Split(entry, ":")

		if len(fields) != 5 {
			servers[idx] = webrtc.ICEServer{
				URLs: []string{entry},
			}
			continue
		}

		serverURL := fields[0] + ":" + fields[3] + ":" + fields[4]

		if fields[1] != "AUTH_SECRET" {
			servers[idx] = webrtc.ICEServer{
				URLs:       []string{serverURL},
				Username:   fields[1],
				Credential: fields[2],
			}
			continue
		}

		const charset = "abcdefghijklmnopqrstuvwxyz1234567890"
		suffix := make([]byte, 20)
		for j := range suffix {
			suffix[j] = charset[rand.Intn(len(charset))]
		}

		expiry := time.Now().Add(24 * 3600 * time.Second).Unix()
		username := strconv.FormatInt(expiry, 10) + ":" + string(suffix)

		mac := hmac.New(sha1.New, []byte(fields[2]))
		mac.Write([]byte(username))

		servers[idx] = webrtc.ICEServer{
			URLs:       []string{serverURL},
			Username:   username,
			Credential: base64.StdEncoding.EncodeToString(mac.Sum(nil)),
		}
	}

	return servers
}
// readOffer reads a session description from the WebSocket and returns it
// if it is an SDP offer. Otherwise it returns an error.
func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) {
	offer := &webrtc.SessionDescription{}
	if err := c.wsconn.ReadJSON(offer); err != nil {
		return nil, err
	}

	if offer.Type == webrtc.SDPTypeOffer {
		return offer, nil
	}

	return nil, fmt.Errorf("received SDP is not an offer")
}
// readCandidate reads an ICE candidate from the WebSocket.
func (c *webRTCConn) readCandidate() (*webrtc.ICECandidateInit, error) {
	candidate := new(webrtc.ICECandidateInit)
	if err := c.wsconn.ReadJSON(candidate); err != nil {
		return nil, err
	}
	return candidate, nil
}
// apiReaderDescribe implements reader.
// It describes the connection in the API as a "webRTCConn" with its UUID.
func (c *webRTCConn) apiReaderDescribe() interface{} {
	id := c.uuid.String()

	return struct {
		Type string `json:"type"`
		ID   string `json:"id"`
	}{
		Type: "webRTCConn",
		ID:   id,
	}
}