mediamtx/client.go

725 lines
16 KiB
Go
Raw Normal View History

2019-12-31 12:48:17 +00:00
package main
import (
"errors"
"fmt"
"io"
"log"
"net"
"net/url"
"strconv"
"strings"
"rtsp-server/rtsp"
"gortc.io/sdp"
)
// errTeardown is returned by handleRequest for a TEARDOWN request; run reacts
// by returning without writing any response.
var errTeardown = errors.New("teardown")

// errPlay is returned by handleRequest together with the PLAY response; run
// writes the response first and only then moves the client to the PLAY state.
var errPlay = errors.New("play")

// errRecord is returned by handleRequest together with the RECORD response;
// run writes the response and then moves the client to the RECORD state.
var errRecord = errors.New("record")
// interleavedChannelToTrack maps an interleaved channel number to the flow it
// carries and the index of the track it belongs to: even channels are RTP,
// odd channels are RTCP, and each track owns two consecutive channels.
func interleavedChannelToTrack(channel int) (trackFlow, int) {
	isOdd := channel%2 != 0
	if isOdd {
		// step back to the even channel of the pair before halving
		return _TRACK_FLOW_RTCP, (channel - 1) / 2
	}
	return _TRACK_FLOW_RTP, channel / 2
}
// trackToInterleavedChannel is the inverse of interleavedChannelToTrack: it
// returns the even channel of the track's pair for RTP and the following odd
// channel for any other flow.
func trackToInterleavedChannel(flow trackFlow, id int) int {
	channel := id * 2
	if flow != _TRACK_FLOW_RTP {
		channel++
	}
	return channel
}
// transportHeader is the set of fields found in an RTSP Transport header.
// Each ";"-separated field (for example "unicast" or "client_port=5000-5001")
// is stored verbatim as a key, so flag-like fields can be tested with a plain
// map lookup.
type transportHeader map[string]struct{}

// newTransportHeader splits the raw header value on ";" and returns the set of
// its fields. Whitespace around a field is dropped and empty fields (such as
// the one produced by a trailing ";") are skipped, so "RTP/AVP; unicast" and
// "RTP/AVP;unicast" yield the same set.
func newTransportHeader(in string) transportHeader {
	fields := strings.Split(in, ";")
	th := make(transportHeader, len(fields))
	for _, field := range fields {
		field = strings.TrimSpace(field)
		if field == "" {
			continue
		}
		th[field] = struct{}{}
	}
	return th
}

// getKeyValue returns the value of the "key=value" field with the given key,
// or "" when no such field exists.
func (th transportHeader) getKeyValue(key string) string {
	prefix := key + "="
	for field := range th {
		if strings.HasPrefix(field, prefix) {
			return field[len(prefix):]
		}
	}
	return ""
}

