diff --git a/cluster/testdata/empty_tls_config.yml b/cluster/testdata/empty_tls_config.yml new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/cluster/testdata/empty_tls_config.yml @@ -0,0 +1 @@ +{} diff --git a/cluster/testdata/tls_config_with_missing_client.yml b/cluster/testdata/tls_config_with_missing_client.yml new file mode 100644 index 00000000..de79b181 --- /dev/null +++ b/cluster/testdata/tls_config_with_missing_client.yml @@ -0,0 +1,5 @@ +tls_server_config: + cert_file: "certs/node2.pem" + key_file: "certs/node2-key.pem" + client_ca_file: "certs/ca.pem" + client_auth_type: "VerifyClientCertIfGiven" diff --git a/cluster/testdata/tls_config_with_missing_server.yml b/cluster/testdata/tls_config_with_missing_server.yml new file mode 100644 index 00000000..5542236d --- /dev/null +++ b/cluster/testdata/tls_config_with_missing_server.yml @@ -0,0 +1,4 @@ +tls_client_config: + cert_file: "certs/node1.pem" + key_file: "certs/node1-key.pem" + ca_file: "certs/ca.pem" diff --git a/cluster/tls_config.go b/cluster/tls_config.go index 135c2fa2..f4f14972 100644 --- a/cluster/tls_config.go +++ b/cluster/tls_config.go @@ -14,6 +14,7 @@ package cluster import ( + "fmt" "os" "path/filepath" @@ -31,15 +32,25 @@ func GetTLSTransportConfig(configPath string) (*TLSTransportConfig, error) { if configPath == "" { return nil, nil } + bytes, err := os.ReadFile(configPath) if err != nil { return nil, err } - cfg := &TLSTransportConfig{} + + cfg := &TLSTransportConfig{ + TLSClientConfig: &config.TLSConfig{}, + } if err := yaml.UnmarshalStrict(bytes, cfg); err != nil { return nil, err } + + if cfg.TLSServerConfig == nil { + return nil, fmt.Errorf("missing 'tls_server_config' entry in the TLS configuration") + } + cfg.TLSServerConfig.SetDirectory(filepath.Dir(configPath)) cfg.TLSClientConfig.SetDirectory(filepath.Dir(configPath)) + return cfg, nil } diff --git a/cluster/tls_transport.go b/cluster/tls_transport.go index 16fbe251..eb521e04 100644 --- a/cluster/tls_transport.go +++ b/cluster/tls_transport.go @@ -80,27 +80,33 @@ func NewTLSTransport( if cfg == nil { return nil, errors.New("must specify TLSTransportConfig") } + tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig) if err != nil { return nil, errors.Wrap(err, "invalid TLS server config") } + tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig) if err != nil { return nil, errors.Wrap(err, "invalid TLS client config") } + ip := net.ParseIP(bindAddr) if ip == nil { return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr) } + addr := &net.TCPAddr{IP: ip, Port: bindPort} listener, err := tls.Listen(network, addr.String(), tlsServerCfg) if err != nil { return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort)) } + connPool, err := newConnectionPool(tlsClientCfg) if err != nil { return nil, errors.Wrap(err, "failed to initialize tls transport connection pool") } + ctx, cancel := context.WithCancel(ctx) t := &TLSTransport{ ctx: ctx, diff --git a/cluster/tls_transport_test.go b/cluster/tls_transport_test.go index 3c3a11cb..7b2a9224 100644 --- a/cluster/tls_transport_test.go +++ b/cluster/tls_transport_test.go @@ -29,34 +29,72 @@ import ( var logger = log.NewNopLogger() +func newTLSTransport(file, address string, port int) (*TLSTransport, error) { + cfg, err := GetTLSTransportConfig(file) + if err != nil { + return nil, err + } + + return NewTLSTransport(context2.Background(), log.NewNopLogger(), nil, address, port, cfg) +} + func TestNewTLSTransport(t *testing.T) { - testCases := []struct { + for _, tc := range []struct { bindAddr string bindPort int tlsConfFile string err string }{ - {err: "must specify TLSTransportConfig"}, - {err: "invalid bind address \"\"", tlsConfFile: "testdata/tls_config_node1.yml"}, - {bindAddr: "abc123", err: "invalid bind address \"abc123\"", tlsConfFile: "testdata/tls_config_node1.yml"}, - {bindAddr: localhost, bindPort: 0, tlsConfFile: "testdata/tls_config_node1.yml"}, - {bindAddr: localhost, bindPort: 9094, tlsConfFile: "testdata/tls_config_node2.yml"}, - } - l := log.NewNopLogger() - for _, tc := range testCases { - cfg := mustTLSTransportConfig(tc.tlsConfFile) - transport, err := NewTLSTransport(context2.Background(), l, nil, tc.bindAddr, tc.bindPort, cfg) - if len(tc.err) > 0 { - require.Equal(t, tc.err, err.Error()) - require.Nil(t, transport) - } else { - require.Nil(t, err) + { + err: "must specify TLSTransportConfig", + }, + { + tlsConfFile: "testdata/empty_tls_config.yml", + err: "missing 'tls_server_config' entry in the TLS configuration", + }, + { + tlsConfFile: "testdata/tls_config_with_missing_server.yml", + err: "missing 'tls_server_config' entry in the TLS configuration", + }, + { + err: "invalid bind address \"\"", + tlsConfFile: "testdata/tls_config_node1.yml", + }, + { + bindAddr: "abc123", + err: "invalid bind address \"abc123\"", + tlsConfFile: "testdata/tls_config_node1.yml", + }, + { + bindAddr: localhost, + bindPort: 0, + tlsConfFile: "testdata/tls_config_node1.yml", + }, + { + bindAddr: localhost, + bindPort: 9094, + tlsConfFile: "testdata/tls_config_node2.yml", + }, + { + tlsConfFile: "testdata/tls_config_with_missing_client.yml", + bindAddr: localhost, + }, + } { + t.Run("", func(t *testing.T) { + transport, err := newTLSTransport(tc.tlsConfFile, tc.bindAddr, tc.bindPort) + if len(tc.err) > 0 { + require.Error(t, err) + require.Equal(t, tc.err, err.Error()) + return + } + + defer transport.Shutdown() + + require.NoError(t, err) require.Equal(t, tc.bindAddr, transport.bindAddr) require.Equal(t, tc.bindPort, transport.bindPort) - require.Equal(t, l, transport.logger) require.NotNil(t, transport.listener) - transport.Shutdown() - } + }) } } @@ -79,7 +117,7 @@ func TestFinalAdvertiseAddr(t *testing.T) { {bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095}, } for _, tc := range testCases { - tlsConf := mustTLSTransportConfig("testdata/tls_config_node1.yml") + tlsConf := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml") transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf) require.Nil(t, err) ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort) @@ -104,11 +142,11 @@ func TestFinalAdvertiseAddr(t *testing.T) { } func TestWriteTo(t *testing.T) { - tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") + tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml") t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1) defer t1.Shutdown() - tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml") + tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml") t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2) defer t2.Shutdown() @@ -123,11 +161,11 @@ func TestWriteTo(t *testing.T) { } func BenchmarkWriteTo(b *testing.B) { - tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") + tlsConf1 := loadTLSTransportConfig(b, "testdata/tls_config_node1.yml") t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1) defer t1.Shutdown() - tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml") + tlsConf2 := loadTLSTransportConfig(b, "testdata/tls_config_node2.yml") t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2) defer t2.Shutdown() @@ -144,13 +182,13 @@ func BenchmarkWriteTo(b *testing.B) { require.Equal(b, from, packet.From.String()) } -func TestDialTimout(t *testing.T) { - tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") +func TestDialTimeout(t *testing.T) { + tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml") t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1) require.Nil(t, err) defer t1.Shutdown() - tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml") + tlsConf2 := loadTLSTransportConfig(t, "testdata/tls_config_node2.yml") t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2) require.Nil(t, err) defer t2.Shutdown() @@ -193,7 +231,7 @@ func (l *logWr) Write(p []byte) (n int, err error) { } func TestShutdown(t *testing.T) { - tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml") + tlsConf1 := loadTLSTransportConfig(t, "testdata/tls_config_node1.yml") l := &logWr{} t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1) // Sleeping to make sure listeners have started and can subsequently be shut down gracefully. @@ -204,10 +242,13 @@ func TestShutdown(t *testing.T) { require.Contains(t, string(l.bytes), "shutting down tls transport") } -func mustTLSTransportConfig(filename string) *TLSTransportConfig { +func loadTLSTransportConfig(tb testing.TB, filename string) *TLSTransportConfig { + tb.Helper() + config, err := GetTLSTransportConfig(filename) if err != nil { - panic(err) + tb.Fatal(err) } + return config }