diff --git a/client.go b/client/client.go similarity index 67% rename from client.go rename to client/client.go index 73f9fe3e..1836311f 100644 --- a/client.go +++ b/client/client.go @@ -1,4 +1,4 @@ -package main +package client import ( "errors" @@ -7,6 +7,7 @@ import ( "net" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -18,190 +19,186 @@ import ( "github.com/aler9/rtsp-simple-server/conf" "github.com/aler9/rtsp-simple-server/externalcmd" + "github.com/aler9/rtsp-simple-server/serverudp" + "github.com/aler9/rtsp-simple-server/stats" ) const ( - clientCheckStreamInterval = 5 * time.Second - clientReceiverReportInterval = 10 * time.Second + checkStreamInterval = 5 * time.Second + receiverReportInterval = 10 * time.Second ) -type clientDescribeReq struct { - client *client - pathName string - pathConf *conf.PathConf -} - -type clientAnnounceReq struct { - res chan error - client *client - pathName string - pathConf *conf.PathConf - trackCount int - sdp []byte -} - -type clientSetupPlayReq struct { - res chan error - client *client - pathName string - trackId int -} - type readRequestPair struct { req *base.Request res chan error } -type clientTrack struct { +type streamTrack struct { rtpPort int rtcpPort int } -type describeRes struct { +type describeData struct { sdp []byte err error } -type clientState int +type state int const ( - clientStateInitial clientState = iota - clientStateWaitDescription - clientStatePrePlay - clientStatePlay - clientStatePreRecord - clientStateRecord + stateInitial state = iota + stateWaitingDescribe + statePrePlay + statePlay + statePreRecord + stateRecord ) -func (cs clientState) String() string { +func (cs state) String() string { switch cs { - case clientStateInitial: + case stateInitial: return "Initial" - case clientStateWaitDescription: - return "WaitDescription" + case stateWaitingDescribe: + return "WaitingDescribe" - case clientStatePrePlay: + case statePrePlay: return "PrePlay" - case clientStatePlay: + case statePlay: return "Play" - case clientStatePreRecord: + case statePreRecord: return "PreRecord" - case clientStateRecord: + case stateRecord: return "Record" } return "Invalid" } -type client struct { - p *program - conn *gortsplib.ConnServer - state clientState - path *path +type Path interface { + Name() string + SourceTrackCount() int + Conf() *conf.PathConf + OnClientRemove(*Client) + OnClientPlay(*Client) + OnClientRecord(*Client) + OnFrame(int, gortsplib.StreamType, []byte) +} + +type Parent interface { + Log(string, ...interface{}) + OnClientClose(*Client) + OnClientDescribe(*Client, string, *base.Request) (Path, error) + OnClientAnnounce(*Client, string, gortsplib.Tracks, *base.Request) (Path, error) + OnClientSetupPlay(*Client, string, int, *base.Request) (Path, error) +} + +type Client struct { + wg *sync.WaitGroup + stats *stats.Stats + serverUdpRtp *serverudp.Server + serverUdpRtcp *serverudp.Server + readTimeout time.Duration + runOnConnect string + protocols map[gortsplib.StreamProtocol]struct{} + conn *gortsplib.ConnServer + parent Parent + + state state + path Path authUser string authPass string authHelper *auth.Server authFailures int streamProtocol gortsplib.StreamProtocol - streamTracks map[int]*clientTrack + streamTracks map[int]*streamTrack rtcpReceivers []*rtcpreceiver.RtcpReceiver udpLastFrameTimes []*int64 describeCSeq base.HeaderValue describeUrl string - describe chan describeRes - tcpFrame chan *base.InterleavedFrame - terminate chan struct{} + // in + describeData chan describeData // from path + tcpFrame chan *base.InterleavedFrame // from source + terminate chan struct{} } -func newClient(p *program, nconn net.Conn) { - c := &client{ - p: p, +func New( + wg *sync.WaitGroup, + stats *stats.Stats, + serverUdpRtp *serverudp.Server, + serverUdpRtcp *serverudp.Server, + readTimeout time.Duration, + writeTimeout time.Duration, + runOnConnect string, + protocols map[gortsplib.StreamProtocol]struct{}, + nconn net.Conn, + parent Parent) *Client { + + c := &Client{ + wg: wg, + stats: stats, + serverUdpRtp: serverUdpRtp, + serverUdpRtcp: serverUdpRtcp, + readTimeout: readTimeout, + runOnConnect: runOnConnect, + protocols: protocols, conn: gortsplib.NewConnServer(gortsplib.ConnServerConf{ Conn: nconn, - ReadTimeout: p.conf.ReadTimeout, - WriteTimeout: p.conf.WriteTimeout, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, ReadBufferCount: 2, }), - state: clientStateInitial, - streamTracks: make(map[int]*clientTrack), - describe: make(chan describeRes), + parent: parent, + state: stateInitial, + streamTracks: make(map[int]*streamTrack), + describeData: make(chan describeData), tcpFrame: make(chan *base.InterleavedFrame), terminate: make(chan struct{}), } - p.clients[c] = struct{}{} - atomic.AddInt64(p.countClients, 1) + atomic.AddInt64(c.stats.CountClients, 1) c.log("connected") - p.clientsWg.Add(1) + c.wg.Add(1) go c.run() + return c } -func (c *client) close() { - delete(c.p.clients, c) - - atomic.AddInt64(c.p.countClients, -1) - - switch c.state { - case clientStatePlay: - atomic.AddInt64(c.p.countReaders, -1) - c.p.readersMap.remove(c) - - case clientStateRecord: - atomic.AddInt64(c.p.countPublishers, -1) - - if c.streamProtocol == gortsplib.StreamProtocolUDP { - for _, track := range c.streamTracks { - addr := makeUDPPublisherAddr(c.ip(), track.rtpPort) - c.p.udpPublishersMap.remove(addr) - - addr = makeUDPPublisherAddr(c.ip(), track.rtcpPort) - c.p.udpPublishersMap.remove(addr) - } - } - - c.path.onSourceSetNotReady() - } - - if c.path != nil && c.path.source == c { - c.path.onSourceRemove() - } - +func (c *Client) Close() { + atomic.AddInt64(c.stats.CountClients, -1) close(c.terminate) - - c.log("disconnected") } -func (c *client) log(format string, args ...interface{}) { - c.p.log("[client %s] "+format, append([]interface{}{c.conn.NetConn().RemoteAddr().String()}, args...)...) +func (c *Client) IsSource() {} + +func (c *Client) log(format string, args ...interface{}) { + c.parent.Log("[client %s] "+format, append([]interface{}{c.conn.NetConn().RemoteAddr().String()}, args...)...) } -func (c *client) isSource() {} - -func (c *client) ip() net.IP { +func (c *Client) ip() net.IP { return c.conn.NetConn().RemoteAddr().(*net.TCPAddr).IP } -func (c *client) zone() string { +func (c *Client) zone() string { return c.conn.NetConn().RemoteAddr().(*net.TCPAddr).Zone } var errRunTerminate = errors.New("terminate") -var errRunWaitDescription = errors.New("wait description") +var errRunWaitingDescribe = errors.New("wait description") var errRunPlay = errors.New("play") var errRunRecord = errors.New("record") -func (c *client) run() { - defer c.p.clientsWg.Done() +func (c *Client) run() { + defer c.wg.Done() + defer c.log("disconnected") var onConnectCmd *externalcmd.ExternalCmd - if c.p.conf.RunOnConnect != "" { + if c.runOnConnect != "" { var err error - onConnectCmd, err = externalcmd.New(c.p.conf.RunOnConnect, "") + onConnectCmd, err = externalcmd.New(c.runOnConnect, "") if err != nil { c.log("ERR: %s", err) } @@ -217,11 +214,11 @@ func (c *client) run() { onConnectCmd.Close() } - close(c.describe) + close(c.describeData) close(c.tcpFrame) } -func (c *client) writeResError(cseq base.HeaderValue, code base.StatusCode, err error) { +func (c *Client) writeResError(cseq base.HeaderValue, code base.StatusCode, err error) { c.log("ERR: %s", err) c.conn.WriteResponse(&base.Response{ @@ -232,39 +229,47 @@ func (c *client) writeResError(cseq base.HeaderValue, code base.StatusCode, err }) } -var errAuthCritical = errors.New("auth critical") -var errAuthNotCritical = errors.New("auth not critical") +type ErrAuthNotCritical struct { + *base.Response +} -func (c *client) authenticate(ips []interface{}, user string, pass string, req *base.Request) error { +func (ErrAuthNotCritical) Error() string { + return "auth not critical" +} + +type ErrAuthCritical struct { + *base.Response +} + +func (ErrAuthCritical) Error() string { + return "auth critical" +} + +func (c *Client) Authenticate(authMethods []headers.AuthMethod, ips []interface{}, user string, pass string, req *base.Request) error { // validate ip - err := func() error { - if ips == nil { - return nil - } - + if ips != nil { ip := c.ip() + if !ipEqualOrInRange(ip, ips) { c.log("ERR: ip '%s' not allowed", ip) - return errAuthCritical - } - return nil - }() - if err != nil { - return err + return ErrAuthCritical{&base.Response{ + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "CSeq": req.Header["CSeq"], + "WWW-Authenticate": c.authHelper.GenerateHeader(), + }, + }} + } } - // validate credentials - err = func() error { - if user == "" { - return nil - } - + // validate user + if user != "" { // reset authHelper every time the credentials change if c.authHelper == nil || c.authUser != user || c.authPass != pass { c.authUser = user c.authPass = pass - c.authHelper = auth.NewServer(user, pass, c.p.conf.AuthMethodsParsed) + c.authHelper = auth.NewServer(user, pass, authMethods) } err := c.authHelper.ValidateHeader(req.Header["Authorization"], req.Method, req.Url) @@ -277,43 +282,40 @@ func (c *client) authenticate(ips []interface{}, user string, pass string, req * // 3) without credentials // 4) with password and username // hence we must allow up to 3 failures - var retErr error if c.authFailures > 3 { c.log("ERR: unauthorized: %s", err) - retErr = errAuthCritical - } else if c.authFailures > 1 { - c.log("WARN: unauthorized: %s", err) - retErr = errAuthNotCritical + return ErrAuthCritical{&base.Response{ + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "CSeq": req.Header["CSeq"], + "WWW-Authenticate": c.authHelper.GenerateHeader(), + }, + }} } else { - retErr = errAuthNotCritical + if c.authFailures > 1 { + c.log("WARN: unauthorized: %s", err) + } + + return ErrAuthNotCritical{&base.Response{ + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "CSeq": req.Header["CSeq"], + "WWW-Authenticate": c.authHelper.GenerateHeader(), + }, + }} } - - c.conn.WriteResponse(&base.Response{ - StatusCode: base.StatusUnauthorized, - Header: base.Header{ - "CSeq": req.Header["CSeq"], - "WWW-Authenticate": c.authHelper.GenerateHeader(), - }, - }) - - return retErr } - - // reset authFailures after a successful login - c.authFailures = 0 - - return nil - }() - if err != nil { - return err } + // login successful, reset authFailures + c.authFailures = 0 + return nil } -func (c *client) handleRequest(req *base.Request) error { +func (c *Client) handleRequest(req *base.Request) error { c.log(string(req.Method)) cseq, ok := req.Header["CSeq"] @@ -367,58 +369,45 @@ func (c *client) handleRequest(req *base.Request) error { return nil case base.DESCRIBE: - if c.state != clientStateInitial { + if c.state != stateInitial { c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, clientStateInitial)) + fmt.Errorf("client is in state '%s' instead of '%s'", c.state, stateInitial)) return errRunTerminate } pathName = removeQueryFromPath(pathName) - pathConf, err := c.p.conf.CheckPathNameAndFindConf(pathName) + path, err := c.parent.OnClientDescribe(c, pathName, req) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate - } + switch terr := err.(type) { + case ErrAuthNotCritical: + c.conn.WriteResponse(terr.Response) + return nil - err = c.authenticate(pathConf.ReadIpsParsed, pathConf.ReadUser, pathConf.ReadPass, req) - if err != nil { - if err == errAuthCritical { + case ErrAuthCritical: + c.conn.WriteResponse(terr.Response) + return errRunTerminate + + default: + c.writeResError(cseq, base.StatusBadRequest, err) return errRunTerminate } - return nil } - c.p.clientDescribe <- clientDescribeReq{c, pathName, pathConf} - + c.path = path + c.state = stateWaitingDescribe c.describeCSeq = cseq c.describeUrl = req.Url.String() - return errRunWaitDescription + return errRunWaitingDescribe case base.ANNOUNCE: - if c.state != clientStateInitial { + if c.state != stateInitial { c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, clientStateInitial)) + fmt.Errorf("client is in state '%s' instead of '%s'", c.state, stateInitial)) return errRunTerminate } - pathName = removeQueryFromPath(pathName) - - pathConf, err := c.p.conf.CheckPathNameAndFindConf(pathName) - if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate - } - - err = c.authenticate(pathConf.PublishIpsParsed, pathConf.PublishUser, pathConf.PublishPass, req) - if err != nil { - if err == errAuthCritical { - return errRunTerminate - } - return nil - } - ct, ok := req.Header["Content-Type"] if !ok || len(ct) != 1 { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("Content-Type header missing")) @@ -441,15 +430,26 @@ func (c *client) handleRequest(req *base.Request) error { return errRunTerminate } - sdp := tracks.Write() + pathName = removeQueryFromPath(pathName) - res := make(chan error) - c.p.clientAnnounce <- clientAnnounceReq{res, c, pathName, pathConf, len(tracks), sdp} - err = <-res + path, err := c.parent.OnClientAnnounce(c, pathName, tracks, req) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate + switch terr := err.(type) { + case ErrAuthNotCritical: + c.conn.WriteResponse(terr.Response) + return nil + + case ErrAuthCritical: + c.conn.WriteResponse(terr.Response) + return errRunTerminate + + default: + c.writeResError(cseq, base.StatusBadRequest, err) + return errRunTerminate + } } + c.path = path + c.state = statePreRecord c.conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -471,7 +471,7 @@ func (c *client) handleRequest(req *base.Request) error { return errRunTerminate } - basePath, controlPath, err := splitPath(pathName) + basePath, controlPath, err := splitPathIntoBaseAndControl(pathName) if err != nil { c.writeResError(cseq, base.StatusBadRequest, err) return errRunTerminate @@ -481,28 +481,14 @@ func (c *client) handleRequest(req *base.Request) error { switch c.state { // play - case clientStateInitial, clientStatePrePlay: + case stateInitial, statePrePlay: if th.Mode != nil && *th.Mode != gortsplib.TransportModePlay { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header must contain mode=play or not contain a mode")) return errRunTerminate } - pathConf, err := c.p.conf.CheckPathNameAndFindConf(basePath) - if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate - } - - err = c.authenticate(pathConf.ReadIpsParsed, pathConf.ReadUser, pathConf.ReadPass, req) - if err != nil { - if err == errAuthCritical { - return errRunTerminate - } - return nil - } - - if c.path != nil && basePath != c.path.name { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.name, basePath)) + if c.path != nil && basePath != c.path.Name() { + c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) return errRunTerminate } @@ -525,7 +511,7 @@ func (c *client) handleRequest(req *base.Request) error { // play with UDP if th.Protocol == gortsplib.StreamProtocolUDP { - if _, ok := c.p.conf.ProtocolsParsed[gortsplib.StreamProtocolUDP]; !ok { + if _, ok := c.protocols[gortsplib.StreamProtocolUDP]; !ok { c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) return errRunTerminate } @@ -540,16 +526,28 @@ func (c *client) handleRequest(req *base.Request) error { return errRunTerminate } - res := make(chan error) - c.p.clientSetupPlay <- clientSetupPlayReq{res, c, basePath, trackId} - err = <-res + path, err := c.parent.OnClientSetupPlay(c, basePath, trackId, req) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate + switch terr := err.(type) { + case ErrAuthNotCritical: + c.conn.WriteResponse(terr.Response) + return nil + + case ErrAuthCritical: + c.conn.WriteResponse(terr.Response) + return errRunTerminate + + default: + c.writeResError(cseq, base.StatusBadRequest, err) + return errRunTerminate + } } + c.path = path + c.state = statePrePlay + c.streamProtocol = gortsplib.StreamProtocolUDP - c.streamTracks[trackId] = &clientTrack{ + c.streamTracks[trackId] = &streamTrack{ rtpPort: (*th.ClientPorts)[0], rtcpPort: (*th.ClientPorts)[1], } @@ -561,7 +559,7 @@ func (c *client) handleRequest(req *base.Request) error { return &v }(), ClientPorts: th.ClientPorts, - ServerPorts: &[2]int{c.p.conf.RtpPort, c.p.conf.RtcpPort}, + ServerPorts: &[2]int{c.serverUdpRtp.Port(), c.serverUdpRtcp.Port()}, } c.conn.WriteResponse(&base.Response{ @@ -576,7 +574,7 @@ func (c *client) handleRequest(req *base.Request) error { // play with TCP } else { - if _, ok := c.p.conf.ProtocolsParsed[gortsplib.StreamProtocolTCP]; !ok { + if _, ok := c.protocols[gortsplib.StreamProtocolTCP]; !ok { c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) return errRunTerminate } @@ -586,16 +584,28 @@ func (c *client) handleRequest(req *base.Request) error { return errRunTerminate } - res := make(chan error) - c.p.clientSetupPlay <- clientSetupPlayReq{res, c, basePath, trackId} - err = <-res + path, err := c.parent.OnClientSetupPlay(c, basePath, trackId, req) if err != nil { - c.writeResError(cseq, base.StatusBadRequest, err) - return errRunTerminate + switch terr := err.(type) { + case ErrAuthNotCritical: + c.conn.WriteResponse(terr.Response) + return nil + + case ErrAuthCritical: + c.conn.WriteResponse(terr.Response) + return errRunTerminate + + default: + c.writeResError(cseq, base.StatusBadRequest, err) + return errRunTerminate + } } + c.path = path + c.state = statePrePlay + c.streamProtocol = gortsplib.StreamProtocolTCP - c.streamTracks[trackId] = &clientTrack{ + c.streamTracks[trackId] = &streamTrack{ rtpPort: 0, rtcpPort: 0, } @@ -619,21 +629,21 @@ func (c *client) handleRequest(req *base.Request) error { } // record - case clientStatePreRecord: + case statePreRecord: if th.Mode == nil || *th.Mode != gortsplib.TransportModeRecord { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("transport header does not contain mode=record")) return errRunTerminate } // after ANNOUNCE, c.path is already set - if basePath != c.path.name { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.name, basePath)) + if basePath != c.path.Name() { + c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), basePath)) return errRunTerminate } // record with UDP if th.Protocol == gortsplib.StreamProtocolUDP { - if _, ok := c.p.conf.ProtocolsParsed[gortsplib.StreamProtocolUDP]; !ok { + if _, ok := c.protocols[gortsplib.StreamProtocolUDP]; !ok { c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) return errRunTerminate } @@ -648,13 +658,13 @@ func (c *client) handleRequest(req *base.Request) error { return errRunTerminate } - if len(c.streamTracks) >= c.path.sourceTrackCount { + if len(c.streamTracks) >= c.path.SourceTrackCount() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) return errRunTerminate } c.streamProtocol = gortsplib.StreamProtocolUDP - c.streamTracks[len(c.streamTracks)] = &clientTrack{ + c.streamTracks[len(c.streamTracks)] = &streamTrack{ rtpPort: (*th.ClientPorts)[0], rtcpPort: (*th.ClientPorts)[1], } @@ -666,7 +676,7 @@ func (c *client) handleRequest(req *base.Request) error { return &v }(), ClientPorts: th.ClientPorts, - ServerPorts: &[2]int{c.p.conf.RtpPort, c.p.conf.RtcpPort}, + ServerPorts: &[2]int{c.serverUdpRtp.Port(), c.serverUdpRtcp.Port()}, } c.conn.WriteResponse(&base.Response{ @@ -681,7 +691,7 @@ func (c *client) handleRequest(req *base.Request) error { // record with TCP } else { - if _, ok := c.p.conf.ProtocolsParsed[gortsplib.StreamProtocolTCP]; !ok { + if _, ok := c.protocols[gortsplib.StreamProtocolTCP]; !ok { c.writeResError(cseq, base.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) return errRunTerminate } @@ -703,13 +713,13 @@ func (c *client) handleRequest(req *base.Request) error { return errRunTerminate } - if len(c.streamTracks) >= c.path.sourceTrackCount { + if len(c.streamTracks) >= c.path.SourceTrackCount() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) return errRunTerminate } c.streamProtocol = gortsplib.StreamProtocolTCP - c.streamTracks[len(c.streamTracks)] = &clientTrack{ + c.streamTracks[len(c.streamTracks)] = &streamTrack{ rtpPort: 0, rtcpPort: 0, } @@ -736,9 +746,9 @@ func (c *client) handleRequest(req *base.Request) error { } case base.PLAY: - if c.state != clientStatePrePlay { + if c.state != statePrePlay { c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, clientStatePrePlay)) + fmt.Errorf("client is in state '%s' instead of '%s'", c.state, statePrePlay)) return errRunTerminate } @@ -747,8 +757,8 @@ func (c *client) handleRequest(req *base.Request) error { // path can end with a slash, remove it pathName = strings.TrimSuffix(pathName, "/") - if pathName != c.path.name { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.name, pathName)) + if pathName != c.path.Name() { + c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), pathName)) return errRunTerminate } @@ -771,9 +781,9 @@ func (c *client) handleRequest(req *base.Request) error { return errRunPlay case base.RECORD: - if c.state != clientStatePreRecord { + if c.state != statePreRecord { c.writeResError(cseq, base.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, clientStatePreRecord)) + fmt.Errorf("client is in state '%s' instead of '%s'", c.state, statePreRecord)) return errRunTerminate } @@ -782,12 +792,12 @@ func (c *client) handleRequest(req *base.Request) error { // path can end with a slash, remove it pathName = strings.TrimSuffix(pathName, "/") - if pathName != c.path.name { - c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.name, pathName)) + if pathName != c.path.Name() { + c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("path has changed, was '%s', now is '%s'", c.path.Name(), pathName)) return errRunTerminate } - if len(c.streamTracks) != c.path.sourceTrackCount { + if len(c.streamTracks) != c.path.SourceTrackCount() { c.writeResError(cseq, base.StatusBadRequest, fmt.Errorf("not all tracks have been setup")) return errRunTerminate } @@ -812,7 +822,7 @@ func (c *client) handleRequest(req *base.Request) error { } } -func (c *client) runInitial() bool { +func (c *Client) runInitial() bool { readDone := make(chan error) go func() { for { @@ -833,8 +843,8 @@ func (c *client) runInitial() bool { select { case err := <-readDone: switch err { - case errRunWaitDescription: - return c.runWaitDescription() + case errRunWaitingDescribe: + return c.runWaitingDescribe() case errRunPlay: return c.runPlay() @@ -847,7 +857,8 @@ func (c *client) runInitial() bool { if err != io.EOF && err != errRunTerminate { c.log("ERR: %s", err) } - c.p.clientClose <- c + + c.parent.OnClientClose(c) <-c.terminate return false } @@ -859,9 +870,12 @@ func (c *client) runInitial() bool { } } -func (c *client) runWaitDescription() bool { +func (c *Client) runWaitingDescribe() bool { select { - case res := <-c.describe: + case res := <-c.describeData: + c.path = nil + c.state = stateInitial + if res.err != nil { c.writeResError(c.describeCSeq, base.StatusNotFound, res.err) return true @@ -879,16 +893,23 @@ func (c *client) runWaitDescription() bool { return true case <-c.terminate: + go func() { + for range c.describeData { + } + }() + + c.path.OnClientRemove(c) + c.conn.Close() return false } } -func (c *client) runPlay() bool { - // start sending frames only after sending the response to the PLAY request - c.p.clientPlay <- c +func (c *Client) runPlay() bool { + // start sending frames only after replying to the PLAY request + c.path.OnClientPlay(c) - c.log("is reading from path '%s', %d %s with %s", c.path.name, len(c.streamTracks), func() string { + c.log("is reading from path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { if len(c.streamTracks) == 1 { return "track" } @@ -896,9 +917,9 @@ func (c *client) runPlay() bool { }(), c.streamProtocol) var onReadCmd *externalcmd.ExternalCmd - if c.path.conf.RunOnRead != "" { + if c.path.Conf().RunOnRead != "" { var err error - onReadCmd, err = externalcmd.New(c.path.conf.RunOnRead, c.path.name) + onReadCmd, err = externalcmd.New(c.path.Conf().RunOnRead, c.path.Name()) if err != nil { c.log("ERR: %s", err) } @@ -917,7 +938,7 @@ func (c *client) runPlay() bool { return false } -func (c *client) runPlayUDP() { +func (c *Client) runPlayUDP() { readDone := make(chan error) go func() { for { @@ -941,18 +962,23 @@ func (c *client) runPlayUDP() { if err != io.EOF && err != errRunTerminate { c.log("ERR: %s", err) } - c.p.clientClose <- c + + c.path.OnClientRemove(c) + + c.parent.OnClientClose(c) <-c.terminate return case <-c.terminate: + c.path.OnClientRemove(c) + c.conn.Close() <-readDone return } } -func (c *client) runPlayTCP() { +func (c *Client) runPlayTCP() { readRequest := make(chan readRequestPair) defer close(readRequest) @@ -992,11 +1018,15 @@ func (c *client) runPlayTCP() { if err != io.EOF && err != errRunTerminate { c.log("ERR: %s", err) } + go func() { for range c.tcpFrame { } }() - c.p.clientClose <- c + + c.path.OnClientRemove(c) + + c.parent.OnClientClose(c) <-c.terminate return @@ -1009,6 +1039,9 @@ func (c *client) runPlayTCP() { req.res <- fmt.Errorf("terminated") } }() + + c.path.OnClientRemove(c) + c.conn.Close() <-readDone return @@ -1016,7 +1049,7 @@ func (c *client) runPlayTCP() { } } -func (c *client) runRecord() bool { +func (c *Client) runRecord() bool { c.rtcpReceivers = make([]*rtcpreceiver.RtcpReceiver, len(c.streamTracks)) for trackId := range c.streamTracks { c.rtcpReceivers[trackId] = rtcpreceiver.New() @@ -1030,19 +1063,26 @@ func (c *client) runRecord() bool { } } - c.p.clientRecord <- c + c.path.OnClientRecord(c) - c.log("is publishing to path '%s', %d %s with %s", c.path.name, len(c.streamTracks), func() string { + c.log("is publishing to path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { if len(c.streamTracks) == 1 { return "track" } return "tracks" }(), c.streamProtocol) + if c.streamProtocol == gortsplib.StreamProtocolUDP { + for trackId, track := range c.streamTracks { + c.serverUdpRtp.AddPublisher(c.ip(), track.rtpPort, c, trackId) + c.serverUdpRtcp.AddPublisher(c.ip(), track.rtcpPort, c, trackId) + } + } + var onPublishCmd *externalcmd.ExternalCmd - if c.path.conf.RunOnPublish != "" { + if c.path.Conf().RunOnPublish != "" { var err error - onPublishCmd, err = externalcmd.New(c.path.conf.RunOnPublish, c.path.name) + onPublishCmd, err = externalcmd.New(c.path.Conf().RunOnPublish, c.path.Name()) if err != nil { c.log("ERR: %s", err) } @@ -1054,6 +1094,13 @@ func (c *client) runRecord() bool { c.runRecordTCP() } + if c.streamProtocol == gortsplib.StreamProtocolUDP { + for _, track := range c.streamTracks { + c.serverUdpRtp.RemovePublisher(c.ip(), track.rtpPort, c) + c.serverUdpRtcp.RemovePublisher(c.ip(), track.rtcpPort, c) + } + } + if onPublishCmd != nil { onPublishCmd.Close() } @@ -1061,10 +1108,10 @@ func (c *client) runRecord() bool { return false } -func (c *client) runRecordUDP() { +func (c *Client) runRecordUDP() { // open the firewall by sending packets to the counterpart for _, track := range c.streamTracks { - c.p.serverUdpRtp.write( + c.serverUdpRtp.Write( []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, &net.UDPAddr{ IP: c.ip(), @@ -1072,7 +1119,7 @@ func (c *client) runRecordUDP() { Port: track.rtpPort, }) - c.p.serverUdpRtcp.write( + c.serverUdpRtcp.Write( []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, &net.UDPAddr{ IP: c.ip(), @@ -1098,10 +1145,10 @@ func (c *client) runRecordUDP() { } }() - checkStreamTicker := time.NewTicker(clientCheckStreamInterval) + checkStreamTicker := time.NewTicker(checkStreamInterval) defer checkStreamTicker.Stop() - receiverReportTicker := time.NewTicker(clientReceiverReportInterval) + receiverReportTicker := time.NewTicker(receiverReportInterval) defer receiverReportTicker.Stop() for { @@ -1111,7 +1158,10 @@ func (c *client) runRecordUDP() { if err != io.EOF && err != errRunTerminate { c.log("ERR: %s", err) } - c.p.clientClose <- c + + c.path.OnClientRemove(c) + + c.parent.OnClientClose(c) <-c.terminate return @@ -1121,11 +1171,14 @@ func (c *client) runRecordUDP() { for _, lastUnix := range c.udpLastFrameTimes { last := time.Unix(atomic.LoadInt64(lastUnix), 0) - if now.Sub(last) >= c.p.conf.ReadTimeout { + if now.Sub(last) >= c.readTimeout { c.log("ERR: no packets received recently (maybe there's a firewall/NAT in between)") c.conn.Close() <-readDone - c.p.clientClose <- c + + c.path.OnClientRemove(c) + + c.parent.OnClientClose(c) <-c.terminate return } @@ -1134,7 +1187,7 @@ func (c *client) runRecordUDP() { case <-receiverReportTicker.C: for trackId := range c.streamTracks { frame := c.rtcpReceivers[trackId].Report() - c.p.serverUdpRtcp.write(frame, &net.UDPAddr{ + c.serverUdpRtcp.Write(frame, &net.UDPAddr{ IP: c.ip(), Zone: c.zone(), Port: c.streamTracks[trackId].rtcpPort, @@ -1142,6 +1195,8 @@ func (c *client) runRecordUDP() { } case <-c.terminate: + c.path.OnClientRemove(c) + c.conn.Close() <-readDone return @@ -1149,7 +1204,7 @@ func (c *client) runRecordUDP() { } } -func (c *client) runRecordTCP() { +func (c *Client) runRecordTCP() { readRequest := make(chan readRequestPair) defer close(readRequest) @@ -1170,8 +1225,7 @@ func (c *client) runRecordTCP() { } c.rtcpReceivers[recvt.TrackId].OnFrame(recvt.StreamType, recvt.Content) - - c.p.readersMap.forwardFrame(c.path, recvt.TrackId, recvt.StreamType, recvt.Content) + c.path.OnFrame(recvt.TrackId, recvt.StreamType, recvt.Content) case *base.Request: err := c.handleRequest(recvt) @@ -1183,7 +1237,7 @@ func (c *client) runRecordTCP() { } }() - receiverReportTicker := time.NewTicker(clientReceiverReportInterval) + receiverReportTicker := time.NewTicker(receiverReportInterval) defer receiverReportTicker.Stop() for { @@ -1197,7 +1251,10 @@ func (c *client) runRecordTCP() { if err != io.EOF && err != errRunTerminate { c.log("ERR: %s", err) } - c.p.clientClose <- c + + c.path.OnClientRemove(c) + + c.parent.OnClientClose(c) <-c.terminate return @@ -1213,9 +1270,54 @@ func (c *client) runRecordTCP() { req.res <- fmt.Errorf("terminated") } }() + + c.path.OnClientRemove(c) + c.conn.Close() <-readDone return } } } + +func (c *Client) OnUdpPublisherFrame(trackId int, streamType base.StreamType, buf []byte) { + atomic.StoreInt64(c.udpLastFrameTimes[trackId], time.Now().Unix()) + + c.rtcpReceivers[trackId].OnFrame(streamType, buf) + c.path.OnFrame(trackId, streamType, buf) +} + +func (c *Client) OnReaderFrame(trackId int, streamType base.StreamType, buf []byte) { + track, ok := c.streamTracks[trackId] + if !ok { + return + } + + if c.streamProtocol == gortsplib.StreamProtocolUDP { + if streamType == gortsplib.StreamTypeRtp { + c.serverUdpRtp.Write(buf, &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: track.rtpPort, + }) + + } else { + c.serverUdpRtcp.Write(buf, &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: track.rtcpPort, + }) + } + + } else { + c.tcpFrame <- &base.InterleavedFrame{ + TrackId: trackId, + StreamType: streamType, + Content: buf, + } + } +} + +func (c *Client) OnPathDescribeData(sdp []byte, err error) { + c.describeData <- describeData{sdp, err} +} diff --git a/client/utils.go b/client/utils.go new file mode 100644 index 00000000..68f44fd4 --- /dev/null +++ b/client/utils.go @@ -0,0 +1,60 @@ +package client + +import ( + "fmt" + "net" + "strings" +) + +func ipEqualOrInRange(ip net.IP, ips []interface{}) bool { + for _, item := range ips { + switch titem := item.(type) { + case net.IP: + if titem.Equal(ip) { + return true + } + + case *net.IPNet: + if titem.Contains(ip) { + return true + } + } + } + return false +} + +func removeQueryFromPath(path string) string { + i := strings.Index(path, "?") + if i >= 0 { + return path[:i] + } + return path +} + +func splitPathIntoBaseAndControl(path string) (string, string, error) { + pos := func() int { + for i := len(path) - 1; i >= 0; i-- { + if path[i] == '/' { + return i + } + } + return -1 + }() + + if pos < 0 { + return "", "", fmt.Errorf("the path must contain a base path and a control path (%s)", path) + } + + basePath := path[:pos] + controlPath := path[pos+1:] + + if len(basePath) == 0 { + return "", "", fmt.Errorf("empty base path (%s)", basePath) + } + + if len(controlPath) == 0 { + return "", "", fmt.Errorf("empty control path (%s)", controlPath) + } + + return basePath, controlPath, nil +} diff --git a/clientman/clientman.go b/clientman/clientman.go new file mode 100644 index 00000000..6f398a1a --- /dev/null +++ b/clientman/clientman.go @@ -0,0 +1,154 @@ +package clientman + +import ( + "sync" + "time" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/base" + "github.com/aler9/gortsplib/headers" + + "github.com/aler9/rtsp-simple-server/client" + "github.com/aler9/rtsp-simple-server/pathman" + "github.com/aler9/rtsp-simple-server/servertcp" + "github.com/aler9/rtsp-simple-server/serverudp" + "github.com/aler9/rtsp-simple-server/stats" +) + +type Parent interface { + Log(string, ...interface{}) +} + +type ClientManager struct { + stats *stats.Stats + serverUdpRtp *serverudp.Server + serverUdpRtcp *serverudp.Server + readTimeout time.Duration + writeTimeout time.Duration + runOnConnect string + protocols map[headers.StreamProtocol]struct{} + pathMan *pathman.PathManager + serverTcp *servertcp.Server + parent Parent + + clients map[*client.Client]struct{} + wg sync.WaitGroup + + // in + clientClose chan *client.Client + terminate chan struct{} + + // out + done chan struct{} +} + +func New(stats *stats.Stats, + serverUdpRtp *serverudp.Server, + serverUdpRtcp *serverudp.Server, + readTimeout time.Duration, + writeTimeout time.Duration, + runOnConnect string, + protocols map[headers.StreamProtocol]struct{}, + pathMan *pathman.PathManager, + serverTcp *servertcp.Server, + parent Parent) *ClientManager { + + cm := &ClientManager{ + stats: stats, + serverUdpRtp: serverUdpRtp, + serverUdpRtcp: serverUdpRtcp, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + runOnConnect: runOnConnect, + protocols: protocols, + pathMan: pathMan, + serverTcp: serverTcp, + parent: parent, + clients: make(map[*client.Client]struct{}), + clientClose: make(chan *client.Client), + terminate: make(chan struct{}), + done: make(chan struct{}), + } + + go cm.run() + return cm +} + +func (cm *ClientManager) Close() { + close(cm.terminate) + <-cm.done +} + +func (cm *ClientManager) Log(format string, args ...interface{}) { + cm.parent.Log(format, args...) +} + +func (cm *ClientManager) run() { + defer close(cm.done) + +outer: + for { + select { + case conn := <-cm.serverTcp.Accept(): + c := client.New(&cm.wg, + cm.stats, + cm.serverUdpRtp, + cm.serverUdpRtcp, + cm.readTimeout, + cm.writeTimeout, + cm.runOnConnect, + cm.protocols, + conn, + cm) + cm.clients[c] = struct{}{} + + case c := <-cm.pathMan.ClientClose(): + if _, ok := cm.clients[c]; !ok { + continue + } + delete(cm.clients, c) + c.Close() + + case c := <-cm.clientClose: + if _, ok := cm.clients[c]; !ok { + continue + } + delete(cm.clients, c) + c.Close() + + case <-cm.terminate: + break outer + } + } + + go func() { + for { + select { + case <-cm.clientClose: + } + } + }() + + for c := range cm.clients { + c.Close() + } + cm.wg.Wait() + + close(cm.clientClose) +} + +func (cm *ClientManager) OnClientClose(c *client.Client) { + cm.clientClose <- c +} + +func (cm *ClientManager) OnClientDescribe(c *client.Client, pathName string, req *base.Request) (client.Path, error) { + return cm.pathMan.OnClientDescribe(c, pathName, req) +} + +func (cm *ClientManager) OnClientAnnounce(c *client.Client, pathName string, tracks gortsplib.Tracks, req *base.Request) (client.Path, error) { + return cm.pathMan.OnClientAnnounce(c, pathName, tracks, req) +} + +func (cm *ClientManager) OnClientSetupPlay(c *client.Client, pathName string, trackId int, req *base.Request) (client.Path, error) { + return cm.pathMan.OnClientSetupPlay(c, pathName, trackId, req) +} diff --git a/conf/conf.go b/conf/conf.go index 8111657e..46e470f4 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -2,7 +2,6 @@ package conf import ( "fmt" - "net" "net/url" "os" "regexp" @@ -17,56 +16,9 @@ import ( "github.com/aler9/rtsp-simple-server/loghandler" ) -func parseIpCidrList(in []string) ([]interface{}, error) { - if len(in) == 0 { - return nil, nil - } - - var ret []interface{} - for _, t := range in { - _, ipnet, err := net.ParseCIDR(t) - if err == nil { - ret = append(ret, ipnet) - continue - } - - ip := net.ParseIP(t) - if ip != nil { - ret = append(ret, ip) - continue - } - - return nil, fmt.Errorf("unable to parse ip/network '%s'", t) - } - return ret, nil -} - -var rePathName = regexp.MustCompile("^[0-9a-zA-Z_\\-/]+$") - -func checkPathName(name string) error { - if name == "" { - return fmt.Errorf("cannot be empty") - } - - if name[0] == '/' { - return fmt.Errorf("can't begin with a slash") - } - - if name[len(name)-1] == '/' { - return fmt.Errorf("can't end with a slash") - } - - if !rePathName.MatchString(name) { - return fmt.Errorf("can contain only alfanumeric characters, underscore, minus or slash") - } - - return nil -} - type PathConf struct { Regexp *regexp.Regexp `yaml:"-"` Source string `yaml:"source"` - SourceUrl *url.URL `yaml:"-"` SourceProtocol string `yaml:"sourceProtocol"` SourceProtocolParsed gortsplib.StreamProtocol `yaml:"-"` SourceOnDemand bool `yaml:"sourceOnDemand"` @@ -133,7 +85,7 @@ func Load(fpath string) (*Conf, error) { } // read from environment - err = confenv.Process("RTSP", conf) + err = confenv.Load("RTSP", conf) if err != nil { return nil, err } @@ -244,7 +196,7 @@ func Load(fpath string) (*Conf, error) { // normal path if name[0] != '~' { - err := checkPathName(name) + err := CheckPathName(name) if err != nil { return nil, fmt.Errorf("invalid path name: %s (%s)", err, name) } @@ -267,16 +219,13 @@ func Load(fpath string) (*Conf, error) { return nil, fmt.Errorf("a path with a regular expression (or path 'all') cannot have a RTSP source; use another path") } - pconf.SourceUrl, err = url.Parse(pconf.Source) + u, err := url.Parse(pconf.Source) if err != nil { return nil, fmt.Errorf("'%s' is not a valid url", pconf.Source) } - if pconf.SourceUrl.Port() == "" { - pconf.SourceUrl.Host += ":554" - } - if pconf.SourceUrl.User != nil { - pass, _ := pconf.SourceUrl.User.Password() - user := pconf.SourceUrl.User.Username() + if u.User != nil { + pass, _ := u.User.Password() + user := u.User.Username() if user != "" && pass == "" || user == "" && pass != "" { fmt.Errorf("username and password must be both provided") @@ -302,12 +251,17 @@ func Load(fpath string) (*Conf, error) { return nil, fmt.Errorf("a path with a regular expression (or path 'all') cannot have a RTMP source; use another path") } - pconf.SourceUrl, err = url.Parse(pconf.Source) + u, err := url.Parse(pconf.Source) if err != nil { return nil, fmt.Errorf("'%s' is not a valid url", pconf.Source) } - if pconf.SourceUrl.Port() == "" { - pconf.SourceUrl.Host += ":1935" + if u.User != nil { + pass, _ := u.User.Password() + user := u.User.Username() + if user != "" && pass == "" || + user == "" && pass != "" { + fmt.Errorf("username and password must be both provided") + } } } else if pconf.Source == "record" { @@ -371,24 +325,3 @@ func Load(fpath string) (*Conf, error) { return conf, nil } - -func (conf *Conf) CheckPathNameAndFindConf(name string) (*PathConf, error) { - err := checkPathName(name) - if err != nil { - return nil, fmt.Errorf("invalid path name: %s (%s)", err, name) - } - - // normal path - if pconf, ok := conf.Paths[name]; ok { - return pconf, nil - } - - // regular expression path - for _, pconf := range conf.Paths { - if pconf.Regexp != nil && pconf.Regexp.MatchString(name) { - return pconf, nil - } - } - - return nil, fmt.Errorf("unable to find a valid configuration for path '%s'", name) -} diff --git a/conf/utils.go b/conf/utils.go new file mode 100644 index 00000000..9e01ecd3 --- /dev/null +++ b/conf/utils.go @@ -0,0 +1,53 @@ +package conf + +import ( + "fmt" + "net" + "regexp" +) + +var rePathName = regexp.MustCompile("^[0-9a-zA-Z_\\-/]+$") + +func CheckPathName(name string) error { + if name == "" { + return fmt.Errorf("cannot be empty") + } + + if name[0] == '/' { + return fmt.Errorf("can't begin with a slash") + } + + if name[len(name)-1] == '/' { + return fmt.Errorf("can't end with a slash") + } + + if !rePathName.MatchString(name) { + return fmt.Errorf("can contain only alfanumeric characters, underscore, minus or slash") + } + + return nil +} + +func parseIpCidrList(in []string) ([]interface{}, error) { + if len(in) == 0 { + return nil, nil + } + + var ret []interface{} + for _, t := range in { + _, ipnet, err := net.ParseCIDR(t) + if err == nil { + ret = append(ret, ipnet) + continue + } + + ip := net.ParseIP(t) + if ip != nil { + ret = append(ret, ip) + continue + } + + return nil, fmt.Errorf("unable to parse ip/network '%s'", t) + } + return ret, nil +} diff --git a/confenv/confenv.go b/confenv/confenv.go index 5c4f9c19..c7c6fca5 100644 --- a/confenv/confenv.go +++ b/confenv/confenv.go @@ -9,7 +9,7 @@ import ( "time" ) -func process(env map[string]string, envKey string, rv reflect.Value) error { +func load(env map[string]string, envKey string, rv reflect.Value) error { rt := rv.Type() switch rt { @@ -93,7 +93,7 @@ func process(env map[string]string, envKey string, rv reflect.Value) error { rv.SetMapIndex(reflect.ValueOf(mapKey), nv) } - err := process(env, envKey+"_"+strings.ToUpper(mapKey), nv.Elem()) + err := load(env, envKey+"_"+strings.ToUpper(mapKey), nv.Elem()) if err != nil { return err } @@ -105,13 +105,13 @@ func process(env map[string]string, envKey string, rv reflect.Value) error { for i := 0; i < flen; i++ { f := rt.Field(i) - // process only public fields + // load only public fields if f.Tag.Get("yaml") == "-" { continue } fieldEnvKey := envKey + "_" + strings.ToUpper(f.Name) - err := process(env, fieldEnvKey, rv.Field(i)) + err := load(env, fieldEnvKey, rv.Field(i)) if err != nil { return err } @@ -122,12 +122,12 @@ func process(env map[string]string, envKey string, rv reflect.Value) error { return fmt.Errorf("unsupported type: %v", rt) } -func Process(envKey string, v interface{}) error { +func Load(envKey string, v interface{}) error { env := make(map[string]string) for _, kv := range os.Environ() { tmp := strings.Split(kv, "=") env[tmp[0]] = tmp[1] } - return process(env, envKey, reflect.ValueOf(v).Elem()) + return load(env, envKey, reflect.ValueOf(v).Elem()) } diff --git a/go.mod b/go.mod index 9effa655..b29e52a7 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.14 require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aler9/gortsplib v0.0.0-20201016210035-b647e5ee314a + github.com/aler9/gortsplib v0.0.0-20201017143703-0b7201de6890 github.com/davecgh/go-spew v1.1.1 // indirect github.com/notedit/rtmp v0.0.2 github.com/pion/rtp v1.6.1 // indirect diff --git a/go.sum b/go.sum index d6dcd870..16a8b5dc 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/aler9/gortsplib v0.0.0-20201016210035-b647e5ee314a h1:oSyGggUDkjjKrNw8M0GH58zbV9XJRtOCITMExgGG6aA= -github.com/aler9/gortsplib v0.0.0-20201016210035-b647e5ee314a/go.mod h1:8mpBfMEJIZn2C5fMM6vRYHgGH49WX0EH8gP1SDxv0Uw= +github.com/aler9/gortsplib v0.0.0-20201017143703-0b7201de6890 h1:p5tvUXK9CrNrkVdRhVetDuOsT64y+qEos5ZWnbxdUMo= +github.com/aler9/gortsplib v0.0.0-20201017143703-0b7201de6890/go.mod h1:8mpBfMEJIZn2C5fMM6vRYHgGH49WX0EH8gP1SDxv0Uw= github.com/aler9/sdp-dirty/v3 v3.0.0-20200919115950-f1abc664f625 h1:A3upkpYzceQTuBPvVleu1zd6R8jInhg5ifimSO7ku/o= github.com/aler9/sdp-dirty/v3 v3.0.0-20200919115950-f1abc664f625/go.mod h1:5bO/aUQr9m3OasDatNNcVqKAgs7r5hgGXmszWHaC6mI= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= diff --git a/loghandler/loghandler.go b/loghandler/loghandler.go index 41a1e985..af85f4e8 100644 --- a/loghandler/loghandler.go +++ b/loghandler/loghandler.go @@ -16,10 +16,17 @@ const ( DestinationSyslog ) +type writeFunc func(p []byte) (int, error) + +func (f writeFunc) Write(p []byte) (int, error) { + return f(p) +} + type LogHandler struct { destinations map[Destination]struct{} - file *os.File - syslog io.WriteCloser + + file *os.File + syslog io.WriteCloser } func New(destinations map[Destination]struct{}, filePath string) (*LogHandler, error) { @@ -45,7 +52,7 @@ func New(destinations map[Destination]struct{}, filePath string) (*LogHandler, e } } - log.SetOutput(lh) + log.SetOutput(writeFunc(lh.write)) return lh, nil } @@ -60,7 +67,7 @@ func (lh *LogHandler) Close() { } } -func (lh *LogHandler) Write(p []byte) (int, error) { +func (lh *LogHandler) write(p []byte) (int, error) { if _, ok := lh.destinations[DestinationStdout]; ok { print(string(p)) } diff --git a/main.go b/main.go index 484f03a9..f3953979 100644 --- a/main.go +++ b/main.go @@ -3,61 +3,39 @@ package main import ( "fmt" "log" - "net" "os" - "sync" "sync/atomic" - "time" "github.com/aler9/gortsplib" "gopkg.in/alecthomas/kingpin.v2" + "github.com/aler9/rtsp-simple-server/clientman" "github.com/aler9/rtsp-simple-server/conf" "github.com/aler9/rtsp-simple-server/loghandler" + "github.com/aler9/rtsp-simple-server/metrics" + "github.com/aler9/rtsp-simple-server/pathman" + "github.com/aler9/rtsp-simple-server/pprof" + "github.com/aler9/rtsp-simple-server/servertcp" + "github.com/aler9/rtsp-simple-server/serverudp" + "github.com/aler9/rtsp-simple-server/stats" ) var Version = "v0.0.0" -const ( - checkPathPeriod = 5 * time.Second -) - type program struct { - conf *conf.Conf - logHandler *loghandler.LogHandler - metrics *metrics - pprof *pprof - paths map[string]*path - serverUdpRtp *serverUDP - serverUdpRtcp *serverUDP - serverTcp *serverTCP - clients map[*client]struct{} - clientsWg sync.WaitGroup - udpPublishersMap *udpPublishersMap - readersMap *readersMap - // use pointers to avoid a crash on 32bit platforms - // https://github.com/golang/go/issues/9959 - countClients *int64 - countPublishers *int64 - countReaders *int64 - countSourcesRtsp *int64 - countSourcesRtspRunning *int64 - countSourcesRtmp *int64 - countSourcesRtmpRunning *int64 + conf *conf.Conf + stats *stats.Stats + logHandler *loghandler.LogHandler + metrics *metrics.Metrics + pprof *pprof.Pprof + serverUdpRtp *serverudp.Server + serverUdpRtcp *serverudp.Server + serverTcp *servertcp.Server + pathMan *pathman.PathManager + clientMan *clientman.ClientManager - clientNew chan net.Conn - clientClose chan *client - clientDescribe chan clientDescribeReq - clientAnnounce chan clientAnnounceReq - clientSetupPlay chan clientSetupPlayReq - clientPlay chan *client - clientRecord chan *client - sourceRtspReady chan *sourceRtsp - sourceRtspNotReady chan *sourceRtsp - sourceRtmpReady chan *sourceRtmp - sourceRtmpNotReady chan *sourceRtmp - terminate chan struct{} - done chan struct{} + terminate chan struct{} + done chan struct{} } func newProgram(args []string) (*program, error) { @@ -79,109 +57,76 @@ func newProgram(args []string) (*program, error) { return nil, err } - logHandler, err := loghandler.New(conf.LogDestinationsParsed, conf.LogFile) + p := &program{ + conf: conf, + terminate: make(chan struct{}), + done: make(chan struct{}), + } + + p.stats = stats.New() + + p.logHandler, err = loghandler.New(conf.LogDestinationsParsed, conf.LogFile) if err != nil { + p.closeResources() return nil, err } - p := &program{ - conf: conf, - logHandler: logHandler, - paths: make(map[string]*path), - clients: make(map[*client]struct{}), - udpPublishersMap: newUdpPublisherMap(), - readersMap: newReadersMap(), - countClients: func() *int64 { - v := int64(0) - return &v - }(), - countPublishers: func() *int64 { - v := int64(0) - return &v - }(), - countReaders: func() *int64 { - v := int64(0) - return &v - }(), - countSourcesRtsp: func() *int64 { - v := int64(0) - return &v - }(), - countSourcesRtspRunning: func() *int64 { - v := int64(0) - return &v - }(), - countSourcesRtmp: func() *int64 { - v := int64(0) - return &v - }(), - countSourcesRtmpRunning: func() *int64 { - v := int64(0) - return &v - }(), - clientNew: make(chan net.Conn), - clientClose: make(chan *client), - clientDescribe: make(chan clientDescribeReq), - clientAnnounce: make(chan clientAnnounceReq), - clientSetupPlay: make(chan clientSetupPlayReq), - clientPlay: make(chan *client), - clientRecord: make(chan *client), - sourceRtspReady: make(chan *sourceRtsp), - sourceRtspNotReady: make(chan *sourceRtsp), - sourceRtmpReady: make(chan *sourceRtmp), - sourceRtmpNotReady: make(chan *sourceRtmp), - terminate: make(chan struct{}), - done: make(chan struct{}), - } - - p.log("rtsp-simple-server %s", Version) + p.Log("rtsp-simple-server %s", Version) if conf.Metrics { - p.metrics, err = newMetrics(p) + p.metrics, err = metrics.New(p.stats, p) if err != nil { + p.closeResources() return nil, err } } if conf.Pprof { - p.pprof, err = newPprof(p) + p.pprof, err = pprof.New(p) if err != nil { + p.closeResources() return nil, err } } - for name, pathConf := range conf.Paths { - if pathConf.Regexp == nil { - p.paths[name] = newPath(p, name, pathConf) - } - } - if _, ok := conf.ProtocolsParsed[gortsplib.StreamProtocolUDP]; ok { - p.serverUdpRtp, err = newServerUDP(p, conf.RtpPort, gortsplib.StreamTypeRtp) + p.serverUdpRtp, err = serverudp.New(p.conf.WriteTimeout, + conf.RtpPort, gortsplib.StreamTypeRtp, p) if err != nil { + p.closeResources() return nil, err } - p.serverUdpRtcp, err = newServerUDP(p, conf.RtcpPort, gortsplib.StreamTypeRtcp) + p.serverUdpRtcp, err = serverudp.New(p.conf.WriteTimeout, + conf.RtcpPort, gortsplib.StreamTypeRtcp, p) if err != nil { + p.closeResources() return nil, err } } - p.serverTcp, err = newServerTCP(p) + p.serverTcp, err = servertcp.New(conf.RtspPort, p) if err != nil { + p.closeResources() return nil, err } - go p.run() + p.pathMan = pathman.New(p.stats, p.serverUdpRtp, p.serverUdpRtcp, + p.conf.ReadTimeout, p.conf.WriteTimeout, p.conf.AuthMethodsParsed, + conf.Paths, p) + p.clientMan = clientman.New(p.stats, p.serverUdpRtp, p.serverUdpRtcp, + p.conf.ReadTimeout, p.conf.WriteTimeout, p.conf.RunOnConnect, + p.conf.ProtocolsParsed, p.pathMan, p.serverTcp, p) + + go p.run() return p, nil } -func (p *program) log(format string, args ...interface{}) { - countClients := atomic.LoadInt64(p.countClients) - countPublishers := atomic.LoadInt64(p.countPublishers) - countReaders := atomic.LoadInt64(p.countReaders) +func (p *program) Log(format string, args ...interface{}) { + countClients := atomic.LoadInt64(p.stats.CountClients) + countPublishers := atomic.LoadInt64(p.stats.CountPublishers) + countReaders := atomic.LoadInt64(p.stats.CountReaders) log.Printf(fmt.Sprintf("[%d/%d/%d] "+format, append([]interface{}{countClients, countPublishers, countReaders}, args...)...)) @@ -190,207 +135,49 @@ func (p *program) log(format string, args ...interface{}) { func (p *program) run() { defer close(p.done) - if p.metrics != nil { - go p.metrics.run() - } - - if p.pprof != nil { - go p.pprof.run() - } - - if p.serverUdpRtp != nil { - go p.serverUdpRtp.run() - } - - if p.serverUdpRtcp != nil { - go p.serverUdpRtcp.run() - } - - go p.serverTcp.run() - - for _, p := range p.paths { - p.onInit() - } - - checkPathsTicker := time.NewTicker(checkPathPeriod) - defer checkPathsTicker.Stop() - outer: for { select { - case <-checkPathsTicker.C: - for _, path := range p.paths { - path.onCheck() - } - - case conn := <-p.clientNew: - newClient(p, conn) - - case client := <-p.clientClose: - if _, ok := p.clients[client]; !ok { - continue - } - client.close() - - case req := <-p.clientDescribe: - // create path if it doesn't exist - if _, ok := p.paths[req.pathName]; !ok { - p.paths[req.pathName] = newPath(p, req.pathName, req.pathConf) - } - - p.paths[req.pathName].onDescribe(req.client) - - case req := <-p.clientAnnounce: - // create path if it doesn't exist - if path, ok := p.paths[req.pathName]; !ok { - p.paths[req.pathName] = newPath(p, req.pathName, req.pathConf) - - } else { - if path.source != nil { - req.res <- fmt.Errorf("someone is already publishing on path '%s'", req.pathName) - continue - } - } - - p.paths[req.pathName].source = req.client - p.paths[req.pathName].sourceTrackCount = req.trackCount - p.paths[req.pathName].sourceSdp = req.sdp - - req.client.path = p.paths[req.pathName] - req.client.state = clientStatePreRecord - req.res <- nil - - case req := <-p.clientSetupPlay: - path, ok := p.paths[req.pathName] - if !ok || !path.sourceReady { - req.res <- fmt.Errorf("no one is publishing on path '%s'", req.pathName) - continue - } - - if req.trackId >= path.sourceTrackCount { - req.res <- fmt.Errorf("track %d does not exist", req.trackId) - continue - } - - req.client.path = path - req.client.state = clientStatePrePlay - req.res <- nil - - case client := <-p.clientPlay: - atomic.AddInt64(p.countReaders, 1) - client.state = clientStatePlay - p.readersMap.add(client) - - case client := <-p.clientRecord: - atomic.AddInt64(p.countPublishers, 1) - client.state = clientStateRecord - - if client.streamProtocol == gortsplib.StreamProtocolUDP { - for trackId, track := range client.streamTracks { - addr := makeUDPPublisherAddr(client.ip(), track.rtpPort) - p.udpPublishersMap.add(addr, &udpPublisher{ - client: client, - trackId: trackId, - streamType: gortsplib.StreamTypeRtp, - }) - - addr = makeUDPPublisherAddr(client.ip(), track.rtcpPort) - p.udpPublishersMap.add(addr, &udpPublisher{ - client: client, - trackId: trackId, - streamType: gortsplib.StreamTypeRtcp, - }) - } - } - - client.path.onSourceSetReady() - - case s := <-p.sourceRtspReady: - s.path.onSourceSetReady() - - case s := <-p.sourceRtspNotReady: - s.path.onSourceSetNotReady() - - case s := <-p.sourceRtmpReady: - s.path.onSourceSetReady() - - case s := <-p.sourceRtmpNotReady: - s.path.onSourceSetNotReady() - case <-p.terminate: break outer } } - go func() { - for { - select { - case _, ok := <-p.clientNew: - if !ok { - return - } + p.closeResources() +} - case <-p.clientClose: - case <-p.clientDescribe: - - case req := <-p.clientAnnounce: - req.res <- fmt.Errorf("terminated") - - case req := <-p.clientSetupPlay: - req.res <- fmt.Errorf("terminated") - - case <-p.clientPlay: - case <-p.clientRecord: - case <-p.sourceRtspReady: - case <-p.sourceRtspNotReady: - case <-p.sourceRtmpReady: - case <-p.sourceRtmpNotReady: - } - } - }() - - p.udpPublishersMap.clear() - p.readersMap.clear() - - for _, p := range p.paths { - p.onClose() +func (p *program) closeResources() { + if p.clientMan != nil { + p.clientMan.Close() } - p.serverTcp.close() + if p.pathMan != nil { + p.pathMan.Close() + } + + if p.serverTcp != nil { + p.serverTcp.Close() + } if p.serverUdpRtcp != nil { - p.serverUdpRtcp.close() + p.serverUdpRtcp.Close() } if p.serverUdpRtp != nil { - p.serverUdpRtp.close() + p.serverUdpRtp.Close() } - for c := range p.clients { - c.close() - } - - p.clientsWg.Wait() - if p.metrics != nil { - p.metrics.close() + p.metrics.Close() } if p.pprof != nil { - p.pprof.close() + p.pprof.Close() } - p.logHandler.Close() - - close(p.clientNew) - close(p.clientClose) - close(p.clientDescribe) - close(p.clientAnnounce) - close(p.clientSetupPlay) - close(p.clientPlay) - close(p.clientRecord) - close(p.sourceRtspReady) - close(p.sourceRtspNotReady) + if p.logHandler != nil { + p.logHandler.Close() + } } func (p *program) close() { diff --git a/main_test.go b/main_test.go index f8508538..441c25ea 100644 --- a/main_test.go +++ b/main_test.go @@ -3,7 +3,6 @@ package main import ( "io/ioutil" "net" - "net/url" "os" "os/exec" "regexp" @@ -191,11 +190,7 @@ func TestEnvironment(t *testing.T) { pa, ok = p.conf.Paths["cam1"] require.Equal(t, true, ok) require.Equal(t, &conf.PathConf{ - Source: "rtsp://testing", - SourceUrl: func() *url.URL { - u, _ := url.Parse("rtsp://testing:554") - return u - }(), + Source: "rtsp://testing", SourceProtocol: "tcp", SourceProtocolParsed: gortsplib.StreamProtocolTCP, SourceOnDemand: true, @@ -213,11 +208,7 @@ func TestEnvironmentNoFile(t *testing.T) { pa, ok := p.conf.Paths["cam1"] require.Equal(t, true, ok) require.Equal(t, &conf.PathConf{ - Source: "rtsp://testing", - SourceUrl: func() *url.URL { - u, _ := url.Parse("rtsp://testing:554") - return u - }(), + Source: "rtsp://testing", SourceProtocol: "udp", }, pa) } diff --git a/metrics.go b/metrics/metrics.go similarity index 58% rename from metrics.go rename to metrics/metrics.go index 467301b9..5df9f5da 100644 --- a/metrics.go +++ b/metrics/metrics.go @@ -1,4 +1,4 @@ -package main +package metrics import ( "context" @@ -8,27 +8,34 @@ import ( "net/http" "sync/atomic" "time" + + "github.com/aler9/rtsp-simple-server/stats" ) const ( - metricsAddress = ":9998" + address = ":9998" ) -type metrics struct { - p *program +type Parent interface { + Log(string, ...interface{}) +} + +type Metrics struct { + stats *stats.Stats + listener net.Listener mux *http.ServeMux server *http.Server } -func newMetrics(p *program) (*metrics, error) { - listener, err := net.Listen("tcp", metricsAddress) +func New(stats *stats.Stats, parent Parent) (*Metrics, error) { + listener, err := net.Listen("tcp", address) if err != nil { return nil, err } - m := &metrics{ - p: p, + m := &Metrics{ + stats: stats, listener: listener, } @@ -39,31 +46,33 @@ func newMetrics(p *program) (*metrics, error) { Handler: m.mux, } - m.p.log("[metrics] opened on " + metricsAddress) + parent.Log("[metrics] opened on " + address) + + go m.run() return m, nil } -func (m *metrics) run() { +func (m *Metrics) Close() { + m.server.Shutdown(context.Background()) +} + +func (m *Metrics) run() { err := m.server.Serve(m.listener) if err != http.ErrServerClosed { panic(err) } } -func (m *metrics) close() { - m.server.Shutdown(context.Background()) -} - -func (m *metrics) onMetrics(w http.ResponseWriter, req *http.Request) { +func (m *Metrics) onMetrics(w http.ResponseWriter, req *http.Request) { now := time.Now().UnixNano() / 1000000 - countClients := atomic.LoadInt64(m.p.countClients) - countPublishers := atomic.LoadInt64(m.p.countPublishers) - countReaders := atomic.LoadInt64(m.p.countReaders) - countSourcesRtsp := atomic.LoadInt64(m.p.countSourcesRtsp) - countSourcesRtspRunning := atomic.LoadInt64(m.p.countSourcesRtspRunning) - countSourcesRtmp := atomic.LoadInt64(m.p.countSourcesRtmp) - countSourcesRtmpRunning := atomic.LoadInt64(m.p.countSourcesRtmpRunning) + countClients := atomic.LoadInt64(m.stats.CountClients) + countPublishers := atomic.LoadInt64(m.stats.CountPublishers) + countReaders := atomic.LoadInt64(m.stats.CountReaders) + countSourcesRtsp := atomic.LoadInt64(m.stats.CountSourcesRtsp) + countSourcesRtspRunning := atomic.LoadInt64(m.stats.CountSourcesRtspRunning) + countSourcesRtmp := atomic.LoadInt64(m.stats.CountSourcesRtmp) + countSourcesRtmpRunning := atomic.LoadInt64(m.stats.CountSourcesRtmpRunning) out := "" out += fmt.Sprintf("rtsp_clients{state=\"idle\"} %d %v\n", diff --git a/path.go b/path.go deleted file mode 100644 index 330a283f..00000000 --- a/path.go +++ /dev/null @@ -1,293 +0,0 @@ -package main - -import ( - "fmt" - "strings" - "sync/atomic" - "time" - - "github.com/aler9/rtsp-simple-server/conf" - "github.com/aler9/rtsp-simple-server/externalcmd" -) - -const ( - describeTimeout = 5 * time.Second - sourceStopAfterDescribePeriod = 10 * time.Second - onDemandCmdStopAfterDescribePeriod = 10 * time.Second -) - -// a source can be a client, a sourceRtsp or a sourceRtmp -type source interface { - isSource() -} - -type path struct { - p *program - name string - conf *conf.PathConf - source source - sourceReady bool - sourceTrackCount int - sourceSdp []byte - lastDescribeReq time.Time - lastDescribeActivation time.Time - onInitCmd *externalcmd.ExternalCmd - onDemandCmd *externalcmd.ExternalCmd -} - -func newPath(p *program, name string, conf *conf.PathConf) *path { - pa := &path{ - p: p, - name: name, - conf: conf, - } - - if strings.HasPrefix(conf.Source, "rtsp://") { - s := newSourceRtsp(p, pa) - pa.source = s - - } else if strings.HasPrefix(conf.Source, "rtmp://") { - s := newSourceRtmp(p, pa) - pa.source = s - } - - return pa -} - -func (pa *path) log(format string, args ...interface{}) { - pa.p.log("[path "+pa.name+"] "+format, args...) -} - -func (pa *path) onInit() { - if source, ok := pa.source.(*sourceRtsp); ok { - go source.run(source.state) - - } else if source, ok := pa.source.(*sourceRtmp); ok { - go source.run(source.state) - } - - if pa.conf.RunOnInit != "" { - pa.log("starting on init command") - - var err error - pa.onInitCmd, err = externalcmd.New(pa.conf.RunOnInit, pa.name) - if err != nil { - pa.log("ERR: %s", err) - } - } -} - -func (pa *path) onClose() { - if source, ok := pa.source.(*sourceRtsp); ok { - close(source.terminate) - <-source.done - - } else if source, ok := pa.source.(*sourceRtmp); ok { - close(source.terminate) - <-source.done - } - - if pa.onInitCmd != nil { - pa.log("stopping on init command (closing)") - pa.onInitCmd.Close() - } - - if pa.onDemandCmd != nil { - pa.log("stopping on demand command (closing)") - pa.onDemandCmd.Close() - } - - for c := range pa.p.clients { - if c.path == pa { - if c.state == clientStateWaitDescription { - c.path = nil - c.state = clientStateInitial - c.describe <- describeRes{nil, fmt.Errorf("publisher of path '%s' has timed out", pa.name)} - } else { - c.close() - } - } - } -} - -func (pa *path) hasClients() bool { - for c := range pa.p.clients { - if c.path == pa { - return true - } - } - return false -} - -func (pa *path) hasClientsWaitingDescribe() bool { - for c := range pa.p.clients { - if c.state == clientStateWaitDescription && c.path == pa { - return true - } - } - return false -} - -func (pa *path) hasClientReaders() bool { - for c := range pa.p.clients { - if c.path == pa && c != pa.source { - return true - } - } - return false -} - -func (pa *path) onCheck() { - // reply to DESCRIBE requests if they are in timeout - if pa.hasClientsWaitingDescribe() && - time.Since(pa.lastDescribeActivation) >= describeTimeout { - for c := range pa.p.clients { - if c.state == clientStateWaitDescription && - c.path == pa { - c.path = nil - c.state = clientStateInitial - c.describe <- describeRes{nil, fmt.Errorf("publisher of path '%s' has timed out", pa.name)} - } - } - } - - // stop on demand rtsp source if needed - if source, ok := pa.source.(*sourceRtsp); ok { - if pa.conf.SourceOnDemand && - source.state == sourceRtspStateRunning && - !pa.hasClients() && - time.Since(pa.lastDescribeReq) >= sourceStopAfterDescribePeriod { - pa.log("stopping on demand rtsp source (not requested anymore)") - atomic.AddInt64(pa.p.countSourcesRtspRunning, -1) - source.state = sourceRtspStateStopped - source.setState <- source.state - } - - // stop on demand rtmp source if needed - } else if source, ok := pa.source.(*sourceRtmp); ok { - if pa.conf.SourceOnDemand && - source.state == sourceRtmpStateRunning && - !pa.hasClients() && - time.Since(pa.lastDescribeReq) >= sourceStopAfterDescribePeriod { - pa.log("stopping on demand rtmp source (not requested anymore)") - atomic.AddInt64(pa.p.countSourcesRtmpRunning, -1) - source.state = sourceRtmpStateStopped - source.setState <- source.state - } - } - - // stop on demand command if needed - if pa.onDemandCmd != nil && - !pa.hasClientReaders() && - time.Since(pa.lastDescribeReq) >= onDemandCmdStopAfterDescribePeriod { - pa.log("stopping on demand command (not requested anymore)") - pa.onDemandCmd.Close() - pa.onDemandCmd = nil - } - - // remove regular expression paths - if pa.conf.Regexp != nil && - pa.source == nil && - !pa.hasClients() { - pa.onClose() - delete(pa.p.paths, pa.name) - } -} - -func (pa *path) onSourceRemove() { - pa.source = nil - - // close all clients that are reading or waiting for reading - for c := range pa.p.clients { - if c.path == pa && - c.state != clientStateWaitDescription && - c != pa.source { - c.close() - } - } -} - -func (pa *path) onSourceSetReady() { - pa.sourceReady = true - - // reply to all clients that are waiting for a description - for c := range pa.p.clients { - if c.state == clientStateWaitDescription && - c.path == pa { - c.path = nil - c.state = clientStateInitial - c.describe <- describeRes{pa.sourceSdp, nil} - } - } -} - -func (pa *path) onSourceSetNotReady() { - pa.sourceReady = false - - // close all clients that are reading or waiting for reading - for c := range pa.p.clients { - if c.path == pa && - c.state != clientStateWaitDescription && - c != pa.source { - c.close() - } - } -} - -func (pa *path) onDescribe(client *client) { - pa.lastDescribeReq = time.Now() - - // publisher not found - if pa.source == nil { - // on demand command is available: put the client on hold - if pa.conf.RunOnDemand != "" { - if pa.onDemandCmd == nil { // start if needed - pa.log("starting on demand command") - pa.lastDescribeActivation = time.Now() - - var err error - pa.onDemandCmd, err = externalcmd.New(pa.conf.RunOnDemand, pa.name) - if err != nil { - pa.log("ERR: %s", err) - } - } - - client.path = pa - client.state = clientStateWaitDescription - - // no on-demand: reply with 404 - } else { - client.describe <- describeRes{nil, fmt.Errorf("no one is publishing on path '%s'", pa.name)} - } - - // publisher was found but is not ready: put the client on hold - } else if !pa.sourceReady { - // start rtsp source if needed - if source, ok := pa.source.(*sourceRtsp); ok { - if source.state == sourceRtspStateStopped { - pa.log("starting on demand rtsp source") - pa.lastDescribeActivation = time.Now() - atomic.AddInt64(pa.p.countSourcesRtspRunning, +1) - source.state = sourceRtspStateRunning - source.setState <- source.state - } - - // start rtmp source if needed - } else if source, ok := pa.source.(*sourceRtmp); ok { - if source.state == sourceRtmpStateStopped { - pa.log("starting on demand rtmp source") - pa.lastDescribeActivation = time.Now() - atomic.AddInt64(pa.p.countSourcesRtmpRunning, +1) - source.state = sourceRtmpStateRunning - source.setState <- source.state - } - } - - client.path = pa - client.state = clientStateWaitDescription - - // publisher was found and is ready - } else { - client.describe <- describeRes{pa.sourceSdp, nil} - } -} diff --git a/path/path.go b/path/path.go new file mode 100644 index 00000000..2ece79f4 --- /dev/null +++ b/path/path.go @@ -0,0 +1,660 @@ +package path + +import ( + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/base" + + "github.com/aler9/rtsp-simple-server/client" + "github.com/aler9/rtsp-simple-server/conf" + "github.com/aler9/rtsp-simple-server/externalcmd" + "github.com/aler9/rtsp-simple-server/serverudp" + "github.com/aler9/rtsp-simple-server/sourcertmp" + "github.com/aler9/rtsp-simple-server/sourcertsp" + "github.com/aler9/rtsp-simple-server/stats" +) + +const ( + pathCheckPeriod = 5 * time.Second + describeTimeout = 5 * time.Second + sourceStopAfterDescribePeriod = 10 * time.Second + onDemandCmdStopAfterDescribePeriod = 10 * time.Second +) + +type Parent interface { + Log(string, ...interface{}) + OnPathClose(*Path) + OnPathClientClose(*client.Client) +} + +// a source can be a client, a sourcertsp.Source or a sourcertmp.Source +type source interface { + IsSource() +} + +type ClientDescribeRes struct { + Path client.Path + Err error +} + +type ClientDescribeReq struct { + Res chan ClientDescribeRes + Client *client.Client + PathName string + Req *base.Request +} + +type ClientAnnounceRes struct { + Path client.Path + Err error +} + +type ClientAnnounceReq struct { + Res chan ClientAnnounceRes + Client *client.Client + PathName string + Tracks gortsplib.Tracks + Req *base.Request +} + +type ClientSetupPlayRes struct { + Path client.Path + Err error +} + +type ClientSetupPlayReq struct { + Res chan ClientSetupPlayRes + Client *client.Client + PathName string + TrackId int + Req *base.Request +} + +type clientRemoveReq struct { + res chan struct{} + client *client.Client +} + +type clientPlayReq struct { + res chan struct{} + client *client.Client +} + +type clientRecordReq struct { + res chan struct{} + client *client.Client +} + +type clientState int + +const ( + clientStateWaitingDescribe clientState = iota + clientStatePrePlay + clientStatePlay + clientStatePreRecord + clientStateRecord +) + +type Path struct { + wg *sync.WaitGroup + stats *stats.Stats + serverUdpRtp *serverudp.Server + serverUdpRtcp *serverudp.Server + readTimeout time.Duration + writeTimeout time.Duration + name string + conf *conf.PathConf + parent Parent + + clients map[*client.Client]clientState + source source + sourceReady bool + sourceTrackCount int + sourceSdp []byte + lastDescribeReq time.Time + lastDescribeActivation time.Time + readers *readersMap + onInitCmd *externalcmd.ExternalCmd + onDemandCmd *externalcmd.ExternalCmd + + // in + sourceSetReady chan struct{} // from source + sourceSetNotReady chan struct{} // from source + clientDescribe chan ClientDescribeReq // from program + clientAnnounce chan ClientAnnounceReq // from program + clientSetupPlay chan ClientSetupPlayReq // from program + clientPlay chan clientPlayReq // from client + clientRecord chan clientRecordReq // from client + clientRemove chan clientRemoveReq // from client + terminate chan struct{} +} + +func New( + wg *sync.WaitGroup, + stats *stats.Stats, + serverUdpRtp *serverudp.Server, + serverUdpRtcp *serverudp.Server, + readTimeout time.Duration, + writeTimeout time.Duration, + name string, + conf *conf.PathConf, + parent Parent) *Path { + + pa := &Path{ + wg: wg, + stats: stats, + serverUdpRtp: serverUdpRtp, + serverUdpRtcp: serverUdpRtcp, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + name: name, + conf: conf, + parent: parent, + clients: make(map[*client.Client]clientState), + readers: newReadersMap(), + sourceSetReady: make(chan struct{}), + sourceSetNotReady: make(chan struct{}), + clientDescribe: make(chan ClientDescribeReq), + clientAnnounce: make(chan ClientAnnounceReq), + clientSetupPlay: make(chan ClientSetupPlayReq), + clientPlay: make(chan clientPlayReq), + clientRecord: make(chan clientRecordReq), + clientRemove: make(chan clientRemoveReq), + terminate: make(chan struct{}), + } + + pa.wg.Add(1) + go pa.run() + return pa +} + +func (pa *Path) Close() { + close(pa.terminate) +} + +func (pa *Path) Log(format string, args ...interface{}) { + pa.parent.Log("[path "+pa.name+"] "+format, args...) +} + +func (pa *Path) run() { + defer pa.wg.Done() + + if strings.HasPrefix(pa.conf.Source, "rtsp://") { + state := sourcertsp.StateStopped + if !pa.conf.SourceOnDemand { + state = sourcertsp.StateRunning + } + + s := sourcertsp.New( + pa.conf.Source, + pa.conf.SourceProtocolParsed, + pa.readTimeout, + pa.writeTimeout, + state, + pa) + pa.source = s + + atomic.AddInt64(pa.stats.CountSourcesRtsp, +1) + if !pa.conf.SourceOnDemand { + atomic.AddInt64(pa.stats.CountSourcesRtspRunning, +1) + } + + } else if strings.HasPrefix(pa.conf.Source, "rtmp://") { + state := sourcertmp.StateStopped + if !pa.conf.SourceOnDemand { + state = sourcertmp.StateRunning + } + + s := sourcertmp.New( + pa.conf.Source, + state, + pa) + pa.source = s + + atomic.AddInt64(pa.stats.CountSourcesRtmp, +1) + if !pa.conf.SourceOnDemand { + atomic.AddInt64(pa.stats.CountSourcesRtmpRunning, +1) + } + } + + if pa.conf.RunOnInit != "" { + pa.Log("starting on init command") + + var err error + pa.onInitCmd, err = externalcmd.New(pa.conf.RunOnInit, pa.name) + if err != nil { + pa.Log("ERR: %s", err) + } + } + + tickerCheck := time.NewTicker(pathCheckPeriod) + defer tickerCheck.Stop() + +outer: + for { + select { + case <-tickerCheck.C: + ok := pa.onCheck() + if !ok { + pa.parent.OnPathClose(pa) + <-pa.terminate + break outer + } + + case <-pa.sourceSetReady: + pa.onSourceSetReady() + + case <-pa.sourceSetNotReady: + pa.onSourceSetNotReady() + + case req := <-pa.clientDescribe: + // reply immediately + req.Res <- ClientDescribeRes{pa, nil} + pa.onClientDescribe(req.Client) + + case req := <-pa.clientSetupPlay: + err := pa.onClientSetupPlay(req.Client, req.TrackId) + if err != nil { + req.Res <- ClientSetupPlayRes{nil, err} + continue + } + req.Res <- ClientSetupPlayRes{pa, nil} + + case req := <-pa.clientPlay: + if _, ok := pa.clients[req.client]; ok { + pa.onClientPlay(req.client) + } + close(req.res) + + case req := <-pa.clientAnnounce: + err := pa.onClientAnnounce(req.Client, req.Tracks) + if err != nil { + req.Res <- ClientAnnounceRes{nil, err} + continue + } + req.Res <- ClientAnnounceRes{pa, nil} + + case req := <-pa.clientRecord: + if _, ok := pa.clients[req.client]; ok { + pa.onClientRecord(req.client) + } + close(req.res) + + case req := <-pa.clientRemove: + if _, ok := pa.clients[req.client]; ok { + pa.onClientRemove(req.client) + } + close(req.res) + + case <-pa.terminate: + break outer + } + } + + go func() { + for { + select { + case _, ok := <-pa.sourceSetReady: + if !ok { + return + } + + case _, ok := <-pa.sourceSetNotReady: + if !ok { + return + } + + case req, ok := <-pa.clientDescribe: + if !ok { + return + } + req.Res <- ClientDescribeRes{nil, fmt.Errorf("terminated")} + + case req, ok := <-pa.clientAnnounce: + if !ok { + return + } + req.Res <- ClientAnnounceRes{nil, fmt.Errorf("terminated")} + + case req, ok := <-pa.clientSetupPlay: + if !ok { + return + } + req.Res <- ClientSetupPlayRes{nil, fmt.Errorf("terminated")} + + case req, ok := <-pa.clientPlay: + if !ok { + return + } + close(req.res) + + case req, ok := <-pa.clientRecord: + if !ok { + return + } + close(req.res) + + case req, ok := <-pa.clientRemove: + if !ok { + return + } + close(req.res) + } + } + }() + + if pa.onInitCmd != nil { + pa.Log("stopping on init command (closing)") + pa.onInitCmd.Close() + } + + if source, ok := pa.source.(*sourcertsp.Source); ok { + source.Close() + + } else if source, ok := pa.source.(*sourcertmp.Source); ok { + source.Close() + } + + if pa.onDemandCmd != nil { + pa.Log("stopping on demand command (closing)") + pa.onDemandCmd.Close() + } + + for c, state := range pa.clients { + if state == clientStateWaitingDescribe { + delete(pa.clients, c) + c.OnPathDescribeData(nil, fmt.Errorf("publisher of path '%s' has timed out", pa.name)) + } else { + pa.onClientRemove(c) + pa.parent.OnPathClientClose(c) + } + } + + close(pa.sourceSetReady) + close(pa.sourceSetNotReady) + close(pa.clientDescribe) + close(pa.clientAnnounce) + close(pa.clientSetupPlay) + close(pa.clientPlay) + close(pa.clientRecord) + close(pa.clientRemove) +} + +func (pa *Path) hasClients() bool { + return len(pa.clients) > 0 +} + +func (pa *Path) hasClientsWaitingDescribe() bool { + for _, state := range pa.clients { + if state == clientStateWaitingDescribe { + return true + } + } + return false +} + +func (pa *Path) hasClientReadersOrWaitingDescribe() bool { + for c := range pa.clients { + if c != pa.source { + return true + } + } + return false +} + +func (pa *Path) onCheck() bool { + // reply to DESCRIBE requests if they are in timeout + if pa.hasClientsWaitingDescribe() && + time.Since(pa.lastDescribeActivation) >= describeTimeout { + for c, state := range pa.clients { + if state == clientStateWaitingDescribe { + delete(pa.clients, c) + c.OnPathDescribeData(nil, fmt.Errorf("publisher of path '%s' has timed out", pa.name)) + } + } + } + + // stop on demand rtsp source if needed + if source, ok := pa.source.(*sourcertsp.Source); ok { + if pa.conf.SourceOnDemand && + source.State() == sourcertsp.StateRunning && + !pa.hasClients() && + time.Since(pa.lastDescribeReq) >= sourceStopAfterDescribePeriod { + pa.Log("stopping on demand rtsp source (not requested anymore)") + atomic.AddInt64(pa.stats.CountSourcesRtspRunning, -1) + source.SetState(sourcertsp.StateStopped) + } + + // stop on demand rtmp source if needed + } else if source, ok := pa.source.(*sourcertmp.Source); ok { + if pa.conf.SourceOnDemand && + source.State() == sourcertmp.StateRunning && + !pa.hasClients() && + time.Since(pa.lastDescribeReq) >= sourceStopAfterDescribePeriod { + pa.Log("stopping on demand rtmp source (not requested anymore)") + atomic.AddInt64(pa.stats.CountSourcesRtmpRunning, -1) + source.SetState(sourcertmp.StateStopped) + } + } + + // stop on demand command if needed + if pa.onDemandCmd != nil && + !pa.hasClientReadersOrWaitingDescribe() && + time.Since(pa.lastDescribeReq) >= onDemandCmdStopAfterDescribePeriod { + pa.Log("stopping on demand command (not requested anymore)") + pa.onDemandCmd.Close() + pa.onDemandCmd = nil + } + + // remove path if is regexp and has no clients + if pa.conf.Regexp != nil && + pa.source == nil && + !pa.hasClients() { + return false + } + + return true +} + +func (pa *Path) onSourceSetReady() { + pa.sourceReady = true + + // reply to all clients that are waiting for a description + for c, state := range pa.clients { + if state == clientStateWaitingDescribe { + delete(pa.clients, c) + c.OnPathDescribeData(pa.sourceSdp, nil) + } + } +} + +func (pa *Path) onSourceSetNotReady() { + pa.sourceReady = false + + // close all clients that are reading or waiting to read + for c, state := range pa.clients { + if state != clientStateWaitingDescribe && c != pa.source { + pa.onClientRemove(c) + pa.parent.OnPathClientClose(c) + } + } +} + +func (pa *Path) onClientDescribe(c *client.Client) { + pa.lastDescribeReq = time.Now() + + // publisher not found + if pa.source == nil { + // on demand command is available: put the client on hold + if pa.conf.RunOnDemand != "" { + if pa.onDemandCmd == nil { // start if needed + pa.Log("starting on demand command") + pa.lastDescribeActivation = time.Now() + + var err error + pa.onDemandCmd, err = externalcmd.New(pa.conf.RunOnDemand, pa.name) + if err != nil { + pa.Log("ERR: %s", err) + } + } + + pa.clients[c] = clientStateWaitingDescribe + + // no on-demand: reply with 404 + } else { + c.OnPathDescribeData(nil, fmt.Errorf("no one is publishing on path '%s'", pa.name)) + } + + // publisher was found but is not ready: put the client on hold + } else if !pa.sourceReady { + // start rtsp source if needed + if source, ok := pa.source.(*sourcertsp.Source); ok { + if source.State() == sourcertsp.StateStopped { + pa.Log("starting on demand rtsp source") + pa.lastDescribeActivation = time.Now() + atomic.AddInt64(pa.stats.CountSourcesRtspRunning, +1) + source.SetState(sourcertsp.StateRunning) + } + + // start rtmp source if needed + } else if source, ok := pa.source.(*sourcertmp.Source); ok { + if source.State() == sourcertmp.StateStopped { + pa.Log("starting on demand rtmp source") + pa.lastDescribeActivation = time.Now() + atomic.AddInt64(pa.stats.CountSourcesRtmpRunning, +1) + source.SetState(sourcertmp.StateRunning) + } + } + + pa.clients[c] = clientStateWaitingDescribe + + // publisher was found and is ready + } else { + c.OnPathDescribeData(pa.sourceSdp, nil) + } +} + +func (pa *Path) onClientSetupPlay(c *client.Client, trackId int) error { + if !pa.sourceReady { + return fmt.Errorf("no one is publishing on path '%s'", pa.name) + } + + if trackId >= pa.sourceTrackCount { + return fmt.Errorf("track %d does not exist", trackId) + } + + pa.clients[c] = clientStatePrePlay + return nil +} + +func (pa *Path) onClientPlay(c *client.Client) { + atomic.AddInt64(pa.stats.CountReaders, 1) + pa.clients[c] = clientStatePlay + pa.readers.add(c) +} + +func (pa *Path) onClientAnnounce(c *client.Client, tracks gortsplib.Tracks) error { + if pa.source != nil { + return fmt.Errorf("someone is already publishing on path '%s'", pa.name) + } + + pa.clients[c] = clientStatePreRecord + pa.source = c + pa.sourceTrackCount = len(tracks) + pa.sourceSdp = tracks.Write() + return nil +} + +func (pa *Path) onClientRecord(c *client.Client) { + atomic.AddInt64(pa.stats.CountPublishers, 1) + pa.clients[c] = clientStateRecord + pa.onSourceSetReady() +} + +func (pa *Path) onClientRemove(c *client.Client) { + state := pa.clients[c] + delete(pa.clients, c) + + switch state { + case clientStatePlay: + atomic.AddInt64(pa.stats.CountReaders, -1) + pa.readers.remove(c) + + case clientStateRecord: + atomic.AddInt64(pa.stats.CountPublishers, -1) + pa.onSourceSetNotReady() + } + + if pa.source == c { + pa.source = nil + + // close all clients that are reading or waiting to read + for oc, state := range pa.clients { + if state != clientStateWaitingDescribe && oc != pa.source { + pa.onClientRemove(oc) + pa.parent.OnPathClientClose(oc) + } + } + } +} + +func (pa *Path) OnSourceReady(tracks gortsplib.Tracks) { + pa.sourceSdp = tracks.Write() + pa.sourceTrackCount = len(tracks) + pa.sourceSetReady <- struct{}{} +} + +func (pa *Path) OnSourceNotReady() { + pa.sourceSetNotReady <- struct{}{} +} + +func (pa *Path) Name() string { + return pa.name +} + +func (pa *Path) SourceTrackCount() int { + return pa.sourceTrackCount +} + +func (pa *Path) Conf() *conf.PathConf { + return pa.conf +} + +func (pa *Path) OnPathManDescribe(req ClientDescribeReq) { + pa.clientDescribe <- req +} + +func (pa *Path) OnPathManSetupPlay(req ClientSetupPlayReq) { + pa.clientSetupPlay <- req +} + +func (pa *Path) OnPathManAnnounce(req ClientAnnounceReq) { + pa.clientAnnounce <- req +} + +func (pa *Path) OnClientRemove(c *client.Client) { + res := make(chan struct{}) + pa.clientRemove <- clientRemoveReq{res, c} + <-res +} + +func (pa *Path) OnClientPlay(c *client.Client) { + res := make(chan struct{}) + pa.clientPlay <- clientPlayReq{res, c} + <-res +} + +func (pa *Path) OnClientRecord(c *client.Client) { + res := make(chan struct{}) + pa.clientRecord <- clientRecordReq{res, c} + <-res +} + +func (pa *Path) OnFrame(trackId int, streamType gortsplib.StreamType, buf []byte) { + pa.readers.forwardFrame(trackId, streamType, buf) +} diff --git a/path/readersmap.go b/path/readersmap.go new file mode 100644 index 00000000..acdf5767 --- /dev/null +++ b/path/readersmap.go @@ -0,0 +1,46 @@ +package path + +import ( + "sync" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/base" +) + +type Reader interface { + OnReaderFrame(int, base.StreamType, []byte) +} + +type readersMap struct { + mutex sync.RWMutex + ma map[Reader]struct{} +} + +func newReadersMap() *readersMap { + return &readersMap{ + ma: make(map[Reader]struct{}), + } +} + +func (m *readersMap) add(reader Reader) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.ma[reader] = struct{}{} +} + +func (m *readersMap) remove(reader Reader) { + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.ma, reader) +} + +func (m *readersMap) forwardFrame(trackId int, streamType gortsplib.StreamType, buf []byte) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + for c := range m.ma { + c.OnReaderFrame(trackId, streamType, buf) + } +} diff --git a/pathman/pathman.go b/pathman/pathman.go new file mode 100644 index 00000000..7e0ecb75 --- /dev/null +++ b/pathman/pathman.go @@ -0,0 +1,267 @@ +package pathman + +import ( + "fmt" + "sync" + "time" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/base" + "github.com/aler9/gortsplib/headers" + + "github.com/aler9/rtsp-simple-server/client" + "github.com/aler9/rtsp-simple-server/conf" + "github.com/aler9/rtsp-simple-server/path" + "github.com/aler9/rtsp-simple-server/serverudp" + "github.com/aler9/rtsp-simple-server/stats" +) + +type Parent interface { + Log(string, ...interface{}) +} + +type PathManager struct { + stats *stats.Stats + serverUdpRtp *serverudp.Server + serverUdpRtcp *serverudp.Server + readTimeout time.Duration + writeTimeout time.Duration + authMethods []headers.AuthMethod + confPaths map[string]*conf.PathConf + parent Parent + + paths map[string]*path.Path + wg sync.WaitGroup + + // in + pathClose chan *path.Path + clientDescribe chan path.ClientDescribeReq + clientAnnounce chan path.ClientAnnounceReq + clientSetupPlay chan path.ClientSetupPlayReq + terminate chan struct{} + + // out + clientClose chan *client.Client + done chan struct{} +} + +func New(stats *stats.Stats, + serverUdpRtp *serverudp.Server, + serverUdpRtcp *serverudp.Server, + readTimeout time.Duration, + writeTimeout time.Duration, + authMethods []headers.AuthMethod, + confPaths map[string]*conf.PathConf, + parent Parent) *PathManager { + + pm := &PathManager{ + stats: stats, + serverUdpRtp: serverUdpRtp, + serverUdpRtcp: serverUdpRtcp, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + authMethods: authMethods, + confPaths: confPaths, + parent: parent, + paths: make(map[string]*path.Path), + pathClose: make(chan *path.Path), + clientDescribe: make(chan path.ClientDescribeReq), + clientAnnounce: make(chan path.ClientAnnounceReq), + clientSetupPlay: make(chan path.ClientSetupPlayReq), + terminate: make(chan struct{}), + clientClose: make(chan *client.Client), + done: make(chan struct{}), + } + + for name, pathConf := range confPaths { + if pathConf.Regexp == nil { + pa := path.New(&pm.wg, pm.stats, pm.serverUdpRtp, pm.serverUdpRtcp, + pm.readTimeout, pm.writeTimeout, name, pathConf, pm) + pm.paths[name] = pa + } + } + + go pm.run() + return pm +} + +func (pm *PathManager) Close() { + go func() { + for range pm.clientClose { + } + }() + close(pm.terminate) + <-pm.done +} + +func (pm *PathManager) Log(format string, args ...interface{}) { + pm.parent.Log(format, args...) +} + +func (pm *PathManager) run() { + defer close(pm.done) + +outer: + for { + select { + case pa := <-pm.pathClose: + delete(pm.paths, pa.Name()) + pa.Close() + + case req := <-pm.clientDescribe: + pathConf, err := pm.findPathConf(req.PathName) + if err != nil { + req.Res <- path.ClientDescribeRes{nil, err} + continue + } + + err = req.Client.Authenticate(pm.authMethods, pathConf.ReadIpsParsed, + pathConf.ReadUser, pathConf.ReadPass, req.Req) + if err != nil { + req.Res <- path.ClientDescribeRes{nil, err} + continue + } + + // create path if it doesn't exist + if _, ok := pm.paths[req.PathName]; !ok { + pa := path.New(&pm.wg, pm.stats, pm.serverUdpRtp, pm.serverUdpRtcp, + pm.readTimeout, pm.writeTimeout, req.PathName, pathConf, pm) + pm.paths[req.PathName] = pa + } + + pm.paths[req.PathName].OnPathManDescribe(req) + + case req := <-pm.clientAnnounce: + pathConf, err := pm.findPathConf(req.PathName) + if err != nil { + req.Res <- path.ClientAnnounceRes{nil, err} + continue + } + + err = req.Client.Authenticate(pm.authMethods, + pathConf.PublishIpsParsed, pathConf.PublishUser, pathConf.PublishPass, req.Req) + if err != nil { + req.Res <- path.ClientAnnounceRes{nil, err} + continue + } + + // create path if it doesn't exist + if _, ok := pm.paths[req.PathName]; !ok { + pa := path.New(&pm.wg, pm.stats, pm.serverUdpRtp, pm.serverUdpRtcp, + pm.readTimeout, pm.writeTimeout, req.PathName, pathConf, pm) + pm.paths[req.PathName] = pa + } + + pm.paths[req.PathName].OnPathManAnnounce(req) + + case req := <-pm.clientSetupPlay: + if _, ok := pm.paths[req.PathName]; !ok { + req.Res <- path.ClientSetupPlayRes{nil, fmt.Errorf("no one is publishing on path '%s'", req.PathName)} + continue + } + + pathConf, err := pm.findPathConf(req.PathName) + if err != nil { + req.Res <- path.ClientSetupPlayRes{nil, err} + continue + } + + err = req.Client.Authenticate(pm.authMethods, + pathConf.ReadIpsParsed, pathConf.ReadUser, pathConf.ReadPass, req.Req) + if err != nil { + req.Res <- path.ClientSetupPlayRes{nil, err} + continue + } + + pm.paths[req.PathName].OnPathManSetupPlay(req) + + case <-pm.terminate: + break outer + } + } + + go func() { + for { + select { + case _, ok := <-pm.pathClose: + if !ok { + return + } + + case req := <-pm.clientDescribe: + req.Res <- path.ClientDescribeRes{nil, fmt.Errorf("terminated")} + + case req := <-pm.clientAnnounce: + req.Res <- path.ClientAnnounceRes{nil, fmt.Errorf("terminated")} + + case req := <-pm.clientSetupPlay: + req.Res <- path.ClientSetupPlayRes{nil, fmt.Errorf("terminated")} + } + } + }() + + for _, pa := range pm.paths { + pa.Close() + } + pm.wg.Wait() + + close(pm.clientClose) + close(pm.pathClose) + close(pm.clientDescribe) + close(pm.clientAnnounce) + close(pm.clientSetupPlay) +} + +func (pm *PathManager) findPathConf(name string) (*conf.PathConf, error) { + err := conf.CheckPathName(name) + if err != nil { + return nil, fmt.Errorf("invalid path name: %s (%s)", err, name) + } + + // normal path + if pathConf, ok := pm.confPaths[name]; ok { + return pathConf, nil + } + + // regular expression path + for _, pathConf := range pm.confPaths { + if pathConf.Regexp != nil && pathConf.Regexp.MatchString(name) { + return pathConf, nil + } + } + + return nil, fmt.Errorf("unable to find a valid configuration for path '%s'", name) +} + +func (pm *PathManager) OnPathClose(pa *path.Path) { + pm.pathClose <- pa +} + +func (pm *PathManager) OnPathClientClose(c *client.Client) { + pm.clientClose <- c +} + +func (pm *PathManager) OnClientDescribe(c *client.Client, pathName string, req *base.Request) (client.Path, error) { + res := make(chan path.ClientDescribeRes) + pm.clientDescribe <- path.ClientDescribeReq{res, c, pathName, req} + re := <-res + return re.Path, re.Err +} + +func (pm *PathManager) OnClientAnnounce(c *client.Client, pathName string, tracks gortsplib.Tracks, req *base.Request) (client.Path, error) { + res := make(chan path.ClientAnnounceRes) + pm.clientAnnounce <- path.ClientAnnounceReq{res, c, pathName, tracks, req} + re := <-res + return re.Path, re.Err +} + +func (pm *PathManager) OnClientSetupPlay(c *client.Client, pathName string, trackId int, req *base.Request) (client.Path, error) { + res := make(chan path.ClientSetupPlayRes) + pm.clientSetupPlay <- path.ClientSetupPlayReq{res, c, pathName, trackId, req} + re := <-res + return re.Path, re.Err +} + +func (pm *PathManager) ClientClose() chan *client.Client { + return pm.clientClose +} diff --git a/pprof.go b/pprof/pprof.go similarity index 55% rename from pprof.go rename to pprof/pprof.go index 63672a80..8275e304 100644 --- a/pprof.go +++ b/pprof/pprof.go @@ -1,4 +1,4 @@ -package main +package pprof import ( "context" @@ -8,21 +8,25 @@ import ( ) const ( - pprofAddress = ":9999" + address = ":9999" ) -type pprof struct { +type Parent interface { + Log(string, ...interface{}) +} + +type Pprof struct { listener net.Listener server *http.Server } -func newPprof(p *program) (*pprof, error) { - listener, err := net.Listen("tcp", pprofAddress) +func New(parent Parent) (*Pprof, error) { + listener, err := net.Listen("tcp", address) if err != nil { return nil, err } - pp := &pprof{ + pp := &Pprof{ listener: listener, } @@ -30,17 +34,19 @@ func newPprof(p *program) (*pprof, error) { Handler: http.DefaultServeMux, } - p.log("[pprof] opened on " + pprofAddress) + parent.Log("[pprof] opened on " + address) + + go pp.run() return pp, nil } -func (pp *pprof) run() { +func (pp *Pprof) Close() { + pp.server.Shutdown(context.Background()) +} + +func (pp *Pprof) run() { err := pp.server.Serve(pp.listener) if err != http.ErrServerClosed { panic(err) } } - -func (pp *pprof) close() { - pp.server.Shutdown(context.Background()) -} diff --git a/servertcp.go b/servertcp.go deleted file mode 100644 index f0cef10e..00000000 --- a/servertcp.go +++ /dev/null @@ -1,52 +0,0 @@ -package main - -import ( - "net" -) - -type serverTCP struct { - p *program - listener *net.TCPListener - - done chan struct{} -} - -func newServerTCP(p *program) (*serverTCP, error) { - listener, err := net.ListenTCP("tcp", &net.TCPAddr{ - Port: p.conf.RtspPort, - }) - if err != nil { - return nil, err - } - - l := &serverTCP{ - p: p, - listener: listener, - done: make(chan struct{}), - } - - l.log("opened on :%d", p.conf.RtspPort) - return l, nil -} - -func (l *serverTCP) log(format string, args ...interface{}) { - l.p.log("[TCP server] "+format, args...) -} - -func (l *serverTCP) run() { - defer close(l.done) - - for { - conn, err := l.listener.AcceptTCP() - if err != nil { - break - } - - l.p.clientNew <- conn - } -} - -func (l *serverTCP) close() { - l.listener.Close() - <-l.done -} diff --git a/servertcp/server.go b/servertcp/server.go new file mode 100644 index 00000000..0e9c366f --- /dev/null +++ b/servertcp/server.go @@ -0,0 +1,69 @@ +package servertcp + +import ( + "net" +) + +type Parent interface { + Log(string, ...interface{}) +} + +type Server struct { + parent Parent + + listener *net.TCPListener + + // out + accept chan net.Conn + done chan struct{} +} + +func New(port int, parent Parent) (*Server, error) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + Port: port, + }) + if err != nil { + return nil, err + } + + s := &Server{ + parent: parent, + listener: listener, + accept: make(chan net.Conn), + done: make(chan struct{}), + } + + parent.Log("[TCP server] opened on :%d", port) + + go s.run() + return s, nil +} + +func (s *Server) Close() { + go func() { + for co := range s.accept { + co.Close() + } + }() + s.listener.Close() + <-s.done +} + +func (s *Server) run() { + defer close(s.done) + + for { + conn, err := s.listener.AcceptTCP() + if err != nil { + break + } + + s.accept <- conn + } + + close(s.accept) +} + +func (s *Server) Accept() <-chan net.Conn { + return s.accept +} diff --git a/serverudp.go b/serverudp.go deleted file mode 100644 index 9882ea4a..00000000 --- a/serverudp.go +++ /dev/null @@ -1,113 +0,0 @@ -package main - -import ( - "net" - "sync/atomic" - "time" - - "github.com/aler9/gortsplib" - "github.com/aler9/gortsplib/multibuffer" -) - -const ( - udpReadBufferSize = 2048 -) - -type udpBufAddrPair struct { - buf []byte - addr *net.UDPAddr -} - -type serverUDP struct { - p *program - pc *net.UDPConn - streamType gortsplib.StreamType - readBuf *multibuffer.MultiBuffer - - writec chan udpBufAddrPair - done chan struct{} -} - -func newServerUDP(p *program, port int, streamType gortsplib.StreamType) (*serverUDP, error) { - pc, err := net.ListenUDP("udp", &net.UDPAddr{ - Port: port, - }) - if err != nil { - return nil, err - } - - l := &serverUDP{ - p: p, - pc: pc, - streamType: streamType, - readBuf: multibuffer.New(2, udpReadBufferSize), - writec: make(chan udpBufAddrPair), - done: make(chan struct{}), - } - - l.log("opened on :%d", port) - return l, nil -} - -func (l *serverUDP) log(format string, args ...interface{}) { - var label string - if l.streamType == gortsplib.StreamTypeRtp { - label = "RTP" - } else { - label = "RTCP" - } - l.p.log("[UDP/"+label+" server] "+format, args...) -} - -func (l *serverUDP) run() { - defer close(l.done) - - writeDone := make(chan struct{}) - go func() { - defer close(writeDone) - for w := range l.writec { - l.pc.SetWriteDeadline(time.Now().Add(l.p.conf.WriteTimeout)) - l.pc.WriteTo(w.buf, w.addr) - } - }() - - for { - buf := l.readBuf.Next() - n, addr, err := l.pc.ReadFromUDP(buf) - if err != nil { - break - } - - pub := l.p.udpPublishersMap.get(makeUDPPublisherAddr(addr.IP, addr.Port)) - if pub == nil { - continue - } - - // client sent RTP on RTCP port or vice-versa - if pub.streamType != l.streamType { - continue - } - - atomic.StoreInt64(pub.client.udpLastFrameTimes[pub.trackId], time.Now().Unix()) - - pub.client.rtcpReceivers[pub.trackId].OnFrame(l.streamType, buf[:n]) - - l.p.readersMap.forwardFrame(pub.client.path, - pub.trackId, - l.streamType, - buf[:n]) - - } - - close(l.writec) - <-writeDone -} - -func (l *serverUDP) close() { - l.pc.Close() - <-l.done -} - -func (l *serverUDP) write(data []byte, addr *net.UDPAddr) { - l.writec <- udpBufAddrPair{data, addr} -} diff --git a/serverudp/server.go b/serverudp/server.go new file mode 100644 index 00000000..29fefb51 --- /dev/null +++ b/serverudp/server.go @@ -0,0 +1,180 @@ +package serverudp + +import ( + "net" + "sync" + "time" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/base" + "github.com/aler9/gortsplib/multibuffer" +) + +const ( + readBufferSize = 2048 +) + +type Publisher interface { + OnUdpPublisherFrame(int, base.StreamType, []byte) +} + +type publisherData struct { + publisher Publisher + trackId int +} + +type bufAddrPair struct { + buf []byte + addr *net.UDPAddr +} + +type Parent interface { + Log(string, ...interface{}) +} + +type publisherAddr struct { + ip [net.IPv6len]byte // use a fixed-size array to enable the equality operator + port int +} + +func (p *publisherAddr) fill(ip net.IP, port int) { + p.port = port + + if len(ip) == net.IPv4len { + copy(p.ip[0:], []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}) // v4InV6Prefix + copy(p.ip[12:], ip) + } else { + copy(p.ip[:], ip) + } +} + +type Server struct { + writeTimeout time.Duration + streamType gortsplib.StreamType + + pc *net.UDPConn + readBuf *multibuffer.MultiBuffer + publishersMutex sync.RWMutex + publishers map[publisherAddr]*publisherData + + // in + write chan bufAddrPair + + // out + done chan struct{} +} + +func New(writeTimeout time.Duration, + port int, + streamType gortsplib.StreamType, + parent Parent) (*Server, error) { + + pc, err := net.ListenUDP("udp", &net.UDPAddr{ + Port: port, + }) + if err != nil { + return nil, err + } + + s := &Server{ + writeTimeout: writeTimeout, + streamType: streamType, + pc: pc, + readBuf: multibuffer.New(2, readBufferSize), + publishers: make(map[publisherAddr]*publisherData), + write: make(chan bufAddrPair), + done: make(chan struct{}), + } + + var label string + if s.streamType == gortsplib.StreamTypeRtp { + label = "RTP" + } else { + label = "RTCP" + } + parent.Log("[UDP/"+label+" server] opened on :%d", port) + + go s.run() + return s, nil +} + +func (s *Server) Close() { + s.pc.Close() + <-s.done +} + +func (s *Server) run() { + defer close(s.done) + + writeDone := make(chan struct{}) + go func() { + defer close(writeDone) + for w := range s.write { + s.pc.SetWriteDeadline(time.Now().Add(s.writeTimeout)) + s.pc.WriteTo(w.buf, w.addr) + } + }() + + for { + buf := s.readBuf.Next() + n, addr, err := s.pc.ReadFromUDP(buf) + if err != nil { + break + } + + pub := s.getPublisher(addr.IP, addr.Port) + if pub == nil { + continue + } + + pub.publisher.OnUdpPublisherFrame(pub.trackId, s.streamType, buf[:n]) + } + + close(s.write) + <-writeDone +} + +func (s *Server) Port() int { + return s.pc.LocalAddr().(*net.UDPAddr).Port +} + +func (s *Server) Write(data []byte, addr *net.UDPAddr) { + s.write <- bufAddrPair{data, addr} +} + +func (s *Server) AddPublisher(ip net.IP, port int, publisher Publisher, trackId int) { + s.publishersMutex.Lock() + defer s.publishersMutex.Unlock() + + var addr publisherAddr + addr.fill(ip, port) + + s.publishers[addr] = &publisherData{ + publisher: publisher, + trackId: trackId, + } +} + +func (s *Server) RemovePublisher(ip net.IP, port int, publisher Publisher) { + s.publishersMutex.Lock() + defer s.publishersMutex.Unlock() + + var addr publisherAddr + addr.fill(ip, port) + + delete(s.publishers, addr) +} + +func (s *Server) getPublisher(ip net.IP, port int) *publisherData { + s.publishersMutex.RLock() + defer s.publishersMutex.RUnlock() + + var addr publisherAddr + addr.fill(ip, port) + + el, ok := s.publishers[addr] + if !ok { + return nil + } + return el +} diff --git a/sourcertmp.go b/sourcertmp/source.go similarity index 68% rename from sourcertmp.go rename to sourcertmp/source.go index 08c6866c..18996e56 100644 --- a/sourcertmp.go +++ b/sourcertmp/source.go @@ -1,4 +1,4 @@ -package main +package sourcertmp import ( "fmt" @@ -15,53 +15,73 @@ import ( ) const ( - sourceRtmpRetryInterval = 5 * time.Second + retryInterval = 5 * time.Second ) -type sourceRtmpState int - -const ( - sourceRtmpStateStopped sourceRtmpState = iota - sourceRtmpStateRunning -) - -type sourceRtmp struct { - p *program - path *path - state sourceRtmpState - innerRunning bool - - innerTerminate chan struct{} - innerDone chan struct{} - setState chan sourceRtmpState - terminate chan struct{} - done chan struct{} +type Parent interface { + Log(string, ...interface{}) + OnSourceReady(gortsplib.Tracks) + OnSourceNotReady() + OnFrame(int, gortsplib.StreamType, []byte) } -func newSourceRtmp(p *program, path *path) *sourceRtmp { - s := &sourceRtmp{ - p: p, - path: path, - setState: make(chan sourceRtmpState), - terminate: make(chan struct{}), - done: make(chan struct{}), - } - - atomic.AddInt64(p.countSourcesRtmp, +1) - - if path.conf.SourceOnDemand { - s.state = sourceRtmpStateStopped - } else { - s.state = sourceRtmpStateRunning - atomic.AddInt64(p.countSourcesRtmpRunning, +1) +type State int + +const ( + StateStopped State = iota + StateRunning +) + +type Source struct { + ur string + state State + parent Parent + + innerRunning bool + + // in + innerTerminate chan struct{} + innerDone chan struct{} + stateChange chan State + terminate chan struct{} + + // out + done chan struct{} +} + +func New(ur string, + state State, + parent Parent) *Source { + s := &Source{ + ur: ur, + state: state, + parent: parent, + stateChange: make(chan State), + terminate: make(chan struct{}), + done: make(chan struct{}), } + go s.run(s.state) return s } -func (s *sourceRtmp) isSource() {} +func (s *Source) Close() { + close(s.terminate) + <-s.done +} -func (s *sourceRtmp) run(initialState sourceRtmpState) { +func (s *Source) IsSource() {} + +func (s *Source) State() State { + return s.state +} + +func (s *Source) SetState(state State) { + s.state = state + s.stateChange <- s.state +} + +func (s *Source) run(initialState State) { defer close(s.done) s.applyState(initialState) @@ -69,7 +89,7 @@ func (s *sourceRtmp) run(initialState sourceRtmpState) { outer: for { select { - case state := <-s.setState: + case state := <-s.stateChange: s.applyState(state) case <-s.terminate: @@ -82,13 +102,13 @@ outer: <-s.innerDone } - close(s.setState) + close(s.stateChange) } -func (s *sourceRtmp) applyState(state sourceRtmpState) { - if state == sourceRtmpStateRunning { +func (s *Source) applyState(state State) { + if state == StateRunning { if !s.innerRunning { - s.path.log("rtmp source started") + s.parent.Log("rtmp source started") s.innerRunning = true s.innerTerminate = make(chan struct{}) s.innerDone = make(chan struct{}) @@ -99,12 +119,12 @@ func (s *sourceRtmp) applyState(state sourceRtmpState) { close(s.innerTerminate) <-s.innerDone s.innerRunning = false - s.path.log("rtmp source stopped") + s.parent.Log("rtmp source stopped") } } } -func (s *sourceRtmp) runInner() { +func (s *Source) runInner() { defer close(s.innerDone) outer: @@ -114,7 +134,7 @@ outer: break outer } - t := time.NewTimer(sourceRtmpRetryInterval) + t := time.NewTimer(retryInterval) defer t.Stop() select { @@ -125,8 +145,8 @@ outer: } } -func (s *sourceRtmp) runInnerInner() bool { - s.path.log("connecting to rtmp source") +func (s *Source) runInnerInner() bool { + s.parent.Log("connecting to rtmp source") var conn *rtmp.Conn var nconn net.Conn @@ -134,7 +154,7 @@ func (s *sourceRtmp) runInnerInner() bool { dialDone := make(chan struct{}, 1) go func() { defer close(dialDone) - conn, nconn, err = rtmp.NewClient().Dial(s.path.conf.Source, rtmp.PrepareReading) + conn, nconn, err = rtmp.NewClient().Dial(s.ur, rtmp.PrepareReading) }() select { @@ -144,7 +164,7 @@ func (s *sourceRtmp) runInnerInner() bool { } if err != nil { - s.path.log("rtmp source ERR: %s", err) + s.parent.Log("rtmp source ERR: %s", err) return true } @@ -202,7 +222,7 @@ func (s *sourceRtmp) runInnerInner() bool { } if err != nil { - s.path.log("rtmp source ERR: %s", err) + s.parent.Log("rtmp source ERR: %s", err) return true } @@ -215,13 +235,13 @@ func (s *sourceRtmp) runInnerInner() bool { if h264Sps != nil { videoTrack, err = gortsplib.NewTrackH264(len(tracks), h264Sps, h264Pps) if err != nil { - s.path.log("rtmp source ERR: %s", err) + s.parent.Log("rtmp source ERR: %s", err) return true } h264Encoder, err = rtph264.NewEncoder(uint8(len(tracks))) if err != nil { - s.path.log("rtmp source ERR: %s", err) + s.parent.Log("rtmp source ERR: %s", err) return true } @@ -231,13 +251,13 @@ func (s *sourceRtmp) runInnerInner() bool { if aacConfig != nil { audioTrack, err = gortsplib.NewTrackAac(len(tracks), aacConfig) if err != nil { - s.path.log("rtmp source ERR: %s", err) + s.parent.Log("rtmp source ERR: %s", err) return true } aacEncoder, err = rtpaac.NewEncoder(uint8(len(tracks)), aacConfig) if err != nil { - s.path.log("rtmp source ERR: %s", err) + s.parent.Log("rtmp source ERR: %s", err) return true } @@ -245,15 +265,12 @@ func (s *sourceRtmp) runInnerInner() bool { } if len(tracks) == 0 { - s.path.log("rtmp source ERR: no tracks found") + s.parent.Log("rtmp source ERR: no tracks found") return true } - s.path.sourceSdp = tracks.Write() - s.path.sourceTrackCount = len(tracks) - - s.p.sourceRtmpReady <- s - s.path.log("rtmp source ready") + s.parent.OnSourceReady(tracks) + s.parent.Log("rtmp source ready") readDone := make(chan error) go func() { @@ -286,7 +303,7 @@ func (s *sourceRtmp) runInnerInner() bool { } for _, f := range frames { - s.p.readersMap.forwardFrame(s.path, videoTrack.Id, gortsplib.StreamTypeRtp, f) + s.parent.OnFrame(videoTrack.Id, gortsplib.StreamTypeRtp, f) } case av.AAC: @@ -302,7 +319,7 @@ func (s *sourceRtmp) runInnerInner() bool { } for _, f := range frames { - s.p.readersMap.forwardFrame(s.path, audioTrack.Id, gortsplib.StreamTypeRtp, f) + s.parent.OnFrame(audioTrack.Id, gortsplib.StreamTypeRtp, f) } default: @@ -325,14 +342,14 @@ outer: case err := <-readDone: nconn.Close() - s.path.log("rtmp source ERR: %s", err) + s.parent.Log("rtmp source ERR: %s", err) ret = true break outer } } - s.p.sourceRtmpNotReady <- s - s.path.log("rtmp source not ready") + s.parent.OnSourceNotReady() + s.parent.Log("rtmp source not ready") return ret } diff --git a/sourcertsp.go b/sourcertsp.go deleted file mode 100644 index e9cd83df..00000000 --- a/sourcertsp.go +++ /dev/null @@ -1,322 +0,0 @@ -package main - -import ( - "sync" - "sync/atomic" - "time" - - "github.com/aler9/gortsplib" -) - -const ( - sourceRtspRetryInterval = 5 * time.Second -) - -type sourceRtspState int - -const ( - sourceRtspStateStopped sourceRtspState = iota - sourceRtspStateRunning -) - -type sourceRtsp struct { - p *program - path *path - state sourceRtspState - tracks []*gortsplib.Track - innerRunning bool - - innerTerminate chan struct{} - innerDone chan struct{} - setState chan sourceRtspState - terminate chan struct{} - done chan struct{} -} - -func newSourceRtsp(p *program, path *path) *sourceRtsp { - s := &sourceRtsp{ - p: p, - path: path, - setState: make(chan sourceRtspState), - terminate: make(chan struct{}), - done: make(chan struct{}), - } - - atomic.AddInt64(p.countSourcesRtsp, +1) - - if path.conf.SourceOnDemand { - s.state = sourceRtspStateStopped - } else { - s.state = sourceRtspStateRunning - atomic.AddInt64(p.countSourcesRtspRunning, +1) - } - - return s -} - -func (s *sourceRtsp) isSource() {} - -func (s *sourceRtsp) run(initialState sourceRtspState) { - defer close(s.done) - - s.applyState(initialState) - -outer: - for { - select { - case state := <-s.setState: - s.applyState(state) - - case <-s.terminate: - break outer - } - } - - if s.innerRunning { - close(s.innerTerminate) - <-s.innerDone - } - - close(s.setState) -} - -func (s *sourceRtsp) applyState(state sourceRtspState) { - if state == sourceRtspStateRunning { - if !s.innerRunning { - s.path.log("rtsp source started") - s.innerRunning = true - s.innerTerminate = make(chan struct{}) - s.innerDone = make(chan struct{}) - go s.runInner() - } - } else { - if s.innerRunning { - close(s.innerTerminate) - <-s.innerDone - s.innerRunning = false - s.path.log("rtsp source stopped") - } - } -} - -func (s *sourceRtsp) runInner() { - defer close(s.innerDone) - -outer: - for { - ok := s.runInnerInner() - if !ok { - break outer - } - - t := time.NewTimer(sourceRtspRetryInterval) - defer t.Stop() - - select { - case <-s.innerTerminate: - break outer - case <-t.C: - } - } -} - -func (s *sourceRtsp) runInnerInner() bool { - s.path.log("connecting to rtsp source") - - var conn *gortsplib.ConnClient - var err error - dialDone := make(chan struct{}, 1) - go func() { - defer close(dialDone) - conn, err = gortsplib.NewConnClient(gortsplib.ConnClientConf{ - Host: s.path.conf.SourceUrl.Host, - ReadTimeout: s.p.conf.ReadTimeout, - WriteTimeout: s.p.conf.WriteTimeout, - ReadBufferCount: 2, - }) - }() - - select { - case <-s.innerTerminate: - return false - case <-dialDone: - } - - if err != nil { - s.path.log("rtsp source ERR: %s", err) - return true - } - - _, err = conn.Options(s.path.conf.SourceUrl) - if err != nil { - conn.Close() - s.path.log("rtsp source ERR: %s", err) - return true - } - - tracks, _, err := conn.Describe(s.path.conf.SourceUrl) - if err != nil { - conn.Close() - s.path.log("rtsp source ERR: %s", err) - return true - } - - // create a filtered SDP that is used by the server (not by the client) - s.path.sourceSdp = tracks.Write() - s.path.sourceTrackCount = len(tracks) - s.tracks = tracks - - if s.path.conf.SourceProtocolParsed == gortsplib.StreamProtocolUDP { - return s.runUDP(conn) - } else { - return s.runTCP(conn) - } -} - -func (s *sourceRtsp) runUDP(conn *gortsplib.ConnClient) bool { - for _, track := range s.tracks { - _, err := conn.SetupUDP(s.path.conf.SourceUrl, gortsplib.TransportModePlay, track, 0, 0) - if err != nil { - conn.Close() - s.path.log("rtsp source ERR: %s", err) - return true - } - } - - _, err := conn.Play(s.path.conf.SourceUrl) - if err != nil { - conn.Close() - s.path.log("rtsp source ERR: %s", err) - return true - } - - s.p.sourceRtspReady <- s - s.path.log("rtsp source ready") - - var wg sync.WaitGroup - - // receive RTP packets - for trackId := range s.tracks { - wg.Add(1) - go func(trackId int) { - defer wg.Done() - - for { - buf, err := conn.ReadFrameUDP(trackId, gortsplib.StreamTypeRtp) - if err != nil { - break - } - - s.p.readersMap.forwardFrame(s.path, trackId, - gortsplib.StreamTypeRtp, buf) - } - }(trackId) - } - - // receive RTCP packets - for trackId := range s.tracks { - wg.Add(1) - go func(trackId int) { - defer wg.Done() - - for { - buf, err := conn.ReadFrameUDP(trackId, gortsplib.StreamTypeRtcp) - if err != nil { - break - } - - s.p.readersMap.forwardFrame(s.path, trackId, - gortsplib.StreamTypeRtcp, buf) - } - }(trackId) - } - - tcpConnDone := make(chan error) - go func() { - tcpConnDone <- conn.LoopUDP() - }() - - var ret bool - -outer: - for { - select { - case <-s.innerTerminate: - conn.Close() - <-tcpConnDone - ret = false - break outer - - case err := <-tcpConnDone: - conn.Close() - s.path.log("rtsp source ERR: %s", err) - ret = true - break outer - } - } - - wg.Wait() - - s.p.sourceRtspNotReady <- s - s.path.log("rtsp source not ready") - - return ret -} - -func (s *sourceRtsp) runTCP(conn *gortsplib.ConnClient) bool { - for _, track := range s.tracks { - _, err := conn.SetupTCP(s.path.conf.SourceUrl, gortsplib.TransportModePlay, track) - if err != nil { - conn.Close() - s.path.log("rtsp source ERR: %s", err) - return true - } - } - - _, err := conn.Play(s.path.conf.SourceUrl) - if err != nil { - conn.Close() - s.path.log("rtsp source ERR: %s", err) - return true - } - - s.p.sourceRtspReady <- s - s.path.log("rtsp source ready") - - tcpConnDone := make(chan error) - go func() { - for { - trackId, streamType, content, err := conn.ReadFrameTCP() - if err != nil { - tcpConnDone <- err - return - } - - s.p.readersMap.forwardFrame(s.path, trackId, streamType, content) - } - }() - - var ret bool - -outer: - for { - select { - case <-s.innerTerminate: - conn.Close() - <-tcpConnDone - ret = false - break outer - - case err := <-tcpConnDone: - conn.Close() - s.path.log("rtsp source ERR: %s", err) - ret = true - break outer - } - } - - s.p.sourceRtspNotReady <- s - s.path.log("rtsp source not ready") - - return ret -} diff --git a/sourcertsp/source.go b/sourcertsp/source.go new file mode 100644 index 00000000..f91fe8ec --- /dev/null +++ b/sourcertsp/source.go @@ -0,0 +1,345 @@ +package sourcertsp + +import ( + "net/url" + "sync" + "time" + + "github.com/aler9/gortsplib" +) + +const ( + retryInterval = 5 * time.Second +) + +type Parent interface { + Log(string, ...interface{}) + OnSourceReady(gortsplib.Tracks) + OnSourceNotReady() + OnFrame(int, gortsplib.StreamType, []byte) +} + +type State int + +const ( + StateStopped State = iota + StateRunning +) + +type Source struct { + ur string + proto gortsplib.StreamProtocol + readTimeout time.Duration + writeTimeout time.Duration + state State + parent Parent + + innerRunning bool + + // in + innerTerminate chan struct{} + innerDone chan struct{} + stateChange chan State + terminate chan struct{} + + // out + done chan struct{} +} + +func New(ur string, + proto gortsplib.StreamProtocol, + readTimeout time.Duration, + writeTimeout time.Duration, + state State, + parent Parent) *Source { + s := &Source{ + ur: ur, + proto: proto, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + state: state, + parent: parent, + stateChange: make(chan State), + terminate: make(chan struct{}), + done: make(chan struct{}), + } + + go s.run(s.state) + return s +} + +func (s *Source) Close() { + close(s.terminate) + <-s.done +} + +func (s *Source) IsSource() {} + +func (s *Source) State() State { + return s.state +} + +func (s *Source) SetState(state State) { + s.state = state + s.stateChange <- s.state +} + +func (s *Source) run(initialState State) { + defer close(s.done) + + s.applyState(initialState) + +outer: + for { + select { + case state := <-s.stateChange: + s.applyState(state) + + case <-s.terminate: + break outer + } + } + + if s.innerRunning { + close(s.innerTerminate) + <-s.innerDone + } + + close(s.stateChange) +} + +func (s *Source) applyState(state State) { + if state == StateRunning { + if !s.innerRunning { + s.parent.Log("rtsp source started") + s.innerRunning = true + s.innerTerminate = make(chan struct{}) + s.innerDone = make(chan struct{}) + go s.runInner() + } + } else { + if s.innerRunning { + close(s.innerTerminate) + <-s.innerDone + s.innerRunning = false + s.parent.Log("rtsp source stopped") + } + } +} + +func (s *Source) runInner() { + defer close(s.innerDone) + +outer: + for { + ok := s.runInnerInner() + if !ok { + break outer + } + + t := time.NewTimer(retryInterval) + defer t.Stop() + + select { + case <-s.innerTerminate: + break outer + case <-t.C: + } + } +} + +func (s *Source) runInnerInner() bool { + s.parent.Log("connecting to rtsp source") + + u, _ := url.Parse(s.ur) + + var conn *gortsplib.ConnClient + var err error + dialDone := make(chan struct{}, 1) + go func() { + defer close(dialDone) + conn, err = gortsplib.NewConnClient(gortsplib.ConnClientConf{ + Host: u.Host, + ReadTimeout: s.readTimeout, + WriteTimeout: s.writeTimeout, + ReadBufferCount: 2, + }) + }() + + select { + case <-s.innerTerminate: + return false + case <-dialDone: + } + + if err != nil { + s.parent.Log("rtsp source ERR: %s", err) + return true + } + + _, err = conn.Options(u) + if err != nil { + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + return true + } + + tracks, _, err := conn.Describe(u) + if err != nil { + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + return true + } + + if s.proto == gortsplib.StreamProtocolUDP { + return s.runUDP(u, conn, tracks) + } else { + return s.runTCP(u, conn, tracks) + } +} + +func (s *Source) runUDP(u *url.URL, conn *gortsplib.ConnClient, tracks gortsplib.Tracks) bool { + for _, track := range tracks { + _, err := conn.SetupUDP(u, gortsplib.TransportModePlay, track, 0, 0) + if err != nil { + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + return true + } + } + + _, err := conn.Play(u) + if err != nil { + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + return true + } + + s.parent.OnSourceReady(tracks) + s.parent.Log("rtsp source ready") + + var wg sync.WaitGroup + + // receive RTP packets + for trackId := range tracks { + wg.Add(1) + go func(trackId int) { + defer wg.Done() + + for { + buf, err := conn.ReadFrameUDP(trackId, gortsplib.StreamTypeRtp) + if err != nil { + break + } + + s.parent.OnFrame(trackId, gortsplib.StreamTypeRtp, buf) + } + }(trackId) + } + + // receive RTCP packets + for trackId := range tracks { + wg.Add(1) + go func(trackId int) { + defer wg.Done() + + for { + buf, err := conn.ReadFrameUDP(trackId, gortsplib.StreamTypeRtcp) + if err != nil { + break + } + + s.parent.OnFrame(trackId, gortsplib.StreamTypeRtcp, buf) + } + }(trackId) + } + + tcpConnDone := make(chan error) + go func() { + tcpConnDone <- conn.LoopUDP() + }() + + var ret bool + +outer: + for { + select { + case <-s.innerTerminate: + conn.Close() + <-tcpConnDone + ret = false + break outer + + case err := <-tcpConnDone: + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + ret = true + break outer + } + } + + wg.Wait() + + s.parent.OnSourceNotReady() + s.parent.Log("rtsp source not ready") + + return ret +} + +func (s *Source) runTCP(u *url.URL, conn *gortsplib.ConnClient, tracks gortsplib.Tracks) bool { + for _, track := range tracks { + _, err := conn.SetupTCP(u, gortsplib.TransportModePlay, track) + if err != nil { + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + return true + } + } + + _, err := conn.Play(u) + if err != nil { + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + return true + } + + s.parent.OnSourceReady(tracks) + s.parent.Log("rtsp source ready") + + tcpConnDone := make(chan error) + go func() { + for { + trackId, streamType, content, err := conn.ReadFrameTCP() + if err != nil { + tcpConnDone <- err + return + } + + s.parent.OnFrame(trackId, streamType, content) + } + }() + + var ret bool + +outer: + for { + select { + case <-s.innerTerminate: + conn.Close() + <-tcpConnDone + ret = false + break outer + + case err := <-tcpConnDone: + conn.Close() + s.parent.Log("rtsp source ERR: %s", err) + ret = true + break outer + } + } + + s.parent.OnSourceNotReady() + s.parent.Log("rtsp source not ready") + + return ret +} diff --git a/stats/stats.go b/stats/stats.go new file mode 100644 index 00000000..24422e5c --- /dev/null +++ b/stats/stats.go @@ -0,0 +1,30 @@ +package stats + +func ptrInt64() *int64 { + v := int64(0) + return &v +} + +type Stats struct { + // use pointers to avoid a crash on 32bit platforms + // https://github.com/golang/go/issues/9959 + CountClients *int64 + CountPublishers *int64 + CountReaders *int64 + CountSourcesRtsp *int64 + CountSourcesRtspRunning *int64 + CountSourcesRtmp *int64 + CountSourcesRtmpRunning *int64 +} + +func New() *Stats { + return &Stats{ + CountClients: ptrInt64(), + CountPublishers: ptrInt64(), + CountReaders: ptrInt64(), + CountSourcesRtsp: ptrInt64(), + CountSourcesRtspRunning: ptrInt64(), + CountSourcesRtmp: ptrInt64(), + CountSourcesRtmpRunning: ptrInt64(), + } +} diff --git a/utils.go b/utils.go deleted file mode 100644 index 94a76cb2..00000000 --- a/utils.go +++ /dev/null @@ -1,205 +0,0 @@ -package main - -import ( - "fmt" - "net" - "strings" - "sync" - - "github.com/aler9/gortsplib" - "github.com/aler9/gortsplib/base" -) - -func ipEqualOrInRange(ip net.IP, ips []interface{}) bool { - for _, item := range ips { - switch titem := item.(type) { - case net.IP: - if titem.Equal(ip) { - return true - } - - case *net.IPNet: - if titem.Contains(ip) { - return true - } - } - } - return false -} - -func splitPath(path string) (string, string, error) { - pos := func() int { - for i := len(path) - 1; i >= 0; i-- { - if path[i] == '/' { - return i - } - } - return -1 - }() - - if pos < 0 { - return "", "", fmt.Errorf("the path must contain a base path and a control path (%s)", path) - } - - basePath := path[:pos] - controlPath := path[pos+1:] - - if len(basePath) == 0 { - return "", "", fmt.Errorf("empty base path (%s)", basePath) - } - - if len(controlPath) == 0 { - return "", "", fmt.Errorf("empty control path (%s)", controlPath) - } - - return basePath, controlPath, nil -} - -func removeQueryFromPath(path string) string { - i := strings.Index(path, "?") - if i >= 0 { - return path[:i] - } - return path -} - -type udpPublisherAddr struct { - ip [net.IPv6len]byte // use a fixed-size array to enable the equality operator - port int -} - -func makeUDPPublisherAddr(ip net.IP, port int) udpPublisherAddr { - ret := udpPublisherAddr{ - port: port, - } - - if len(ip) == net.IPv4len { - copy(ret.ip[0:], []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}) // v4InV6Prefix - copy(ret.ip[12:], ip) - } else { - copy(ret.ip[:], ip) - } - - return ret -} - -type udpPublisher struct { - client *client - trackId int - streamType gortsplib.StreamType -} - -type udpPublishersMap struct { - mutex sync.RWMutex - ma map[udpPublisherAddr]*udpPublisher -} - -func newUdpPublisherMap() *udpPublishersMap { - return &udpPublishersMap{ - ma: make(map[udpPublisherAddr]*udpPublisher), - } -} - -func (m *udpPublishersMap) clear() { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.ma = make(map[udpPublisherAddr]*udpPublisher) -} - -func (m *udpPublishersMap) add(addr udpPublisherAddr, pub *udpPublisher) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.ma[addr] = pub -} - -func (m *udpPublishersMap) remove(addr udpPublisherAddr) { - m.mutex.Lock() - defer m.mutex.Unlock() - - delete(m.ma, addr) -} - -func (m *udpPublishersMap) get(addr udpPublisherAddr) *udpPublisher { - m.mutex.RLock() - defer m.mutex.RUnlock() - - el, ok := m.ma[addr] - if !ok { - return nil - } - return el -} - -type readersMap struct { - mutex sync.RWMutex - ma map[*client]struct{} -} - -func newReadersMap() *readersMap { - return &readersMap{ - ma: make(map[*client]struct{}), - } -} - -func (m *readersMap) clear() { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.ma = make(map[*client]struct{}) -} - -func (m *readersMap) add(reader *client) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.ma[reader] = struct{}{} -} - -func (m *readersMap) remove(reader *client) { - m.mutex.Lock() - defer m.mutex.Unlock() - - delete(m.ma, reader) -} - -func (m *readersMap) forwardFrame(path *path, trackId int, streamType gortsplib.StreamType, frame []byte) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - for c := range m.ma { - if c.path != path { - continue - } - - track, ok := c.streamTracks[trackId] - if !ok { - continue - } - - if c.streamProtocol == gortsplib.StreamProtocolUDP { - if streamType == gortsplib.StreamTypeRtp { - c.p.serverUdpRtp.write(frame, &net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: track.rtpPort, - }) - - } else { - c.p.serverUdpRtcp.write(frame, &net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: track.rtcpPort, - }) - } - - } else { - c.tcpFrame <- &base.InterleavedFrame{ - TrackId: trackId, - StreamType: streamType, - Content: frame, - } - } - } -}