From de9d21df6aae924a64a798f8b2c5e734f41dfcca Mon Sep 17 00:00:00 2001
From: Joe Adams <github@joeadams.io>
Date: Wed, 24 Aug 2022 22:07:37 -0400
Subject: [PATCH] Add dsn type for handling datasources

dsn is designed to replace the other uses of dsn as a string in the long term. dsn is designed to be safe to log, properly redacting passwords. The goal is eventually always parse datasource information into a dsn type object which can safely be passed around and logged without worrying about wrapping calls in a redaction function (today this function is loggableDSN().

This should solve the root issue in #648, #677, and #643, although the full fix will require more changes to update all code references over to use the dsn type.

Signed-off-by: Joe Adams <github@joeadams.io>
---
 cmd/postgres_exporter/datasource.go      | 194 +++++++++++++++++++++
 cmd/postgres_exporter/datasource_test.go | 206 +++++++++++++++++++++++
 2 files changed, 400 insertions(+)
 create mode 100644 cmd/postgres_exporter/datasource_test.go

diff --git a/cmd/postgres_exporter/datasource.go b/cmd/postgres_exporter/datasource.go
index 716138f3..fdfcbd6a 100644
--- a/cmd/postgres_exporter/datasource.go
+++ b/cmd/postgres_exporter/datasource.go
@@ -20,6 +20,7 @@ import (
 	"os"
 	"regexp"
 	"strings"
+	"unicode"
 
 	"github.com/go-kit/log/level"
 	"github.com/prometheus/client_golang/prometheus"
@@ -172,3 +173,196 @@ func getDataSources() ([]string, error) {
 
 	return []string{dsn}, nil
 }
+
+// dsn represents a parsed datasource. It contains fields for the individual connection components.
+type dsn struct {
+	scheme   string
+	username string
+	password string
+	host     string
+	path     string
+	query    string
+}
+
+// String makes a dsn safe to print by excluding any passwords. This allows dsn to be used in
+// strings and log messages without needing to call a redaction function first.
+func (d dsn) String() string {
+	if d.password != "" {
+		return fmt.Sprintf("%s://%s:******@%s%s?%s", d.scheme, d.username, d.host, d.path, d.query)
+	}
+
+	if d.username != "" {
+		return fmt.Sprintf("%s://%s@%s%s?%s", d.scheme, d.username, d.host, d.path, d.query)
+	}
+
+	return fmt.Sprintf("%s://%s%s?%s", d.scheme, d.host, d.path, d.query)
+}
+
+// dsnFromString parses a connection string into a dsn. It will attempt to parse the string as
+// a URL and as a set of key=value pairs. If both attempts fail, dsnFromString will return an error.
+func dsnFromString(in string) (dsn, error) {
+	if strings.HasPrefix(in, "postgresql://") {
+		return dsnFromURL(in)
+	}
+
+	// Try to parse as key=value pairs
+	d, err := dsnFromKeyValue(in)
+	if err == nil {
+		return d, nil
+	}
+
+	return dsn{}, fmt.Errorf("could not understand DSN")
+}
+
+// dsnFromURL parses the input as a URL and returns the dsn representation.
+func dsnFromURL(in string) (dsn, error) {
+	u, err := url.Parse(in)
+	if err != nil {
+		return dsn{}, err
+	}
+	pass, _ := u.User.Password()
+	user := u.User.Username()
+
+	query := u.Query()
+
+	if queryPass := query.Get("password"); queryPass != "" {
+		if pass == "" {
+			pass = queryPass
+		}
+	}
+	query.Del("password")
+
+	if queryUser := query.Get("user"); queryUser != "" {
+		if user == "" {
+			user = queryUser
+		}
+	}
+	query.Del("user")
+
+	d := dsn{
+		scheme:   u.Scheme,
+		username: user,
+		password: pass,
+		host:     u.Host,
+		path:     u.Path,
+		query:    query.Encode(),
+	}
+
+	return d, nil
+}
+
+// dsnFromKeyValue parses the input as a set of key=value pairs and returns the dsn representation.
+func dsnFromKeyValue(in string) (dsn, error) {
+	// Attempt to confirm at least one key=value pair before starting the rune parser
+	connstringRe := regexp.MustCompile(`^ *[a-zA-Z0-9]+ *= *[^= ]+`)
+	if !connstringRe.MatchString(in) {
+		return dsn{}, fmt.Errorf("input is not a key-value DSN")
+	}
+
+	// Anything other than known fields should be part of the querystring
+	query := url.Values{}
+
+	pairs, err := parseKeyValue(in)
+	if err != nil {
+		return dsn{}, fmt.Errorf("failed to parse key-value DSN: %v", err)
+	}
+
+	// Build the dsn from the key=value pairs
+	d := dsn{
+		scheme: "postgresql",
+	}
+
+	hostname := ""
+	port := ""
+
+	for k, v := range pairs {
+		switch k {
+		case "host":
+			hostname = v
+		case "port":
+			port = v
+		case "user":
+			d.username = v
+		case "password":
+			d.password = v
+		default:
+			query.Set(k, v)
+		}
+	}
+
+	if hostname == "" {
+		hostname = "localhost"
+	}
+
+	if port == "" {
+		d.host = hostname
+	} else {
+		d.host = fmt.Sprintf("%s:%s", hostname, port)
+	}
+
+	d.query = query.Encode()
+
+	return d, nil
+}
+
+// parseKeyValue is a key=value parser. It loops over each rune to split out keys and values
+// and attempting to honor quoted values. parseKeyValue will return an error if it is unable
+// to properly parse the input.
+func parseKeyValue(in string) (map[string]string, error) {
+	out := map[string]string{}
+
+	inPart := false
+	inQuote := false
+	part := []rune{}
+	key := ""
+	for _, c := range in {
+		switch {
+		case unicode.In(c, unicode.Quotation_Mark):
+			if inQuote {
+				inQuote = false
+			} else {
+				inQuote = true
+			}
+		case unicode.In(c, unicode.White_Space):
+			if inPart {
+				if inQuote {
+					part = append(part, c)
+				} else {
+					// Are we finishing a key=value?
+					if key == "" {
+						return out, fmt.Errorf("invalid input")
+					}
+					out[key] = string(part)
+					inPart = false
+					part = []rune{}
+				}
+			} else {
+				// Are we finishing a key=value?
+				if key == "" {
+					return out, fmt.Errorf("invalid input")
+				}
+				out[key] = string(part)
+				inPart = false
+				part = []rune{}
+				// Do something with the value
+			}
+		case c == '=':
+			if inPart {
+				inPart = false
+				key = string(part)
+				part = []rune{}
+			} else {
+				return out, fmt.Errorf("invalid input")
+			}
+		default:
+			inPart = true
+			part = append(part, c)
+		}
+	}
+
+	if key != "" && len(part) > 0 {
+		out[key] = string(part)
+	}
+
+	return out, nil
+}
diff --git a/cmd/postgres_exporter/datasource_test.go b/cmd/postgres_exporter/datasource_test.go
new file mode 100644
index 00000000..02fb8dde
--- /dev/null
+++ b/cmd/postgres_exporter/datasource_test.go
@@ -0,0 +1,206 @@
+// Copyright 2022 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"reflect"
+	"testing"
+)
+
+// Test_dsn_String is designed to test different dsn combinations for their string representation.
+// dsn.String() is designed to be safe to print, redacting any password information and these test
+// cases are intended to cover known cases.
+func Test_dsn_String(t *testing.T) {
+	type fields struct {
+		scheme   string
+		username string
+		password string
+		host     string
+		path     string
+		query    string
+	}
+	tests := []struct {
+		name   string
+		fields fields
+		want   string
+	}{
+		{
+			name: "Without Password",
+			fields: fields{
+				scheme:   "postgresql",
+				username: "test",
+				host:     "localhost:5432",
+				query:    "",
+			},
+			want: "postgresql://test@localhost:5432?",
+		},
+		{
+			name: "With Password",
+			fields: fields{
+				scheme:   "postgresql",
+				username: "test",
+				password: "supersecret",
+				host:     "localhost:5432",
+				query:    "",
+			},
+			want: "postgresql://test:******@localhost:5432?",
+		},
+		{
+			name: "With Password and Query String",
+			fields: fields{
+				scheme:   "postgresql",
+				username: "test",
+				password: "supersecret",
+				host:     "localhost:5432",
+				query:    "ssldisable=true",
+			},
+			want: "postgresql://test:******@localhost:5432?ssldisable=true",
+		},
+		{
+			name: "With Password, Path, and Query String",
+			fields: fields{
+				scheme:   "postgresql",
+				username: "test",
+				password: "supersecret",
+				host:     "localhost:5432",
+				path:     "/somevalue",
+				query:    "ssldisable=true",
+			},
+			want: "postgresql://test:******@localhost:5432/somevalue?ssldisable=true",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			d := dsn{
+				scheme:   tt.fields.scheme,
+				username: tt.fields.username,
+				password: tt.fields.password,
+				host:     tt.fields.host,
+				path:     tt.fields.path,
+				query:    tt.fields.query,
+			}
+			if got := d.String(); got != tt.want {
+				t.Errorf("dsn.String() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+// Test_dsnFromString tests the dsnFromString function with known variations
+// of connection string inputs to ensure that it properly parses the input into
+// a dsn.
+func Test_dsnFromString(t *testing.T) {
+
+	tests := []struct {
+		name    string
+		input   string
+		want    dsn
+		wantErr bool
+	}{
+		{
+			name:  "Key value with password",
+			input: "host=host.example.com user=postgres port=5432 password=s3cr3t",
+			want: dsn{
+				scheme:   "postgresql",
+				host:     "host.example.com:5432",
+				username: "postgres",
+				password: "s3cr3t",
+			},
+			wantErr: false,
+		},
+		{
+			name:  "Key value with quoted password and space",
+			input: "host=host.example.com user=postgres port=5432 password=\"s3cr 3t\"",
+			want: dsn{
+				scheme:   "postgresql",
+				host:     "host.example.com:5432",
+				username: "postgres",
+				password: "s3cr 3t",
+			},
+			wantErr: false,
+		},
+		{
+			name:  "Key value with different order",
+			input: "password=abcde host=host.example.com user=postgres port=5432",
+			want: dsn{
+				scheme:   "postgresql",
+				host:     "host.example.com:5432",
+				username: "postgres",
+				password: "abcde",
+			},
+			wantErr: false,
+		},
+		{
+			name:  "Key value with different order, quoted password, duplicate password",
+			input: "password=abcde host=host.example.com user=postgres port=5432 password=\"s3cr 3t\"",
+			want: dsn{
+				scheme:   "postgresql",
+				host:     "host.example.com:5432",
+				username: "postgres",
+				password: "s3cr 3t",
+			},
+			wantErr: false,
+		},
+		{
+			name:  "URL with user in query string",
+			input: "postgresql://host.example.com:5432/tsdb?user=postgres",
+			want: dsn{
+				scheme:   "postgresql",
+				host:     "host.example.com:5432",
+				path:     "/tsdb",
+				query:    "",
+				username: "postgres",
+			},
+			wantErr: false,
+		},
+		{
+			name:  "URL with user and password",
+			input: "postgresql://user:s3cret@host.example.com:5432/tsdb?user=postgres",
+			want: dsn{
+				scheme:   "postgresql",
+				host:     "host.example.com:5432",
+				path:     "/tsdb",
+				query:    "",
+				username: "user",
+				password: "s3cret",
+			},
+			wantErr: false,
+		},
+		{
+			name:  "URL with user and password in query string",
+			input: "postgresql://host.example.com:5432/tsdb?user=postgres&password=s3cr3t",
+			want: dsn{
+				scheme:   "postgresql",
+				host:     "host.example.com:5432",
+				path:     "/tsdb",
+				query:    "",
+				username: "postgres",
+				password: "s3cr3t",
+			},
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := dsnFromString(tt.input)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("dsnFromString() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("dsnFromString() = %+v, want %+v", got, tt.want)
+			}
+		})
+	}
+}