cluster: fix panic when tls_client_config is empty

Closes #3403

Signed-off-by: Simon Pasquier <spasquie@redhat.com>
This commit is contained in:
Simon Pasquier 2023-08-04 14:29:05 +02:00
parent 130b8b6761
commit aea6204d58
6 changed files with 99 additions and 31 deletions

1
cluster/testdata/empty_tls_config.yml vendored Normal file
View File

@ -0,0 +1 @@
{}

View File

@ -0,0 +1,5 @@
tls_server_config:
cert_file: "certs/node2.pem"
key_file: "certs/node2-key.pem"
client_ca_file: "certs/ca.pem"
client_auth_type: "VerifyClientCertIfGiven"

View File

@ -0,0 +1,4 @@
tls_client_config:
cert_file: "certs/node1.pem"
key_file: "certs/node1-key.pem"
ca_file: "certs/ca.pem"

View File

@ -14,6 +14,7 @@
package cluster package cluster
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -31,15 +32,25 @@ func GetTLSTransportConfig(configPath string) (*TLSTransportConfig, error) {
if configPath == "" { if configPath == "" {
return nil, nil return nil, nil
} }
bytes, err := os.ReadFile(configPath) bytes, err := os.ReadFile(configPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cfg := &TLSTransportConfig{}
cfg := &TLSTransportConfig{
TLSClientConfig: &config.TLSConfig{},
}
if err := yaml.UnmarshalStrict(bytes, cfg); err != nil { if err := yaml.UnmarshalStrict(bytes, cfg); err != nil {
return nil, err return nil, err
} }
if cfg.TLSServerConfig == nil {
return nil, fmt.Errorf("missing 'tls_server_config' entry in the TLS configuration")
}
cfg.TLSServerConfig.SetDirectory(filepath.Dir(configPath)) cfg.TLSServerConfig.SetDirectory(filepath.Dir(configPath))
cfg.TLSClientConfig.SetDirectory(filepath.Dir(configPath)) cfg.TLSClientConfig.SetDirectory(filepath.Dir(configPath))
return cfg, nil return cfg, nil
} }

View File

