From ad16a32f4c86005796b7a803d6a3c2c0388c724d Mon Sep 17 00:00:00 2001
From: kr0m <kr0m@Garrus.alfaexploit.com>
Date: Wed, 29 Nov 2023 10:24:04 +0100
Subject: [PATCH] Fixed postgresql>=10 secondary server lag always 0, SuperQ
 proposed a more clean code solution :), pg_replication_test modified to test
 pgReplicationQueryBeforeVersion10 or pgReplicationQueryAfterVersion10
 depending of the postgresql version

Signed-off-by: kr0m <kr0m@Garrus.alfaexploit.com>
---
 collector/pg_replication.go      | 25 +++++++++++----
 collector/pg_replication_test.go | 52 ++++++++++++++++++++++++++++----
 2 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/collector/pg_replication.go b/collector/pg_replication.go
index 6067cc9b..4f083d79 100644
--- a/collector/pg_replication.go
+++ b/collector/pg_replication.go
@@ -15,7 +15,7 @@ package collector
 
 import (
 	"context"
-
+	"github.com/blang/semver/v4"
 	"github.com/prometheus/client_golang/prometheus"
 )
 
@@ -52,23 +52,36 @@ var (
 		[]string{}, nil,
 	)
 
-	pgReplicationQuery = `SELECT
+	pgReplicationQueryBeforeVersion10 = `SELECT
 	CASE
 		WHEN NOT pg_is_in_recovery() THEN 0
-                WHEN pg_last_wal_receive_lsn () = pg_last_wal_replay_lsn () THEN 0
+        WHEN pg_last_wal_receive_lsn () = pg_last_wal_replay_lsn () THEN 0
 		ELSE GREATEST (0, EXTRACT(EPOCH FROM (now() - pg_last_xact_replay_timestamp())))
 	END AS lag,
 	CASE
 		WHEN pg_is_in_recovery() THEN 1
 		ELSE 0
 	END as is_replica`
+
+	pgReplicationQueryAfterVersion10 = `SELECT
+    CASE
+        WHEN NOT pg_is_in_recovery() THEN 0
+        ELSE GREATEST (0, EXTRACT(EPOCH FROM (now() - pg_last_xact_replay_timestamp())))
+    END AS lag,
+    CASE
+        WHEN pg_is_in_recovery() THEN 1
+        ELSE 0
+    END as is_replica`
 )
 
 func (c *PGReplicationCollector) Update(ctx context.Context, instance *instance, ch chan<- prometheus.Metric) error {
 	db := instance.getDB()
-	row := db.QueryRowContext(ctx,
-		pgReplicationQuery,
-	)
+	query := pgReplicationQueryBeforeVersion10
+	if instance.version.GE(semver.MustParse("10.0.0")) {
+		query = pgReplicationQueryAfterVersion10
+	}
+
+	row := db.QueryRowContext(ctx, query)
 
 	var lag float64
 	var isReplica int64
diff --git a/collector/pg_replication_test.go b/collector/pg_replication_test.go
index b6df698e..22259307 100644
--- a/collector/pg_replication_test.go
+++ b/collector/pg_replication_test.go
@@ -14,15 +14,15 @@ package collector
 
 import (
 	"context"
-	"testing"
-
 	"github.com/DATA-DOG/go-sqlmock"
+	"github.com/blang/semver/v4"
 	"github.com/prometheus/client_golang/prometheus"
 	dto "github.com/prometheus/client_model/go"
 	"github.com/smartystreets/goconvey/convey"
+	"testing"
 )
 
-func TestPgReplicationCollector(t *testing.T) {
+func TestPgReplicationCollectorBeforeVersion10(t *testing.T) {
 	db, mock, err := sqlmock.New()
 	if err != nil {
 		t.Fatalf("Error opening a stub db connection: %s", err)
@@ -32,9 +32,49 @@ func TestPgReplicationCollector(t *testing.T) {
 	inst := &instance{db: db}
 
 	columns := []string{"lag", "is_replica"}
-	rows := sqlmock.NewRows(columns).
-		AddRow(1000, 1)
-	mock.ExpectQuery(sanitizeQuery(pgReplicationQuery)).WillReturnRows(rows)
+	rows := sqlmock.NewRows(columns).AddRow(1000, 1)
+	mock.ExpectQuery(sanitizeQuery(pgReplicationQueryBeforeVersion10)).WillReturnRows(rows)
+
+	ch := make(chan prometheus.Metric)
+	go func() {
+		defer close(ch)
+		c := PGReplicationCollector{}
+
+		if err := c.Update(context.Background(), inst, ch); err != nil {
+			t.Errorf("Error calling PGReplicationCollector.Update: %s", err)
+		}
+	}()
+
+	expected := []MetricResult{
+		{labels: labelMap{}, value: 1000, metricType: dto.MetricType_GAUGE},
+		{labels: labelMap{}, value: 1, metricType: dto.MetricType_GAUGE},
+	}
+
+	convey.Convey("Metrics comparison", t, func() {
+		for _, expect := range expected {
+			m := readMetric(<-ch)
+			convey.So(expect, convey.ShouldResemble, m)
+		}
+	})
+	if err := mock.ExpectationsWereMet(); err != nil {
+		t.Errorf("there were unfulfilled exceptions: %s", err)
+	}
+}
+
+func TestPgReplicationCollectorAfterVersion10(t *testing.T) {
+	db, mock, err := sqlmock.New()
+	if err != nil {
+		t.Fatalf("Error opening a stub db connection: %s", err)
+	}
+	defer db.Close()
+
+	//inst := &instance{db: db}
+	// Force test with a defined DB instance version, so ExpectQuery(pgReplicationQueryAfterVersion10) will match with PGReplicationCollector.Update query variable value
+	inst := &instance{db: db, version: semver.MustParse("10.0.0")}
+
+	columns := []string{"lag", "is_replica"}
+	rows := sqlmock.NewRows(columns).AddRow(1000, 1)
+	mock.ExpectQuery(sanitizeQuery(pgReplicationQueryAfterVersion10)).WillReturnRows(rows)
 
 	ch := make(chan prometheus.Metric)
 	go func() {