diff --git a/internal/cloud/retry.go b/internal/cloud/retry.go
new file mode 100644
index 0000000000..ab784f5964
--- /dev/null
+++ b/internal/cloud/retry.go
@@ -0,0 +1,116 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package cloud
+
+import (
+	"context"
+	"errors"
+	"log"
+	"time"
+)
+
+// Fatal is implemented by errors that RetryBackoff must not retry. When the
+// func passed to RetryBackoff returns an error that is, or wraps, a Fatal,
+// retrying stops and RetryBackoff returns the error given by FatalError.
+type Fatal interface {
+	FatalError() error
+}
+
+// NonRetryableError is a simple implementation of Fatal that wraps an error.
+type NonRetryableError struct {
+	InnerError error
+}
+
+// FatalError returns the inner error, but also implements Fatal, which
+// signals to RetryBackoff that a non-retryable error occurred.
+func (e NonRetryableError) FatalError() error {
+	return e.InnerError
+}
+
+// Error returns the inner error string.
+func (e NonRetryableError) Error() string {
+	return e.InnerError.Error()
+}
+
+var (
+	// initialBackoffDelay is the delay used before the first retry.
+	initialBackoffDelay = time.Second
+	// maxBackoffDelay is the upper bound for the delay between attempts.
+	maxBackoffDelay = 3 * time.Second
+)
+
+// RetryBackoff calls f until it returns nil or an error that is, or wraps, a
+// Fatal. The first attempt is made immediately; after each retryable error
+// the delay doubles, starting at initialBackoffDelay and capped at
+// maxBackoffDelay.
+//
+// RetryBackoff only returns an error if the context is in error or if a
+// Fatal error was encountered. A context error takes precedence over the
+// outcome of f.
+//
+// f runs on its own goroutine so that RetryBackoff can return as soon as ctx
+// is done. An attempt that is still in flight at that point keeps running
+// until f returns, so f should observe ctx itself, and callers must not read
+// anything f writes when RetryBackoff returns an error.
+func RetryBackoff(ctx context.Context, f func() error) error {
+	// Buffered so the goroutine can always deliver its result and exit, even
+	// when nobody is receiving because ctx finished first.
+	resultCh := make(chan error, 1)
+	go func() {
+		resultCh <- retryLoop(ctx, f)
+	}()
+
+	select {
+	case err := <-resultCh:
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			return ctxErr
+		}
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// retryLoop runs the attempts for RetryBackoff. It returns nil when f
+// succeeds or ctx is done, and the result of FatalError when f returns an
+// error that is, or wraps, a Fatal.
+func retryLoop(ctx context.Context, f func() error) error {
+	// the retry delay between each attempt
+	var delay time.Duration
+
+	for {
+		// Never start an attempt once the context is done.
+		if ctx.Err() != nil {
+			return nil
+		}
+
+		err := f()
+		if err == nil {
+			return nil
+		}
+
+		// errors.As also finds a Fatal that has been wrapped, e.g. with %w.
+		var fatal Fatal
+		if errors.As(err, &fatal) {
+			return fatal.FatalError()
+		}
+
+		delay *= 2
+		if delay == 0 {
+			delay = initialBackoffDelay
+		}
+		delay = min(delay, maxBackoffDelay)
+
+		log.Printf("[WARN] retryable error: %q, delaying for %s", err, delay)
+
+		// Stop the timer on cancellation instead of leaving it pending.
+		timer := time.NewTimer(delay)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return nil
+		case <-timer.C:
+		}
+	}
+}
diff --git a/internal/cloud/retry_test.go b/internal/cloud/retry_test.go
new file mode 100644
index 0000000000..3c8c8f989a
--- /dev/null
+++ b/internal/cloud/retry_test.go
@@ -0,0 +1,100 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package cloud
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+)
+
+type fatalError struct{}
+
+var fe = errors.New("this was a fatal error")
+
+func (f fatalError) FatalError() error {
+	return fe
+}
+
+func (f fatalError) Error() string {
+	return f.FatalError().Error()
+}
+
+func Test_RetryBackoff_canceled(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithCancel(context.Background())
+	// Cancel before calling RetryBackoff so only the context error can result.
+	cancel()
+
+	err := RetryBackoff(ctx, func() error {
+		return nil
+	})
+
+	if !errors.Is(err, context.Canceled) {
+		t.Errorf("expected canceled error, got %q", err)
+	}
+}
+
+func Test_RetryBackoff_deadline(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond))
+
+	defer cancel()
+	// f sleeps past the 1ms deadline, so the context error is expected.
+	err := RetryBackoff(ctx, func() error {
+		time.Sleep(10 * time.Millisecond)
+		return nil
+	})
+
+	if !errors.Is(err, context.DeadlineExceeded) {
+		t.Errorf("expected timeout error, got %q", err)
+	}
+}
+
+func Test_RetryBackoff_happy(t *testing.T) {
+	t.Parallel()
+	// A nil result on the first attempt ends the call with no error.
+	err := RetryBackoff(context.Background(), func() error {
+		return nil
+	})
+
+	if err != nil {
+		t.Errorf("expected nil err, got %q", err)
+	}
+}
+
+func Test_RetryBackoff_fatal(t *testing.T) {
+	t.Parallel()
+	// A Fatal error stops retrying; RetryBackoff returns FatalError().
+	err := RetryBackoff(context.Background(), func() error {
+		return fatalError{}
+	})
+
+	if !errors.Is(err, fe) {
+		t.Errorf("expected fatal error, got %q", err)
+	}
+}
+
+func Test_RetryBackoff_non_fatal(t *testing.T) {
+	t.Parallel()
+
+	var retriedCount = 0
+	// Fail the first attempt only; the second attempt succeeds.
+	err := RetryBackoff(context.Background(), func() error {
+		retriedCount += 1
+		if retriedCount == 2 {
+			return nil
+		}
+		return errors.New("retryable error")
+	})
+
+	if err != nil {
+		t.Errorf("expected no error, got %q", err)
+	}
+
+	if retriedCount != 2 {
+		t.Errorf("expected 2 retries, got %d", retriedCount)
+	}
+}
diff --git a/internal/cloud/state.go b/internal/cloud/state.go
index c886bf445a..36fa4dcf2f 100644
--- 
a/internal/cloud/state.go +++ b/internal/cloud/state.go @@ -516,10 +516,36 @@ func (s *State) Delete(force bool) error { // GetRootOutputValues fetches output values from HCP Terraform func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + // The cloud backend initializes this value to true, but retries here are driven + // by RetryBackoff, so the client's own server-error retries are turned off for + // this call and re-enabled on return. This presumes the caller does not share tfeClient with other goroutines. + s.tfeClient.RetryServerErrors(false) + defer s.tfeClient.RetryServerErrors(true) - so, err := s.tfeClient.StateVersionOutputs.ReadCurrent(ctx, s.workspace.ID) + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + var so *tfe.StateVersionOutputsList + err := RetryBackoff(ctx, func() error { + var err error + so, err = s.tfeClient.StateVersionOutputs.ReadCurrent(ctx, s.workspace.ID) + + if err != nil { + if strings.Contains(err.Error(), "service unavailable") { + return err + } + return NonRetryableError{err} + } + return nil + }) if err != nil { + switch err { + case context.DeadlineExceeded: + return nil, fmt.Errorf("current outputs were not ready to be read within the deadline. Please try again") + case context.Canceled: + return nil, fmt.Errorf("canceled reading current outputs") + } return nil, fmt.Errorf("could not read state version outputs: %w", err) }