@ -80,27 +80,33 @@ func NewTLSTransport(
if cfg == nil { if cfg == nil {
return nil, errors.New("must specify TLSTransportConfig") return nil, errors.New("must specify TLSTransportConfig")
} }
tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig) tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "invalid TLS server config") return nil, errors.Wrap(err, "invalid TLS server config")
} }
tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig) tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "invalid TLS client config") return nil, errors.Wrap(err, "invalid TLS client config")
} }
ip := net.ParseIP(bindAddr) ip := net.ParseIP(bindAddr)
if ip == nil { if ip == nil {
return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr) return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr)
} }
addr := &net.TCPAddr{IP: ip, Port: bindPort} addr := &net.TCPAddr{IP: ip, Port: bindPort}
listener, err := tls.Listen(network, addr.String(), tlsServerCfg) listener, err := tls.Listen(network, addr.String(), tlsServerCfg)
if err != nil { if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort)) return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort))
} }
connPool, err := newConnectionPool(tlsClientCfg) connPool, err := newConnectionPool(tlsClientCfg)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to initialize tls transport connection pool") return nil, errors.Wrap(err, "failed to initialize tls transport connection pool")
} }
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
t := &TLSTransport{ t := &TLSTransport{
ctx: ctx, ctx: ctx,

View File

@ -29,34 +29,72 @@ import (
var logger = log.NewNopLogger() var logger = log.NewNopLogger()
func newTLSTransport(file, address string, port int) (*TLSTransport, error) {
cfg, err := GetTLSTransportConfig(file)
if err != nil {
return nil, err
}
return NewTLSTransport(context2.Background(), log.NewNopLogger(), nil, address, port, cfg)
}
func TestNewTLSTransport(t *testing.T) { func TestNewTLSTransport(t *testing.T) {
testCases := []struct { for _, tc := range []struct {
bindAddr string bindAddr string
bindPort int bindPort int
tlsConfFile string tlsConfFile string
err string err string
}{ }{
{err: "must specify TLSTransportConfig"}, {
{err: "invalid bind address \"\"", tlsConfFile: "testdata/tls_config_node1.yml"}, err: "must specify TLSTransportConfig",
{bindAddr: "abc123", err: "invalid bind address \"abc123\"", tlsConfFile: "testdata/tls_config_node1.yml"}, },
{bindAddr: localhost, bindPort: 0, tlsConfFile: "testdata/tls_config_node1.yml"}, {
{bindAddr: localhost, bindPort: 9094, tlsConfFile: "testdata/tls_config_node2.yml"}, tlsConfFile: "testdata/empty_tls_config.yml",
} err: "missing 'tls_server_config' entry in the TLS configuration",
l := log.NewNopLogger() },
for _, tc := range testCases { {
cfg := mustTLSTransportConfig(tc.tlsConfFile) tlsConfFile: "testdata/tls_config_with_missing_server.yml",
transport, err := NewTLSTransport(context2.Background(), l, nil, tc.bindAddr, tc.bindPort, cfg) err: "missing 'tls_server_config' entry in the TLS configuration",
if len(tc.err) > 0 { },
require.Equal(t, tc.err, err.Error()) {
require.Nil(t, transport) err: "invalid bind address \"\"",
} else { tlsConfFile: "testdata/tls_config_node1.yml",
require.Nil(t, err) },
{
bindAddr: "abc123",
err: "invalid bind address \"abc123\"",
tlsConfFile: "testdata/tls_config_node1.yml",
},
{
bindAddr: localhost,
bindPort: 0,
tlsConfFile: "testdata/tls_config_node1.yml",
},
{
bindAddr: localhost,
bindPort: 9094,
tlsConfFile: "testdata/tls_config_node2.yml",
},
{
tlsConfFile: "testdata/tls_config_with_missing_client.yml",
bindAddr: localhost,
},
} {
t.Run("", func(t *testing.T) {
transport, err := newTLSTransport(tc.tlsConfFile, tc.bindAddr, tc.bindPort)
if len(tc.err) > 0 {
require.Error(t, err)
require.Equal(t, tc.err, err.Error())
return
}
defer transport.Shutdown()
require.NoError(t, err)
require.Equal(t, tc.bindAddr, transport.bindAddr) require.Equal(t, tc.bindAddr, transport.bindAddr)
require.Equal(t, tc.bindPort, transport.bindPort) require.Equal(t, tc.bindPort, transport.bindPort)
require.Equal(t, l, transport.logger)
require.NotNil(t, transport.listener) require.NotNil(t, transport.listener)
transport.Shutdown() })
}
} }
} }
@ -79,7 +117,7 @@ func TestFinalAdvertiseAddr(t *testing.T) {
{bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095}, {bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095},
} }
for _, tc := range testCases { for _, tc := range testCases {
tlsConf := mustTLSTransportConfig("testdata/tls_config_node1.yml") tlsConf := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf) transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf)
require.Nil(t, err) require.Nil(t, err)
ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort) ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort)
@ -104,11 +142,11 @@ func TestFinalAdvertiseAddr(t *testing.T) {
} }
func TestWriteTo(t *testing.T) { func TestWriteTo(t *testing.T) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1) t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
defer t1.Shutdown() defer t1.Shutdown()
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml") tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml")
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2) t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
defer t2.Shutdown() defer t2.Shutdown()
@ -123,11 +161,11 @@ func TestWriteTo(t *testing.T) {
} }
func BenchmarkWriteTo(b *testing.B) { func BenchmarkWriteTo(b *testing.B) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") tlsConf1 := loadTLSTransportConfig(b, "testdata/tls_config_node1.yml")
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1) t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
defer t1.Shutdown() defer t1.Shutdown()
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml") tlsConf2 := loadTLSTransportConfig(b, "testdata/tls_config_node2.yml")
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2) t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
defer t2.Shutdown() defer t2.Shutdown()
@ -144,13 +182,13 @@ func BenchmarkWriteTo(b *testing.B) {
require.Equal(b, from, packet.From.String()) require.Equal(b, from, packet.From.String())
} }
func TestDialTimout(t *testing.T) { func TestDialTimeout(t *testing.T) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1) t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
require.Nil(t, err) require.Nil(t, err)
defer t1.Shutdown() defer t1.Shutdown()
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml") tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml")
t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2) t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
require.Nil(t, err) require.Nil(t, err)
defer t2.Shutdown() defer t2.Shutdown()
@ -193,7 +231,7 @@ func (l *logWr) Write(p []byte) (n int, err error) {
} }
func TestShutdown(t *testing.T) { func TestShutdown(t *testing.T) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
l := &logWr{} l := &logWr{}
t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1) t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1)
// Sleeping to make sure listeners have started and can subsequently be shut down gracefully. // Sleeping to make sure listeners have started and can subsequently be shut down gracefully.
@ -204,10 +242,13 @@ func TestShutdown(t *testing.T) {
require.Contains(t, string(l.bytes), "shutting down tls transport") require.Contains(t, string(l.bytes), "shutting down tls transport")
} }
func mustTLSTransportConfig(filename string) *TLSTransportConfig { func loadTLSTransportConfig(tb testing.TB, filename string) *TLSTransportConfig {
tb.Helper()
config, err := GetTLSTransportConfig(filename) config, err := GetTLSTransportConfig(filename)
if err != nil { if err != nil {
panic(err) tb.Fatal(err)
} }
return config return config
} }