fix race condition

This commit is contained in:
aler9 2020-01-26 18:08:15 +01:00
parent 924592a219
commit 6fefb89a7c
5 changed files with 100 additions and 64 deletions

View File

@ -112,13 +112,15 @@ type client struct {
streamSdpParsed *sdp.Message // filled only if publisher
streamProtocol streamProtocol
streamTracks []*track
chanWrite chan *gortsplib.InterleavedFrame
}
func newClient(p *program, nconn net.Conn) *client {
c := &client{
p: p,
conn: gortsplib.NewConnServer(nconn),
state: _CLIENT_STATE_STARTING,
p: p,
conn: gortsplib.NewConnServer(nconn),
state: _CLIENT_STATE_STARTING,
chanWrite: make(chan *gortsplib.InterleavedFrame),
}
c.p.mutex.Lock()
@ -136,6 +138,7 @@ func (c *client) close() error {
delete(c.p.clients, c)
c.conn.NetConn().Close()
close(c.chanWrite)
if c.path != "" {
if pub, ok := c.p.publishers[c.path]; ok && pub == c {
@ -251,7 +254,7 @@ func (c *client) handleRequest(req *gortsplib.Request) bool {
StatusCode: 200,
Status: "OK",
Header: gortsplib.Header{
"CSeq": cseq,
"CSeq": []string{cseq[0]},
"Public": []string{strings.Join([]string{
"DESCRIBE",
"ANNOUNCE",
@ -411,7 +414,7 @@ func (c *client) handleRequest(req *gortsplib.Request) bool {
StatusCode: 461,
Status: "Unsupported Transport",
Header: gortsplib.Header{
"CSeq": cseq,
"CSeq": []string{cseq[0]},
},
})
return false
@ -484,7 +487,7 @@ func (c *client) handleRequest(req *gortsplib.Request) bool {
StatusCode: 461,
Status: "Unsupported Transport",
Header: gortsplib.Header{
"CSeq": cseq,
"CSeq": []string{cseq[0]},
},
})
return false
@ -579,7 +582,7 @@ func (c *client) handleRequest(req *gortsplib.Request) bool {
StatusCode: 461,
Status: "Unsupported Transport",
Header: gortsplib.Header{
"CSeq": cseq,
"CSeq": []string{cseq[0]},
},
})
return false
@ -641,7 +644,7 @@ func (c *client) handleRequest(req *gortsplib.Request) bool {
StatusCode: 461,
Status: "Unsupported Transport",
Header: gortsplib.Header{
"CSeq": cseq,
"CSeq": []string{cseq[0]},
},
})
return false
@ -763,9 +766,19 @@ func (c *client) handleRequest(req *gortsplib.Request) bool {
c.state = _CLIENT_STATE_PLAY
c.p.mutex.Unlock()
//c.conn.NetConn().SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT))
// when protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP feedback, do not parse it, wait until connection closes
if c.streamProtocol == _STREAM_PROTOCOL_TCP {
// write RTP frames sequentially
go func() {
for frame := range c.chanWrite {
c.conn.NetConn().SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT))
c.conn.WriteInterleavedFrame(frame)
}
}()
// receive RTP feedback, do not parse it, wait until connection closes
buf := make([]byte, 2048)
for {
_, err := c.conn.NetConn().Read(buf)

2
go.mod
View File

@ -5,7 +5,7 @@ go 1.13
require (
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect
github.com/aler9/gortsplib v0.0.0-20200126151926-9e5868a1b8a3
github.com/aler9/gortsplib v0.0.0-20200126152308-13da0e672306
gopkg.in/alecthomas/kingpin.v2 v2.2.6
gortc.io/sdp v0.17.0
)

4
go.sum
View File

@ -2,8 +2,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/aler9/gortsplib v0.0.0-20200126151926-9e5868a1b8a3 h1:CK8JKsFz82I2mPrB5oBnu0HYZm4q4UW/1xzr58/ZzU4=
github.com/aler9/gortsplib v0.0.0-20200126151926-9e5868a1b8a3/go.mod h1:YiIgmmv0ELkWUy11Jj2h5AgfqLCpy8sIX/l9MmS8+uw=
github.com/aler9/gortsplib v0.0.0-20200126152308-13da0e672306 h1:mSGMii9I9cEyw2cgyujnlaYwml9MwUkC2Ko8R0vKS6w=
github.com/aler9/gortsplib v0.0.0-20200126152308-13da0e672306/go.mod h1:YiIgmmv0ELkWUy11Jj2h5AgfqLCpy8sIX/l9MmS8+uw=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=

29
main.go
View File

@ -152,25 +152,28 @@ func (p *program) forwardTrack(path string, id int, flow trackFlow, frame []byte
if c.path == path && c.state == _CLIENT_STATE_PLAY {
if c.streamProtocol == _STREAM_PROTOCOL_UDP {
if flow == _TRACK_FLOW_RTP {
p.rtpl.nconn.SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT))
p.rtpl.nconn.WriteTo(frame, &net.UDPAddr{
IP: c.ip,
Port: c.streamTracks[id].rtpPort,
})
p.rtpl.chanWrite <- &udpWrite{
addr: &net.UDPAddr{
IP: c.ip,
Port: c.streamTracks[id].rtpPort,
},
buf: frame,
}
} else {
p.rtcpl.nconn.SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT))
p.rtcpl.nconn.WriteTo(frame, &net.UDPAddr{
IP: c.ip,
Port: c.streamTracks[id].rtcpPort,
})
p.rtcpl.chanWrite <- &udpWrite{
addr: &net.UDPAddr{
IP: c.ip,
Port: c.streamTracks[id].rtcpPort,
},
buf: frame,
}
}
} else {
c.conn.NetConn().SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT))
c.conn.WriteInterleavedFrame(&gortsplib.InterleavedFrame{
c.chanWrite <- &gortsplib.InterleavedFrame{
Channel: trackToInterleavedChannel(id, flow),
Content: frame,
})
}
}
}
}

