diff --git a/client.go b/client.go index 52a4d369..f49d372a 100644 --- a/client.go +++ b/client.go @@ -24,6 +24,33 @@ const ( clientUdpWriteBufferSize = 128 * 1024 ) +type udpClient struct { + client *client + trackId int + streamType gortsplib.StreamType +} + +type udpClientAddr struct { + // use a fixed-size array for ip comparison + ip [net.IPv6len]byte + port int +} + +func makeUdpClientAddr(ip net.IP, port int) udpClientAddr { + ret := udpClientAddr{ + port: port, + } + + if len(ip) == net.IPv4len { + copy(ret.ip[0:], []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}) // v4InV6Prefix + copy(ret.ip[12:], ip) + } else { + copy(ret.ip[:], ip) + } + + return ret +} + type describeRes struct { sdp []byte err error diff --git a/main.go b/main.go index 94310b81..e0a8d86c 100644 --- a/main.go +++ b/main.go @@ -163,19 +163,19 @@ type programEventTerminate struct{} func (programEventTerminate) isProgramEvent() {} type program struct { - conf *conf - logFile *os.File - metrics *metrics - serverRtsp *serverTcp - serverRtp *serverUdp - serverRtcp *serverUdp - sources []*source - clients map[*client]struct{} - udpClientPublishers map[ipKey]*client - paths map[string]*path - cmds []*exec.Cmd - publisherCount int - readerCount int + conf *conf + logFile *os.File + metrics *metrics + serverRtsp *serverTcp + serverRtp *serverUdp + serverRtcp *serverUdp + sources []*source + clients map[*client]struct{} + udpClientsByAddr map[udpClientAddr]*udpClient + paths map[string]*path + cmds []*exec.Cmd + publisherCount int + readerCount int events chan programEvent done chan struct{} @@ -201,12 +201,12 @@ func newProgram(args []string, stdin io.Reader) (*program, error) { } p := &program{ - conf: conf, - clients: make(map[*client]struct{}), - udpClientPublishers: make(map[ipKey]*client), - paths: make(map[string]*path), - events: make(chan programEvent), - done: make(chan struct{}), + conf: conf, + clients: make(map[*client]struct{}), + udpClientsByAddr: make(map[udpClientAddr]*udpClient), + paths: make(map[string]*path), + events: make(chan programEvent), + done: make(chan struct{}), } if _, ok := p.conf.logDestinationsParsed[logDestinationFile]; ok { @@ -424,9 +424,25 @@ outer: case programEventClientRecord: p.publisherCount += 1 evt.client.state = clientStateRecord + if evt.client.streamProtocol == gortsplib.StreamProtocolUdp { - p.udpClientPublishers[makeIpKey(evt.client.ip())] = evt.client + for trackId, track := range evt.client.streamTracks { + key := makeUdpClientAddr(evt.client.ip(), track.rtpPort) + p.udpClientsByAddr[key] = &udpClient{ + client: evt.client, + trackId: trackId, + streamType: gortsplib.StreamTypeRtp, + } + + key = makeUdpClientAddr(evt.client.ip(), track.rtcpPort) + p.udpClientsByAddr[key] = &udpClient{ + client: evt.client, + trackId: trackId, + streamType: gortsplib.StreamTypeRtcp, + } + } } + p.paths[evt.client.pathName].publisherSetReady() close(evt.done) @@ -434,19 +450,30 @@ outer: p.publisherCount -= 1 evt.client.state = clientStatePreRecord if evt.client.streamProtocol == gortsplib.StreamProtocolUdp { - delete(p.udpClientPublishers, makeIpKey(evt.client.ip())) + for _, track := range evt.client.streamTracks { + key := makeUdpClientAddr(evt.client.ip(), track.rtpPort) + delete(p.udpClientsByAddr, key) + + key = makeUdpClientAddr(evt.client.ip(), track.rtcpPort) + delete(p.udpClientsByAddr, key) + } } p.paths[evt.client.pathName].publisherSetNotReady() close(evt.done) case programEventClientFrameUdp: - client, trackId := p.findUdpClientPublisher(evt.addr, evt.streamType) - if client == nil { + pub, ok := p.udpClientsByAddr[makeUdpClientAddr(evt.addr.IP, evt.addr.Port)] + if !ok { continue } - client.rtcpReceivers[trackId].OnFrame(evt.streamType, evt.buf) - p.forwardFrame(client.pathName, trackId, evt.streamType, evt.buf) + // client sent RTP on RTCP port or vice-versa + if pub.streamType != evt.streamType { + continue + } + + pub.client.rtcpReceivers[pub.trackId].OnFrame(evt.streamType, evt.buf) + p.forwardFrame(pub.client.pathName, pub.trackId, evt.streamType, evt.buf) case programEventClientFrameTcp: p.forwardFrame(evt.path, evt.trackId, evt.streamType, evt.buf) @@ -555,25 +582,6 @@ func (p *program) findConfForPath(name string) *confPath { return nil } -func (p *program) findUdpClientPublisher(addr *net.UDPAddr, streamType gortsplib.StreamType) (*client, int) { - c, ok := p.udpClientPublishers[makeIpKey(addr.IP)] - if ok { - for i, t := range c.streamTracks { - if streamType == gortsplib.StreamTypeRtp { - if t.rtpPort == addr.Port { - return c, i - } - } else { - if t.rtcpPort == addr.Port { - return c, i - } - } - } - } - - return nil, -1 -} - func (p *program) forwardFrame(path string, trackId int, streamType gortsplib.StreamType, frame []byte) { for c := range p.clients { if c.pathName != path || diff --git a/utils.go b/utils.go index d136cb15..3ce1906b 100644 --- a/utils.go +++ b/utils.go @@ -149,20 +149,6 @@ func splitPath(path string) (string, string, error) { return comps[0], comps[1], nil } -// use a fixed-size array for ip comparison -type ipKey [net.IPv6len]byte - -func makeIpKey(ip net.IP) ipKey { - var ret ipKey - if len(ip) == net.IPv4len { - copy(ret[0:], []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}) // v4InV6Prefix - copy(ret[12:], ip) - } else { - copy(ret[:], ip) - } - return ret -} - var rePathName = regexp.MustCompile("^[0-9a-zA-Z_-]+$") func checkPathName(name string) error {