add --read-ips and --publish-ips arguments; fix #12

This commit is contained in:
aler9 2020-06-15 22:41:14 +02:00
parent e644587a6c
commit fbd9f74c8b
4 changed files with 117 additions and 35 deletions

View File

@ -91,8 +91,10 @@ Flags:
--write-timeout=5s timeout of write operations
--publish-user="" optional username required to publish
--publish-pass="" optional password required to publish
--publish-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can publish
--read-user="" optional username required to read
--read-pass="" optional password required to read
--read-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can read
--pre-script="" optional script to run on client connect
--post-script="" optional script to run on client disconnect
```

48
main.go
View File

@ -3,6 +3,7 @@ package main
import (
"fmt"
"log"
"net"
"os"
"regexp"
"strings"
@ -13,6 +14,30 @@ import (
var Version string = "v0.0.0"
func parseIpCidrList(in string) ([]interface{}, error) {
if in == "" {
return nil, nil
}
var ret []interface{}
for _, t := range strings.Split(in, ",") {
_, ipnet, err := net.ParseCIDR(t)
if err == nil {
ret = append(ret, ipnet)
continue
}
ip := net.ParseIP(t)
if ip != nil {
ret = append(ret, ip)
continue
}
return nil, fmt.Errorf("unable to parse ip/network '%s'", t)
}
return ret, nil
}
type trackFlow int
const (
@ -49,8 +74,10 @@ type args struct {
writeTimeout time.Duration
publishUser string
publishPass string
publishIps string
readUser string
readPass string
readIps string
preScript string
postScript string
}
@ -58,6 +85,8 @@ type args struct {
type program struct {
args args
protocols map[streamProtocol]struct{}
publishIps []interface{}
readIps []interface{}
tcpl *serverTcpListener
udplRtp *serverUdpListener
udplRtcp *serverUdpListener
@ -76,8 +105,10 @@ func newProgram(sargs []string) (*program, error) {
argWriteTimeout := kingpin.Flag("write-timeout", "timeout of write operations").Default("5s").Duration()
argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String()
argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String()
argPublishIps := kingpin.Flag("publish-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can publish").Default("").String()
argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String()
argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String()
argReadIps := kingpin.Flag("read-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can read").Default("").String()
argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String()
argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String()
@ -93,8 +124,10 @@ func newProgram(sargs []string) (*program, error) {
writeTimeout: *argWriteTimeout,
publishUser: *argPublishUser,
publishPass: *argPublishPass,
publishIps: *argPublishIps,
readUser: *argReadUser,
readPass: *argReadPass,
readIps: *argReadIps,
preScript: *argPreScript,
postScript: *argPostScript,
}
@ -120,12 +153,14 @@ func newProgram(sargs []string) (*program, error) {
if len(protocols) == 0 {
return nil, fmt.Errorf("no protocols provided")
}
if (args.rtpPort % 2) != 0 {
return nil, fmt.Errorf("rtp port must be even")
}
if args.rtcpPort != (args.rtpPort + 1) {
return nil, fmt.Errorf("rtcp and rtp ports must be consecutive")
}
if args.publishUser != "" {
if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) {
return nil, fmt.Errorf("publish username must be alphanumeric")
@ -136,6 +171,11 @@ func newProgram(sargs []string) (*program, error) {
return nil, fmt.Errorf("publish password must be alphanumeric")
}
}
publishIps, err := parseIpCidrList(args.publishIps)
if err != nil {
return nil, err
}
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled")
}
@ -152,16 +192,20 @@ func newProgram(sargs []string) (*program, error) {
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled")
}
readIps, err := parseIpCidrList(args.readIps)
if err != nil {
return nil, err
}
log.Printf("rtsp-simple-server %s", Version)
p := &program{
args: args,
protocols: protocols,
publishIps: publishIps,
readIps: readIps,
}
var err error
p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP)
if err != nil {
return nil, err

View File

@ -142,6 +142,7 @@ func TestPublishAuth(t *testing.T) {
p, err := newProgram([]string{
"--publish-user=testuser",
"--publish-pass=testpass",
"--publish-ips=172.17.0.0/16",
})
require.NoError(t, err)
defer p.close()
@ -185,6 +186,7 @@ func TestReadAuth(t *testing.T) {
p, err := newProgram([]string{
"--read-user=testuser",
"--read-pass=testpass",
"--read-ips=172.17.0.0/16",
})
require.NoError(t, err)
defer p.close()

View File

@ -202,7 +202,36 @@ func (c *serverClient) writeResError(req *gortsplib.Request, code gortsplib.Stat
var errAuthCritical = errors.New("auth critical")
var errAuthNotCritical = errors.New("auth not critical")
func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer) error {
func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer, ips []interface{}) error {
err := func() error {
if ips == nil {
return nil
}
connIp := c.conn.NetConn().LocalAddr().(*net.TCPAddr).IP
for _, item := range ips {
switch titem := item.(type) {
case net.IP:
if titem.Equal(connIp) {
return nil
}
case *net.IPNet:
if titem.Contains(connIp) {
return nil
}
}
}
c.log("ERR: ip '%s' not allowed", connIp)
return errAuthCritical
}()
if err != nil {
return err
}
err = func() error {
if user == "" {
return nil
}
@ -216,7 +245,7 @@ func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass st
err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url)
if err != nil {
if !initialRequest {
c.log("ERR: Unauthorized: %s", err)
c.log("ERR: unauthorized: %s", err)
}
c.conn.WriteResponse(&gortsplib.Response{
@ -233,6 +262,11 @@ func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass st
return errAuthNotCritical
}
return nil
}()
if err != nil {
return err
}
return nil
}
@ -291,7 +325,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false
}
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth)
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil {
if err == errAuthCritical {
return false
@ -333,7 +367,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false
}
err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth)
err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth, c.p.publishIps)
if err != nil {
if err == errAuthCritical {
return false
@ -405,7 +439,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
switch c.state {
// play
case _CLIENT_STATE_STARTING, _CLIENT_STATE_PRE_PLAY:
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth)
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil {
if err == errAuthCritical {
return false