From 851131b0740be7291b98f295567a97f32fffc655 Mon Sep 17 00:00:00 2001 From: Tom Wilkie Date: Sun, 30 Jun 2019 11:50:23 +0100 Subject: [PATCH] Allow injection of arbitrary headers in promtool, for auth etc. (#4389) Signed-off-by: Tom Wilkie --- cmd/promtool/main.go | 17 ++++++++++++++--- cmd/promtool/main_test.go | 4 ++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cmd/promtool/main.go b/cmd/promtool/main.go index 709f75dda..ffae6991d 100644 --- a/cmd/promtool/main.go +++ b/cmd/promtool/main.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "math" + "net/http" "net/url" "os" "path/filepath" @@ -30,6 +31,7 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/api" v1 "github.com/prometheus/client_golang/api/prometheus/v1" + "github.com/prometheus/client_golang/prometheus/promhttp" config_util "github.com/prometheus/common/config" "github.com/prometheus/common/model" "github.com/prometheus/common/version" @@ -70,6 +72,7 @@ func main() { queryRangeCmd := queryCmd.Command("range", "Run range query.") queryRangeServer := queryRangeCmd.Arg("server", "Prometheus server to query.").Required().String() queryRangeExpr := queryRangeCmd.Arg("expr", "PromQL query expression.").Required().String() + queryRangeHeaders := queryRangeCmd.Flag("header", "Extra headers to send to server.").StringMap() queryRangeBegin := queryRangeCmd.Flag("start", "Query range start time (RFC3339 or Unix timestamp).").String() queryRangeEnd := queryRangeCmd.Flag("end", "Query range end time (RFC3339 or Unix timestamp).").String() queryRangeStep := queryRangeCmd.Flag("step", "Query step size (duration).").Duration() @@ -123,7 +126,7 @@ func main() { os.Exit(QueryInstant(*queryServer, *queryExpr, p)) case queryRangeCmd.FullCommand(): - os.Exit(QueryRange(*queryRangeServer, *queryRangeExpr, *queryRangeBegin, *queryRangeEnd, *queryRangeStep, p)) + os.Exit(QueryRange(*queryRangeServer, *queryRangeHeaders, *queryRangeExpr, *queryRangeBegin, *queryRangeEnd, *queryRangeStep, p)) case querySeriesCmd.FullCommand(): os.Exit(QuerySeries(*querySeriesServer, *querySeriesMatch, *querySeriesBegin, *querySeriesEnd, p)) @@ -143,7 +146,6 @@ func main() { case testRulesCmd.FullCommand(): os.Exit(RulesUnitTest(*testRulesFiles...)) } - } // CheckConfig validates configuration files. @@ -361,11 +363,20 @@ func QueryInstant(url, query string, p printer) int { } // QueryRange performs a range query against a Prometheus server. -func QueryRange(url, query, start, end string, step time.Duration, p printer) int { +func QueryRange(url string, headers map[string]string, query, start, end string, step time.Duration, p printer) int { config := api.Config{ Address: url, } + if len(headers) > 0 { + config.RoundTripper = promhttp.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + for key, value := range headers { + req.Header.Add(key, value) + } + return http.DefaultTransport.RoundTrip(req) + }) + } + // Create new client. c, err := api.NewClient(config) if err != nil { diff --git a/cmd/promtool/main_test.go b/cmd/promtool/main_test.go index 84ff006db..80b5c1426 100644 --- a/cmd/promtool/main_test.go +++ b/cmd/promtool/main_test.go @@ -26,7 +26,7 @@ func TestQueryRange(t *testing.T) { defer s.Close() p := &promqlPrinter{} - exitCode := QueryRange(s.URL, "up", "0", "300", 0, p) + exitCode := QueryRange(s.URL, map[string]string{}, "up", "0", "300", 0, p) expectedPath := "/api/v1/query_range" gotPath := getRequest().URL.Path if gotPath != expectedPath { @@ -45,7 +45,7 @@ func TestQueryRange(t *testing.T) { t.Error() } - exitCode = QueryRange(s.URL, "up", "0", "300", 10*time.Millisecond, p) + exitCode = QueryRange(s.URL, map[string]string{}, "up", "0", "300", 10*time.Millisecond, p) gotPath = getRequest().URL.Path if gotPath != expectedPath { t.Errorf("unexpected URL path %s (wanted %s)", gotPath, expectedPath)