View File

@ -3,12 +3,19 @@ package main
import (
"log"
"net"
"time"
)
type udpWrite struct {
addr *net.UDPAddr
buf []byte
}
type serverUdpListener struct {
p *program
nconn *net.UDPConn
flow trackFlow
p *program
nconn *net.UDPConn
flow trackFlow
chanWrite chan *udpWrite
}
func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListener, error) {
@ -20,9 +27,10 @@ func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListe
}
l := &serverUdpListener{
p: p,
nconn: nconn,
flow: flow,
p: p,
nconn: nconn,
flow: flow,
chanWrite: make(chan *udpWrite),
}
l.log("opened on :%d", port)
@ -40,45 +48,57 @@ func (l *serverUdpListener) log(format string, args ...interface{}) {
}
func (l *serverUdpListener) run() {
buf := make([]byte, 2048) // UDP MTU is 1400
go func() {
for {
// create a buffer for each read.
// this is necessary since the buffer is propagated with channels
// so it must be unique.
buf := make([]byte, 2048) // UDP MTU is 1400
n, addr, err := l.nconn.ReadFromUDP(buf)
if err != nil {
l.log("ERR: %s", err)
break
}
for {
n, addr, err := l.nconn.ReadFromUDP(buf)
if err != nil {
l.log("ERR: %s", err)
break
}
func() {
l.p.mutex.RLock()
defer l.p.mutex.RUnlock()
func() {
l.p.mutex.RLock()
defer l.p.mutex.RUnlock()
// find path and track id from ip and port
path, trackId := func() (string, int) {
for _, pub := range l.p.publishers {
for i, t := range pub.streamTracks {
if !pub.ip.Equal(addr.IP) {
continue
}
if l.flow == _TRACK_FLOW_RTP {
if t.rtpPort == addr.Port {
return pub.path, i
// find path and track id from ip and port
path, trackId := func() (string, int) {
for _, pub := range l.p.publishers {
for i, t := range pub.streamTracks {
if !pub.ip.Equal(addr.IP) {
continue
}
} else {
if t.rtcpPort == addr.Port {
return pub.path, i
if l.flow == _TRACK_FLOW_RTP {
if t.rtpPort == addr.Port {
return pub.path, i
}
} else {
if t.rtcpPort == addr.Port {
return pub.path, i
}
}
}
}
return "", -1
}()
if path == "" {
return
}
return "", -1
}()
if path == "" {
return
}
l.p.forwardTrack(path, trackId, l.flow, buf[:n])
}()
}
l.p.forwardTrack(path, trackId, l.flow, buf[:n])
}()
}
}()
go func() {
for {
w := <-l.chanWrite
l.nconn.SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT))
l.nconn.WriteTo(w.buf, w.addr)
}
}()
}