// Copyright 2020 The Prometheus Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Forked from https://github.com/mxinden/memberlist-tls-transport. // Implements Transport interface so that all gossip communications occur via TLS over TCP. package cluster import ( "context" "crypto/tls" "errors" "fmt" "log/slog" "net" "strings" "time" "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/memberlist" "github.com/prometheus/client_golang/prometheus" common "github.com/prometheus/common/config" "github.com/prometheus/exporter-toolkit/web" ) const ( metricNamespace = "alertmanager" metricSubsystem = "tls_transport" network = "tcp" ) // TLSTransport is a Transport implementation that uses TLS over TCP for both // packet and stream operations. type TLSTransport struct { ctx context.Context cancel context.CancelFunc logger *slog.Logger bindAddr string bindPort int done chan struct{} listener net.Listener packetCh chan *memberlist.Packet streamCh chan net.Conn connPool *connectionPool tlsServerCfg *tls.Config tlsClientCfg *tls.Config packetsSent prometheus.Counter packetsRcvd prometheus.Counter streamsSent prometheus.Counter streamsRcvd prometheus.Counter readErrs prometheus.Counter writeErrs *prometheus.CounterVec } // NewTLSTransport returns a TLS transport with the given configuration. // On successful initialization, a tls listener will be created and listening. // A valid bindAddr is required. If bindPort == 0, the system will assign // a free port automatically. func NewTLSTransport( ctx context.Context, logger *slog.Logger, reg prometheus.Registerer, bindAddr string, bindPort int, cfg *TLSTransportConfig, ) (*TLSTransport, error) { if cfg == nil { return nil, errors.New("must specify TLSTransportConfig") } tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig) if err != nil { return nil, fmt.Errorf("invalid TLS server config: %w", err) } tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig) if err != nil { return nil, fmt.Errorf("invalid TLS client config: %w", err) } ip := net.ParseIP(bindAddr) if ip == nil { return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr) } addr := &net.TCPAddr{IP: ip, Port: bindPort} listener, err := tls.Listen(network, addr.String(), tlsServerCfg) if err != nil { return nil, fmt.Errorf("failed to start TLS listener on %q port %d: %w", bindAddr, bindPort, err) } connPool, err := newConnectionPool(tlsClientCfg) if err != nil { return nil, fmt.Errorf("failed to initialize tls transport connection pool: %w", err) } ctx, cancel := context.WithCancel(ctx) t := &TLSTransport{ ctx: ctx, cancel: cancel, logger: logger, bindAddr: bindAddr, bindPort: bindPort, done: make(chan struct{}), listener: listener, packetCh: make(chan *memberlist.Packet), streamCh: make(chan net.Conn), connPool: connPool, tlsServerCfg: tlsServerCfg, tlsClientCfg: tlsClientCfg, } t.registerMetrics(reg) go func() { t.listen() close(t.done) }() return t, nil } // FinalAdvertiseAddr is given the user's configured values (which // might be empty) and returns the desired IP and port to advertise to // the rest of the cluster. func (t *TLSTransport) FinalAdvertiseAddr(ip string, port int) (net.IP, int, error) { var advertiseAddr net.IP var advertisePort int if ip != "" { advertiseAddr = net.ParseIP(ip) if advertiseAddr == nil { return nil, 0, fmt.Errorf("failed to parse advertise address %q", ip) } if ip4 := advertiseAddr.To4(); ip4 != nil { advertiseAddr = ip4 } advertisePort = port } else { if t.bindAddr == "0.0.0.0" { // Otherwise, if we're not bound to a specific IP, let's // use a suitable private IP address. var err error ip, err = sockaddr.GetPrivateIP() if err != nil { return nil, 0, fmt.Errorf("failed to get interface addresses: %w", err) } if ip == "" { return nil, 0, fmt.Errorf("no private IP address found, and explicit IP not provided") } advertiseAddr = net.ParseIP(ip) if advertiseAddr == nil { return nil, 0, fmt.Errorf("failed to parse advertise address: %q", ip) } } else { advertiseAddr = t.listener.Addr().(*net.TCPAddr).IP } advertisePort = t.GetAutoBindPort() } return advertiseAddr, advertisePort, nil } // PacketCh returns a channel that can be read to receive incoming // packets from other peers. func (t *TLSTransport) PacketCh() <-chan *memberlist.Packet { return t.packetCh } // StreamCh returns a channel that can be read to handle incoming stream // connections from other peers. func (t *TLSTransport) StreamCh() <-chan net.Conn { return t.streamCh } // Shutdown is called when memberlist is shutting down; this gives the // TLS Transport a chance to clean up the listener and other goroutines. func (t *TLSTransport) Shutdown() error { t.logger.Debug("shutting down tls transport") t.cancel() err := t.listener.Close() t.connPool.shutdown() <-t.done return err } // WriteTo is a packet-oriented interface that borrows a connection // from the pool, and writes to it. It also returns a timestamp of when // the packet was written. func (t *TLSTransport) WriteTo(b []byte, addr string) (time.Time, error) { conn, err := t.connPool.borrowConnection(addr, DefaultTCPTimeout) if err != nil { t.writeErrs.WithLabelValues("packet").Inc() return time.Now(), fmt.Errorf("failed to dial: %w", err) } fromAddr := t.listener.Addr().String() err = conn.writePacket(fromAddr, b) if err != nil { t.writeErrs.WithLabelValues("packet").Inc() return time.Now(), fmt.Errorf("failed to write packet: %w", err) } t.packetsSent.Add(float64(len(b))) return time.Now(), nil } // DialTimeout is used to create a connection that allows memberlist // to perform two-way communications with a peer. func (t *TLSTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { conn, err := dialTLSConn(addr, timeout, t.tlsClientCfg) if err != nil { t.writeErrs.WithLabelValues("stream").Inc() return nil, fmt.Errorf("failed to dial: %w", err) } err = conn.writeStream() netConn := conn.getRawConn() if err != nil { t.writeErrs.WithLabelValues("stream").Inc() return netConn, fmt.Errorf("failed to create stream connection: %w", err) } t.streamsSent.Inc() return netConn, nil } // GetAutoBindPort returns the bind port that was automatically given by the system // if a bindPort of 0 was specified during instantiation. func (t *TLSTransport) GetAutoBindPort() int { return t.listener.Addr().(*net.TCPAddr).Port } // listen starts up multiple handlers accepting concurrent connections. func (t *TLSTransport) listen() { for { select { case <-t.ctx.Done(): return default: conn, err := t.listener.Accept() if err != nil { // The error "use of closed network connection" is returned when the listener is closed. // It is not exported in a more reasonable way. See https://github.com/golang/go/issues/4373. if strings.Contains(err.Error(), "use of closed network connection") { return } t.readErrs.Inc() t.logger.Debug("error accepting connection", "err", err) } else { go t.handle(conn) } } } } func (t *TLSTransport) handle(conn net.Conn) { for { packet, err := rcvTLSConn(conn).read() if err != nil { t.logger.Debug("error reading from connection", "err", err) t.readErrs.Inc() return } select { case <-t.ctx.Done(): return default: if packet != nil { n := len(packet.Buf) t.packetCh <- packet t.packetsRcvd.Add(float64(n)) } else { t.streamCh <- conn t.streamsRcvd.Inc() return } } } } func (t *TLSTransport) registerMetrics(reg prometheus.Registerer) { t.packetsSent = prometheus.NewCounter( prometheus.CounterOpts{ Namespace: metricNamespace, Subsystem: metricSubsystem, Name: "packet_bytes_sent_total", Help: "The number of packet bytes sent to outgoing connections (excluding internal metadata).", }, ) t.packetsRcvd = prometheus.NewCounter( prometheus.CounterOpts{ Namespace: metricNamespace, Subsystem: metricSubsystem, Name: "packet_bytes_received_total", Help: "The number of packet bytes received from incoming connections (excluding internal metadata).", }, ) t.streamsSent = prometheus.NewCounter( prometheus.CounterOpts{ Namespace: metricNamespace, Subsystem: metricSubsystem, Name: "stream_connections_sent_total", Help: "The number of stream connections sent.", }, ) t.streamsRcvd = prometheus.NewCounter( prometheus.CounterOpts{ Namespace: metricNamespace, Subsystem: metricSubsystem, Name: "stream_connections_received_total", Help: "The number of stream connections received.", }, ) t.readErrs = prometheus.NewCounter( prometheus.CounterOpts{ Namespace: metricNamespace, Subsystem: metricSubsystem, Name: "read_errors_total", Help: "The number of errors encountered while reading from incoming connections.", }, ) t.writeErrs = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: metricNamespace, Subsystem: metricSubsystem, Name: "write_errors_total", Help: "The number of errors encountered while writing to outgoing connections.", }, []string{"connection_type"}, ) if reg != nil { reg.MustRegister(t.packetsSent) reg.MustRegister(t.packetsRcvd) reg.MustRegister(t.streamsSent) reg.MustRegister(t.streamsRcvd) reg.MustRegister(t.readErrs) reg.MustRegister(t.writeErrs) } }