Allow the use of bearer_token or bearer_token_file for MarathonSD

This commit is contained in:
Michael Kraus 2017-03-02 09:44:20 +01:00
parent 0a7fb56b16
commit 47bdcf0f67
3 changed files with 39 additions and 9 deletions

View File

@ -241,6 +241,7 @@ func resolveFilepaths(baseDir string, cfg *Config) {
kcfg.TLSConfig.KeyFile = join(kcfg.TLSConfig.KeyFile) kcfg.TLSConfig.KeyFile = join(kcfg.TLSConfig.KeyFile)
} }
for _, mcfg := range cfg.MarathonSDConfigs { for _, mcfg := range cfg.MarathonSDConfigs {
mcfg.BearerTokenFile = join(mcfg.BearerTokenFile)
mcfg.TLSConfig.CAFile = join(mcfg.TLSConfig.CAFile) mcfg.TLSConfig.CAFile = join(mcfg.TLSConfig.CAFile)
mcfg.TLSConfig.CertFile = join(mcfg.TLSConfig.CertFile) mcfg.TLSConfig.CertFile = join(mcfg.TLSConfig.CertFile)
mcfg.TLSConfig.KeyFile = join(mcfg.TLSConfig.KeyFile) mcfg.TLSConfig.KeyFile = join(mcfg.TLSConfig.KeyFile)
@ -920,6 +921,8 @@ type MarathonSDConfig struct {
Timeout model.Duration `yaml:"timeout,omitempty"` Timeout model.Duration `yaml:"timeout,omitempty"`
RefreshInterval model.Duration `yaml:"refresh_interval,omitempty"` RefreshInterval model.Duration `yaml:"refresh_interval,omitempty"`
TLSConfig TLSConfig `yaml:"tls_config,omitempty"` TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
BearerToken string `yaml:"bearer_token,omitempty"`
BearerTokenFile string `yaml:"bearer_token_file,omitempty"`
// Catches all undefined fields and must be empty after parsing. // Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"` XXX map[string]interface{} `yaml:",inline"`
@ -939,6 +942,12 @@ func (c *MarathonSDConfig) UnmarshalYAML(unmarshal func(interface{}) error) erro
if len(c.Servers) == 0 { if len(c.Servers) == 0 {
return fmt.Errorf("Marathon SD config must contain at least one Marathon server") return fmt.Errorf("Marathon SD config must contain at least one Marathon server")
} }
if len(c.BearerToken) > 0 && len(c.BearerTokenFile) > 0 {
return fmt.Errorf("at most one of bearer_token & bearer_token_file must be configured")
}
if len(c.BearerToken) == 0 && len(c.BearerTokenFile) == 0 {
return fmt.Errorf("at most one of bearer_token & bearer_token_file must be configured")
}
return nil return nil
} }

View File

@ -20,6 +20,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
"strings"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -77,6 +78,7 @@ type Discovery struct {
refreshInterval time.Duration refreshInterval time.Duration
lastRefresh map[string]*config.TargetGroup lastRefresh map[string]*config.TargetGroup
appsClient AppListClient appsClient AppListClient
token string
} }
// Initialize sets up the discovery for usage. // Initialize sets up the discovery for usage.
@ -86,6 +88,15 @@ func NewDiscovery(conf *config.MarathonSDConfig) (*Discovery, error) {
return nil, err return nil, err
} }
token := conf.BearerToken
if conf.BearerTokenFile != "" {
bf, err := ioutil.ReadFile(conf.BearerTokenFile)
if err != nil {
return nil, err
}
token = strings.TrimSpace(string(bf))
}
client := &http.Client{ client := &http.Client{
Timeout: time.Duration(conf.Timeout), Timeout: time.Duration(conf.Timeout),
Transport: &http.Transport{ Transport: &http.Transport{
@ -98,6 +109,7 @@ func NewDiscovery(conf *config.MarathonSDConfig) (*Discovery, error) {
servers: conf.Servers, servers: conf.Servers,
refreshInterval: time.Duration(conf.RefreshInterval), refreshInterval: time.Duration(conf.RefreshInterval),
appsClient: fetchApps, appsClient: fetchApps,
token: token,
}, nil }, nil
} }
@ -160,7 +172,7 @@ func (md *Discovery) updateServices(ctx context.Context, ch chan<- []*config.Tar
func (md *Discovery) fetchTargetGroups() (map[string]*config.TargetGroup, error) { func (md *Discovery) fetchTargetGroups() (map[string]*config.TargetGroup, error) {
url := RandomAppsURL(md.servers) url := RandomAppsURL(md.servers)
apps, err := md.appsClient(md.client, url) apps, err := md.appsClient(md.client, url, md.token)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -201,11 +213,20 @@ type AppList struct {
} }
// AppListClient defines a function that can be used to get an application list from marathon. // AppListClient defines a function that can be used to get an application list from marathon.
type AppListClient func(client *http.Client, url string) (*AppList, error) type AppListClient func(client *http.Client, url, token string) (*AppList, error)
// fetchApps requests a list of applications from a marathon server. // fetchApps requests a list of applications from a marathon server.
func fetchApps(client *http.Client, url string) (*AppList, error) { func fetchApps(client *http.Client, url, token string) (*AppList, error) {
resp, err := client.Get(url) request, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
if token != "" {
request.Header.Set("Authorization", "token="+token)
}
resp, err := client.Do(request)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -44,7 +44,7 @@ func TestMarathonSDHandleError(t *testing.T) {
var ( var (
errTesting = errors.New("testing failure") errTesting = errors.New("testing failure")
ch = make(chan []*config.TargetGroup, 1) ch = make(chan []*config.TargetGroup, 1)
client = func(client *http.Client, url string) (*AppList, error) { return nil, errTesting } client = func(client *http.Client, url, token string) (*AppList, error) { return nil, errTesting }
) )
if err := testUpdateServices(client, ch); err != errTesting { if err := testUpdateServices(client, ch); err != errTesting {
t.Fatalf("Expected error: %s", err) t.Fatalf("Expected error: %s", err)
@ -59,7 +59,7 @@ func TestMarathonSDHandleError(t *testing.T) {
func TestMarathonSDEmptyList(t *testing.T) { func TestMarathonSDEmptyList(t *testing.T) {
var ( var (
ch = make(chan []*config.TargetGroup, 1) ch = make(chan []*config.TargetGroup, 1)
client = func(client *http.Client, url string) (*AppList, error) { return &AppList{}, nil } client = func(client *http.Client, url, token string) (*AppList, error) { return &AppList{}, nil }
) )
if err := testUpdateServices(client, ch); err != nil { if err := testUpdateServices(client, ch); err != nil {
t.Fatalf("Got error: %s", err) t.Fatalf("Got error: %s", err)
@ -130,7 +130,7 @@ func TestMarathonSDRemoveApp(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("%s", err) t.Fatalf("%s", err)
} }
md.appsClient = func(client *http.Client, url string) (*AppList, error) { md.appsClient = func(client *http.Client, url, token string) (*AppList, error) {
return marathonTestAppList(marathonValidLabel, 1), nil return marathonTestAppList(marathonValidLabel, 1), nil
} }
go func() { go func() {
@ -165,7 +165,7 @@ func TestMarathonSDRunAndStop(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("%s", err) t.Fatalf("%s", err)
} }
md.appsClient = func(client *http.Client, url string) (*AppList, error) { md.appsClient = func(client *http.Client, url, token string) (*AppList, error) {
return marathonTestAppList(marathonValidLabel, 1), nil return marathonTestAppList(marathonValidLabel, 1), nil
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -213,7 +213,7 @@ func marathonTestZeroTaskPortAppList(labels map[string]string, runningTasks int)
func TestMarathonZeroTaskPorts(t *testing.T) { func TestMarathonZeroTaskPorts(t *testing.T) {
var ( var (
ch = make(chan []*config.TargetGroup, 1) ch = make(chan []*config.TargetGroup, 1)
client = func(client *http.Client, url string) (*AppList, error) { client = func(client *http.Client, url, token string) (*AppList, error) {
return marathonTestZeroTaskPortAppList(marathonValidLabel, 1), nil return marathonTestZeroTaskPortAppList(marathonValidLabel, 1), nil
} }
) )