diff --git a/main.go b/main.go index 8bdebd1a..7f35b4c1 100644 --- a/main.go +++ b/main.go @@ -63,24 +63,40 @@ type program struct { udplRtcp *serverUdpListener } -func newProgram(args args) (*program, error) { - if args.protocolsStr == "" { - args.protocolsStr = "udp,tcp" - } - if args.rtspPort == 0 { - args.rtspPort = 8554 - } - if args.rtpPort == 0 { - args.rtpPort = 8000 - } - if args.rtcpPort == 0 { - args.rtcpPort = 8001 - } - if args.readTimeout == time.Duration(0) { - args.readTimeout = 5 * time.Second - } - if args.writeTimeout == time.Duration(0) { - args.writeTimeout = 5 * time.Second +func newProgram(sargs []string) (*program, error) { + kingpin.CommandLine.Help = "rtsp-simple-server " + Version + "\n\n" + + "RTSP server." + + argVersion := kingpin.Flag("version", "print rtsp-simple-server version").Bool() + argProtocolsStr := kingpin.Flag("protocols", "supported protocols").Default("udp,tcp").String() + argRtspPort := kingpin.Flag("rtsp-port", "port of the RTSP TCP listener").Default("8554").Int() + argRtpPort := kingpin.Flag("rtp-port", "port of the RTP UDP listener").Default("8000").Int() + argRtcpPort := kingpin.Flag("rtcp-port", "port of the RTCP UDP listener").Default("8001").Int() + argReadTimeout := kingpin.Flag("read-timeout", "timeout for read operations").Default("5s").Duration() + argWriteTimeout := kingpin.Flag("write-timeout", "timeout for write operations").Default("5s").Duration() + argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String() + argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String() + argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String() + argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String() + argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String() + argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String() + + kingpin.MustParse(kingpin.CommandLine.Parse(sargs)) + + args := args{ + version: *argVersion, + protocolsStr: *argProtocolsStr, + rtspPort: *argRtspPort, + rtpPort: *argRtpPort, + rtcpPort: *argRtcpPort, + readTimeout: *argReadTimeout, + writeTimeout: *argWriteTimeout, + publishUser: *argPublishUser, + publishPass: *argPublishPass, + readUser: *argReadUser, + readPass: *argReadPass, + preScript: *argPreScript, + postScript: *argPostScript, } if args.version == true { @@ -88,14 +104,6 @@ func newProgram(args args) (*program, error) { os.Exit(0) } - if (args.rtpPort % 2) != 0 { - return nil, fmt.Errorf("rtp port must be even") - } - - if args.rtcpPort != (args.rtpPort + 1) { - return nil, fmt.Errorf("rtcp and rtp ports must be consecutive") - } - protocols := make(map[streamProtocol]struct{}) for _, proto := range strings.Split(args.protocolsStr, ",") { switch proto { @@ -112,21 +120,37 @@ func newProgram(args args) (*program, error) { if len(protocols) == 0 { return nil, fmt.Errorf("no protocols provided") } - + if (args.rtpPort % 2) != 0 { + return nil, fmt.Errorf("rtp port must be even") + } + if args.rtcpPort != (args.rtpPort + 1) { + return nil, fmt.Errorf("rtcp and rtp ports must be consecutive") + } if args.publishUser != "" { if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) { return nil, fmt.Errorf("publish username must be alphanumeric") } } - if args.publishPass != "" { if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishPass) { return nil, fmt.Errorf("publish password must be alphanumeric") } } - - if args.publishUser != "" && args.publishPass == "" || args.publishUser == "" && args.publishPass != "" { - return nil, fmt.Errorf("publish username and password must be both filled") + if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" { + return nil, fmt.Errorf("read username and password must be both filled") + } + if args.readUser != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.readUser) { + return nil, fmt.Errorf("read username must be alphanumeric") + } + } + if args.readPass != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.readPass) { + return nil, fmt.Errorf("read password must be alphanumeric") + } + } + if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" { + return nil, fmt.Errorf("read username and password must be both filled") } log.Printf("rtsp-simple-server %s", Version) @@ -167,40 +191,7 @@ func (p *program) close() { } func main() { - kingpin.CommandLine.Help = "rtsp-simple-server " + Version + "\n\n" + - "RTSP server." - - argVersion := kingpin.Flag("version", "print rtsp-simple-server version").Bool() - argProtocolsStr := kingpin.Flag("protocols", "supported protocols").Default("udp,tcp").String() - argRtspPort := kingpin.Flag("rtsp-port", "port of the RTSP TCP listener").Default("8554").Int() - argRtpPort := kingpin.Flag("rtp-port", "port of the RTP UDP listener").Default("8000").Int() - argRtcpPort := kingpin.Flag("rtcp-port", "port of the RTCP UDP listener").Default("8001").Int() - argReadTimeout := kingpin.Flag("read-timeout", "timeout for read operations").Default("5s").Duration() - argWriteTimeout := kingpin.Flag("write-timeout", "timeout for write operations").Default("5s").Duration() - argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String() - argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String() - argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String() - argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String() - argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String() - argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String() - - kingpin.Parse() - - _, err := newProgram(args{ - version: *argVersion, - protocolsStr: *argProtocolsStr, - rtspPort: *argRtspPort, - rtpPort: *argRtpPort, - rtcpPort: *argRtcpPort, - readTimeout: *argReadTimeout, - writeTimeout: *argWriteTimeout, - publishUser: *argPublishUser, - publishPass: *argPublishPass, - readUser: *argReadUser, - readPass: *argReadPass, - preScript: *argPreScript, - postScript: *argPostScript, - }) + _, err := newProgram(os.Args[1:]) if err != nil { log.Fatal("ERR: ", err) } diff --git a/main_test.go b/main_test.go index f2af2bf9..dc6d031a 100644 --- a/main_test.go +++ b/main_test.go @@ -61,7 +61,7 @@ func TestProtocols(t *testing.T) { {"tcp", "tcp"}, } { t.Run(pair[0]+"_"+pair[1], func(t *testing.T) { - p, err := newProgram(args{}) + p, err := newProgram([]string{}) require.NoError(t, err) defer p.close() @@ -103,9 +103,9 @@ func TestProtocols(t *testing.T) { } func TestPublishAuth(t *testing.T) { - p, err := newProgram(args{ - publishUser: "testuser", - publishPass: "testpass", + p, err := newProgram([]string{ + "--publish-user=testuser", + "--publish-pass=testpass", }) require.NoError(t, err) defer p.close() @@ -146,9 +146,9 @@ func TestPublishAuth(t *testing.T) { } func TestReadAuth(t *testing.T) { - p, err := newProgram(args{ - readUser: "testuser", - readPass: "testpass", + p, err := newProgram([]string{ + "--read-user=testuser", + "--read-pass=testpass", }) require.NoError(t, err) defer p.close()