diff --git a/state/remote/remote_test.go b/state/remote/remote_test.go index 665e49a20c..77cdd5b4f1 100644 --- a/state/remote/remote_test.go +++ b/state/remote/remote_test.go @@ -2,6 +2,8 @@ package remote import ( "bytes" + "crypto/md5" + "encoding/json" "testing" "github.com/hashicorp/terraform/state" @@ -60,3 +62,57 @@ func (nilClient) Get() (*Payload, error) { return nil, nil } func (c nilClient) Put([]byte) error { return nil } func (c nilClient) Delete() error { return nil } + +// mockClient is a client that tracks persisted state snapshots only in +// memory and also logs what it has been asked to do for use in test +// assertions. +type mockClient struct { + current []byte + log []mockClientRequest +} + +type mockClientRequest struct { + Method string + Content map[string]interface{} +} + +func (c *mockClient) Get() (*Payload, error) { + c.appendLog("Get", c.current) + if c.current == nil { + return nil, nil + } + checksum := md5.Sum(c.current) + return &Payload{ + Data: c.current, + MD5: checksum[:], + }, nil +} + +func (c *mockClient) Put(data []byte) error { + c.appendLog("Put", data) + c.current = data + return nil +} + +func (c *mockClient) Delete() error { + c.appendLog("Delete", c.current) + c.current = nil + return nil +} + +func (c *mockClient) appendLog(method string, content []byte) { + // For easier test assertions, we actually log the result of decoding + // the content JSON rather than the raw bytes. Callers are in principle + // allowed to provide any arbitrary bytes here, but we know we're only + // using this to test our own State implementation here and that always + // uses the JSON state format, so this is fine. + + var contentVal map[string]interface{} + if content != nil { + err := json.Unmarshal(content, &contentVal) + if err != nil { + panic(err) // should never happen because our tests control this input + } + } + c.log = append(c.log, mockClientRequest{method, contentVal}) +} diff --git a/state/remote/state.go b/state/remote/state.go index e73fbe8f58..6d701da800 100644 --- a/state/remote/state.go +++ b/state/remote/state.go @@ -123,9 +123,11 @@ func (s *State) PersistState() error { defer s.mu.Unlock() if s.readState != nil { - if !statefile.StatesMarshalEqual(s.state, s.readState) { - s.serial++ + if statefile.StatesMarshalEqual(s.state, s.readState) { + // If the state hasn't changed at all then we have nothing to do. + return nil } + s.serial++ } else { // We might be writing a new state altogether, but before we do that // we'll check to make sure there isn't already a snapshot present diff --git a/state/remote/state_test.go b/state/remote/state_test.go index efdf04cf86..949eefe116 100644 --- a/state/remote/state_test.go +++ b/state/remote/state_test.go @@ -4,7 +4,12 @@ import ( "sync" "testing" + "github.com/google/go-cmp/cmp" + "github.com/zclconf/go-cty/cty" + + "github.com/hashicorp/terraform/states" "github.com/hashicorp/terraform/states/statemgr" + "github.com/hashicorp/terraform/version" ) func TestState_impl(t *testing.T) { @@ -20,7 +25,7 @@ func TestStateRace(t *testing.T) { Client: nilClient{}, } - current := state.TestStateInitial() + current := states.NewState() var wg sync.WaitGroup @@ -35,3 +40,116 @@ func TestStateRace(t *testing.T) { } wg.Wait() } + +func TestStatePersist(t *testing.T) { + mgr := &State{ + Client: &mockClient{ + // Initial state just to give us a fixed starting point for our + // test assertions below, or else we'd need to deal with + // random lineage. + current: []byte(` + { + "version": 4, + "lineage": "mock-lineage", + "serial": 1, + "terraform_version":"0.0.0", + "outputs": {}, + "resources": [] + } + `), + }, + } + + // In normal use (during a Terraform operation) we always refresh and read + // before any writes would happen, so we'll mimic that here for realism. + if err := mgr.RefreshState(); err != nil { + t.Fatalf("failed to RefreshState: %s", err) + } + s := mgr.State() + + s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false) + if err := mgr.WriteState(s); err != nil { + t.Fatalf("failed to WriteState: %s", err) + } + if err := mgr.PersistState(); err != nil { + t.Fatalf("failed to PersistState: %s", err) + } + + // Persisting the same state again should be a no-op: it doesn't fail, + // but it ought not to appear in the client's log either. + if err := mgr.WriteState(s); err != nil { + t.Fatalf("failed to WriteState: %s", err) + } + if err := mgr.PersistState(); err != nil { + t.Fatalf("failed to PersistState: %s", err) + } + + // ...but if we _do_ change something in the state then we should see + // it re-persist. + s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false) + if err := mgr.WriteState(s); err != nil { + t.Fatalf("failed to WriteState: %s", err) + } + if err := mgr.PersistState(); err != nil { + t.Fatalf("failed to PersistState: %s", err) + } + + got := mgr.Client.(*mockClient).log + want := []mockClientRequest{ + // The initial fetch from mgr.RefreshState above. + { + Method: "Get", + Content: map[string]interface{}{ + "version": 4.0, // encoding/json decodes this as float64 by default + "lineage": "mock-lineage", + "serial": 1.0, // encoding/json decodes this as float64 by default + "terraform_version": "0.0.0", + "outputs": map[string]interface{}{}, + "resources": []interface{}{}, + }, + }, + + // First call to PersistState, with output "foo" set to "bar". + { + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "mock-lineage", + "serial": 2.0, // serial increases because the outputs changed + "terraform_version": version.Version, + "outputs": map[string]interface{}{ + "foo": map[string]interface{}{ + "type": "string", + "value": "bar", + }, + }, + "resources": []interface{}{}, + }, + }, + + // Second call to PersistState generates no client requests, because + // nothing changed in the state itself. + + // Third call to PersistState, with the "foo" output value updated + // to "baz". + { + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "mock-lineage", + "serial": 3.0, // serial increases because the outputs changed + "terraform_version": version.Version, + "outputs": map[string]interface{}{ + "foo": map[string]interface{}{ + "type": "string", + "value": "baz", + }, + }, + "resources": []interface{}{}, + }, + }, + } + if diff := cmp.Diff(want, got); len(diff) > 0 { + t.Errorf("incorrect client requests\n%s", diff) + } +}