cluster: fix panic when `tls_client_config` is empty

Closes #3403

Signed-off-by: Simon Pasquier <spasquie@redhat.com>
This commit is contained in:
Simon Pasquier 2023-08-04 14:29:05 +02:00
parent 130b8b6761
commit aea6204d58
6 changed files with 99 additions and 31 deletions

1
cluster/testdata/empty_tls_config.yml vendored Normal file
View File

@ -0,0 +1 @@
{}

View File

@ -0,0 +1,5 @@
tls_server_config:
cert_file: "certs/node2.pem"
key_file: "certs/node2-key.pem"
client_ca_file: "certs/ca.pem"
client_auth_type: "VerifyClientCertIfGiven"

View File

@ -0,0 +1,4 @@
tls_client_config:
cert_file: "certs/node1.pem"
key_file: "certs/node1-key.pem"
ca_file: "certs/ca.pem"

View File

@ -14,6 +14,7 @@
package cluster
import (
"fmt"
"os"
"path/filepath"
@ -31,15 +32,25 @@ func GetTLSTransportConfig(configPath string) (*TLSTransportConfig, error) {
if configPath == "" {
return nil, nil
}
bytes, err := os.ReadFile(configPath)
if err != nil {
return nil, err
}
cfg := &TLSTransportConfig{}
cfg := &TLSTransportConfig{
TLSClientConfig: &config.TLSConfig{},
}
if err := yaml.UnmarshalStrict(bytes, cfg); err != nil {
return nil, err
}
if cfg.TLSServerConfig == nil {
return nil, fmt.Errorf("missing 'tls_server_config' entry in the TLS configuration")
}
cfg.TLSServerConfig.SetDirectory(filepath.Dir(configPath))
cfg.TLSClientConfig.SetDirectory(filepath.Dir(configPath))
return cfg, nil
}

View File

@ -80,27 +80,33 @@ func NewTLSTransport(
if cfg == nil {
return nil, errors.New("must specify TLSTransportConfig")
}
tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig)
if err != nil {
return nil, errors.Wrap(err, "invalid TLS server config")
}
tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig)
if err != nil {
return nil, errors.Wrap(err, "invalid TLS client config")
}
ip := net.ParseIP(bindAddr)
if ip == nil {
return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr)
}
addr := &net.TCPAddr{IP: ip, Port: bindPort}
listener, err := tls.Listen(network, addr.String(), tlsServerCfg)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort))
}
connPool, err := newConnectionPool(tlsClientCfg)
if err != nil {
return nil, errors.Wrap(err, "failed to initialize tls transport connection pool")
}
ctx, cancel := context.WithCancel(ctx)
t := &TLSTransport{
ctx: ctx,

View File

@ -29,34 +29,72 @@ import (
var logger = log.NewNopLogger()
func newTLSTransport(file, address string, port int) (*TLSTransport, error) {
cfg, err := GetTLSTransportConfig(file)
if err != nil {
return nil, err
}
return NewTLSTransport(context2.Background(), log.NewNopLogger(), nil, address, port, cfg)
}
func TestNewTLSTransport(t *testing.T) {
testCases := []struct {
for _, tc := range []struct {
bindAddr string
bindPort int
tlsConfFile string
err string
}{
{err: "must specify TLSTransportConfig"},
{err: "invalid bind address \"\"", tlsConfFile: "testdata/tls_config_node1.yml"},
{bindAddr: "abc123", err: "invalid bind address \"abc123\"", tlsConfFile: "testdata/tls_config_node1.yml"},
{bindAddr: localhost, bindPort: 0, tlsConfFile: "testdata/tls_config_node1.yml"},
{bindAddr: localhost, bindPort: 9094, tlsConfFile: "testdata/tls_config_node2.yml"},
}
l := log.NewNopLogger()
for _, tc := range testCases {
cfg := mustTLSTransportConfig(tc.tlsConfFile)
transport, err := NewTLSTransport(context2.Background(), l, nil, tc.bindAddr, tc.bindPort, cfg)
if len(tc.err) > 0 {
require.Equal(t, tc.err, err.Error())
require.Nil(t, transport)
} else {
require.Nil(t, err)
{
err: "must specify TLSTransportConfig",
},
{
tlsConfFile: "testdata/empty_tls_config.yml",
err: "missing 'tls_server_config' entry in the TLS configuration",
},
{
tlsConfFile: "testdata/tls_config_with_missing_server.yml",
err: "missing 'tls_server_config' entry in the TLS configuration",
},
{
err: "invalid bind address \"\"",
tlsConfFile: "testdata/tls_config_node1.yml",
},
{
bindAddr: "abc123",
err: "invalid bind address \"abc123\"",
tlsConfFile: "testdata/tls_config_node1.yml",
},
{
bindAddr: localhost,
bindPort: 0,
tlsConfFile: "testdata/tls_config_node1.yml",
},
{
bindAddr: localhost,
bindPort: 9094,
tlsConfFile: "testdata/tls_config_node2.yml",
},
{
tlsConfFile: "testdata/tls_config_with_missing_client.yml",
bindAddr: localhost,
},
} {
t.Run("", func(t *testing.T) {
transport, err := newTLSTransport(tc.tlsConfFile, tc.bindAddr, tc.bindPort)
if len(tc.err) > 0 {
require.Error(t, err)
require.Equal(t, tc.err, err.Error())
return
}
defer transport.Shutdown()
require.NoError(t, err)
require.Equal(t, tc.bindAddr, transport.bindAddr)
require.Equal(t, tc.bindPort, transport.bindPort)
require.Equal(t, l, transport.logger)
require.NotNil(t, transport.listener)
transport.Shutdown()
}
})
}
}
@ -79,7 +117,7 @@ func TestFinalAdvertiseAddr(t *testing.T) {
{bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095},
}
for _, tc := range testCases {
tlsConf := mustTLSTransportConfig("testdata/tls_config_node1.yml")
tlsConf := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf)
require.Nil(t, err)
ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort)
@ -104,11 +142,11 @@ func TestFinalAdvertiseAddr(t *testing.T) {
}
func TestWriteTo(t *testing.T) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
defer t1.Shutdown()
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml")
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
defer t2.Shutdown()
@ -123,11 +161,11 @@ func TestWriteTo(t *testing.T) {
}
func BenchmarkWriteTo(b *testing.B) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
tlsConf1 := loadTLSTransportConfig(b, "testdata/tls_config_node1.yml")
t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
defer t1.Shutdown()
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
tlsConf2 := loadTLSTransportConfig(b, "testdata/tls_config_node2.yml")
t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
defer t2.Shutdown()
@ -144,13 +182,13 @@ func BenchmarkWriteTo(b *testing.B) {
require.Equal(b, from, packet.From.String())
}
func TestDialTimout(t *testing.T) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
func TestDialTimeout(t *testing.T) {
tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
require.Nil(t, err)
defer t1.Shutdown()
tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml")
t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
require.Nil(t, err)
defer t2.Shutdown()
@ -193,7 +231,7 @@ func (l *logWr) Write(p []byte) (n int, err error) {
}
func TestShutdown(t *testing.T) {
tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml")
l := &logWr{}
t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1)
// Sleeping to make sure listeners have started and can subsequently be shut down gracefully.
@ -204,10 +242,13 @@ func TestShutdown(t *testing.T) {
require.Contains(t, string(l.bytes), "shutting down tls transport")
}
func mustTLSTransportConfig(filename string) *TLSTransportConfig {
func loadTLSTransportConfig(tb testing.TB, filename string) *TLSTransportConfig {
tb.Helper()
config, err := GetTLSTransportConfig(filename)
if err != nil {
panic(err)
tb.Fatal(err)
}
return config
}