mirror of
https://github.com/bluenviron/mediamtx
synced 2025-01-20 22:21:01 +00:00
add --read-ips and --publish-ips arguments; fix #12
This commit is contained in:
parent
e644587a6c
commit
fbd9f74c8b
@ -91,8 +91,10 @@ Flags:
|
||||
--write-timeout=5s timeout of write operations
|
||||
--publish-user="" optional username required to publish
|
||||
--publish-pass="" optional password required to publish
|
||||
--publish-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can publish
|
||||
--read-user="" optional username required to read
|
||||
--read-pass="" optional password required to read
|
||||
--read-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can read
|
||||
--pre-script="" optional script to run on client connect
|
||||
--post-script="" optional script to run on client disconnect
|
||||
```
|
||||
|
62
main.go
62
main.go
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
@ -13,6 +14,30 @@ import (
|
||||
|
||||
var Version string = "v0.0.0"
|
||||
|
||||
func parseIpCidrList(in string) ([]interface{}, error) {
|
||||
if in == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var ret []interface{}
|
||||
for _, t := range strings.Split(in, ",") {
|
||||
_, ipnet, err := net.ParseCIDR(t)
|
||||
if err == nil {
|
||||
ret = append(ret, ipnet)
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(t)
|
||||
if ip != nil {
|
||||
ret = append(ret, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to parse ip/network '%s'", t)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
type trackFlow int
|
||||
|
||||
const (
|
||||
@ -49,18 +74,22 @@ type args struct {
|
||||
writeTimeout time.Duration
|
||||
publishUser string
|
||||
publishPass string
|
||||
publishIps string
|
||||
readUser string
|
||||
readPass string
|
||||
readIps string
|
||||
preScript string
|
||||
postScript string
|
||||
}
|
||||
|
||||
type program struct {
|
||||
args args
|
||||
protocols map[streamProtocol]struct{}
|
||||
tcpl *serverTcpListener
|
||||
udplRtp *serverUdpListener
|
||||
udplRtcp *serverUdpListener
|
||||
args args
|
||||
protocols map[streamProtocol]struct{}
|
||||
publishIps []interface{}
|
||||
readIps []interface{}
|
||||
tcpl *serverTcpListener
|
||||
udplRtp *serverUdpListener
|
||||
udplRtcp *serverUdpListener
|
||||
}
|
||||
|
||||
func newProgram(sargs []string) (*program, error) {
|
||||
@ -76,8 +105,10 @@ func newProgram(sargs []string) (*program, error) {
|
||||
argWriteTimeout := kingpin.Flag("write-timeout", "timeout of write operations").Default("5s").Duration()
|
||||
argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String()
|
||||
argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String()
|
||||
argPublishIps := kingpin.Flag("publish-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can publish").Default("").String()
|
||||
argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String()
|
||||
argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String()
|
||||
argReadIps := kingpin.Flag("read-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can read").Default("").String()
|
||||
argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String()
|
||||
argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String()
|
||||
|
||||
@ -93,8 +124,10 @@ func newProgram(sargs []string) (*program, error) {
|
||||
writeTimeout: *argWriteTimeout,
|
||||
publishUser: *argPublishUser,
|
||||
publishPass: *argPublishPass,
|
||||
publishIps: *argPublishIps,
|
||||
readUser: *argReadUser,
|
||||
readPass: *argReadPass,
|
||||
readIps: *argReadIps,
|
||||
preScript: *argPreScript,
|
||||
postScript: *argPostScript,
|
||||
}
|
||||
@ -120,12 +153,14 @@ func newProgram(sargs []string) (*program, error) {
|
||||
if len(protocols) == 0 {
|
||||
return nil, fmt.Errorf("no protocols provided")
|
||||
}
|
||||
|
||||
if (args.rtpPort % 2) != 0 {
|
||||
return nil, fmt.Errorf("rtp port must be even")
|
||||
}
|
||||
if args.rtcpPort != (args.rtpPort + 1) {
|
||||
return nil, fmt.Errorf("rtcp and rtp ports must be consecutive")
|
||||
}
|
||||
|
||||
if args.publishUser != "" {
|
||||
if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) {
|
||||
return nil, fmt.Errorf("publish username must be alphanumeric")
|
||||
@ -136,6 +171,11 @@ func newProgram(sargs []string) (*program, error) {
|
||||
return nil, fmt.Errorf("publish password must be alphanumeric")
|
||||
}
|
||||
}
|
||||
publishIps, err := parseIpCidrList(args.publishIps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
|
||||
return nil, fmt.Errorf("read username and password must be both filled")
|
||||
}
|
||||
@ -152,16 +192,20 @@ func newProgram(sargs []string) (*program, error) {
|
||||
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
|
||||
return nil, fmt.Errorf("read username and password must be both filled")
|
||||
}
|
||||
readIps, err := parseIpCidrList(args.readIps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("rtsp-simple-server %s", Version)
|
||||
|
||||
p := &program{
|
||||
args: args,
|
||||
protocols: protocols,
|
||||
args: args,
|
||||
protocols: protocols,
|
||||
publishIps: publishIps,
|
||||
readIps: readIps,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -142,6 +142,7 @@ func TestPublishAuth(t *testing.T) {
|
||||
p, err := newProgram([]string{
|
||||
"--publish-user=testuser",
|
||||
"--publish-pass=testpass",
|
||||
"--publish-ips=172.17.0.0/16",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer p.close()
|
||||
@ -185,6 +186,7 @@ func TestReadAuth(t *testing.T) {
|
||||
p, err := newProgram([]string{
|
||||
"--read-user=testuser",
|
||||
"--read-pass=testpass",
|
||||
"--read-ips=172.17.0.0/16",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer p.close()
|
||||
|
@ -202,36 +202,70 @@ func (c *serverClient) writeResError(req *gortsplib.Request, code gortsplib.Stat
|
||||
var errAuthCritical = errors.New("auth critical")
|
||||
var errAuthNotCritical = errors.New("auth not critical")
|
||||
|
||||
func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer) error {
|
||||
if user == "" {
|
||||
return nil
|
||||
}
|
||||
func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer, ips []interface{}) error {
|
||||
err := func() error {
|
||||
if ips == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
initialRequest := false
|
||||
if *auth == nil {
|
||||
initialRequest = true
|
||||
*auth = gortsplib.NewAuthServer(user, pass, nil)
|
||||
}
|
||||
connIp := c.conn.NetConn().LocalAddr().(*net.TCPAddr).IP
|
||||
|
||||
err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url)
|
||||
for _, item := range ips {
|
||||
switch titem := item.(type) {
|
||||
case net.IP:
|
||||
if titem.Equal(connIp) {
|
||||
return nil
|
||||
}
|
||||
|
||||
case *net.IPNet:
|
||||
if titem.Contains(connIp) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.log("ERR: ip '%s' not allowed", connIp)
|
||||
return errAuthCritical
|
||||
}()
|
||||
if err != nil {
|
||||
if !initialRequest {
|
||||
c.log("ERR: Unauthorized: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = func() error {
|
||||
if user == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.conn.WriteResponse(&gortsplib.Response{
|
||||
StatusCode: gortsplib.StatusUnauthorized,
|
||||
Header: gortsplib.Header{
|
||||
"CSeq": []string{req.Header["CSeq"][0]},
|
||||
"WWW-Authenticate": (*auth).GenerateHeader(),
|
||||
},
|
||||
})
|
||||
|
||||
if !initialRequest {
|
||||
return errAuthCritical
|
||||
initialRequest := false
|
||||
if *auth == nil {
|
||||
initialRequest = true
|
||||
*auth = gortsplib.NewAuthServer(user, pass, nil)
|
||||
}
|
||||
|
||||
return errAuthNotCritical
|
||||
err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url)
|
||||
if err != nil {
|
||||
if !initialRequest {
|
||||
c.log("ERR: unauthorized: %s", err)
|
||||
}
|
||||
|
||||
c.conn.WriteResponse(&gortsplib.Response{
|
||||
StatusCode: gortsplib.StatusUnauthorized,
|
||||
Header: gortsplib.Header{
|
||||
"CSeq": []string{req.Header["CSeq"][0]},
|
||||
"WWW-Authenticate": (*auth).GenerateHeader(),
|
||||
},
|
||||
})
|
||||
|
||||
if !initialRequest {
|
||||
return errAuthCritical
|
||||
}
|
||||
|
||||
return errAuthNotCritical
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -291,7 +325,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth)
|
||||
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
|
||||
if err != nil {
|
||||
if err == errAuthCritical {
|
||||
return false
|
||||
@ -333,7 +367,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth)
|
||||
err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth, c.p.publishIps)
|
||||
if err != nil {
|
||||
if err == errAuthCritical {
|
||||
return false
|
||||
@ -405,7 +439,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
|
||||
switch c.state {
|
||||
// play
|
||||
case _CLIENT_STATE_STARTING, _CLIENT_STATE_PRE_PLAY:
|
||||
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth)
|
||||
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
|
||||
if err != nil {
|
||||
if err == errAuthCritical {
|
||||
return false
|
||||
|
Loading…
Reference in New Issue
Block a user