From 4bdf1adb6cec7d1806359f9596b52c121af50412 Mon Sep 17 00:00:00 2001 From: Johannes 'fish' Ziemke Date: Wed, 26 Jun 2013 12:25:20 +0200 Subject: [PATCH] Use github.com/miekg/dns for resolving SRV records --- .build/Makefile | 6 ++- retrieval/target_provider.go | 73 ++++++++++++++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/.build/Makefile b/.build/Makefile index c4e73bc46..943bb87b2 100644 --- a/.build/Makefile +++ b/.build/Makefile @@ -45,7 +45,7 @@ cc-implementation-Linux-stamp: [ -x "$$(which cc)" ] || $(APT_GET_INSTALL) build-essential touch $@ -dependencies-stamp: cache-stamp cc-stamp leveldb-stamp snappy-stamp +dependencies-stamp: cache-stamp cc-stamp leveldb-stamp snappy-stamp godns-stamp touch $@ goprotobuf-protoc-gen-go-stamp: protoc-stamp goprotobuf-stamp @@ -56,6 +56,10 @@ goprotobuf-stamp: protoc-stamp $(GO_GET) code.google.com/p/goprotobuf/proto $(THIRD_PARTY_BUILD_OUTPUT) touch $@ +godns-stamp: + $(GO_GET) github.com/miekg/dns $(THIRD_PARTY_BUILD_OUTPUT) + touch $@ + leveldb-stamp: cache-stamp cache/leveldb-$(LEVELDB_VERSION).tar.gz cc-stamp rsync-stamp snappy-stamp tar xzvf cache/leveldb-$(LEVELDB_VERSION).tar.gz -C dirty $(THIRD_PARTY_BUILD_OUTPUT) cd dirty/leveldb-$(LEVELDB_VERSION) && CFLAGS="$(CFLAGS) -lsnappy" CXXFLAGS="$(CXXFLAGS) -lsnappy $(LDFLAGS)" LDFLAGS="-lsnappy $(LDFLAGS)" bash -x ./build_detect_platform build_config.mk ./ diff --git a/retrieval/target_provider.go b/retrieval/target_provider.go index 6a87f6964..8d5437839 100644 --- a/retrieval/target_provider.go +++ b/retrieval/target_provider.go @@ -15,16 +15,19 @@ package retrieval import ( "fmt" - "net" + "log" "net/url" "time" clientmodel "github.com/prometheus/client_golang/model" + "github.com/miekg/dns" "github.com/prometheus/prometheus/config" "github.com/prometheus/prometheus/utility" ) +const resolvConf = "/etc/resolv.conf" + // TargetProvider encapsulates retrieving all targets for a job. type TargetProvider interface { // Retrieves the current list of targets for this provider. @@ -57,7 +60,7 @@ func (p *sdTargetProvider) Targets() ([]Target, error) { return p.targets, nil } - _, addrs, err := net.LookupSRV("", "", p.job.GetSdName()) + response, err := lookupSRV(p.job.GetSdName()) if err != nil { return nil, err } @@ -66,12 +69,17 @@ func (p *sdTargetProvider) Targets() ([]Target, error) { clientmodel.JobLabel: clientmodel.LabelValue(p.job.GetName()), } - targets := make([]Target, 0, len(addrs)) + targets := make([]Target, 0, len(response.Answer)) endpoint := &url.URL{ Scheme: "http", Path: p.job.GetMetricsPath(), } - for _, addr := range addrs { + for _, record := range response.Answer { + addr, ok := record.(*dns.SRV) + if !ok { + log.Printf("%s is not a valid SRV record", addr) + continue + } // Remove the final dot from rooted DNS names to make them look more usual. if addr.Target[len(addr.Target)-1] == '.' { addr.Target = addr.Target[:len(addr.Target)-1] @@ -84,3 +92,60 @@ func (p *sdTargetProvider) Targets() ([]Target, error) { p.targets = targets return targets, nil } + +func lookupSRV(name string) (*dns.Msg, error) { + name = dns.Fqdn(name) + conf, err := dns.ClientConfigFromFile(resolvConf) + if err != nil { + return nil, fmt.Errorf("Couldn't load resolv.conf: %s", err) + } + client := &dns.Client{} + msg := &dns.Msg{} + msg.SetQuestion(name, dns.TypeSRV) + + response := &dns.Msg{} + for _, server := range conf.Servers { + server := fmt.Sprintf("%s:%s", server, conf.Port) + response, err = lookup(msg, client, server, false) + if err == nil { + return response, nil + } + } + return response, fmt.Errorf("Couldn't resolve %s: No server responded", name) +} + +func lookup(msg *dns.Msg, client *dns.Client, server string, edns bool) (*dns.Msg, error) { + if edns { + opt := &dns.OPT{ + Hdr: dns.RR_Header{ + Name: ".", + Rrtype: dns.TypeOPT, + }, + } + opt.SetUDPSize(dns.DefaultMsgSize) + msg.Extra = append(msg.Extra, opt) + } + + response, _, err := client.Exchange(msg, server) + if err != nil { + return nil, err + } + + if msg.Id != response.Id { + return nil, fmt.Errorf("DNS ID mismatch, request: %d, response: %d", msg.Id, response.Id) + } + + if response.MsgHdr.Truncated { + if client.Net == "tcp" { + return nil, fmt.Errorf("Got truncated message on tcp") + } + + if edns { // Truncated even though EDNS is used + client.Net = "tcp" + } + + return lookup(msg, client, server, !edns) + } + + return response, nil +}