diff --git a/retrieval/discovery/marathon/marathon.go b/retrieval/discovery/marathon/marathon.go index 905e8ac5f..942362454 100644 --- a/retrieval/discovery/marathon/marathon.go +++ b/retrieval/discovery/marathon/marathon.go @@ -48,7 +48,6 @@ const appListPath string = "/v2/apps/?embed=apps.tasks" type Discovery struct { Servers []string RefreshInterval time.Duration - Done chan struct{} lastRefresh map[string]*config.TargetGroup Client AppListClient } @@ -62,7 +61,7 @@ func (md *Discovery) Run(ctx context.Context, ch chan<- []*config.TargetGroup) { case <-ctx.Done(): return case <-time.After(md.RefreshInterval): - err := md.updateServices(ch) + err := md.updateServices(ctx, ch) if err != nil { log.Errorf("Error while updating services: %s", err) } @@ -70,7 +69,7 @@ func (md *Discovery) Run(ctx context.Context, ch chan<- []*config.TargetGroup) { } } -func (md *Discovery) updateServices(ch chan<- []*config.TargetGroup) error { +func (md *Discovery) updateServices(ctx context.Context, ch chan<- []*config.TargetGroup) error { targetMap, err := md.fetchTargetGroups() if err != nil { return err @@ -80,14 +79,23 @@ func (md *Discovery) updateServices(ch chan<- []*config.TargetGroup) error { for _, tg := range targetMap { all = append(all, tg) } - ch <- all - // Remove services which did disappear + select { + case <-ctx.Done(): + return ctx.Err() + case ch <- all: + } + + // Remove services which did disappear. for source := range md.lastRefresh { _, ok := targetMap[source] if !ok { - log.Debugf("Removing group for %s", source) - ch <- []*config.TargetGroup{{Source: source}} + select { + case <-ctx.Done(): + return ctx.Err() + case ch <- []*config.TargetGroup{{Source: source}}: + log.Debugf("Removing group for %s", source) + } } } diff --git a/retrieval/discovery/marathon/marathon_test.go b/retrieval/discovery/marathon/marathon_test.go index 98a6c5e89..c01c876c2 100644 --- a/retrieval/discovery/marathon/marathon_test.go +++ b/retrieval/discovery/marathon/marathon_test.go @@ -47,7 +47,7 @@ func TestMarathonSDHandleError(t *testing.T) { default: } }() - err := md.updateServices(ch) + err := md.updateServices(context.Background(), ch) if err != errTesting { t.Fatalf("Expected error: %s", err) } @@ -66,7 +66,7 @@ func TestMarathonSDEmptyList(t *testing.T) { default: } }() - err := md.updateServices(ch) + err := md.updateServices(context.Background(), ch) if err != nil { t.Fatalf("Got error: %s", err) } @@ -115,7 +115,7 @@ func TestMarathonSDSendGroup(t *testing.T) { t.Fatal("Did not get a target group.") } }() - err := md.updateServices(ch) + err := md.updateServices(context.Background(), ch) if err != nil { t.Fatalf("Got error: %s", err) } @@ -136,7 +136,7 @@ func TestMarathonSDRemoveApp(t *testing.T) { } } }() - err := md.updateServices(ch) + err := md.updateServices(context.Background(), ch) if err != nil { t.Fatalf("Got error on first update: %s", err) } @@ -144,7 +144,7 @@ func TestMarathonSDRemoveApp(t *testing.T) { md.Client = func(url string) (*AppList, error) { return marathonTestAppList(marathonValidLabel, 0), nil } - err = md.updateServices(ch) + err = md.updateServices(context.Background(), ch) if err != nil { t.Fatalf("Got error on second update: %s", err) }