mirror of
https://github.com/prometheus/alertmanager
synced 2025-02-17 02:57:01 +00:00
cluster: fix panic when tls_client_config
is empty
Closes #3403 Signed-off-by: Simon Pasquier <spasquie@redhat.com>
This commit is contained in:
parent
130b8b6761
commit
aea6204d58
1
cluster/testdata/empty_tls_config.yml
vendored
Normal file
1
cluster/testdata/empty_tls_config.yml
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
{}
|
5
cluster/testdata/tls_config_with_missing_client.yml
vendored
Normal file
5
cluster/testdata/tls_config_with_missing_client.yml
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
tls_server_config:
|
||||||
|
cert_file: "certs/node2.pem"
|
||||||
|
key_file: "certs/node2-key.pem"
|
||||||
|
client_ca_file: "certs/ca.pem"
|
||||||
|
client_auth_type: "VerifyClientCertIfGiven"
|
4
cluster/testdata/tls_config_with_missing_server.yml
vendored
Normal file
4
cluster/testdata/tls_config_with_missing_server.yml
vendored
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
tls_client_config:
|
||||||
|
cert_file: "certs/node1.pem"
|
||||||
|
key_file: "certs/node1-key.pem"
|
||||||
|
ca_file: "certs/ca.pem"
|
@ -14,6 +14,7 @@
|
|||||||
package cluster
|
package cluster
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
@ -31,15 +32,25 @@ func GetTLSTransportConfig(configPath string) (*TLSTransportConfig, error) {
|
|||||||
if configPath == "" {
|
if configPath == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
bytes, err := os.ReadFile(configPath)
|
bytes, err := os.ReadFile(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
cfg := &TLSTransportConfig{}
|
|
||||||
|
cfg := &TLSTransportConfig{
|
||||||
|
TLSClientConfig: &config.TLSConfig{},
|
||||||
|
}
|
||||||
if err := yaml.UnmarshalStrict(bytes, cfg); err != nil {
|
if err := yaml.UnmarshalStrict(bytes, cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.TLSServerConfig == nil {
|
||||||
|
return nil, fmt.Errorf("missing 'tls_server_config' entry in the TLS configuration")
|
||||||
|
}
|
||||||
|
|
||||||
cfg.TLSServerConfig.SetDirectory(filepath.Dir(configPath))
|
cfg.TLSServerConfig.SetDirectory(filepath.Dir(configPath))
|
||||||
cfg.TLSClientConfig.SetDirectory(filepath.Dir(configPath))
|
cfg.TLSClientConfig.SetDirectory(filepath.Dir(configPath))
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
@ -80,27 +80,33 @@ func NewTLSTransport(
|
|||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil, errors.New("must specify TLSTransportConfig")
|
return nil, errors.New("must specify TLSTransportConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig)
|
tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "invalid TLS server config")
|
return nil, errors.Wrap(err, "invalid TLS server config")
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig)
|
tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "invalid TLS client config")
|
return nil, errors.Wrap(err, "invalid TLS client config")
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP(bindAddr)
|
ip := net.ParseIP(bindAddr)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr)
|
return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := &net.TCPAddr{IP: ip, Port: bindPort}
|
addr := &net.TCPAddr{IP: ip, Port: bindPort}
|
||||||
listener, err := tls.Listen(network, addr.String(), tlsServerCfg)
|
listener, err := tls.Listen(network, addr.String(), tlsServerCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort))
|
return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort))
|
||||||
}
|
}
|
||||||
|
|
||||||
connPool, err := newConnectionPool(tlsClientCfg)
|
connPool, err := newConnectionPool(tlsClientCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to initialize tls transport connection pool")
|
return nil, errors.Wrap(err, "failed to initialize tls transport connection pool")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
t := &TLSTransport{
|
t := &TLSTransport{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -29,34 +29,72 @@ import (
|
|||||||
|
|
||||||
var logger = log.NewNopLogger()
|
var logger = log.NewNopLogger()
|
||||||
|
|
||||||
|
func newTLSTransport(file, address string, port int) (*TLSTransport, error) {
|
||||||
|
cfg, err := GetTLSTransportConfig(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewTLSTransport(context2.Background(), log.NewNopLogger(), nil, address, port, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewTLSTransport(t *testing.T) {
|
func TestNewTLSTransport(t *testing.T) {
|
||||||
testCases := []struct {
|
for _, tc := range []struct {
|
||||||
bindAddr string
|
bindAddr string
|
||||||
bindPort int
|
bindPort int
|
||||||
tlsConfFile string
|
tlsConfFile string
|
||||||
err string
|
err string
|
||||||
}{
|
}{
|
||||||
{err: "must specify TLSTransportConfig"},
|
{
|
||||||
{err: "invalid bind address \"\"", tlsConfFile: "testdata/tls_config_node1.yml"},
|
err: "must specify TLSTransportConfig",
|
||||||
{bindAddr: "abc123", err: "invalid bind address \"abc123\"", tlsConfFile: "testdata/tls_config_node1.yml"},
|
},
|
||||||
{bindAddr: localhost, bindPort: 0, tlsConfFile: "testdata/tls_config_node1.yml"},
|
{
|
||||||
{bindAddr: localhost, bindPort: 9094, tlsConfFile: "testdata/tls_config_node2.yml"},
|
tlsConfFile: "testdata/empty_tls_config.yml",
|
||||||
}
|
err: "missing 'tls_server_config' entry in the TLS configuration",
|
||||||
l := log.NewNopLogger()
|
},
|
||||||
for _, tc := range testCases {
|
{
|
||||||
cfg := mustTLSTransportConfig(tc.tlsConfFile)
|
tlsConfFile: "testdata/tls_config_with_missing_server.yml",
|
||||||
transport, err := NewTLSTransport(context2.Background(), l, nil, tc.bindAddr, tc.bindPort, cfg)
|
err: "missing 'tls_server_config' entry in the TLS configuration",
|
||||||
if len(tc.err) > 0 {
|
},
|
||||||
require.Equal(t, tc.err, err.Error())
|
{
|
||||||
require.Nil(t, transport)
|
err: "invalid bind address \"\"",
|
||||||
} else {
|
tlsConfFile: "testdata/tls_config_node1.yml",
|
||||||
require.Nil(t, err)
|
},
|
||||||
|
{
|
||||||
|
bindAddr: "abc123",
|
||||||
|
err: "invalid bind address \"abc123\"",
|
||||||
|
tlsConfFile: "testdata/tls_config_node1.yml",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bindAddr: localhost,
|
||||||
|
bindPort: 0,
|
||||||
|
tlsConfFile: "testdata/tls_config_node1.yml",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bindAddr: localhost,
|
||||||
|
bindPort: 9094,
|
||||||
|
tlsConfFile: "testdata/tls_config_node2.yml",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tlsConfFile: "testdata/tls_config_with_missing_client.yml",
|
||||||
|
bindAddr: localhost,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
transport, err := newTLSTransport(tc.tlsConfFile, tc.bindAddr, tc.bindPort)
|
||||||
|
if len(tc.err) > 0 {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, tc.err, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer transport.Shutdown()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
require.Equal(t, tc.bindAddr, transport.bindAddr)
|
require.Equal(t, tc.bindAddr, transport.bindAddr)
|
||||||
require.Equal(t, tc.bindPort, transport.bindPort)
|
require.Equal(t, tc.bindPort, transport.bindPort)
|
||||||
require.Equal(t, l, transport.logger)
|
|
||||||
require.NotNil(t, transport.listener)
|
require.NotNil(t, transport.listener)
|
||||||
transport.Shutdown()
|
})
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,7 +117,7 @@ func TestFinalAdvertiseAddr(t *testing.T) {
|
|||||||
{bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095},
|
{bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095},
|
||||||
}
|
}
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
tlsConf := mustTLSTransportConfig("testdata/tls_config_node1.yml")
|
tlsConf := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
|
||||||
transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf)
|
transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort)
|
ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort)
|
||||||
@ -104,11 +142,11 @@ func TestFinalAdvertiseAddr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteTo(t *testing.T) {
|
func TestWriteTo(t *testing.T) {
|
||||||
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
|
tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
|
||||||
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
|
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
|
||||||
defer t1.Shutdown()
|
defer t1.Shutdown()
|
||||||
|
|
||||||
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
|
tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml")
|
||||||
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
|
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
|
||||||
defer t2.Shutdown()
|
defer t2.Shutdown()
|
||||||
|
|
||||||
@ -123,11 +161,11 @@ func TestWriteTo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWriteTo(b *testing.B) {
|
func BenchmarkWriteTo(b *testing.B) {
|
||||||
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
|
tlsConf1 := loadTLSTransportConfig(b, "testdata/tls_config_node1.yml")
|
||||||
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
|
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
|
||||||
defer t1.Shutdown()
|
defer t1.Shutdown()
|
||||||
|
|
||||||
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
|
tlsConf2 := loadTLSTransportConfig(b, "testdata/tls_config_node2.yml")
|
||||||
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
|
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
|
||||||
defer t2.Shutdown()
|
defer t2.Shutdown()
|
||||||
|
|
||||||
@ -144,13 +182,13 @@ func BenchmarkWriteTo(b *testing.B) {
|
|||||||
require.Equal(b, from, packet.From.String())
|
require.Equal(b, from, packet.From.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialTimout(t *testing.T) {
|
func TestDialTimeout(t *testing.T) {
|
||||||
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
|
tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
|
||||||
t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
|
t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer t1.Shutdown()
|
defer t1.Shutdown()
|
||||||
|
|
||||||
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
|
tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml")
|
||||||
t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
|
t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer t2.Shutdown()
|
defer t2.Shutdown()
|
||||||
@ -193,7 +231,7 @@ func (l *logWr) Write(p []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestShutdown(t *testing.T) {
|
func TestShutdown(t *testing.T) {
|
||||||
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
|
tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
|
||||||
l := &logWr{}
|
l := &logWr{}
|
||||||
t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1)
|
t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1)
|
||||||
// Sleeping to make sure listeners have started and can subsequently be shut down gracefully.
|
// Sleeping to make sure listeners have started and can subsequently be shut down gracefully.
|
||||||
@ -204,10 +242,13 @@ func TestShutdown(t *testing.T) {
|
|||||||
require.Contains(t, string(l.bytes), "shutting down tls transport")
|
require.Contains(t, string(l.bytes), "shutting down tls transport")
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustTLSTransportConfig(filename string) *TLSTransportConfig {
|
func loadTLSTransportConfig(tb testing.TB, filename string) *TLSTransportConfig {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
config, err := GetTLSTransportConfig(filename)
|
config, err := GetTLSTransportConfig(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
tb.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user