diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 66d58dfcc4..ca7302eeb6 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -138,6 +138,12 @@ func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) error { event.WriteSysEvent(ctx, op, "scheduler already started, skipping") return nil } + if ctx == nil { + return errors.New(errors.InvalidParameter, op, "missing context") + } + if wg == nil { + return errors.New(errors.InvalidParameter, op, "missing wait group") + } if err := ctx.Err(); err != nil { return errors.Wrap(err, op) @@ -154,9 +160,7 @@ func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) error { return errors.Wrap(err, op) } - if wg != nil { - wg.Add(2) - } + wg.Add(2) go func() { defer wg.Done() s.start(ctx) diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 29f6f92c7a..05fca0aab2 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -2,6 +2,7 @@ package scheduler import ( "context" + "sync" "testing" "time" @@ -366,3 +367,54 @@ func TestScheduler_UpdateJobNextRunInAtLeast(t *testing.T) { assert.Equal(previousNextRun.Add(-1*time.Hour).Round(time.Minute), dbJob.NextScheduledRun.AsTime().Round(time.Minute)) }) } + +func TestScheduler_Start(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iam.TestRepo(t, conn, wrapper) + + sched := TestScheduler(t, conn, wrapper) + + tests := []struct { + name string + ctx context.Context + wg *sync.WaitGroup + wantErr bool + wantErrContains string + }{ + + { + name: "missing-ctx", + wg: &sync.WaitGroup{}, + wantErr: true, + wantErrContains: "missing context", + }, + { + name: "missing-waitgroup", + ctx: context.Background(), + wantErr: true, + wantErrContains: "missing wait group", + }, + + { + name: "valid", + ctx: context.Background(), + wg: &sync.WaitGroup{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + err := sched.Start(tt.ctx, tt.wg) + if tt.wantErr { + require.Error(err) + if tt.wantErrContains != "" { + assert.Contains(err.Error(), tt.wantErrContains) + } + return + } + require.NoError(err) + }) + } +}