diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 3e742520c..ec59c011e 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -5,6 +5,19 @@ import ( "time" ) +// For mocking purposes. +// This little bit of wrapping needs to be done becuase go doesn't do +// covariance, but it does coerse *time.Timer into stoppable implicitly if we +// write it out like so. +var afterFunc = func(d time.Duration, f func()) stoppable { + return time.AfterFunc(d, f) +} + +// stoppable is the subset of time.Timer which we use, split out for mocking purposes +type stoppable interface { + Stop() bool +} + // ProcessFunc is a function to process an item in the work queue. type ProcessFunc func(interface{}) @@ -22,13 +35,13 @@ type ScheduledWorkQueue interface { type scheduledWorkQueue struct { processFunc ProcessFunc - work map[interface{}]*time.Timer + work map[interface{}]stoppable workLock sync.Mutex } // NewScheduledWorkQueue will create a new workqueue with the given processFunc func NewScheduledWorkQueue(processFunc ProcessFunc) ScheduledWorkQueue { - return &scheduledWorkQueue{processFunc, make(map[interface{}]*time.Timer), sync.Mutex{}} + return &scheduledWorkQueue{processFunc, make(map[interface{}]stoppable), sync.Mutex{}} } // Add will add an item to this queue, executing the ProcessFunc after the @@ -39,7 +52,7 @@ func (s *scheduledWorkQueue) Add(obj interface{}, duration time.Duration) { s.Forget(obj) s.workLock.Lock() defer s.workLock.Unlock() - s.work[obj] = time.AfterFunc(duration, func() { + s.work[obj] = afterFunc(duration, func() { defer s.Forget(obj) s.processFunc(obj) }) diff --git a/pkg/scheduler/scheduler_test.go b/pkg/scheduler/scheduler_test.go index 556a362bf..b71871fc5 100644 --- a/pkg/scheduler/scheduler_test.go +++ b/pkg/scheduler/scheduler_test.go @@ -7,6 +7,9 @@ import ( ) func TestAdd(t *testing.T) { + after := newMockAfter() + afterFunc = after.AfterFunc + var wg sync.WaitGroup type testT struct { obj string @@ -20,19 +23,24 @@ func TestAdd(t *testing.T) { for _, test := range tests { wg.Add(1) t.Run(test.obj, func(test testT) func(*testing.T) { + waitSubtest := make(chan struct{}) return func(t *testing.T) { - startTime := time.Now() + startTime := after.currentTime queue := NewScheduledWorkQueue(func(obj interface{}) { defer wg.Done() - durationEarly := test.duration - time.Now().Sub(startTime) + durationEarly := test.duration - after.currentTime.Sub(startTime) + if durationEarly > 0 { t.Errorf("got queue item %.2f seconds too early", float64(durationEarly)/float64(time.Second)) } if obj != test.obj { t.Errorf("expected obj '%+v' but got obj '%+v'", test.obj, obj) } + waitSubtest <- struct{}{} }) queue.Add(test.obj, test.duration) + after.warp(test.duration + time.Millisecond) + <-waitSubtest } }(test)) } @@ -41,6 +49,9 @@ func TestAdd(t *testing.T) { } func TestForget(t *testing.T) { + after := newMockAfter() + afterFunc = after.AfterFunc + var wg sync.WaitGroup type testT struct { obj string @@ -61,10 +72,60 @@ func TestForget(t *testing.T) { }) queue.Add(test.obj, test.duration) queue.Forget(test.obj) - time.Sleep(test.duration * 2) + after.warp(test.duration * 2) } }(test)) } wg.Wait() } + +type timerQueueItem struct { + f func() + t time.Time + run bool + stopped bool +} + +func (tq *timerQueueItem) Stop() bool { + stopped := tq.stopped + tq.stopped = true + return stopped +} + +type mockAfter struct { + startTime time.Time + currentTime time.Time + queue []*timerQueueItem +} + +func newMockAfter() *mockAfter { + return &mockAfter{ + queue: make([]*timerQueueItem, 0), + } +} + +func (m *mockAfter) AfterFunc(d time.Duration, f func()) stoppable { + item := &timerQueueItem{ + f: f, + t: m.currentTime.Add(d), + } + m.queue = append(m.queue, item) + return item +} + +func (m *mockAfter) warp(d time.Duration) { + m.currentTime = m.currentTime.Add(d) + for _, item := range m.queue { + if item.run || item.stopped { + continue + } + + if item.t.Before(m.currentTime) { + item.run = true + go func(f func()) { + f() + }(item.f) + } + } +}