From 1989997fe19fd699e4cb08c1db8963fe8e43ef25 Mon Sep 17 00:00:00 2001 From: Jonathan Martin Date: Sun, 10 Mar 2024 11:33:00 +0100 Subject: [PATCH] webrtc: fix memory leak when WHEP track gathering fails and decrease count of idle HTTP connections (#3124) * fix: whep gathering failure leaks peer connections * fix: failure to create whep client leaks when read fails, client is not closed * fix: close idle connection with whip client * fix: no link check on early whip client close * move http.Client.CloseIdleConnections() outside WHIPClient * automatically call WHEPClient.Close() in case of errors during WHEPClient.Read() or WHEPClient.Publish() --------- Co-authored-by: Jonathan Martin Co-authored-by: aler9 <46489434+aler9@users.noreply.github.com> --- internal/protocols/webrtc/whip_client.go | 17 ++++++++++++++++- internal/staticsources/webrtc/source.go | 19 +++++++++++-------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/internal/protocols/webrtc/whip_client.go b/internal/protocols/webrtc/whip_client.go index 2c4d5a01..61c076a4 100644 --- a/internal/protocols/webrtc/whip_client.go +++ b/internal/protocols/webrtc/whip_client.go @@ -78,6 +78,7 @@ func (c *WHIPClient) Publish( err = c.pc.SetAnswer(res.Answer) if err != nil { + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, err } @@ -91,6 +92,7 @@ outer: case ca := <-c.pc.NewLocalCandidate(): err := WHIPPatchCandidate(ctx, c.HTTPClient, c.URL.String(), offer, res.ETag, ca) if err != nil { + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, err } @@ -101,6 +103,7 @@ outer: break outer case <-t.C: + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, fmt.Errorf("deadline exceeded while waiting connection") } @@ -156,6 +159,7 @@ func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) { var sdp sdp.SessionDescription err = sdp.Unmarshal([]byte(res.Answer.SDP)) if err != nil { + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, err } @@ -163,12 +167,14 @@ func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) { // check that there are at most two tracks _, err = TrackCount(sdp.MediaDescriptions) if err != nil { + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, err } err = c.pc.SetAnswer(res.Answer) if err != nil { + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, err } @@ -182,6 +188,7 @@ outer: case ca := <-c.pc.NewLocalCandidate(): err := WHIPPatchCandidate(ctx, c.HTTPClient, c.URL.String(), offer, res.ETag, ca) if err != nil { + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, err } @@ -192,12 +199,20 @@ outer: break outer case <-t.C: + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck c.pc.Close() return nil, fmt.Errorf("deadline exceeded while waiting connection") } } - return c.pc.GatherIncomingTracks(ctx, 0) + tracks, err := c.pc.GatherIncomingTracks(ctx, 0) + if err != nil { + WHIPDeleteSession(context.Background(), c.HTTPClient, c.URL.String()) //nolint:errcheck + c.pc.Close() + return nil, err + } + + return tracks, nil } // Close closes the client. diff --git a/internal/staticsources/webrtc/source.go b/internal/staticsources/webrtc/source.go index 5a6aab67..c3f588df 100644 --- a/internal/staticsources/webrtc/source.go +++ b/internal/staticsources/webrtc/source.go @@ -40,15 +40,18 @@ func (s *Source) Run(params defs.StaticSourceRunParams) error { u.Scheme = strings.ReplaceAll(u.Scheme, "whep", "http") - client := webrtc.WHIPClient{ - HTTPClient: &http.Client{ - Timeout: time.Duration(s.ReadTimeout), - Transport: &http.Transport{ - TLSClientConfig: tls.ConfigForFingerprint(params.Conf.SourceFingerprint), - }, + hc := &http.Client{ + Timeout: time.Duration(s.ReadTimeout), + Transport: &http.Transport{ + TLSClientConfig: tls.ConfigForFingerprint(params.Conf.SourceFingerprint), }, - URL: u, - Log: s, + } + defer hc.CloseIdleConnections() + + client := webrtc.WHIPClient{ + HTTPClient: hc, + URL: u, + Log: s, } tracks, err := client.Read(params.Context)