diff --git a/conf.go b/conf.go new file mode 100644 index 00000000..1b14fabb --- /dev/null +++ b/conf.go @@ -0,0 +1,199 @@ +package main + +import ( + "fmt" + "io" + "os" + "regexp" + "time" + + "github.com/aler9/gortsplib" + "gopkg.in/yaml.v2" +) + +type ConfPath struct { + Source string `yaml:"source"` + SourceProtocol string `yaml:"sourceProtocol"` + PublishUser string `yaml:"publishUser"` + PublishPass string `yaml:"publishPass"` + PublishIps []string `yaml:"publishIps"` + publishIpsParsed []interface{} + ReadUser string `yaml:"readUser"` + ReadPass string `yaml:"readPass"` + ReadIps []string `yaml:"readIps"` + readIpsParsed []interface{} +} + +type conf struct { + Protocols []string `yaml:"protocols"` + protocolsParsed map[streamProtocol]struct{} + RtspPort int `yaml:"rtspPort"` + RtpPort int `yaml:"rtpPort"` + RtcpPort int `yaml:"rtcpPort"` + PreScript string `yaml:"preScript"` + PostScript string `yaml:"postScript"` + ReadTimeout time.Duration `yaml:"readTimeout"` + WriteTimeout time.Duration `yaml:"writeTimeout"` + AuthMethods []string `yaml:"authMethods"` + authMethodsParsed []gortsplib.AuthMethod + Pprof bool `yaml:"pprof"` + Paths map[string]*ConfPath `yaml:"paths"` +} + +func loadConf(fpath string, stdin io.Reader) (*conf, error) { + conf := &conf{} + + err := func() error { + if fpath == "stdin" { + err := yaml.NewDecoder(stdin).Decode(conf) + if err != nil { + return err + } + + return nil + + } else { + // conf.yml is optional + if fpath == "conf.yml" { + if _, err := os.Stat(fpath); err != nil { + return nil + } + } + + f, err := os.Open(fpath) + if err != nil { + return err + } + defer f.Close() + + err = yaml.NewDecoder(f).Decode(conf) + if err != nil { + return err + } + + return nil + } + }() + if err != nil { + return nil, err + } + + if len(conf.Protocols) == 0 { + conf.Protocols = []string{"udp", "tcp"} + } + conf.protocolsParsed = make(map[streamProtocol]struct{}) + for _, proto := range conf.Protocols { + switch proto { + case "udp": + conf.protocolsParsed[_STREAM_PROTOCOL_UDP] = struct{}{} + + case "tcp": + conf.protocolsParsed[_STREAM_PROTOCOL_TCP] = struct{}{} + + default: + return nil, fmt.Errorf("unsupported protocol: %s", proto) + } + } + if len(conf.protocolsParsed) == 0 { + return nil, fmt.Errorf("no protocols provided") + } + + if conf.RtspPort == 0 { + conf.RtspPort = 8554 + } + if conf.RtpPort == 0 { + conf.RtpPort = 8000 + } + if (conf.RtpPort % 2) != 0 { + return nil, fmt.Errorf("rtp port must be even") + } + if conf.RtcpPort == 0 { + conf.RtcpPort = 8001 + } + if conf.RtcpPort != (conf.RtpPort + 1) { + return nil, fmt.Errorf("rtcp and rtp ports must be consecutive") + } + + if conf.ReadTimeout == 0 { + conf.ReadTimeout = 5 * time.Second + } + if conf.WriteTimeout == 0 { + conf.WriteTimeout = 5 * time.Second + } + + if len(conf.AuthMethods) == 0 { + conf.AuthMethods = []string{"basic", "digest"} + } + for _, method := range conf.AuthMethods { + switch method { + case "basic": + conf.authMethodsParsed = append(conf.authMethodsParsed, gortsplib.Basic) + + case "digest": + conf.authMethodsParsed = append(conf.authMethodsParsed, gortsplib.Digest) + + default: + return nil, fmt.Errorf("unsupported authentication method: %s", method) + } + } + + if len(conf.Paths) == 0 { + conf.Paths = map[string]*ConfPath{ + "all": {}, + } + } + + for path, pconf := range conf.Paths { + if pconf.Source == "" { + pconf.Source = "record" + } + + if pconf.PublishUser != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishUser) { + return nil, fmt.Errorf("publish username must be alphanumeric") + } + } + if pconf.PublishPass != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishPass) { + return nil, fmt.Errorf("publish password must be alphanumeric") + } + } + pconf.publishIpsParsed, err = parseIpCidrList(pconf.PublishIps) + if err != nil { + return nil, err + } + + if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { + return nil, fmt.Errorf("read username and password must be both filled") + } + if pconf.ReadUser != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadUser) { + return nil, fmt.Errorf("read username must be alphanumeric") + } + } + if pconf.ReadPass != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadPass) { + return nil, fmt.Errorf("read password must be alphanumeric") + } + } + if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { + return nil, fmt.Errorf("read username and password must be both filled") + } + pconf.readIpsParsed, err = parseIpCidrList(pconf.ReadIps) + if err != nil { + return nil, err + } + + if pconf.Source != "record" { + if path == "all" { + return nil, fmt.Errorf("path 'all' cannot have a RTSP source") + } + + if pconf.SourceProtocol == "" { + pconf.SourceProtocol = "udp" + } + } + } + + return conf, nil +} diff --git a/conf.yml b/conf.yml index ce9563c3..2fbcb123 100644 --- a/conf.yml +++ b/conf.yml @@ -15,6 +15,8 @@ postScript: readTimeout: 5s # timeout of write operations writeTimeout: 5s +# supported authentication methods +authMethods: [basic, digest] # enable pprof on port 9999 to monitor performance pprof: false diff --git a/main.go b/main.go index 48596f43..de1b0937 100644 --- a/main.go +++ b/main.go @@ -8,12 +8,9 @@ import ( "net/http" _ "net/http/pprof" "os" - "regexp" - "time" "github.com/aler9/gortsplib" "gopkg.in/alecthomas/kingpin.v2" - "gopkg.in/yaml.v2" "gortc.io/sdp" ) @@ -154,66 +151,6 @@ type programEventTerminate struct{} func (programEventTerminate) isProgramEvent() {} -type ConfPath struct { - Source string `yaml:"source"` - SourceProtocol string `yaml:"sourceProtocol"` - PublishUser string `yaml:"publishUser"` - PublishPass string `yaml:"publishPass"` - PublishIps []string `yaml:"publishIps"` - publishIpsParsed []interface{} - ReadUser string `yaml:"readUser"` - ReadPass string `yaml:"readPass"` - ReadIps []string `yaml:"readIps"` - readIpsParsed []interface{} -} - -type conf struct { - Protocols []string `yaml:"protocols"` - RtspPort int `yaml:"rtspPort"` - RtpPort int `yaml:"rtpPort"` - RtcpPort int `yaml:"rtcpPort"` - ReadTimeout time.Duration `yaml:"readTimeout"` - WriteTimeout time.Duration `yaml:"writeTimeout"` - PreScript string `yaml:"preScript"` - PostScript string `yaml:"postScript"` - Pprof bool `yaml:"pprof"` - Paths map[string]*ConfPath `yaml:"paths"` -} - -func loadConf(fpath string, stdin io.Reader) (*conf, error) { - if fpath == "stdin" { - var ret conf - err := yaml.NewDecoder(stdin).Decode(&ret) - if err != nil { - return nil, err - } - - return &ret, nil - - } else { - // conf.yml is optional - if fpath == "conf.yml" { - if _, err := os.Stat(fpath); err != nil { - return &conf{}, nil - } - } - - f, err := os.Open(fpath) - if err != nil { - return nil, err - } - defer f.Close() - - var ret conf - err = yaml.NewDecoder(f).Decode(&ret) - if err != nil { - return nil, err - } - - return &ret, nil - } -} - // a publisher can be either a serverClient or a streamer type publisher interface { publisherIsReady() bool @@ -223,7 +160,6 @@ type publisher interface { type program struct { conf *conf - protocols map[streamProtocol]struct{} rtspl *serverTcpListener rtpl *serverUdpListener rtcpl *serverUdpListener @@ -256,58 +192,8 @@ func newProgram(sargs []string, stdin io.Reader) (*program, error) { return nil, err } - if conf.ReadTimeout == 0 { - conf.ReadTimeout = 5 * time.Second - } - if conf.WriteTimeout == 0 { - conf.WriteTimeout = 5 * time.Second - } - - if len(conf.Protocols) == 0 { - conf.Protocols = []string{"udp", "tcp"} - } - protocols := make(map[streamProtocol]struct{}) - for _, proto := range conf.Protocols { - switch proto { - case "udp": - protocols[_STREAM_PROTOCOL_UDP] = struct{}{} - - case "tcp": - protocols[_STREAM_PROTOCOL_TCP] = struct{}{} - - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) - } - } - if len(protocols) == 0 { - return nil, fmt.Errorf("no protocols provided") - } - - if conf.RtspPort == 0 { - conf.RtspPort = 8554 - } - if conf.RtpPort == 0 { - conf.RtpPort = 8000 - } - if (conf.RtpPort % 2) != 0 { - return nil, fmt.Errorf("rtp port must be even") - } - if conf.RtcpPort == 0 { - conf.RtcpPort = 8001 - } - if conf.RtcpPort != (conf.RtpPort + 1) { - return nil, fmt.Errorf("rtcp and rtp ports must be consecutive") - } - - if len(conf.Paths) == 0 { - conf.Paths = map[string]*ConfPath{ - "all": {}, - } - } - p := &program{ conf: conf, - protocols: protocols, clients: make(map[*serverClient]struct{}), publishers: make(map[string]publisher), events: make(chan programEvent), @@ -315,55 +201,7 @@ func newProgram(sargs []string, stdin io.Reader) (*program, error) { } for path, pconf := range conf.Paths { - if pconf.Source == "" { - pconf.Source = "record" - } - - if pconf.PublishUser != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishUser) { - return nil, fmt.Errorf("publish username must be alphanumeric") - } - } - if pconf.PublishPass != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishPass) { - return nil, fmt.Errorf("publish password must be alphanumeric") - } - } - pconf.publishIpsParsed, err = parseIpCidrList(pconf.PublishIps) - if err != nil { - return nil, err - } - - if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { - return nil, fmt.Errorf("read username and password must be both filled") - } - if pconf.ReadUser != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadUser) { - return nil, fmt.Errorf("read username must be alphanumeric") - } - } - if pconf.ReadPass != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadPass) { - return nil, fmt.Errorf("read password must be alphanumeric") - } - } - if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { - return nil, fmt.Errorf("read username and password must be both filled") - } - pconf.readIpsParsed, err = parseIpCidrList(pconf.ReadIps) - if err != nil { - return nil, err - } - if pconf.Source != "record" { - if path == "all" { - return nil, fmt.Errorf("path 'all' cannot have a RTSP source") - } - - if pconf.SourceProtocol == "" { - pconf.SourceProtocol = "udp" - } - s, err := newStreamer(p, path, pconf.Source, pconf.SourceProtocol) if err != nil { return nil, err diff --git a/server-client.go b/server-client.go index 508cfff5..6bc70aaf 100644 --- a/server-client.go +++ b/server-client.go @@ -489,7 +489,7 @@ func (c *serverClient) authenticate(ips []interface{}, user string, pass string, if c.authHelper == nil || c.authUser != user || c.authPass != pass { c.authUser = user c.authPass = pass - c.authHelper = gortsplib.NewAuthServer(user, pass, nil) + c.authHelper = gortsplib.NewAuthServer(user, pass, c.p.conf.authMethodsParsed) } err := c.authHelper.ValidateHeader(req.Header["Authorization"], req.Method, req.Url) @@ -732,7 +732,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) error { } return false }() { - if _, ok := c.p.protocols[_STREAM_PROTOCOL_UDP]; !ok { + if _, ok := c.p.conf.protocolsParsed[_STREAM_PROTOCOL_UDP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) return errClientTerminate } @@ -778,7 +778,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) error { // play via TCP } else if _, ok := th["RTP/AVP/TCP"]; ok { - if _, ok := c.p.protocols[_STREAM_PROTOCOL_TCP]; !ok { + if _, ok := c.p.conf.protocolsParsed[_STREAM_PROTOCOL_TCP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) return errClientTerminate } @@ -847,7 +847,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) error { } return false }() { - if _, ok := c.p.protocols[_STREAM_PROTOCOL_UDP]; !ok { + if _, ok := c.p.conf.protocolsParsed[_STREAM_PROTOCOL_UDP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) return errClientTerminate } @@ -893,7 +893,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) error { // record via TCP } else if _, ok := th["RTP/AVP/TCP"]; ok { - if _, ok := c.p.protocols[_STREAM_PROTOCOL_TCP]; !ok { + if _, ok := c.p.conf.protocolsParsed[_STREAM_PROTOCOL_TCP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) return errClientTerminate }