diff --git a/web/web.go b/web/web.go index 37d12f4c36..3602150d4d 100644 --- a/web/web.go +++ b/web/web.go @@ -918,8 +918,13 @@ func (h *Handler) version(w http.ResponseWriter, r *http.Request) { } func (h *Handler) quit(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Requesting termination... Goodbye!") - close(h.quitCh) + select { + case <-h.quitCh: + fmt.Fprintf(w, "Termination already in progress.") + default: + fmt.Fprintf(w, "Requesting termination... Goodbye!") + close(h.quitCh) + } } func (h *Handler) reload(w http.ResponseWriter, r *http.Request) { diff --git a/web/web_test.go b/web/web_test.go index 7759383573..c32485db53 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -27,6 +27,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "testing" "time" @@ -451,6 +452,60 @@ func TestShutdownWithStaleConnection(t *testing.T) { } } +func TestHandleMultipleQuitRequests(t *testing.T) { + opts := &Options{ + ListenAddress: ":9090", + MaxConnections: 512, + EnableLifecycle: true, + RoutePrefix: "/", + ExternalURL: &url.URL{ + Scheme: "http", + Host: "localhost:9090", + Path: "/", + }, + } + webHandler := New(nil, opts) + webHandler.config = &config.Config{} + webHandler.notifier = ¬ifier.Manager{} + ctx, cancel := context.WithCancel(context.Background()) + closed := make(chan struct{}) + go func() { + err := webHandler.Run(ctx) + if err != nil { + panic(fmt.Sprintf("Can't start web handler:%s", err)) + } + close(closed) + }() + + // Give some time for the web goroutine to run since we need the server + // to be up before starting tests. + time.Sleep(5 * time.Second) + + start := make(chan struct{}) + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + resp, err := http.Post("http://localhost:9090/-/quit", "", strings.NewReader("")) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + }() + } + close(start) + wg.Wait() + + // Stop the web handler. + cancel() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Server still running after 5 seconds.") + } +} + func cleanupTestResponse(t *testing.T, resp *http.Response) { _, err := io.Copy(ioutil.Discard, resp.Body) require.NoError(t, err)