scheduled_task: fix memory leaks (#1649)

This commit is contained in:
Jan-Otto Kröpke 2024-09-28 15:15:15 +02:00 committed by GitHub
parent 798bf32dec
commit 01e809315c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,6 +33,8 @@ var ConfigDefaults = Config{
type Collector struct { type Collector struct {
config Config config Config
scheduledTasksCh chan *scheduledTaskResults
lastResult *prometheus.Desc lastResult *prometheus.Desc
missedRuns *prometheus.Desc missedRuns *prometheus.Desc
state *prometheus.Desc state *prometheus.Desc
@ -57,7 +59,9 @@ const (
SCHED_S_TASK_HAS_NOT_RUN TaskResult = 0x00041303 SCHED_S_TASK_HAS_NOT_RUN TaskResult = 0x00041303
) )
type ScheduledTask struct { var taskStates = []string{"disabled", "queued", "ready", "running", "unknown"}
type scheduledTask struct {
Name string Name string
Path string Path string
Enabled bool Enabled bool
@ -66,7 +70,10 @@ type ScheduledTask struct {
LastTaskResult TaskResult LastTaskResult TaskResult
} }
type ScheduledTasks []ScheduledTask type scheduledTaskResults struct {
scheduledTasks []scheduledTask
err error
}
func New(config *Config) *Collector { func New(config *Config) *Collector {
if config == nil { if config == nil {
@ -133,10 +140,23 @@ func (c *Collector) GetPerfCounter(_ *slog.Logger) ([]string, error) {
} }
func (c *Collector) Close(_ *slog.Logger) error { func (c *Collector) Close(_ *slog.Logger) error {
close(c.scheduledTasksCh)
c.scheduledTasksCh = nil
return nil return nil
} }
func (c *Collector) Build(_ *slog.Logger, _ *wmi.Client) error { func (c *Collector) Build(_ *slog.Logger, _ *wmi.Client) error {
initErrCh := make(chan error)
c.scheduledTasksCh = make(chan *scheduledTaskResults)
go c.initializeScheduleService(initErrCh)
if err := <-initErrCh; err != nil {
return fmt.Errorf("initialize schedule service: %w", err)
}
c.lastResult = prometheus.NewDesc( c.lastResult = prometheus.NewDesc(
prometheus.BuildFQName(types.Namespace, Name, "last_result"), prometheus.BuildFQName(types.Namespace, Name, "last_result"),
"The result that was returned the last time the registered task was run", "The result that was returned the last time the registered task was run",
@ -174,12 +194,10 @@ func (c *Collector) Collect(_ *types.ScrapeContext, logger *slog.Logger, ch chan
return nil return nil
} }
var TASK_STATES = []string{"disabled", "queued", "ready", "running", "unknown"}
func (c *Collector) collect(ch chan<- prometheus.Metric) error { func (c *Collector) collect(ch chan<- prometheus.Metric) error {
scheduledTasks, err := getScheduledTasks() scheduledTasks, err := c.getScheduledTasks()
if err != nil { if err != nil {
return err return fmt.Errorf("get scheduled tasks: %w", err)
} }
for _, task := range scheduledTasks { for _, task := range scheduledTasks {
@ -188,7 +206,7 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error {
continue continue
} }
for _, state := range TASK_STATES { for _, state := range taskStates {
var stateValue float64 var stateValue float64
if strings.ToLower(task.State.String()) == state { if strings.ToLower(task.State.String()) == state {
@ -231,14 +249,15 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error {
return nil return nil
} }
const SCHEDULED_TASK_PROGRAM_ID = "Schedule.Service.1" func (c *Collector) getScheduledTasks() ([]scheduledTask, error) {
c.scheduledTasksCh <- nil
// S_FALSE is returned by CoInitialize if it was already called on this thread. scheduledTasks := <-c.scheduledTasksCh
const S_FALSE = 0x00000001
func getScheduledTasks() (ScheduledTasks, error) { return scheduledTasks.scheduledTasks, scheduledTasks.err
var scheduledTasks ScheduledTasks }
func (c *Collector) initializeScheduleService(initErrCh chan<- error) {
// The only way to run WMI queries in parallel while being thread-safe is to // The only way to run WMI queries in parallel while being thread-safe is to
// ensure the CoInitialize[Ex]() call is bound to its current OS thread. // ensure the CoInitialize[Ex]() call is bound to its current OS thread.
// Otherwise, attempting to initialize and run parallel queries across // Otherwise, attempting to initialize and run parallel queries across
@ -248,46 +267,95 @@ func getScheduledTasks() (ScheduledTasks, error) {
if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil { if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
var oleCode *ole.OleError var oleCode *ole.OleError
if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != S_FALSE { if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != wmi.S_FALSE {
return nil, err initErrCh <- err
return
} }
} }
defer ole.CoUninitialize() defer ole.CoUninitialize()
schedClassID, err := ole.ClassIDFrom(SCHEDULED_TASK_PROGRAM_ID) scheduleClassID, err := ole.ClassIDFrom("Schedule.Service.1")
if err != nil { if err != nil {
return scheduledTasks, err initErrCh <- err
return
} }
taskSchedulerObj, err := ole.CreateInstance(schedClassID, nil) taskSchedulerObj, err := ole.CreateInstance(scheduleClassID, nil)
if err != nil || taskSchedulerObj == nil { if err != nil || taskSchedulerObj == nil {
return scheduledTasks, err initErrCh <- err
return
} }
defer taskSchedulerObj.Release() defer taskSchedulerObj.Release()
taskServiceObj := taskSchedulerObj.MustQueryInterface(ole.IID_IDispatch) taskServiceObj := taskSchedulerObj.MustQueryInterface(ole.IID_IDispatch)
_, err = oleutil.CallMethod(taskServiceObj, "Connect")
if err != nil {
return scheduledTasks, err
}
defer taskServiceObj.Release() defer taskServiceObj.Release()
res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`) taskService, err := oleutil.CallMethod(taskServiceObj, "Connect")
if err != nil { if err != nil {
return scheduledTasks, err initErrCh <- err
return
} }
rootFolderObj := res.ToIDispatch() defer func(taskService *ole.VARIANT) {
defer rootFolderObj.Release() _ = taskService.Clear()
}(taskService)
err = fetchTasksRecursively(rootFolderObj, &scheduledTasks) close(initErrCh)
return scheduledTasks, err scheduledTasks := make([]scheduledTask, 0, 100)
for range c.scheduledTasksCh {
func() {
// Clear the slice to avoid memory leaks
clear(scheduledTasks)
scheduledTasks = scheduledTasks[:0]
res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`)
if err != nil {
c.scheduledTasksCh <- &scheduledTaskResults{err: err}
return
}
rootFolderObj := res.ToIDispatch()
defer rootFolderObj.Release()
err = fetchTasksRecursively(rootFolderObj, &scheduledTasks)
c.scheduledTasksCh <- &scheduledTaskResults{scheduledTasks: scheduledTasks, err: err}
}()
}
} }
func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error { func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error {
if err := fetchTasksInFolder(folder, scheduledTasks); err != nil {
return err
}
res, err := oleutil.CallMethod(folder, "GetFolders", 1)
if err != nil {
return err
}
subFolders := res.ToIDispatch()
defer subFolders.Release()
err = oleutil.ForEach(subFolders, func(v *ole.VARIANT) error {
subFolder := v.ToIDispatch()
defer subFolder.Release()
return fetchTasksRecursively(subFolder, scheduledTasks)
})
return err
}
func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error {
res, err := oleutil.CallMethod(folder, "GetTasks", 1) res, err := oleutil.CallMethod(folder, "GetTasks", 1)
if err != nil { if err != nil {
return err return err
@ -313,31 +381,8 @@ func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) e
return err return err
} }
func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error { func parseTask(task *ole.IDispatch) (scheduledTask, error) {
if err := fetchTasksInFolder(folder, scheduledTasks); err != nil { var scheduledTask scheduledTask
return err
}
res, err := oleutil.CallMethod(folder, "GetFolders", 1)
if err != nil {
return err
}
subFolders := res.ToIDispatch()
defer subFolders.Release()
err = oleutil.ForEach(subFolders, func(v *ole.VARIANT) error {
subFolder := v.ToIDispatch()
defer subFolder.Release()
return fetchTasksRecursively(subFolder, scheduledTasks)
})
return err
}
func parseTask(task *ole.IDispatch) (ScheduledTask, error) {
var scheduledTask ScheduledTask
taskNameVar, err := oleutil.GetProperty(task, "Name") taskNameVar, err := oleutil.GetProperty(task, "Name")
if err != nil { if err != nil {