hls source: support proxying any number of tracks

Tracks were previously limited to 2
This commit is contained in:
aler9 2023-01-06 15:25:05 +01:00
parent fa1c07253f
commit 3f7009f72a
9 changed files with 177 additions and 186 deletions

2
go.mod
View File

@ -5,7 +5,7 @@ go 1.18
require (
code.cloudfoundry.org/bytefmt v0.0.0
github.com/abema/go-mp4 v0.0.0
github.com/aler9/gortsplib/v2 v2.0.0-20230103153002-0ce435414414
github.com/aler9/gortsplib/v2 v2.0.0-20230106140016-a759ba9d014b
github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757
github.com/fsnotify/fsnotify v1.4.9
github.com/gin-gonic/gin v1.8.1

4
go.sum
View File

@ -4,8 +4,8 @@ github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2c
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/aler9/go-mp4 v0.0.0-20221229200349-f3d01e787968 h1:wU8pLx4dc8bLB+JuVPWuGp+BoMkOabj98a0RmO3gqvw=
github.com/aler9/go-mp4 v0.0.0-20221229200349-f3d01e787968/go.mod h1:vPl9t5ZK7K0x68jh12/+ECWBCXoWuIDtNgPtU2f04ws=
github.com/aler9/gortsplib/v2 v2.0.0-20230103153002-0ce435414414 h1:pVyJ7Uuk5kdU/RhCepxJQJEC9hsrFgxIIw1mIHn02Zs=
github.com/aler9/gortsplib/v2 v2.0.0-20230103153002-0ce435414414/go.mod h1:lMdAxc6daduSzVwh75yQkvH9UHCYHpng/vJ8uXKFzdA=
github.com/aler9/gortsplib/v2 v2.0.0-20230106140016-a759ba9d014b h1:6Yg4zJ6XowH8dJpSYfyBnp1VR4wOFvCiNdkSdhK+OWQ=
github.com/aler9/gortsplib/v2 v2.0.0-20230106140016-a759ba9d014b/go.mod h1:lMdAxc6daduSzVwh75yQkvH9UHCYHpng/vJ8uXKFzdA=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82/go.mod h1:qsMrZCbeBf/mCLOeF16KDkPu4gktn/pOWyaq1aYQE7U=
github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8=

View File

@ -43,8 +43,6 @@ func (s *hlsSource) Log(level logger.Level, format string, args ...interface{})
// run implements sourceStaticImpl.
func (s *hlsSource) run(ctx context.Context) error {
var stream *stream
var videoMedia *media.Media
var audioMedia *media.Media
defer func() {
if stream != nil {
@ -52,23 +50,51 @@ func (s *hlsSource) run(ctx context.Context) error {
}
}()
onTracks := func(videoFormat *format.H264, audioFormat *format.MPEG4Audio) error {
c, err := hls.NewClient(
s.ur,
s.fingerprint,
s,
)
if err != nil {
return err
}
c.OnTracks(func(tracks []format.Format) error {
var medias media.Medias
if videoFormat != nil {
videoMedia = &media.Media{
for _, track := range tracks {
medi := &media.Media{
Type: media.TypeVideo,
Formats: []format.Format{videoFormat},
Formats: []format.Format{track},
}
medias = append(medias, videoMedia)
}
medias = append(medias, medi)
ctrack := track
if audioFormat != nil {
audioMedia = &media.Media{
Type: media.TypeAudio,
Formats: []format.Format{audioFormat},
switch track.(type) {
case *format.H264:
c.OnData(track, func(pts time.Duration, dat interface{}) {
err := stream.writeData(medi, ctrack, &formatprocessor.DataH264{
PTS: pts,
AU: dat.([][]byte),
NTP: time.Now(),
})
if err != nil {
s.Log(logger.Warn, "%v", err)
}
})
case *format.MPEG4Audio:
c.OnData(track, func(pts time.Duration, dat interface{}) {
err := stream.writeData(medi, ctrack, &formatprocessor.DataMPEG4Audio{
PTS: pts,
AUs: [][]byte{dat.([]byte)},
NTP: time.Now(),
})
if err != nil {
s.Log(logger.Warn, "%v", err)
}
})
}
medias = append(medias, audioMedia)
}
res := s.parent.sourceStaticImplSetReady(pathSourceStaticSetReadyReq{
@ -83,41 +109,9 @@ func (s *hlsSource) run(ctx context.Context) error {
stream = res.stream
return nil
}
})
onVideoData := func(pts time.Duration, au [][]byte) {
err := stream.writeData(videoMedia, videoMedia.Formats[0], &formatprocessor.DataH264{
PTS: pts,
AU: au,
NTP: time.Now(),
})
if err != nil {
s.Log(logger.Warn, "%v", err)
}
}
onAudioData := func(pts time.Duration, au []byte) {
err := stream.writeData(audioMedia, audioMedia.Formats[0], &formatprocessor.DataMPEG4Audio{
PTS: pts,
AUs: [][]byte{au},
NTP: time.Now(),
})
if err != nil {
s.Log(logger.Warn, "%v", err)
}
}
c, err := hls.NewClient(
s.ur,
s.fingerprint,
onTracks,
onVideoData,
onAudioData,
s,
)
if err != nil {
return err
}
c.Start()
select {
case err := <-c.Wait():

View File

@ -35,13 +35,12 @@ type ClientLogger interface {
// Client is a HLS client.
type Client struct {
fingerprint string
onTracks func(*format.H264, *format.MPEG4Audio) error
onVideoData func(time.Duration, [][]byte)
onAudioData func(time.Duration, []byte)
logger ClientLogger
ctx context.Context
ctxCancel func()
onTracks func([]format.Format) error
onData map[format.Format]func(time.Duration, interface{})
playlistURL *url.URL
// out
@ -52,9 +51,6 @@ type Client struct {
func NewClient(
playlistURLStr string,
fingerprint string,
onTracks func(*format.H264, *format.MPEG4Audio) error,
onVideoData func(time.Duration, [][]byte),
onAudioData func(time.Duration, []byte),
logger ClientLogger,
) (*Client, error) {
playlistURL, err := url.Parse(playlistURLStr)
@ -66,21 +62,22 @@ func NewClient(
c := &Client{
fingerprint: fingerprint,
onTracks: onTracks,
onVideoData: onVideoData,
onAudioData: onAudioData,
logger: logger,
ctx: ctx,
ctxCancel: ctxCancel,
playlistURL: playlistURL,
onData: make(map[format.Format]func(time.Duration, interface{})),
outErr: make(chan error, 1),
}
go c.run()
return c, nil
}
// Start starts the client.
func (c *Client) Start() {
go c.run()
}
// Close closes all the Client resources.
func (c *Client) Close() {
c.ctxCancel()
@ -91,6 +88,16 @@ func (c *Client) Wait() chan error {
return c.outErr
}
// OnTracks sets a callback that is called when tracks are read.
func (c *Client) OnTracks(cb func([]format.Format) error) {
c.onTracks = cb
}
// OnData sets a callback that is called when data arrives.
func (c *Client) OnData(forma format.Format, cb func(time.Duration, interface{})) {
c.onData[forma] = cb
}
func (c *Client) run() {
c.outErr <- c.runInner()
}
@ -104,8 +111,7 @@ func (c *Client) runInner() error {
c.logger,
rp,
c.onTracks,
c.onVideoData,
c.onAudioData,
c.onData,
)
rp.add(dl)

View File

@ -96,16 +96,15 @@ type clientTimeSync interface{}
type clientDownloaderPrimary struct {
primaryPlaylistURL *url.URL
logger ClientLogger
onTracks func(*format.H264, *format.MPEG4Audio) error
onVideoData func(time.Duration, [][]byte)
onAudioData func(time.Duration, []byte)
onTracks func([]format.Format) error
onData map[format.Format]func(time.Duration, interface{})
rp *clientRoutinePool
httpClient *http.Client
leadingTimeSync clientTimeSync
// in
streamFormats chan []format.Format
streamTracks chan []format.Format
// out
startStreaming chan struct{}
@ -117,9 +116,8 @@ func newClientDownloaderPrimary(
fingerprint string,
logger ClientLogger,
rp *clientRoutinePool,
onTracks func(*format.H264, *format.MPEG4Audio) error,
onVideoData func(time.Duration, [][]byte),
onAudioData func(time.Duration, []byte),
onTracks func([]format.Format) error,
onData map[format.Format]func(time.Duration, interface{}),
) *clientDownloaderPrimary {
var tlsConfig *tls.Config
if fingerprint != "" {
@ -145,15 +143,14 @@ func newClientDownloaderPrimary(
primaryPlaylistURL: primaryPlaylistURL,
logger: logger,
onTracks: onTracks,
onVideoData: onVideoData,
onAudioData: onAudioData,
onData: onData,
rp: rp,
httpClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
streamFormats: make(chan []format.Format),
streamTracks: make(chan []format.Format),
startStreaming: make(chan struct{}),
leadingTimeSyncReady: make(chan struct{}),
}
@ -179,11 +176,11 @@ func (d *clientDownloaderPrimary) run(ctx context.Context) error {
plt,
d.logger,
d.rp,
d.onStreamFormats,
d.onStreamTracks,
d.onSetLeadingTimeSync,
d.onGetLeadingTimeSync,
d.onVideoData,
d.onAudioData)
d.onData,
)
d.rp.add(ds)
streamCount++
@ -205,11 +202,11 @@ func (d *clientDownloaderPrimary) run(ctx context.Context) error {
nil,
d.logger,
d.rp,
d.onStreamFormats,
d.onStreamTracks,
d.onSetLeadingTimeSync,
d.onGetLeadingTimeSync,
d.onVideoData,
d.onAudioData)
d.onData,
)
d.rp.add(ds)
streamCount++
@ -231,11 +228,11 @@ func (d *clientDownloaderPrimary) run(ctx context.Context) error {
nil,
d.logger,
d.rp,
d.onStreamFormats,
d.onStreamTracks,
d.onSetLeadingTimeSync,
d.onGetLeadingTimeSync,
d.onVideoData,
d.onAudioData)
d.onData,
)
d.rp.add(ds)
streamCount++
}
@ -248,31 +245,18 @@ func (d *clientDownloaderPrimary) run(ctx context.Context) error {
for i := 0; i < streamCount; i++ {
select {
case streamFormats := <-d.streamFormats:
tracks = append(tracks, streamFormats...)
case streamTracks := <-d.streamTracks:
tracks = append(tracks, streamTracks...)
case <-ctx.Done():
return fmt.Errorf("terminated")
}
}
var videoTrack *format.H264
var audioTrack *format.MPEG4Audio
for _, track := range tracks {
switch ttrack := track.(type) {
case *format.H264:
videoTrack = ttrack
case *format.MPEG4Audio:
audioTrack = ttrack
}
}
if videoTrack == nil && audioTrack == nil {
if len(tracks) == 0 {
return fmt.Errorf("no supported tracks found")
}
err = d.onTracks(videoTrack, audioTrack)
err = d.onTracks(tracks)
if err != nil {
return err
}
@ -282,9 +266,9 @@ func (d *clientDownloaderPrimary) run(ctx context.Context) error {
return nil
}
func (d *clientDownloaderPrimary) onStreamFormats(ctx context.Context, tracks []format.Format) bool {
func (d *clientDownloaderPrimary) onStreamTracks(ctx context.Context, tracks []format.Format) bool {
select {
case d.streamFormats <- tracks:
case d.streamTracks <- tracks:
case <-ctx.Done():
return false
}

View File

@ -50,11 +50,10 @@ type clientDownloaderStream struct {
initialPlaylist *m3u8.MediaPlaylist
logger ClientLogger
rp *clientRoutinePool
onStreamFormats func(context.Context, []format.Format) bool
onStreamTracks func(context.Context, []format.Format) bool
onSetLeadingTimeSync func(clientTimeSync)
onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool)
onVideoData func(time.Duration, [][]byte)
onAudioData func(time.Duration, []byte)
onData map[format.Format]func(time.Duration, interface{})
curSegmentID *uint64
}
@ -66,11 +65,10 @@ func newClientDownloaderStream(
initialPlaylist *m3u8.MediaPlaylist,
logger ClientLogger,
rp *clientRoutinePool,
onStreamFormats func(context.Context, []format.Format) bool,
onStreamTracks func(context.Context, []format.Format) bool,
onSetLeadingTimeSync func(clientTimeSync),
onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool),
onVideoData func(time.Duration, [][]byte),
onAudioData func(time.Duration, []byte),
onData map[format.Format]func(time.Duration, interface{}),
) *clientDownloaderStream {
return &clientDownloaderStream{
isLeading: isLeading,
@ -79,11 +77,10 @@ func newClientDownloaderStream(
initialPlaylist: initialPlaylist,
logger: logger,
rp: rp,
onStreamFormats: onStreamFormats,
onStreamTracks: onStreamTracks,
onSetLeadingTimeSync: onSetLeadingTimeSync,
onGetLeadingTimeSync: onGetLeadingTimeSync,
onVideoData: onVideoData,
onAudioData: onAudioData,
onData: onData,
}
}
@ -113,11 +110,10 @@ func (d *clientDownloaderStream) run(ctx context.Context) error {
segmentQueue,
d.logger,
d.rp,
d.onStreamFormats,
d.onStreamTracks,
d.onSetLeadingTimeSync,
d.onGetLeadingTimeSync,
d.onVideoData,
d.onAudioData,
d.onData,
)
if err != nil {
return err
@ -130,11 +126,10 @@ func (d *clientDownloaderStream) run(ctx context.Context) error {
segmentQueue,
d.logger,
d.rp,
d.onStreamFormats,
d.onStreamTracks,
d.onSetLeadingTimeSync,
d.onGetLeadingTimeSync,
d.onVideoData,
d.onAudioData,
d.onData,
)
d.rp.add(proc)
}

View File

@ -14,7 +14,8 @@ import (
func fmp4PickLeadingTrack(init *fmp4.Init) int {
// pick first video track
for _, track := range init.Tracks {
if _, ok := track.Format.(*format.H264); ok {
switch track.Format.(type) {
case *format.H264, *format.H265:
return track.ID
}
}
@ -30,8 +31,7 @@ type clientProcessorFMP4 struct {
rp *clientRoutinePool
onSetLeadingTimeSync func(clientTimeSync)
onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool)
onVideoData func(time.Duration, [][]byte)
onAudioData func(time.Duration, []byte)
onData map[format.Format]func(time.Duration, interface{})
init fmp4.Init
leadingTrackID int
@ -51,8 +51,7 @@ func newClientProcessorFMP4(
onStreamFormats func(context.Context, []format.Format) bool,
onSetLeadingTimeSync func(clientTimeSync),
onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool),
onVideoData func(time.Duration, [][]byte),
onAudioData func(time.Duration, []byte),
onData map[format.Format]func(time.Duration, interface{}),
) (*clientProcessorFMP4, error) {
p := &clientProcessorFMP4{
isLeading: isLeading,
@ -61,8 +60,7 @@ func newClientProcessorFMP4(
rp: rp,
onSetLeadingTimeSync: onSetLeadingTimeSync,
onGetLeadingTimeSync: onGetLeadingTimeSync,
onVideoData: onVideoData,
onAudioData: onAudioData,
onData: onData,
subpartProcessed: make(chan struct{}, clientFMP4MaxPartTracksPerSegment),
}
@ -186,21 +184,27 @@ func (p *clientProcessorFMP4) initializeTrackProcs(ts *clientTimeSyncFMP4) {
for _, track := range p.init.Tracks {
var cb func(time.Duration, []byte) error
cb2, ok := p.onData[track.Format]
if !ok {
cb2 = func(time.Duration, interface{}) {
}
}
switch track.Format.(type) {
case *format.H264:
case *format.H264, *format.H265:
cb = func(pts time.Duration, payload []byte) error {
nalus, err := h264.AVCCUnmarshal(payload)
if err != nil {
return err
}
p.onVideoData(pts, nalus)
cb2(pts, nalus)
return nil
}
case *format.MPEG4Audio:
case *format.MPEG4Audio, *format.Opus:
cb = func(pts time.Duration, payload []byte) error {
p.onAudioData(pts, payload)
cb2(pts, payload)
return nil
}
}

View File

@ -36,8 +36,7 @@ type clientProcessorMPEGTS struct {
onStreamFormats func(context.Context, []format.Format) bool
onSetLeadingTimeSync func(clientTimeSync)
onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool)
onVideoData func(time.Duration, [][]byte)
onAudioData func(time.Duration, []byte)
onData map[format.Format]func(time.Duration, interface{})
mpegtsTracks []*mpegts.Track
leadingTrackPID uint16
@ -52,8 +51,7 @@ func newClientProcessorMPEGTS(
onStreamFormats func(context.Context, []format.Format) bool,
onSetLeadingTimeSync func(clientTimeSync),
onGetLeadingTimeSync func(context.Context) (clientTimeSync, bool),
onVideoData func(time.Duration, [][]byte),
onAudioData func(time.Duration, []byte),
onData map[format.Format]func(time.Duration, interface{}),
) *clientProcessorMPEGTS {
return &clientProcessorMPEGTS{
isLeading: isLeading,
@ -63,8 +61,7 @@ func newClientProcessorMPEGTS(
onStreamFormats: onStreamFormats,
onSetLeadingTimeSync: onSetLeadingTimeSync,
onGetLeadingTimeSync: onGetLeadingTimeSync,
onVideoData: onVideoData,
onAudioData: onAudioData,
onData: onData,
}
}
@ -174,10 +171,16 @@ func (p *clientProcessorMPEGTS) processSegment(ctx context.Context, byts []byte)
func (p *clientProcessorMPEGTS) initializeTrackProcs(ts *clientTimeSyncMPEGTS) {
p.trackProcs = make(map[uint16]*clientProcessorMPEGTSTrack)
for _, mt := range p.mpegtsTracks {
for _, track := range p.mpegtsTracks {
var cb func(time.Duration, []byte) error
switch mt.Format.(type) {
cb2, ok := p.onData[track.Format]
if !ok {
cb2 = func(time.Duration, interface{}) {
}
}
switch track.Format.(type) {
case *format.H264:
cb = func(pts time.Duration, payload []byte) error {
nalus, err := h264.AnnexBUnmarshal(payload)
@ -186,7 +189,7 @@ func (p *clientProcessorMPEGTS) initializeTrackProcs(ts *clientTimeSyncMPEGTS) {
return nil
}
p.onVideoData(pts, nalus)
cb2(pts, nalus)
return nil
}
@ -199,7 +202,7 @@ func (p *clientProcessorMPEGTS) initializeTrackProcs(ts *clientTimeSyncMPEGTS) {
}
for i, pkt := range adtsPkts {
p.onAudioData(
cb2(
pts+time.Duration(i)*mpeg4audio.SamplesPerAccessUnit*time.Second/time.Duration(pkt.SampleRate),
pkt.AU)
}
@ -213,6 +216,6 @@ func (p *clientProcessorMPEGTS) initializeTrackProcs(ts *clientTimeSyncMPEGTS) {
cb,
)
p.rp.add(proc)
p.trackProcs[mt.ES.ElementaryPID] = proc
p.trackProcs[track.ES.ElementaryPID] = proc
}
}

View File

@ -277,29 +277,32 @@ func TestClientMPEGTS(t *testing.T) {
c, err := NewClient(
prefix+"://localhost:5780/stream.m3u8",
"33949E05FFFB5FF3E8AA16F8213A6251B4D9363804BA53233C4DA9A46D6F2739",
func(videoTrack *format.H264, audioTrack *format.MPEG4Audio) error {
require.Equal(t, &format.H264{
PayloadTyp: 96,
PacketizationMode: 1,
}, videoTrack)
require.Equal(t, (*format.MPEG4Audio)(nil), audioTrack)
return nil
},
func(pts time.Duration, nalus [][]byte) {
require.Equal(t, 2*time.Second, pts)
require.Equal(t, [][]byte{
{7, 1, 2, 3},
{8},
{5},
}, nalus)
close(packetRecv)
},
func(pts time.Duration, au []byte) {
},
testLogger{},
)
require.NoError(t, err)
onH264 := func(pts time.Duration, dat interface{}) {
require.Equal(t, 2*time.Second, pts)
require.Equal(t, [][]byte{
{7, 1, 2, 3},
{8},
{5},
}, dat)
close(packetRecv)
}
c.OnTracks(func(tracks []format.Format) error {
require.Equal(t, 1, len(tracks))
require.Equal(t, &format.H264{
PayloadTyp: 96,
PacketizationMode: 1,
}, tracks[0])
c.OnData(tracks[0], onH264)
return nil
})
c.Start()
<-packetRecv
c.Close()
@ -341,34 +344,33 @@ func TestClientFMP4(t *testing.T) {
packetRecv := make(chan struct{})
onH264 := func(pts time.Duration, dat interface{}) {
require.Equal(t, 2*time.Second, pts)
require.Equal(t, [][]byte{
{7, 1, 2, 3},
{8},
{5},
}, dat)
close(packetRecv)
}
c, err := NewClient(
"http://localhost:5780/stream.m3u8",
"",
func(videoTrack *format.H264, audioTrack *format.MPEG4Audio) error {
require.Equal(t, &format.H264{
PayloadTyp: 96,
PacketizationMode: 1,
SPS: videoTrack.SPS,
PPS: videoTrack.PPS,
}, videoTrack)
require.Equal(t, (*format.MPEG4Audio)(nil), audioTrack)
return nil
},
func(pts time.Duration, nalus [][]byte) {
require.Equal(t, 2*time.Second, pts)
require.Equal(t, [][]byte{
{7, 1, 2, 3},
{8},
{5},
}, nalus)
close(packetRecv)
},
func(pts time.Duration, au []byte) {
},
testLogger{},
)
require.NoError(t, err)
c.OnTracks(func(tracks []format.Format) error {
require.Equal(t, 1, len(tracks))
_, ok := tracks[0].(*format.H264)
require.Equal(t, true, ok)
c.OnData(tracks[0], onH264)
return nil
})
c.Start()
<-packetRecv
c.Close()
@ -425,23 +427,26 @@ func TestClientInvalidSequenceID(t *testing.T) {
packetRecv := make(chan struct{})
onH264 := func(pts time.Duration, dat interface{}) {
close(packetRecv)
}
c, err := NewClient(
"http://localhost:5780/stream.m3u8",
"",
func(*format.H264, *format.MPEG4Audio) error {
return nil
},
func(pts time.Duration, nalus [][]byte) {
close(packetRecv)
},
nil,
testLogger{},
)
require.NoError(t, err)
c.OnTracks(func(tracks []format.Format) error {
c.OnData(tracks[0], onH264)
return nil
})
c.Start()
<-packetRecv
// c.Close()
err = <-c.Wait()
require.EqualError(t, err, "following segment not found or not ready yet")