diff --git a/internal/backend/local/backend_local_test.go b/internal/backend/local/backend_local_test.go index 3b9a315cd4..f929d79eaa 100644 --- a/internal/backend/local/backend_local_test.go +++ b/internal/backend/local/backend_local_test.go @@ -4,6 +4,7 @@ package local import ( + "context" "fmt" "os" "path/filepath" @@ -261,7 +262,7 @@ func (s *stateStorageThatFailsRefresh) State() *states.State { return nil } -func (s *stateStorageThatFailsRefresh) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (s *stateStorageThatFailsRefresh) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { return nil, fmt.Errorf("unimplemented") } diff --git a/internal/cloud/retry.go b/internal/cloud/retry.go new file mode 100644 index 0000000000..ab784f5964 --- /dev/null +++ b/internal/cloud/retry.go @@ -0,0 +1,101 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cloud + +import ( + "context" + "log" + "sync/atomic" + "time" +) + +// Fatal implements a RetryBackoff func return value that, if encountered, +// signals that the func should not be retried. In that case, the error +// returned by the interface method will be returned by RetryBackoff +type Fatal interface { + FatalError() error +} + +// NonRetryableError is a simple implementation of Fatal that wraps an error +type NonRetryableError struct { + InnerError error +} + +// FatalError returns the inner error, but also implements Fatal, which +// signals to RetryBackoff that a non-retryable error occurred. +func (e NonRetryableError) FatalError() error { + return e.InnerError +} + +// Error returns the inner error string +func (e NonRetryableError) Error() string { + return e.InnerError.Error() +} + +var ( + initialBackoffDelay = time.Second + maxBackoffDelay = 3 * time.Second +) + +// RetryBackoff retries function f until nil or a FatalError is returned. +// RetryBackoff only returns an error if the context is in error or if a +// FatalError was encountered. +func RetryBackoff(ctx context.Context, f func() error) error { + // doneCh signals that the routine is done and sends the last error + var doneCh = make(chan struct{}) + var errVal atomic.Value + type errWrap struct { + E error + } + + go func() { + // the retry delay between each attempt + var delay time.Duration = 0 + defer close(doneCh) + + for { + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + + err := f() + switch e := err.(type) { + case nil: + return + case Fatal: + errVal.Store(errWrap{e.FatalError()}) + return + } + + delay *= 2 + if delay == 0 { + delay = initialBackoffDelay + } + + delay = min(delay, maxBackoffDelay) + + log.Printf("[WARN] retryable error: %q, delaying for %s", err, delay) + } + }() + + // Wait until done or deadline + select { + case <-doneCh: + case <-ctx.Done(): + } + + err, hadErr := errVal.Load().(errWrap) + var lastErr error + if hadErr { + lastErr = err.E + } + + if ctx.Err() != nil { + return ctx.Err() + } + + return lastErr +} diff --git a/internal/cloud/retry_test.go b/internal/cloud/retry_test.go new file mode 100644 index 0000000000..3c8c8f989a --- /dev/null +++ b/internal/cloud/retry_test.go @@ -0,0 +1,100 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cloud + +import ( + "context" + "errors" + "testing" + "time" +) + +type fatalError struct{} + +var fe = errors.New("this was a fatal error") + +func (f fatalError) FatalError() error { + return fe +} + +func (f fatalError) Error() string { + return f.FatalError().Error() +} + +func Test_RetryBackoff_canceled(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + + cancel() + + err := RetryBackoff(ctx, func() error { + return nil + }) + + if !errors.Is(err, context.Canceled) { + t.Errorf("expected canceled error, got %q", err) + } +} + +func Test_RetryBackoff_deadline(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond)) + + defer cancel() + + err := RetryBackoff(ctx, func() error { + time.Sleep(10 * time.Millisecond) + return nil + }) + + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected timeout error, got %q", err) + } +} + +func Test_RetryBackoff_happy(t *testing.T) { + t.Parallel() + + err := RetryBackoff(context.Background(), func() error { + return nil + }) + + if err != nil { + t.Errorf("expected nil err, got %q", err) + } +} + +func Test_RetryBackoff_fatal(t *testing.T) { + t.Parallel() + + err := RetryBackoff(context.Background(), func() error { + return fatalError{} + }) + + if !errors.Is(fe, err) { + t.Errorf("expected fatal error, got %q", err) + } +} + +func Test_RetryBackoff_non_fatal(t *testing.T) { + t.Parallel() + + var retriedCount = 0 + + err := RetryBackoff(context.Background(), func() error { + retriedCount += 1 + if retriedCount == 2 { + return nil + } + return errors.New("retryable error") + }) + + if err != nil { + t.Errorf("expected no error, got %q", err) + } + + if retriedCount != 2 { + t.Errorf("expected 2 retries, got %d", retriedCount) + } +} diff --git a/internal/cloud/state.go b/internal/cloud/state.go index cb41aac86f..36fa4dcf2f 100644 --- a/internal/cloud/state.go +++ b/internal/cloud/state.go @@ -515,12 +515,37 @@ func (s *State) Delete(force bool) error { } // GetRootOutputValues fetches output values from HCP Terraform -func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { - ctx := context.Background() +func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + // The cloud backend initializes this value to true, but we want to implement + // some custom retry logic. This code presumes that the tfeClient doesn't need + // to be shared with other goroutines by the caller. + s.tfeClient.RetryServerErrors(false) + defer s.tfeClient.RetryServerErrors(true) + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() - so, err := s.tfeClient.StateVersionOutputs.ReadCurrent(ctx, s.workspace.ID) + var so *tfe.StateVersionOutputsList + err := RetryBackoff(ctx, func() error { + var err error + so, err = s.tfeClient.StateVersionOutputs.ReadCurrent(ctx, s.workspace.ID) + + if err != nil { + if strings.Contains(err.Error(), "service unavailable") { + return err + } + return NonRetryableError{err} + } + return nil + }) if err != nil { + switch err { + case context.DeadlineExceeded: + return nil, fmt.Errorf("current outputs were not ready to be read within the deadline. Please try again") + case context.Canceled: + return nil, fmt.Errorf("canceled reading current outputs") + } return nil, fmt.Errorf("could not read state version outputs: %w", err) } diff --git a/internal/cloud/state_test.go b/internal/cloud/state_test.go index 417c9a7776..61919e0b1e 100644 --- a/internal/cloud/state_test.go +++ b/internal/cloud/state_test.go @@ -40,7 +40,7 @@ func TestState_GetRootOutputValues(t *testing.T) { state := &State{tfeClient: b.client, organization: b.Organization, workspace: &tfe.Workspace{ ID: "ws-abcd", }} - outputs, err := state.GetRootOutputValues() + outputs, err := state.GetRootOutputValues(context.Background()) if err != nil { t.Fatalf("error returned from GetRootOutputValues: %s", err) diff --git a/internal/command/output.go b/internal/command/output.go index 4519d280c0..9b4bf9d7c3 100644 --- a/internal/command/output.go +++ b/internal/command/output.go @@ -69,6 +69,10 @@ func (c *OutputCommand) Outputs(statePath string) (map[string]*states.OutputValu return nil, diags } + // Command can be aborted by interruption signals + ctx, done := c.InterruptibleContext(c.CommandContext()) + defer done() + // This is a read-only command c.ignoreRemoteVersionConflict(b) @@ -85,7 +89,7 @@ func (c *OutputCommand) Outputs(statePath string) (map[string]*states.OutputValu return nil, diags } - output, err := stateStore.GetRootOutputValues() + output, err := stateStore.GetRootOutputValues(ctx) if err != nil { return nil, diags.Append(err) } diff --git a/internal/states/remote/state.go b/internal/states/remote/state.go index 2c6edc1b0d..34fbbc638c 100644 --- a/internal/states/remote/state.go +++ b/internal/states/remote/state.go @@ -5,6 +5,7 @@ package remote import ( "bytes" + "context" "fmt" "log" "sync" @@ -59,7 +60,7 @@ func (s *State) State() *states.State { return s.state.DeepCopy() } -func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { if err := s.RefreshState(); err != nil { return nil, fmt.Errorf("Failed to load state: %s", err) } diff --git a/internal/states/remote/state_test.go b/internal/states/remote/state_test.go index ce34ff7bad..cbb6a6219f 100644 --- a/internal/states/remote/state_test.go +++ b/internal/states/remote/state_test.go @@ -4,6 +4,7 @@ package remote import ( + "context" "log" "sync" "testing" @@ -408,7 +409,7 @@ func TestState_GetRootOutputValues(t *testing.T) { }, } - outputs, err := mgr.GetRootOutputValues() + outputs, err := mgr.GetRootOutputValues(context.Background()) if err != nil { t.Errorf("Expected GetRootOutputValues to not return an error, but it returned %v", err) } diff --git a/internal/states/statemgr/filesystem.go b/internal/states/statemgr/filesystem.go index 908a552b0a..7a7a63e3ca 100644 --- a/internal/states/statemgr/filesystem.go +++ b/internal/states/statemgr/filesystem.go @@ -5,6 +5,7 @@ package statemgr import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -236,7 +237,7 @@ func (s *Filesystem) RefreshState() error { return s.refreshState() } -func (s *Filesystem) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (s *Filesystem) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { err := s.RefreshState() if err != nil { return nil, err diff --git a/internal/states/statemgr/filesystem_test.go b/internal/states/statemgr/filesystem_test.go index 8add434e46..ea180275c5 100644 --- a/internal/states/statemgr/filesystem_test.go +++ b/internal/states/statemgr/filesystem_test.go @@ -4,6 +4,7 @@ package statemgr import ( + "context" "io/ioutil" "os" "os/exec" @@ -417,7 +418,7 @@ func TestFilesystem_refreshWhileLocked(t *testing.T) { func TestFilesystem_GetRootOutputValues(t *testing.T) { fs := testFilesystem(t) - outputs, err := fs.GetRootOutputValues() + outputs, err := fs.GetRootOutputValues(context.Background()) if err != nil { t.Errorf("Expected GetRootOutputValues to not return an error, but it returned %v", err) } diff --git a/internal/states/statemgr/lock.go b/internal/states/statemgr/lock.go index 07b3121ede..9d34c20415 100644 --- a/internal/states/statemgr/lock.go +++ b/internal/states/statemgr/lock.go @@ -4,6 +4,8 @@ package statemgr import ( + "context" + "github.com/hashicorp/terraform/internal/schemarepo" "github.com/hashicorp/terraform/internal/states" ) @@ -21,8 +23,8 @@ func (s *LockDisabled) State() *states.State { return s.Inner.State() } -func (s *LockDisabled) GetRootOutputValues() (map[string]*states.OutputValue, error) { - return s.Inner.GetRootOutputValues() +func (s *LockDisabled) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + return s.Inner.GetRootOutputValues(ctx) } func (s *LockDisabled) WriteState(v *states.State) error { diff --git a/internal/states/statemgr/persistent.go b/internal/states/statemgr/persistent.go index e2fa88b295..1e2c82a735 100644 --- a/internal/states/statemgr/persistent.go +++ b/internal/states/statemgr/persistent.go @@ -4,6 +4,7 @@ package statemgr import ( + "context" "time" version "github.com/hashicorp/go-version" @@ -33,7 +34,7 @@ type Persistent interface { // to differentiate reading the state and reading the outputs within the state. type OutputReader interface { // GetRootOutputValues fetches the root module output values from state or another source - GetRootOutputValues() (map[string]*states.OutputValue, error) + GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) } // Refresher is the interface for managers that can read snapshots from diff --git a/internal/states/statemgr/statemgr_fake.go b/internal/states/statemgr/statemgr_fake.go index 29e1bf7bc1..25800fdbbd 100644 --- a/internal/states/statemgr/statemgr_fake.go +++ b/internal/states/statemgr/statemgr_fake.go @@ -4,6 +4,7 @@ package statemgr import ( + "context" "errors" "sync" @@ -69,7 +70,7 @@ func (m *fakeFull) PersistState(schemas *schemarepo.Schemas) error { return m.fakeP.WriteState(m.t.State()) } -func (m *fakeFull) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (m *fakeFull) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { return m.State().RootOutputValues, nil } @@ -119,7 +120,7 @@ func (m *fakeErrorFull) State() *states.State { return nil } -func (m *fakeErrorFull) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (m *fakeErrorFull) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { return nil, errors.New("fake state manager error") }