diff --git a/internal/scheduler/additional_verification_test.go b/internal/scheduler/additional_verification_test.go index 67b4d0561f..ffb1c7f5ff 100644 --- a/internal/scheduler/additional_verification_test.go +++ b/internal/scheduler/additional_verification_test.go @@ -137,6 +137,82 @@ func TestSchedulerCancelCtx(t *testing.T) { <-jobDone } +func TestSchedulerWaitOnRunningJobs(t *testing.T) { + // do not use t.Parallel() since it relies on the sys eventer + require := require.New(t) + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iam.TestRepo(t, conn, wrapper) + event.TestEnableEventing(t, true) + testConfig := event.DefaultEventerConfig() + testLock := &sync.Mutex{} + testLogger := hclog.New(&hclog.LoggerOptions{ + Mutex: testLock, + }) + err := event.InitSysEventer(testLogger, testLock, "TestSchedulerWorkflow", event.WithEventerConfig(testConfig)) + require.NoError(err) + + sched := TestScheduler(t, conn, wrapper, WithRunJobsLimit(10), WithRunJobsInterval(time.Second), WithMonitorInterval(time.Second)) + + jobReady := make(chan struct{}) + finishJob := make(chan struct{}) + fn := func(ctx context.Context) error { + jobReady <- struct{}{} + <-finishJob + return nil + } + + tj := testJob{name: "name", description: "desc", fn: fn, nextRunIn: time.Hour} + err = sched.RegisterJob(context.Background(), tj) + require.NoError(err) + + baseCtx, baseCnl := context.WithCancel(context.Background()) + var wg sync.WaitGroup + err = sched.Start(baseCtx, &wg) + require.NoError(err) + + waiting := make(chan struct{}) + go func() { + defer close(waiting) + wg.Wait() + }() + + // Wait for scheduler to run job + <-jobReady + + // Verify waitGroup is still waiting + select { + case <-waiting: + t.Fatal("expected waitgroup to be waiting") + default: + } + + // Cancel base context + baseCnl() + + // Sleep to propagate context cancel + time.Sleep(250 * time.Millisecond) + + // Verify waitGroup is still waiting + select { + case <-waiting: + t.Fatal("expected waitgroup to be waiting") + default: + } + + finishJob <- struct{}{} + + // Sleep to ensure job finish is registered + time.Sleep(250 * time.Millisecond) + + // Now that job has finished, verify waitGroup is no longer waiting + select { + case <-waiting: + default: + t.Fatal("expected waitgroup to no longer be waiting") + } +} + func TestSchedulerInterruptedCancelCtx(t *testing.T) { // do not use t.Parallel() since it relies on the sys eventer assert, require := assert.New(t), require.New(t) @@ -259,7 +335,13 @@ func TestSchedulerJobProgress(t *testing.T) { sched := TestScheduler(t, conn, wrapper, WithRunJobsLimit(10), WithRunJobsInterval(time.Second), WithMonitorInterval(time.Second)) jobReady := make(chan struct{}) + done := make(chan struct{}) fn := func(ctx context.Context) error { + select { + case <-done: + return nil + default: + } jobReady <- struct{}{} <-ctx.Done() return nil @@ -268,6 +350,11 @@ func TestSchedulerJobProgress(t *testing.T) { statusRequest := make(chan struct{}) jobStatus := make(chan JobStatus) status := func() JobStatus { + select { + case <-done: + return JobStatus{} + default: + } statusRequest <- struct{}{} return <-jobStatus } @@ -332,7 +419,9 @@ func TestSchedulerJobProgress(t *testing.T) { assert.Equal(uint32(10), run.CompletedCount) baseCnl() - // unblock goroutines waiting on channels + // Close done to bypass future job run / job status requests that will block on channels + close(done) + // unblock existing goroutines waiting on channels jobStatus <- JobStatus{} } diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index ca7302eeb6..1c59b2626f 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -176,13 +176,15 @@ func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) error { func (s *Scheduler) start(ctx context.Context) { const op = "scheduler.(Scheduler).start" timer := time.NewTimer(s.runJobsInterval) + var wg sync.WaitGroup for { select { case <-ctx.Done(): + event.WriteSysEvent(ctx, op, "scheduling loop received shutdown, waiting for jobs to finish", "server id", s.serverId) + wg.Wait() event.WriteSysEvent(ctx, op, "scheduling loop shutting down", "server id", s.serverId) return case <-timer.C: - repo, err := s.jobRepoFn() if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error creating job repo")) @@ -196,7 +198,7 @@ func (s *Scheduler) start(ctx context.Context) { } for _, r := range runs { - err := s.runJob(ctx, r) + err := s.runJob(ctx, &wg, r) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error starting job")) if _, inner := repo.FailRun(ctx, r.PrivateId, 0, 0); inner != nil { @@ -210,7 +212,7 @@ func (s *Scheduler) start(ctx context.Context) { } } -func (s *Scheduler) runJob(ctx context.Context, r *job.Run) error { +func (s *Scheduler) runJob(ctx context.Context, wg *sync.WaitGroup, r *job.Run) error { const op = "scheduler.(Scheduler).runJob" regJob, ok := s.registeredJobs.Load(r.JobName) if !ok { @@ -231,24 +233,27 @@ func (s *Scheduler) runJob(ctx context.Context, r *job.Run) error { var jobContext context.Context jobContext, rj.cancelCtx = context.WithCancel(ctx) + wg.Add(1) go func() { defer rj.cancelCtx() + defer wg.Done() runErr := j.Run(jobContext) - var updateErr error // Get final status report to update run progress with status := j.Status() - - switch runErr { - case nil: + var updateErr error + switch { + case ctx.Err() != nil: + // Base context is no longer valid, skip repo updates as they will fail and exit + case runErr == nil: nextRun, inner := j.NextRunIn() if inner != nil { event.WriteError(ctx, op, inner, event.WithInfoMsg("error getting next run time", "name", j.Name())) } - _, updateErr = repo.CompleteRun(jobContext, r.PrivateId, nextRun, status.Completed, status.Total) + _, updateErr = repo.CompleteRun(ctx, r.PrivateId, nextRun, status.Completed, status.Total) default: event.WriteError(ctx, op, runErr, event.WithInfoMsg("job run failed", "run id", r.PrivateId, "name", j.Name())) - _, updateErr = repo.FailRun(jobContext, r.PrivateId, status.Completed, status.Total) + _, updateErr = repo.FailRun(ctx, r.PrivateId, status.Completed, status.Total) } if updateErr != nil {