// Copyright 2024 The Prometheus Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package netconnlimit import ( "io" "net" "sync" "testing" "time" "github.com/stretchr/testify/require" ) func TestSharedLimitListenerConcurrency(t *testing.T) { testCases := []struct { name string semCapacity int connCount int expected int // Expected number of connections processed simultaneously. }{ { name: "Single connection allowed", semCapacity: 1, connCount: 3, expected: 1, }, { name: "Two connections allowed", semCapacity: 2, connCount: 3, expected: 2, }, { name: "Three connections allowed", semCapacity: 3, connCount: 3, expected: 3, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { sem := NewSharedSemaphore(tc.semCapacity) listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "failed to create listener") defer listener.Close() limitedListener := SharedLimitListener(listener, sem) var wg sync.WaitGroup var activeConnCount int64 var mu sync.Mutex wg.Add(tc.connCount) // Accept connections. for i := 0; i < tc.connCount; i++ { go func() { defer wg.Done() conn, err := limitedListener.Accept() require.NoError(t, err, "failed to accept connection") defer conn.Close() // Simulate work and track the active connection count. mu.Lock() activeConnCount++ require.LessOrEqual(t, activeConnCount, int64(tc.expected), "too many simultaneous connections") mu.Unlock() time.Sleep(100 * time.Millisecond) mu.Lock() activeConnCount-- mu.Unlock() }() } // Create clients that attempt to connect to the listener. for i := 0; i < tc.connCount; i++ { go func() { conn, err := net.Dial("tcp", listener.Addr().String()) require.NoError(t, err, "failed to connect to listener") defer conn.Close() _, _ = io.WriteString(conn, "hello") }() } wg.Wait() // Ensure all connections are released and semaphore is empty. require.Empty(t, sem) }) } } func TestSharedLimitListenerClose(t *testing.T) { sem := NewSharedSemaphore(2) listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "failed to create listener") limitedListener := SharedLimitListener(listener, sem) // Close the listener and ensure it does not accept new connections. err = limitedListener.Close() require.NoError(t, err, "failed to close listener") conn, err := limitedListener.Accept() require.Error(t, err, "expected error on accept after listener closed") if conn != nil { conn.Close() } }