From fe12924638d433c99b51f9acb1d7ebb9c1f40881 Mon Sep 17 00:00:00 2001
From: Kushal shukla <85934954+kushalShukla-web@users.noreply.github.com>
Date: Mon, 29 Jul 2024 07:28:08 -0400
Subject: [PATCH] promtool: JUnit-Format XML Test Results (#14506)

* Junit compatible output

Signed-off-by: Kushal Shukla <kushalshukla110@gmail.com>
---
 cmd/promtool/main.go           |  7 ++-
 cmd/promtool/unittest.go       | 40 +++++++++++++----
 cmd/promtool/unittest_test.go  | 50 +++++++++++++++++++++
 docs/command-line/promtool.md  |  9 ++++
 util/junitxml/junitxml.go      | 81 ++++++++++++++++++++++++++++++++++
 util/junitxml/junitxml_test.go | 66 +++++++++++++++++++++++++++
 6 files changed, 243 insertions(+), 10 deletions(-)
 create mode 100644 util/junitxml/junitxml.go
 create mode 100644 util/junitxml/junitxml_test.go

diff --git a/cmd/promtool/main.go b/cmd/promtool/main.go
index e1d275e97..1c8e1dd1c 100644
--- a/cmd/promtool/main.go
+++ b/cmd/promtool/main.go
@@ -204,6 +204,7 @@ func main() {
 	pushMetricsHeaders := pushMetricsCmd.Flag("header", "Prometheus remote write header.").StringMap()
 
 	testCmd := app.Command("test", "Unit testing.")
+	junitOutFile := testCmd.Flag("junit", "File path to store JUnit XML test results.").OpenFile(os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
 	testRulesCmd := testCmd.Command("rules", "Unit tests for rules.")
 	testRulesRun := testRulesCmd.Flag("run", "If set, will only run test groups whose names match the regular expression. Can be specified multiple times.").Strings()
 	testRulesFiles := testRulesCmd.Arg(
@@ -378,7 +379,11 @@ func main() {
 		os.Exit(QueryLabels(serverURL, httpRoundTripper, *queryLabelsMatch, *queryLabelsName, *queryLabelsBegin, *queryLabelsEnd, p))
 
 	case testRulesCmd.FullCommand():
-		os.Exit(RulesUnitTest(
+		results := io.Discard
+		if *junitOutFile != nil {
+			results = *junitOutFile
+		}
+		os.Exit(RulesUnitTestResult(results,
 			promqltest.LazyLoaderOpts{
 				EnableAtModifier:     true,
 				EnableNegativeOffset: true,
diff --git a/cmd/promtool/unittest.go b/cmd/promtool/unittest.go
index 5451c5296..7030635d1 100644
--- a/cmd/promtool/unittest.go
+++ b/cmd/promtool/unittest.go
@@ -18,6 +18,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"os"
 	"path/filepath"
 	"sort"
@@ -29,9 +30,10 @@ import (
 	"github.com/google/go-cmp/cmp"
 	"github.com/grafana/regexp"
 	"github.com/nsf/jsondiff"
-	"github.com/prometheus/common/model"
 	"gopkg.in/yaml.v2"
 
+	"github.com/prometheus/common/model"
+
 	"github.com/prometheus/prometheus/model/histogram"
 	"github.com/prometheus/prometheus/model/labels"
 	"github.com/prometheus/prometheus/promql"
@@ -39,12 +41,18 @@ import (
 	"github.com/prometheus/prometheus/promql/promqltest"
 	"github.com/prometheus/prometheus/rules"
 	"github.com/prometheus/prometheus/storage"
+	"github.com/prometheus/prometheus/util/junitxml"
 )
 
 // RulesUnitTest does unit testing of rules based on the unit testing files provided.
 // More info about the file format can be found in the docs.
 func RulesUnitTest(queryOpts promqltest.LazyLoaderOpts, runStrings []string, diffFlag bool, files ...string) int {
+	return RulesUnitTestResult(io.Discard, queryOpts, runStrings, diffFlag, files...)
+}
+
+func RulesUnitTestResult(results io.Writer, queryOpts promqltest.LazyLoaderOpts, runStrings []string, diffFlag bool, files ...string) int {
 	failed := false
+	junit := &junitxml.JUnitXML{}
 
 	var run *regexp.Regexp
 	if runStrings != nil {
@@ -52,7 +60,7 @@ func RulesUnitTest(queryOpts promqltest.LazyLoaderOpts, runStrings []string, dif
 	}
 
 	for _, f := range files {
-		if errs := ruleUnitTest(f, queryOpts, run, diffFlag); errs != nil {
+		if errs := ruleUnitTest(f, queryOpts, run, diffFlag, junit.Suite(f)); errs != nil {
 			fmt.Fprintln(os.Stderr, "  FAILED:")
 			for _, e := range errs {
 				fmt.Fprintln(os.Stderr, e.Error())
@@ -64,25 +72,30 @@ func RulesUnitTest(queryOpts promqltest.LazyLoaderOpts, runStrings []string, dif
 		}
 		fmt.Println()
 	}
+	err := junit.WriteXML(results)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "failed to write JUnit XML: %s\n", err)
+	}
 	if failed {
 		return failureExitCode
 	}
 	return successExitCode
 }
 
-func ruleUnitTest(filename string, queryOpts promqltest.LazyLoaderOpts, run *regexp.Regexp, diffFlag bool) []error {
-	fmt.Println("Unit Testing: ", filename)
-
+func ruleUnitTest(filename string, queryOpts promqltest.LazyLoaderOpts, run *regexp.Regexp, diffFlag bool, ts *junitxml.TestSuite) []error {
 	b, err := os.ReadFile(filename)
 	if err != nil {
+		ts.Abort(err)
 		return []error{err}
 	}
 
 	var unitTestInp unitTestFile
 	if err := yaml.UnmarshalStrict(b, &unitTestInp); err != nil {
+		ts.Abort(err)
 		return []error{err}
 	}
 	if err := resolveAndGlobFilepaths(filepath.Dir(filename), &unitTestInp); err != nil {
+		ts.Abort(err)
 		return []error{err}
 	}
 
@@ -91,29 +104,38 @@ func ruleUnitTest(filename string, queryOpts promqltest.LazyLoaderOpts, run *reg
 	}
 
 	evalInterval := time.Duration(unitTestInp.EvaluationInterval)
-
+	ts.Settime(time.Now().Format("2006-01-02T15:04:05"))
 	// Giving number for groups mentioned in the file for ordering.
 	// Lower number group should be evaluated before higher number group.
 	groupOrderMap := make(map[string]int)
 	for i, gn := range unitTestInp.GroupEvalOrder {
 		if _, ok := groupOrderMap[gn]; ok {
-			return []error{fmt.Errorf("group name repeated in evaluation order: %s", gn)}
+			err := fmt.Errorf("group name repeated in evaluation order: %s", gn)
+			ts.Abort(err)
+			return []error{err}
 		}
 		groupOrderMap[gn] = i
 	}
 
 	// Testing.
 	var errs []error
-	for _, t := range unitTestInp.Tests {
+	for i, t := range unitTestInp.Tests {
 		if !matchesRun(t.TestGroupName, run) {
 			continue
 		}
-
+		testname := t.TestGroupName
+		if testname == "" {
+			testname = fmt.Sprintf("unnamed#%d", i)
+		}
+		tc := ts.Case(testname)
 		if t.Interval == 0 {
 			t.Interval = unitTestInp.EvaluationInterval
 		}
 		ers := t.test(evalInterval, groupOrderMap, queryOpts, diffFlag, unitTestInp.RuleFiles...)
 		if ers != nil {
+			for _, e := range ers {
+				tc.Fail(e.Error())
+			}
 			errs = append(errs, ers...)
 		}
 	}
diff --git a/cmd/promtool/unittest_test.go b/cmd/promtool/unittest_test.go
index 2dbd5a4e5..9bbac28e9 100644
--- a/cmd/promtool/unittest_test.go
+++ b/cmd/promtool/unittest_test.go
@@ -14,11 +14,15 @@
 package main
 
 import (
+	"bytes"
+	"encoding/xml"
+	"fmt"
 	"testing"
 
 	"github.com/stretchr/testify/require"
 
 	"github.com/prometheus/prometheus/promql/promqltest"
+	"github.com/prometheus/prometheus/util/junitxml"
 )
 
 func TestRulesUnitTest(t *testing.T) {
@@ -125,13 +129,59 @@ func TestRulesUnitTest(t *testing.T) {
 			want: 0,
 		},
 	}
+	reuseFiles := []string{}
+	reuseCount := [2]int{}
 	for _, tt := range tests {
+		if (tt.queryOpts == promqltest.LazyLoaderOpts{
+			EnableNegativeOffset: true,
+		} || tt.queryOpts == promqltest.LazyLoaderOpts{
+			EnableAtModifier: true,
+		}) {
+			reuseFiles = append(reuseFiles, tt.args.files...)
+			reuseCount[tt.want] += len(tt.args.files)
+		}
 		t.Run(tt.name, func(t *testing.T) {
 			if got := RulesUnitTest(tt.queryOpts, nil, false, tt.args.files...); got != tt.want {
 				t.Errorf("RulesUnitTest() = %v, want %v", got, tt.want)
 			}
 		})
 	}
+	t.Run("Junit xml output ", func(t *testing.T) {
+		var buf bytes.Buffer
+		if got := RulesUnitTestResult(&buf, promqltest.LazyLoaderOpts{}, nil, false, reuseFiles...); got != 1 {
+			t.Errorf("RulesUnitTestResults() = %v, want 1", got)
+		}
+		var test junitxml.JUnitXML
+		output := buf.Bytes()
+		err := xml.Unmarshal(output, &test)
+		if err != nil {
+			fmt.Println("error in decoding XML:", err)
+			return
+		}
+		var total int
+		var passes int
+		var failures int
+		var cases int
+		total = len(test.Suites)
+		if total != len(reuseFiles) {
+			t.Errorf("JUnit output had %d testsuite elements; expected %d\n", total, len(reuseFiles))
+		}
+
+		for _, i := range test.Suites {
+			if i.FailureCount == 0 {
+				passes++
+			} else {
+				failures++
+			}
+			cases += len(i.Cases)
+		}
+		if total != passes+failures {
+			t.Errorf("JUnit output mismatch: Total testsuites (%d) does not equal the sum of passes (%d) and failures (%d).", total, passes, failures)
+		}
+		if cases < total {
+			t.Errorf("JUnit output had %d suites without test cases\n", total-cases)
+		}
+	})
 }
 
 func TestRulesUnitTestRun(t *testing.T) {
diff --git a/docs/command-line/promtool.md b/docs/command-line/promtool.md
index 443cd3f0c..6bb80169a 100644
--- a/docs/command-line/promtool.md
+++ b/docs/command-line/promtool.md
@@ -442,6 +442,15 @@ Unit testing.
 
 
 
+#### Flags
+
+| Flag | Description |
+| --- | --- |
+| <code class="text-nowrap">--junit</code> | File path to store JUnit XML test results. |
+
+
+
+
 ##### `promtool test rules`
 
 Unit tests for rules.
diff --git a/util/junitxml/junitxml.go b/util/junitxml/junitxml.go
new file mode 100644
index 000000000..14e4b6dba
--- /dev/null
+++ b/util/junitxml/junitxml.go
@@ -0,0 +1,81 @@
+// Copyright 2024 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package junitxml
+
+import (
+	"encoding/xml"
+	"io"
+)
+
+type JUnitXML struct {
+	XMLName xml.Name     `xml:"testsuites"`
+	Suites  []*TestSuite `xml:"testsuite"`
+}
+
+type TestSuite struct {
+	Name         string      `xml:"name,attr"`
+	TestCount    int         `xml:"tests,attr"`
+	FailureCount int         `xml:"failures,attr"`
+	ErrorCount   int         `xml:"errors,attr"`
+	SkippedCount int         `xml:"skipped,attr"`
+	Timestamp    string      `xml:"timestamp,attr"`
+	Cases        []*TestCase `xml:"testcase"`
+}
+type TestCase struct {
+	Name     string   `xml:"name,attr"`
+	Failures []string `xml:"failure,omitempty"`
+	Error    string   `xml:"error,omitempty"`
+}
+
+func (j *JUnitXML) WriteXML(h io.Writer) error {
+	return xml.NewEncoder(h).Encode(j)
+}
+
+func (j *JUnitXML) Suite(name string) *TestSuite {
+	ts := &TestSuite{Name: name}
+	j.Suites = append(j.Suites, ts)
+	return ts
+}
+
+func (ts *TestSuite) Fail(f string) {
+	ts.FailureCount++
+	curt := ts.lastCase()
+	curt.Failures = append(curt.Failures, f)
+}
+
+func (ts *TestSuite) lastCase() *TestCase {
+	if len(ts.Cases) == 0 {
+		ts.Case("unknown")
+	}
+	return ts.Cases[len(ts.Cases)-1]
+}
+
+func (ts *TestSuite) Case(name string) *TestSuite {
+	j := &TestCase{
+		Name: name,
+	}
+	ts.Cases = append(ts.Cases, j)
+	ts.TestCount++
+	return ts
+}
+
+func (ts *TestSuite) Settime(name string) {
+	ts.Timestamp = name
+}
+
+func (ts *TestSuite) Abort(e error) {
+	ts.ErrorCount++
+	curt := ts.lastCase()
+	curt.Error = e.Error()
+}
diff --git a/util/junitxml/junitxml_test.go b/util/junitxml/junitxml_test.go
new file mode 100644
index 000000000..ad4d0293d
--- /dev/null
+++ b/util/junitxml/junitxml_test.go
@@ -0,0 +1,66 @@
+// Copyright 2024 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package junitxml
+
+import (
+	"bytes"
+	"encoding/xml"
+	"errors"
+	"testing"
+)
+
+func TestJunitOutput(t *testing.T) {
+	var buf bytes.Buffer
+	var test JUnitXML
+	x := FakeTestSuites()
+	if err := x.WriteXML(&buf); err != nil {
+		t.Fatalf("Failed to encode XML: %v", err)
+	}
+
+	output := buf.Bytes()
+
+	err := xml.Unmarshal(output, &test)
+	if err != nil {
+		t.Errorf("Unmarshal failed with error: %v", err)
+	}
+	var total int
+	var cases int
+	total = len(test.Suites)
+	if total != 3 {
+		t.Errorf("JUnit output had %d testsuite elements; expected 3\n", total)
+	}
+	for _, i := range test.Suites {
+		cases += len(i.Cases)
+	}
+
+	if cases != 7 {
+		t.Errorf("JUnit output had %d testcase; expected 7\n", cases)
+	}
+}
+
+func FakeTestSuites() *JUnitXML {
+	ju := &JUnitXML{}
+	good := ju.Suite("all good")
+	good.Case("alpha")
+	good.Case("beta")
+	good.Case("gamma")
+	mixed := ju.Suite("mixed")
+	mixed.Case("good")
+	bad := mixed.Case("bad")
+	bad.Fail("once")
+	bad.Fail("twice")
+	mixed.Case("ugly").Abort(errors.New("buggy"))
+	ju.Suite("fast").Fail("fail early")
+	return ju
+}