89 lines
2.1 KiB
Go
89 lines
2.1 KiB
Go
package zk
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
)
|
|
|
|
// DNSHostProvider is the default HostProvider. Mirroring the Java client's
// StaticHostProvider, it resolves every configured host through DNS exactly
// once, during Init. Re-querying DNS periodically, or after trouble
// connecting, would be a straightforward extension.
type DNSHostProvider struct {
	// mu guards every field below. Taking it in every method leaves room for
	// asynchronous updates to be added later.
	mu sync.Mutex

	servers []string // resolved addresses handed out by Next
	curr    int      // index most recently returned by Next (-1 right after Init)
	last    int      // index at which Next reports retryStart (-1 right after Init)

	// lookupHost, when non-nil, is used instead of net.LookupHost so tests
	// can inject their own resolution results.
	lookupHost func(string) ([]string, error)
}
|
|
|
|
// Init is called first, with the servers specified in the connection
|
|
// string. It uses DNS to look up addresses for each server, then
|
|
// shuffles them all together.
|
|
func (hp *DNSHostProvider) Init(servers []string) error {
|
|
hp.mu.Lock()
|
|
defer hp.mu.Unlock()
|
|
|
|
lookupHost := hp.lookupHost
|
|
if lookupHost == nil {
|
|
lookupHost = net.LookupHost
|
|
}
|
|
|
|
found := []string{}
|
|
for _, server := range servers {
|
|
host, port, err := net.SplitHostPort(server)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
addrs, err := lookupHost(host)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, addr := range addrs {
|
|
found = append(found, net.JoinHostPort(addr, port))
|
|
}
|
|
}
|
|
|
|
if len(found) == 0 {
|
|
return fmt.Errorf("No hosts found for addresses %q", servers)
|
|
}
|
|
|
|
// Randomize the order of the servers to avoid creating hotspots
|
|
stringShuffle(found)
|
|
|
|
hp.servers = found
|
|
hp.curr = -1
|
|
hp.last = -1
|
|
|
|
return nil
|
|
}
|
|
|
|
// Len returns the number of servers available
|
|
func (hp *DNSHostProvider) Len() int {
|
|
hp.mu.Lock()
|
|
defer hp.mu.Unlock()
|
|
return len(hp.servers)
|
|
}
|
|
|
|
// Next returns the next server to connect to. retryStart will be true
|
|
// if we've looped through all known servers without Connected() being
|
|
// called.
|
|
func (hp *DNSHostProvider) Next() (server string, retryStart bool) {
|
|
hp.mu.Lock()
|
|
defer hp.mu.Unlock()
|
|
hp.curr = (hp.curr + 1) % len(hp.servers)
|
|
retryStart = hp.curr == hp.last
|
|
if hp.last == -1 {
|
|
hp.last = 0
|
|
}
|
|
return hp.servers[hp.curr], retryStart
|
|
}
|
|
|
|
// Connected notifies the HostProvider of a successful connection.
|
|
func (hp *DNSHostProvider) Connected() {
|
|
hp.mu.Lock()
|
|
defer hp.mu.Unlock()
|
|
hp.last = hp.curr
|
|
}
|