webrtc: make preflight OPTIONS requests work with external auth (#1941) (#1964)

* webrtc: allow preflight OPTIONS requests to work with external auth (#1941)

* add tests

* improve tests
This commit is contained in:
Alessandro Ros 2023-06-21 13:53:58 +02:00 committed by GitHub
parent d3354a0c99
commit daa6500082
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 287 additions and 133 deletions

View File

@ -22,6 +22,11 @@ import (
var version = "v0.0.0"
var cli struct {
Version bool `help:"print version"`
Confpath string `arg:"" default:"mediamtx.yml"`
}
// Core is an instance of mediamtx.
type Core struct {
ctx context.Context
@ -50,11 +55,6 @@ type Core struct {
done chan struct{}
}
var cli struct {
Version bool `help:"print version"`
Confpath string `arg:"" default:"mediamtx.yml"`
}
// New allocates a core.
func New(args []string) (*Core, bool) {
parser, err := kong.New(&cli,

View File

@ -89,14 +89,14 @@ func (s *hlsHTTPServer) onRequest(ctx *gin.Context) {
ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
switch ctx.Request.Method {
case http.MethodGet:
case http.MethodOptions:
ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers"))
ctx.Writer.WriteHeader(http.StatusOK)
return
case http.MethodGet:
default:
return
}

View File

@ -75,7 +75,7 @@ func (ts *testHTTPAuthenticator) onAuth(ctx *gin.Context) {
in.Password != "testpass" ||
in.Path != "teststream" ||
in.Protocol != ts.protocol ||
in.ID == "" ||
// in.ID == "" ||
in.Action != ts.action ||
(in.Query != "user=testreader&pass=testpass&param=value" &&
in.Query != "user=testpublisher&pass=testpass&param=value" &&

View File

@ -179,14 +179,14 @@ func (s *webRTCHTTPServer) onRequest(ctx *gin.Context) {
if !strings.HasSuffix(pa, "/whip") && !strings.HasSuffix(pa, "/whep") {
switch ctx.Request.Method {
case http.MethodGet:
case http.MethodOptions:
ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers"))
ctx.Writer.WriteHeader(http.StatusOK)
return
case http.MethodGet:
default:
return
}
@ -230,7 +230,7 @@ func (s *webRTCHTTPServer) onRequest(ctx *gin.Context) {
user, pass, hasCredentials := ctx.Request.BasicAuth()
res := s.pathManager.getPathConf(pathGetPathConfReq{
authRes := s.pathManager.getPathConf(pathGetPathConfReq{
name: dir,
publish: publish,
credentials: authCredentials{
@ -241,21 +241,23 @@ func (s *webRTCHTTPServer) onRequest(ctx *gin.Context) {
proto: authProtocolWebRTC,
},
})
if res.err != nil {
if terr, ok := res.err.(pathErrAuth); ok {
if !hasCredentials {
ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`)
if authRes.err != nil {
if ctx.Request.Method != http.MethodOptions {
if terr, ok := authRes.err.(pathErrAuth); ok {
if !hasCredentials {
ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`)
ctx.Writer.WriteHeader(http.StatusUnauthorized)
return
}
s.Log(logger.Info, "authentication error: %v", terr.wrapped)
ctx.Writer.WriteHeader(http.StatusUnauthorized)
return
}
s.Log(logger.Info, "authentication error: %v", terr.wrapped)
ctx.Writer.WriteHeader(http.StatusUnauthorized)
ctx.Writer.WriteHeader(http.StatusNotFound)
return
}
ctx.Writer.WriteHeader(http.StatusNotFound)
return
}
switch fname {
@ -274,7 +276,9 @@ func (s *webRTCHTTPServer) onRequest(ctx *gin.Context) {
case http.MethodOptions:
ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, POST, PATCH")
ctx.Writer.Header().Set("Access-Control-Allow-Headers", ctx.Request.Header.Get("Access-Control-Request-Headers"))
ctx.Writer.Header()["Link"] = iceServersToLinkHeader(s.parent.genICEServers())
if authRes.err == nil {
ctx.Writer.Header()["Link"] = iceServersToLinkHeader(s.parent.genICEServers())
}
ctx.Writer.WriteHeader(http.StatusOK)
case http.MethodPost:

View File

@ -17,7 +17,11 @@ import (
"github.com/stretchr/testify/require"
)
func whipGetICEServers(t *testing.T, hc *http.Client, ur string) []webrtc.ICEServer {
func whipGetICEServers(
t *testing.T,
hc *http.Client,
ur string,
) []webrtc.ICEServer {
req, err := http.NewRequest("OPTIONS", ur, nil)
require.NoError(t, err)
@ -35,7 +39,10 @@ func whipGetICEServers(t *testing.T, hc *http.Client, ur string) []webrtc.ICESer
return servers
}
func whipPostOffer(t *testing.T, hc *http.Client, ur string,
func whipPostOffer(
t *testing.T,
hc *http.Client,
ur string,
offer *webrtc.SessionDescription,
) (*webrtc.SessionDescription, string) {
req, err := http.NewRequest("POST", ur, bytes.NewReader([]byte(offer.SDP)))
@ -50,7 +57,11 @@ func whipPostOffer(t *testing.T, hc *http.Client, ur string,
require.Equal(t, http.StatusCreated, res.StatusCode)
require.Equal(t, "application/sdp", res.Header.Get("Content-Type"))
require.Equal(t, "application/trickle-ice-sdpfrag", res.Header.Get("Accept-Patch"))
require.Equal(t, req.URL.Path, res.Header.Get("Location"))
loc := req.URL.Path
if req.URL.RawQuery != "" {
loc += "?" + req.URL.RawQuery
}
require.Equal(t, loc, res.Header.Get("Location"))
link, ok := res.Header["Link"]
require.Equal(t, true, ok)
@ -73,8 +84,12 @@ func whipPostOffer(t *testing.T, hc *http.Client, ur string,
return answer, etag
}
func whipPostCandidate(t *testing.T, ur string, offer *webrtc.SessionDescription,
etag string, candidate *webrtc.ICECandidateInit,
func whipPostCandidate(
t *testing.T,
ur string,
offer *webrtc.SessionDescription,
etag string,
candidate *webrtc.ICECandidateInit,
) {
frag, err := marshalICEFragment(offer, []*webrtc.ICECandidateInit{candidate})
require.NoError(t, err)
@ -102,7 +117,12 @@ type webRTCTestClient struct {
closed chan struct{}
}
func newWebRTCTestClient(t *testing.T, hc *http.Client, ur string, publish bool) *webRTCTestClient {
func newWebRTCTestClient(
t *testing.T,
hc *http.Client,
ur string,
publish bool,
) *webRTCTestClient {
iceServers := whipGetICEServers(t, hc, ur)
pc, err := webrtc.NewPeerConnection(webrtc.Configuration{
@ -247,62 +267,119 @@ func (c *webRTCTestClient) close() {
}
func TestWebRTCRead(t *testing.T) {
p, ok := newInstance("paths:\n" +
" all:\n")
require.Equal(t, true, ok)
defer p.Close()
for _, auth := range []string{
"none",
"internal",
"external",
} {
t.Run("auth_"+auth, func(t *testing.T) {
var conf string
medi := &media.Media{
Type: media.TypeVideo,
Formats: []formats.Format{&formats.H264{
PayloadTyp: 96,
PacketizationMode: 1,
}},
switch auth {
case "none":
conf = "paths:\n" +
" all:\n"
case "internal":
conf = "paths:\n" +
" all:\n" +
" readUser: myuser\n" +
" readPass: mypass\n"
case "external":
conf = "externalAuthenticationURL: http://localhost:9120/auth\n" +
"paths:\n" +
" all:\n"
}
p, ok := newInstance(conf)
require.Equal(t, true, ok)
defer p.Close()
var a *testHTTPAuthenticator
if auth == "external" {
a = newTestHTTPAuthenticator(t, "rtsp", "publish")
}
medi := &media.Media{
Type: media.TypeVideo,
Formats: []formats.Format{&formats.H264{
PayloadTyp: 96,
PacketizationMode: 1,
}},
}
v := gortsplib.TransportTCP
source := gortsplib.Client{
Transport: &v,
}
err := source.StartRecording(
"rtsp://testpublisher:testpass@localhost:8554/teststream?param=value", media.Medias{medi})
require.NoError(t, err)
defer source.Close()
if auth == "external" {
a.close()
a = newTestHTTPAuthenticator(t, "webrtc", "read")
defer a.close()
}
hc := &http.Client{Transport: &http.Transport{}}
user := ""
pass := ""
switch auth {
case "internal":
user = "myuser"
pass = "mypass"
case "external":
user = "testreader"
pass = "testpass"
}
ur := "http://"
if user != "" {
ur += user + ":" + pass + "@"
}
ur += "localhost:8889/teststream/whep?param=value"
c := newWebRTCTestClient(t, hc, ur, false)
defer c.close()
time.Sleep(500 * time.Millisecond)
source.WritePacketRTP(medi, &rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 96,
SequenceNumber: 123,
Timestamp: 45343,
SSRC: 563423,
},
Payload: []byte{0x01, 0x02, 0x03, 0x04},
})
trak := <-c.incomingTrack
pkt, _, err := trak.ReadRTP()
require.NoError(t, err)
require.Equal(t, &rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 102,
SequenceNumber: pkt.SequenceNumber,
Timestamp: pkt.Timestamp,
SSRC: pkt.SSRC,
CSRC: []uint32{},
},
Payload: []byte{0x01, 0x02, 0x03, 0x04},
}, pkt)
})
}
v := gortsplib.TransportTCP
source := gortsplib.Client{
Transport: &v,
}
err := source.StartRecording("rtsp://localhost:8554/stream", media.Medias{medi})
require.NoError(t, err)
defer source.Close()
hc := &http.Client{Transport: &http.Transport{}}
c := newWebRTCTestClient(t, hc, "http://localhost:8889/stream/whep", false)
defer c.close()
time.Sleep(500 * time.Millisecond)
source.WritePacketRTP(medi, &rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 96,
SequenceNumber: 123,
Timestamp: 45343,
SSRC: 563423,
},
Payload: []byte{0x01, 0x02, 0x03, 0x04},
})
trak := <-c.incomingTrack
pkt, _, err := trak.ReadRTP()
require.NoError(t, err)
require.Equal(t, &rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 102,
SequenceNumber: pkt.SequenceNumber,
Timestamp: pkt.Timestamp,
SSRC: pkt.SSRC,
CSRC: []uint32{},
},
Payload: []byte{0x01, 0x02, 0x03, 0x04},
}, pkt)
}
func TestWebRTCReadNotFound(t *testing.T) {
@ -340,60 +417,133 @@ func TestWebRTCReadNotFound(t *testing.T) {
}
func TestWebRTCPublish(t *testing.T) {
p, ok := newInstance("paths:\n" +
" all:\n")
require.Equal(t, true, ok)
defer p.Close()
for _, auth := range []string{
"none",
"internal",
"external",
} {
t.Run("auth_"+auth, func(t *testing.T) {
var conf string
hc := &http.Client{Transport: &http.Transport{}}
switch auth {
case "none":
conf = "paths:\n" +
" all:\n"
s := newWebRTCTestClient(t, hc, "http://localhost:8889/stream/whip", true)
defer s.close()
case "internal":
conf = "paths:\n" +
" all:\n" +
" publishUser: myuser\n" +
" publishPass: mypass\n"
c := gortsplib.Client{
OnDecodeError: func(err error) {
panic(err)
},
case "external":
conf = "externalAuthenticationURL: http://localhost:9120/auth\n" +
"paths:\n" +
" all:\n"
}
p, ok := newInstance(conf)
require.Equal(t, true, ok)
defer p.Close()
var a *testHTTPAuthenticator
if auth == "external" {
a = newTestHTTPAuthenticator(t, "webrtc", "publish")
}
hc := &http.Client{Transport: &http.Transport{}}
// OPTIONS preflight requests must always work, without authentication
func() {
req, err := http.NewRequest("OPTIONS", "http://localhost:8889/teststream/whip", nil)
require.NoError(t, err)
res, err := hc.Do(req)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
if auth != "none" {
_, ok := res.Header["Link"]
require.Equal(t, false, ok)
}
}()
user := ""
pass := ""
switch auth {
case "internal":
user = "myuser"
pass = "mypass"
case "external":
user = "testpublisher"
pass = "testpass"
}
ur := "http://"
if user != "" {
ur += user + ":" + pass + "@"
}
ur += "localhost:8889/teststream/whip?param=value"
s := newWebRTCTestClient(t, hc, ur, true)
defer s.close()
if auth == "external" {
a.close()
a = newTestHTTPAuthenticator(t, "rtsp", "read")
defer a.close()
}
c := gortsplib.Client{
OnDecodeError: func(err error) {
panic(err)
},
}
u, err := url.Parse("rtsp://testreader:testpass@127.0.0.1:8554/teststream?param=value")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()
medias, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
var forma *formats.VP8
medi := medias.FindFormat(&forma)
_, err = c.Setup(medi, baseURL, 0, 0)
require.NoError(t, err)
received := make(chan struct{})
c.OnPacketRTP(medi, forma, func(pkt *rtp.Packet) {
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, pkt.Payload)
close(received)
})
_, err = c.Play(nil)
require.NoError(t, err)
err = s.outgoingTrack1.WriteRTP(&rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 96,
SequenceNumber: 124,
Timestamp: 45343,
SSRC: 563423,
},
Payload: []byte{0x05, 0x06, 0x07, 0x08},
})
require.NoError(t, err)
<-received
})
}
u, err := url.Parse("rtsp://127.0.0.1:8554/stream")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()
medias, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
var forma *formats.VP8
medi := medias.FindFormat(&forma)
_, err = c.Setup(medi, baseURL, 0, 0)
require.NoError(t, err)
received := make(chan struct{})
c.OnPacketRTP(medi, forma, func(pkt *rtp.Packet) {
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, pkt.Payload)
close(received)
})
_, err = c.Play(nil)
require.NoError(t, err)
err = s.outgoingTrack1.WriteRTP(&rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 96,
SequenceNumber: 124,
Timestamp: 45343,
SSRC: 563423,
},
Payload: []byte{0x05, 0x06, 0x07, 0x08},
})
require.NoError(t, err)
<-received
}