From 600f7bf48ccc2dd0d20a1eb60340b6924b5b1053 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 14 Aug 2021 15:39:29 +0200 Subject: [PATCH] hls: move NALU filtering into hls muxer --- internal/core/hls_remuxer.go | 26 +------------------ internal/hls/muxer.go | 22 +++++++++++++--- internal/hls/muxer_test.go | 13 +++++----- internal/hls/tsfile.go | 50 ++++++++++++++++++++++++++---------- 4 files changed, 62 insertions(+), 49 deletions(-) diff --git a/internal/core/hls_remuxer.go b/internal/core/hls_remuxer.go index 61b2fbb6..0d655fb2 100644 --- a/internal/core/hls_remuxer.go +++ b/internal/core/hls_remuxer.go @@ -18,7 +18,6 @@ import ( "github.com/aler9/gortsplib/pkg/rtph264" "github.com/pion/rtp" - "github.com/aler9/rtsp-simple-server/internal/h264" "github.com/aler9/rtsp-simple-server/internal/hls" "github.com/aler9/rtsp-simple-server/internal/logger" ) @@ -244,8 +243,6 @@ func (r *hlsRemuxer) runRemuxer(remuxerCtx context.Context, remuxerReady chan st var videoTrack *gortsplib.Track videoTrackID := -1 - var h264SPS []byte - var h264PPS []byte var h264Decoder *rtph264.Decoder var audioTrack *gortsplib.Track audioTrackID := -1 @@ -261,12 +258,6 @@ func (r *hlsRemuxer) runRemuxer(remuxerCtx context.Context, remuxerReady chan st videoTrack = t videoTrackID = i - var err error - h264SPS, h264PPS, err = t.ExtractDataH264() - if err != nil { - return err - } - h264Decoder = rtph264.NewDecoder() } else if t.IsAAC() { @@ -341,22 +332,7 @@ func (r *hlsRemuxer) runRemuxer(remuxerCtx context.Context, remuxerReady chan st continue } - for _, nalu := range nalus { - // remove SPS, PPS, AUD - typ := h264.NALUType(nalu[0] & 0x1F) - switch typ { - case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter: - continue - } - - // add SPS and PPS before IDR - if typ == h264.NALUTypeIDR { - videoBuf = append(videoBuf, h264SPS) - videoBuf = append(videoBuf, h264PPS) - } - - videoBuf = append(videoBuf, nalu) - } + videoBuf = append(videoBuf, nalus...) // RTP marker means that all the NALUs with the same PTS have been received. // send them together. diff --git a/internal/hls/muxer.go b/internal/hls/muxer.go index 9d83fdac..53a6ebe6 100644 --- a/internal/hls/muxer.go +++ b/internal/hls/muxer.go @@ -31,6 +31,8 @@ type Muxer struct { videoTrack *gortsplib.Track audioTrack *gortsplib.Track + h264SPS []byte + h264PPS []byte aacConfig rtpaac.MPEG4AudioConfig startPCR time.Time videoDTSEst *h264.DTSEstimator @@ -48,6 +50,16 @@ func NewMuxer( hlsSegmentDuration time.Duration, videoTrack *gortsplib.Track, audioTrack *gortsplib.Track) (*Muxer, error) { + var h264SPS []byte + var h264PPS []byte + if videoTrack != nil { + var err error + h264SPS, h264PPS, err = videoTrack.ExtractDataH264() + if err != nil { + return nil, err + } + } + var aacConfig rtpaac.MPEG4AudioConfig if audioTrack != nil { byts, err := audioTrack.ExtractDataAAC() @@ -66,10 +78,12 @@ func NewMuxer( hlsSegmentDuration: hlsSegmentDuration, videoTrack: videoTrack, audioTrack: audioTrack, + h264SPS: h264SPS, + h264PPS: h264PPS, aacConfig: aacConfig, startPCR: time.Now(), videoDTSEst: h264.NewDTSEstimator(), - tsCurrent: newTSFile(videoTrack != nil, audioTrack != nil), + tsCurrent: newTSFile(videoTrack, audioTrack), tsByName: make(map[string]*tsFile), } @@ -111,7 +125,7 @@ func (m *Muxer) WriteH264(pts time.Duration, nalus [][]byte) error { m.tsCurrent.close() } - m.tsCurrent = newTSFile(m.videoTrack != nil, m.audioTrack != nil) + m.tsCurrent = newTSFile(m.videoTrack, m.audioTrack) m.tsByName[m.tsCurrent.name] = m.tsCurrent m.tsQueue = append(m.tsQueue, m.tsCurrent) @@ -124,6 +138,8 @@ func (m *Muxer) WriteH264(pts time.Duration, nalus [][]byte) error { m.tsCurrent.setPCR(time.Since(m.startPCR)) err := m.tsCurrent.writeH264( + m.h264SPS, + m.h264PPS, m.videoDTSEst.Feed(pts+ptsOffset), pts+ptsOffset, idrPresent, @@ -150,7 +166,7 @@ func (m *Muxer) WriteAAC(pts time.Duration, aus [][]byte) error { } m.audioAUCount = 0 - m.tsCurrent = newTSFile(m.videoTrack != nil, m.audioTrack != nil) + m.tsCurrent = newTSFile(m.videoTrack, m.audioTrack) m.tsByName[m.tsCurrent.name] = m.tsCurrent m.tsQueue = append(m.tsQueue, m.tsCurrent) if len(m.tsQueue) > m.hlsSegmentCount { diff --git a/internal/hls/muxer_test.go b/internal/hls/muxer_test.go index 6b89c277..edcc98a7 100644 --- a/internal/hls/muxer_test.go +++ b/internal/hls/muxer_test.go @@ -17,7 +17,7 @@ func checkTSPacket(t *testing.T, byts []byte, pid int, afc int) { } func TestMuxer(t *testing.T) { - videoTrack, err := gortsplib.NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}) + videoTrack, err := gortsplib.NewTrackH264(96, []byte{0x07, 0x01, 0x02, 0x03}, []byte{0x08}) require.NoError(t, err) audioTrack, err := gortsplib.NewTrackAAC(97, []byte{17, 144}) @@ -92,12 +92,11 @@ func TestMuxer(t *testing.T) { byts = byts[4+alen+20:] require.Equal(t, []byte{ - 0, 0, 0, 1, 9, 240, - 0, 0, 0, 1, 5, - 0, 0, 0, 1, 9, - 0, 0, 0, 1, 8, - 0, 0, 0, 1, 7, + 0, 0, 0, 1, 9, 240, // AUD + 0, 0, 0, 1, 7, 1, 2, 3, // SPS + 0, 0, 0, 1, 8, // PPS + 0, 0, 0, 1, 5, // IDR }, - byts[:26], + byts[:24], ) } diff --git a/internal/hls/tsfile.go b/internal/hls/tsfile.go index e58fcc77..49649ced 100644 --- a/internal/hls/tsfile.go +++ b/internal/hls/tsfile.go @@ -6,6 +6,7 @@ import ( "strconv" "time" + "github.com/aler9/gortsplib" "github.com/asticode/go-astits" "github.com/aler9/rtsp-simple-server/internal/aac" @@ -13,7 +14,7 @@ import ( ) type tsFile struct { - hasVideoTrack bool + videoTrack *gortsplib.Track name string buf *multiAccessBuffer mux *astits.Muxer @@ -23,30 +24,30 @@ type tsFile struct { maxPTS time.Duration } -func newTSFile(hasVideoTrack bool, hasAudioTrack bool) *tsFile { +func newTSFile(videoTrack *gortsplib.Track, audioTrack *gortsplib.Track) *tsFile { t := &tsFile{ - hasVideoTrack: hasVideoTrack, - name: strconv.FormatInt(time.Now().Unix(), 10), - buf: newMultiAccessBuffer(), + videoTrack: videoTrack, + name: strconv.FormatInt(time.Now().Unix(), 10), + buf: newMultiAccessBuffer(), } t.mux = astits.NewMuxer(context.Background(), t.buf) - if hasVideoTrack { + if videoTrack != nil { t.mux.AddElementaryStream(astits.PMTElementaryStream{ ElementaryPID: 256, StreamType: astits.StreamTypeH264Video, }) } - if hasAudioTrack { + if audioTrack != nil { t.mux.AddElementaryStream(astits.PMTElementaryStream{ ElementaryPID: 257, StreamType: astits.StreamTypeAACAudio, }) } - if hasVideoTrack { + if videoTrack != nil { t.mux.SetPCRPID(256) } else { t.mux.SetPCRPID(257) @@ -76,7 +77,9 @@ func (t *tsFile) newReader() io.Reader { return t.buf.NewReader() } -func (t *tsFile) writeH264(dts time.Duration, pts time.Duration, isIDR bool, nalus [][]byte) error { +func (t *tsFile) writeH264( + h264SPS []byte, h264PPS []byte, + dts time.Duration, pts time.Duration, isIDR bool, nalus [][]byte) error { if !t.firstPacketWritten { t.firstPacketWritten = true t.minPTS = pts @@ -90,10 +93,29 @@ func (t *tsFile) writeH264(dts time.Duration, pts time.Duration, isIDR bool, nal } } - // prepend an AUD. This is required by video.js and iOS - nalus = append([][]byte{{byte(h264.NALUTypeAccessUnitDelimiter), 240}}, nalus...) + filteredNALUs := [][]byte{ + // prepend an AUD. This is required by video.js and iOS + {byte(h264.NALUTypeAccessUnitDelimiter), 240}, + } - enc, err := h264.EncodeAnnexB(nalus) + for _, nalu := range nalus { + // remove existing SPS, PPS, AUD + typ := h264.NALUType(nalu[0] & 0x1F) + switch typ { + case h264.NALUTypeSPS, h264.NALUTypePPS, h264.NALUTypeAccessUnitDelimiter: + continue + } + + // add SPS and PPS before IDR + if typ == h264.NALUTypeIDR { + filteredNALUs = append(filteredNALUs, h264SPS) + filteredNALUs = append(filteredNALUs, h264PPS) + } + + filteredNALUs = append(filteredNALUs, nalu) + } + + enc, err := h264.EncodeAnnexB(filteredNALUs) if err != nil { return err } @@ -124,7 +146,7 @@ func (t *tsFile) writeH264(dts time.Duration, pts time.Duration, isIDR bool, nal } func (t *tsFile) writeAAC(sampleRate int, channelCount int, pts time.Duration, au []byte) error { - if !t.hasVideoTrack { + if t.videoTrack == nil { if !t.firstPacketWritten { t.firstPacketWritten = true t.minPTS = pts @@ -154,7 +176,7 @@ func (t *tsFile) writeAAC(sampleRate int, channelCount int, pts time.Duration, a RandomAccessIndicator: true, } - if !t.hasVideoTrack { + if t.videoTrack == nil { af.HasPCR = true af.PCR = &astits.ClockReference{Base: int64(t.pcr.Seconds() * 90000)} }