// getClientPorts parses the "client_port=<rtp>-<rtcp>" field and returns both
// ports. It returns 0, 0 when the field is missing, is not exactly two
// "-"-separated numbers, or when either number does not fit a 16-bit port.
func (th transportHeader) getClientPorts() (int, int) {
	val := th.getKeyValue("client_port")
	if val == "" {
		return 0, 0
	}

	ports := strings.Split(val, "-")
	if len(ports) != 2 {
		return 0, 0
	}

	// bitSize 16 makes the parser reject anything outside 0-65535
	rtpPort, err := strconv.ParseUint(ports[0], 10, 16)
	if err != nil {
		return 0, 0
	}

	rtcpPort, err := strconv.ParseUint(ports[1], 10, 16)
	if err != nil {
		return 0, 0
	}

	return int(rtpPort), int(rtcpPort)
}
// client holds the state of one RTSP connection. A client is either a reader
// (it goes through SETUP and PLAY) or the publisher (it goes through ANNOUNCE,
// SETUP and RECORD).
type client struct {
// program this client belongs to; newRtspClient registers the client in
// p.clients and close removes it again
p *program
// RTSP connection wrapping the accepted net.Conn
rconn *rtsp.Conn
// current protocol state; this file assigns "STARTING", "ANNOUNCE",
// "PRE_PLAY", "PLAY", "PRE_RECORD" and "RECORD"
state string
// remote IP, parsed in run from the connection's remote address
ip net.IP
streamSdpText []byte // filled only if publisher
streamSdpParsed *sdp.Message // filled only if publisher
// transport selected by SETUP (UDP or TCP); all tracks must share it
streamProtocol streamProtocol
// one entry is appended for every successful SETUP request
streamTracks []*track
}
// newRtspClient wraps nconn into an RTSP connection, creates a client in the
// "STARTING" state and registers it in p.clients while holding p.mutex.
func newRtspClient(p *program, nconn net.Conn) *client {
	c := new(client)
	c.p = p
	c.rconn = rtsp.NewConn(nconn)
	c.state = "STARTING"

	p.mutex.Lock()
	defer p.mutex.Unlock()
	p.clients[c] = struct{}{}

	return c
}
// close unregisters the client and closes its connection; calling it again on
// an already removed client is a no-op. When the client is the current
// publisher, the publisher slot is cleared and every remaining client is
// closed as well. It does no locking itself: run calls it with p.mutex held.
// The returned error is always nil.
func (c *client) close() error {
	if _, registered := c.p.clients[c]; !registered {
		// already deleted
		return nil
	}

	delete(c.p.clients, c)
	c.rconn.Close()

	if c.p.publisher != c {
		return nil
	}

	// the publisher has disconnected: drop it and close all other connections
	c.p.publisher = nil
	for other := range c.p.clients {
		other.close()
	}

	return nil
}
// log writes a printf-style message through the standard logger, prefixed
// with the remote address of the client.
func (c *client) log(format string, args ...interface{}) {
	prefix := "[RTSP client " + c.rconn.RemoteAddr().String() + "] "
	log.Printf(prefix+format, args...)
}
// run serves the connection until it fails, the peer disconnects or a request
// ends the session. It reads RTSP requests one at a time, lets handleRequest
// build the response and acts on the error returned alongside it:
//
//   - nil: the response is written and the loop continues;
//   - errTeardown: the connection is dropped without a response;
//   - errPlay / errRecord: the response is written, the state is updated and,
//     when the stream runs over TCP, the connection is used for interleaved
//     data until it closes;
//   - anything else: a 400 response is written and the connection is dropped.
//
// On exit the client is closed under p.mutex and "disconnected" is logged.
func (c *client) run() {
	defer c.log("disconnected")
	defer func() {
		c.p.mutex.Lock()
		defer c.p.mutex.Unlock()
		c.close()
	}()

	// the split error is deliberately ignored: on failure ipstr is empty and
	// c.ip stays nil
	ipstr, _, _ := net.SplitHostPort(c.rconn.RemoteAddr().String())
	c.ip = net.ParseIP(ipstr)

	c.log("connected")

	// tracksLabel picks the singular or plural noun for the log lines below
	tracksLabel := func() string {
		if len(c.streamTracks) == 1 {
			return "track"
		}
		return "tracks"
	}

	for {
		req, err := c.rconn.ReadRequest()
		if err != nil {
			if err != io.EOF {
				c.log("ERR: %s", err)
			}
			return
		}

		// the method is supplied by the peer and must never be used as the
		// format string itself
		c.log("%s", req.Method)

		res, err := c.handleRequest(req)

		switch err {
		// normal response
		case nil:
			if werr := c.rconn.WriteResponse(res); werr != nil {
				c.log("ERR: %s", werr)
				return
			}

		// TEARDOWN, close connection silently
		case errTeardown:
			return

		// PLAY: first write response, then set state
		// otherwise, in case of TCP connections, RTP packets could be written
		// before the response
		// then switch to RTP if TCP
		case errPlay:
			if werr := c.rconn.WriteResponse(res); werr != nil {
				c.log("ERR: %s", werr)
				return
			}

			c.log("is receiving %d %s via %s", len(c.streamTracks), tracksLabel(), c.streamProtocol)

			c.p.mutex.Lock()
			c.state = "PLAY"
			c.p.mutex.Unlock()

			// when protocol is TCP, the RTSP connection becomes a RTP connection
			// receive RTP feedback, do not parse it, wait until connection closes
			if c.streamProtocol == _STREAM_PROTOCOL_TCP {
				buf := make([]byte, 2048)
				for {
					if _, rerr := c.rconn.Read(buf); rerr != nil {
						if rerr != io.EOF {
							c.log("ERR: %s", rerr)
						}
						return
					}
				}
			}

		// RECORD: switch to RTP if TCP
		case errRecord:
			if werr := c.rconn.WriteResponse(res); werr != nil {
				c.log("ERR: %s", werr)
				return
			}

			c.p.mutex.Lock()
			c.state = "RECORD"
			c.p.mutex.Unlock()

			c.log("is publishing %d %s via %s", len(c.streamTracks), tracksLabel(), c.streamProtocol)

			// when protocol is TCP, the RTSP connection becomes a RTP connection
			// receive RTP data and parse it
			if c.streamProtocol == _STREAM_PROTOCOL_TCP {
				buf := make([]byte, 2048)
				for {
					channel, n, rerr := c.rconn.ReadInterleavedFrame(buf)
					if rerr != nil {
						// network errors and EOF are how the session normally
						// ends, so only other errors are logged
						_, isOpErr := rerr.(*net.OpError)
						if !isOpErr && rerr != io.EOF {
							c.log("ERR: %s", rerr)
						}
						return
					}

					flow, trackID := interleavedChannelToTrack(channel)
					if trackID < 0 || trackID >= len(c.streamTracks) {
						c.log("ERR: invalid track id '%d'", trackID)
						return
					}

					c.p.mutex.RLock()
					c.p.forwardTrack(flow, trackID, buf[:n])
					c.p.mutex.RUnlock()
				}
			}

		// error: write and exit
		default:
			c.log("ERR: %s", err)

			badRequest := &rtsp.Response{
				StatusCode: 400,
				Status:     "Bad Request",
			}
			// echo CSeq only when the request carried one
			if cseq, ok := req.Headers["CSeq"]; ok {
				badRequest.Headers = map[string]string{
					"CSeq": cseq,
				}
			}

			// best effort: the connection is dropped right after, so a write
			// failure changes nothing
			c.rconn.WriteResponse(badRequest)
			return
		}
	}
}
// handleRequest validates one RTSP request against the client's current state
// and builds the response. The returned error tells run what to do next:
//
//   - (response, nil): write the response and keep reading requests;
//   - (response, errPlay) / (response, errRecord): write the response, then
//     enter the PLAY / RECORD state;
//   - (nil, errTeardown): close the connection without a response;
//   - (nil, any other error): the request is invalid for the current state or
//     malformed; run answers with 400 and closes the connection.
//
// The sentinel errors are returned unwrapped because run compares them by
// identity.
func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
	cseq, ok := req.Headers["CSeq"]
	if !ok {
		return nil, fmt.Errorf("cseq missing")
	}

	ur, err := url.Parse(req.Path)
	if err != nil {
		return nil, fmt.Errorf("unable to parse path '%s'", req.Path)
	}

	switch req.Method {
	case "OPTIONS":
		// do not check state, since OPTIONS can be requested
		// in any state
		return &rtsp.Response{
			StatusCode: 200,
			Status:     "OK",
			Headers: map[string]string{
				"CSeq": cseq,
				"Public": strings.Join([]string{
					"DESCRIBE",
					"ANNOUNCE",
					"SETUP",
					"PLAY",
					"PAUSE",
					"RECORD",
					"TEARDOWN",
				}, ", "),
			},
		}, nil

	case "DESCRIBE":
		if c.state != "STARTING" {
			return nil, fmt.Errorf("client is in state '%s'", c.state)
		}

		// the SDP served to readers is the one announced by the publisher
		sdpText, err := func() ([]byte, error) {
			c.p.mutex.RLock()
			defer c.p.mutex.RUnlock()

			if c.p.publisher == nil {
				return nil, fmt.Errorf("no one is streaming")
			}

			return c.p.publisher.streamSdpText, nil
		}()
		if err != nil {
			return nil, err
		}

		return &rtsp.Response{
			StatusCode: 200,
			Status:     "OK",
			Headers: map[string]string{
				"CSeq":         cseq,
				"Content-Base": ur.String(),
				"Content-Type": "application/sdp",
			},
			Content: sdpText,
		}, nil

	case "ANNOUNCE":
		if c.state != "STARTING" {
			return nil, fmt.Errorf("client is in state '%s'", c.state)
		}

		ct, ok := req.Headers["Content-Type"]
		if !ok {
			return nil, fmt.Errorf("Content-Type header missing")
		}

		if ct != "application/sdp" {
			return nil, fmt.Errorf("unsupported Content-Type '%s'", ct)
		}

		sdpParsed, err := func() (*sdp.Message, error) {
			s, err := sdp.DecodeSession(req.Content, nil)
			if err != nil {
				return nil, err
			}

			m := &sdp.Message{}
			d := sdp.NewDecoder(s)
			err = d.Decode(m)
			if err != nil {
				return nil, err
			}

			return m, nil
		}()
		if err != nil {
			return nil, fmt.Errorf("invalid SDP: %s", err)
		}

		// claim the single publisher slot and store the announced SDP
		err = func() error {
			c.p.mutex.Lock()
			defer c.p.mutex.Unlock()

			if c.p.publisher != nil {
				return fmt.Errorf("another client is already streaming")
			}

			c.p.publisher = c
			c.streamSdpText = req.Content
			c.streamSdpParsed = sdpParsed
			c.state = "ANNOUNCE"
			return nil
		}()
		if err != nil {
			return nil, err
		}

		return &rtsp.Response{
			StatusCode: 200,
			Status:     "OK",
			Headers: map[string]string{
				"CSeq": cseq,
			},
		}, nil

	case "SETUP":
		transportstr, ok := req.Headers["Transport"]
		if !ok {
			return nil, fmt.Errorf("transport header missing")
		}

		th := newTransportHeader(transportstr)

		if _, ok := th["unicast"]; !ok {
			return nil, fmt.Errorf("transport header does not contain unicast")
		}

		// "RTP/AVP" and "RTP/AVP/UDP" name the same transport (UDP is the
		// default lower transport), so both spellings are accepted for
		// playing as well as for recording
		_, hasAVP := th["RTP/AVP"]
		_, hasAVPUDP := th["RTP/AVP/UDP"]
		wantsUDP := hasAVP || hasAVPUDP
		_, wantsTCP := th["RTP/AVP/TCP"]

		switch c.state {
		// play
		case "STARTING", "PRE_PLAY":
			err := func() error {
				c.p.mutex.RLock()
				defer c.p.mutex.RUnlock()

				if c.p.publisher == nil {
					return fmt.Errorf("no one is streaming")
				}

				return nil
			}()
			if err != nil {
				return nil, err
			}

			switch {
			// play via UDP
			case wantsUDP:
				rtpPort, rtcpPort := th.getClientPorts()
				if rtpPort == 0 || rtcpPort == 0 {
					return nil, fmt.Errorf("transport header does not have valid client ports (%s)", transportstr)
				}

				err = func() error {
					c.p.mutex.Lock()
					defer c.p.mutex.Unlock()

					// the publisher may have disconnected since the check
					// above released the lock; re-check before dereferencing
					if c.p.publisher == nil {
						return fmt.Errorf("no one is streaming")
					}

					if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP {
						return fmt.Errorf("client want to send tracks with different protocols")
					}

					if len(c.streamTracks) >= len(c.p.publisher.streamSdpParsed.Medias) {
						return fmt.Errorf("all the tracks have already been setup")
					}

					c.streamProtocol = _STREAM_PROTOCOL_UDP
					c.streamTracks = append(c.streamTracks, &track{
						rtpPort:  rtpPort,
						rtcpPort: rtcpPort,
					})

					c.state = "PRE_PLAY"
					return nil
				}()
				if err != nil {
					return nil, err
				}

				return &rtsp.Response{
					StatusCode: 200,
					Status:     "OK",
					Headers: map[string]string{
						"CSeq": cseq,
						"Transport": strings.Join([]string{
							"RTP/AVP",
							"unicast",
							fmt.Sprintf("client_port=%d-%d", rtpPort, rtcpPort),
							// use two fake server ports, since we do not want to receive feedback
							// from the client
							fmt.Sprintf("server_port=%d-%d", c.p.rtpPort+2, c.p.rtcpPort+2),
							"ssrc=1234ABCD",
						}, ";"),
						"Session": "12345678",
					},
				}, nil

			// play via TCP
			case wantsTCP:
				err = func() error {
					c.p.mutex.Lock()
					defer c.p.mutex.Unlock()

					// the publisher may have disconnected since the check
					// above released the lock; re-check before dereferencing
					if c.p.publisher == nil {
						return fmt.Errorf("no one is streaming")
					}

					if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP {
						return fmt.Errorf("client want to send tracks with different protocols")
					}

					if len(c.streamTracks) >= len(c.p.publisher.streamSdpParsed.Medias) {
						return fmt.Errorf("all the tracks have already been setup")
					}

					c.streamProtocol = _STREAM_PROTOCOL_TCP
					c.streamTracks = append(c.streamTracks, &track{
						rtpPort:  0,
						rtcpPort: 0,
					})

					c.state = "PRE_PLAY"
					return nil
				}()
				if err != nil {
					return nil, err
				}

				// the channels are assigned by the server: the track that was
				// just appended gets the pair (2*index, 2*index+1)
				lastTrack := len(c.streamTracks) - 1
				interleaved := fmt.Sprintf("%d-%d", lastTrack*2, lastTrack*2+1)

				return &rtsp.Response{
					StatusCode: 200,
					Status:     "OK",
					Headers: map[string]string{
						"CSeq": cseq,
						"Transport": strings.Join([]string{
							"RTP/AVP/TCP",
							"unicast",
							fmt.Sprintf("interleaved=%s", interleaved),
						}, ";"),
						"Session": "12345678",
					},
				}, nil

			default:
				return nil, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transportstr)
			}

		// record
		case "ANNOUNCE", "PRE_RECORD":
			if _, ok := th["mode=record"]; !ok {
				return nil, fmt.Errorf("transport header does not contain mode=record")
			}

			switch {
			// record via UDP
			case wantsUDP:
				rtpPort, rtcpPort := th.getClientPorts()
				if rtpPort == 0 || rtcpPort == 0 {
					return nil, fmt.Errorf("transport header does not have valid client ports (%s)", transportstr)
				}

				err = func() error {
					c.p.mutex.Lock()
					defer c.p.mutex.Unlock()

					if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP {
						return fmt.Errorf("client want to send tracks with different protocols")
					}

					if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) {
						return fmt.Errorf("all the tracks have already been setup")
					}

					c.streamProtocol = _STREAM_PROTOCOL_UDP
					c.streamTracks = append(c.streamTracks, &track{
						rtpPort:  rtpPort,
						rtcpPort: rtcpPort,
					})

					c.state = "PRE_RECORD"
					return nil
				}()
				if err != nil {
					return nil, err
				}

				return &rtsp.Response{
					StatusCode: 200,
					Status:     "OK",
					Headers: map[string]string{
						"CSeq": cseq,
						"Transport": strings.Join([]string{
							"RTP/AVP",
							"unicast",
							fmt.Sprintf("client_port=%d-%d", rtpPort, rtcpPort),
							fmt.Sprintf("server_port=%d-%d", c.p.rtpPort, c.p.rtcpPort),
							"ssrc=1234ABCD",
						}, ";"),
						"Session": "12345678",
					},
				}, nil

			// record via TCP
			case wantsTCP:
				var interleaved string
				err = func() error {
					c.p.mutex.Lock()
					defer c.p.mutex.Unlock()

					if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP {
						return fmt.Errorf("client want to send tracks with different protocols")
					}

					if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) {
						return fmt.Errorf("all the tracks have already been setup")
					}

					interleaved = th.getKeyValue("interleaved")
					if interleaved == "" {
						return fmt.Errorf("transport header does not contain interleaved field")
					}

					// the publisher must request exactly the channel pair of
					// the track about to be added
					nextTrack := len(c.streamTracks)
					expInterleaved := fmt.Sprintf("%d-%d", nextTrack*2, nextTrack*2+1)
					if interleaved != expInterleaved {
						return fmt.Errorf("wrong interleaved value, expected '%s', got '%s'", expInterleaved, interleaved)
					}

					c.streamProtocol = _STREAM_PROTOCOL_TCP
					c.streamTracks = append(c.streamTracks, &track{
						rtpPort:  0,
						rtcpPort: 0,
					})

					c.state = "PRE_RECORD"
					return nil
				}()
				if err != nil {
					return nil, err
				}

				return &rtsp.Response{
					StatusCode: 200,
					Status:     "OK",
					Headers: map[string]string{
						"CSeq": cseq,
						"Transport": strings.Join([]string{
							"RTP/AVP/TCP",
							"unicast",
							fmt.Sprintf("interleaved=%s", interleaved),
						}, ";"),
						"Session": "12345678",
					},
				}, nil

			default:
				return nil, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transportstr)
			}

		default:
			return nil, fmt.Errorf("client is in state '%s'", c.state)
		}

	case "PLAY":
		if c.state != "PRE_PLAY" {
			return nil, fmt.Errorf("client is in state '%s'", c.state)
		}

		err := func() error {
			c.p.mutex.Lock()
			defer c.p.mutex.Unlock()

			// the publisher may have disconnected after SETUP; without this
			// check the dereference below would panic
			if c.p.publisher == nil {
				return fmt.Errorf("no one is streaming")
			}

			if len(c.streamTracks) != len(c.p.publisher.streamSdpParsed.Medias) {
				return fmt.Errorf("not all tracks have been setup")
			}

			return nil
		}()
		if err != nil {
			return nil, err
		}

		// the state change is left to run, which must write the response first
		return &rtsp.Response{
			StatusCode: 200,
			Status:     "OK",
			Headers: map[string]string{
				"CSeq":    cseq,
				"Session": "12345678",
			},
		}, errPlay

	case "PAUSE":
		if c.state != "PLAY" {
			return nil, fmt.Errorf("client is in state '%s'", c.state)
		}

		c.log("paused")

		c.p.mutex.Lock()
		c.state = "PRE_PLAY"
		c.p.mutex.Unlock()

		return &rtsp.Response{
			StatusCode: 200,
			Status:     "OK",
			Headers: map[string]string{
				"CSeq":    cseq,
				"Session": "12345678",
			},
		}, nil

	case "RECORD":
		if c.state != "PRE_RECORD" {
			return nil, fmt.Errorf("client is in state '%s'", c.state)
		}

		err := func() error {
			c.p.mutex.Lock()
			defer c.p.mutex.Unlock()

			if len(c.streamTracks) != len(c.streamSdpParsed.Medias) {
				return fmt.Errorf("not all tracks have been setup")
			}

			return nil
		}()
		if err != nil {
			return nil, err
		}

		// the state change is left to run, which writes the response first
		return &rtsp.Response{
			StatusCode: 200,
			Status:     "OK",
			Headers: map[string]string{
				"CSeq":    cseq,
				"Session": "12345678",
			},
		}, errRecord

	case "TEARDOWN":
		return nil, errTeardown

	default:
		return nil, fmt.Errorf("unhandled method '%s'", req.Method)
	}
}