@ -2,10 +2,14 @@ package s3
import (
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@ -17,6 +21,9 @@ import (
"github.com/hashicorp/terraform/state/remote"
)
// Store the last saved serial in dynamo with this suffix for consistency checks.
const stateIDSuffix = "-md5"
type RemoteClient struct {
s3Client * s3 . S3
dynClient * dynamodb . DynamoDB
@ -28,7 +35,58 @@ type RemoteClient struct {
lockTable string
}
func ( c * RemoteClient ) Get ( ) ( * remote . Payload , error ) {
var (
// The amount of time we will retry a state waiting for it to match the
// expected checksum.
consistencyRetryTimeout = 10 * time . Second
// delay when polling the state
consistencyRetryPollInterval = 2 * time . Second
// checksum didn't match the remote state
errBadChecksum = errors . New ( "invalid state checksum" )
)
// test hook called when checksums don't match
var testChecksumHook func ( )
func ( c * RemoteClient ) Get ( ) ( payload * remote . Payload , err error ) {
deadline := time . Now ( ) . Add ( consistencyRetryTimeout )
// If we have a checksum, and the returned payload doesn't match, we retry
// up until deadline.
for {
payload , err = c . get ( )
if err != nil {
return nil , err
}
// verify that this state is what we expect
if expected , err := c . getMD5 ( ) ; err != nil {
log . Printf ( "[WARNING] failed to fetch state md5: %s" , err )
} else if len ( expected ) > 0 && ! bytes . Equal ( expected , payload . MD5 ) {
log . Printf ( "[WARNING] state md5 mismatch: expected '%x', got '%x'" , expected , payload . MD5 )
if testChecksumHook != nil {
testChecksumHook ( )
}
if time . Now ( ) . Before ( deadline ) {
time . Sleep ( consistencyRetryPollInterval )
log . Println ( "[INFO] retrying S3 RemoteClient.Get..." )
continue
}
return nil , errBadChecksum
}
break
}
return payload , err
}
func ( c * RemoteClient ) get ( ) ( * remote . Payload , error ) {
output , err := c . s3Client . GetObject ( & s3 . GetObjectInput {
Bucket : & c . bucketName ,
Key : & c . path ,
@ -53,8 +111,10 @@ func (c *RemoteClient) Get() (*remote.Payload, error) {
return nil , fmt . Errorf ( "Failed to read remote state: %s" , err )
}
sum := md5 . Sum ( buf . Bytes ( ) )
payload := & remote . Payload {
Data : buf . Bytes ( ) ,
MD5 : sum [ : ] ,
}
// If there was no data, then return nil
@ -92,11 +152,20 @@ func (c *RemoteClient) Put(data []byte) error {
log . Printf ( "[DEBUG] Uploading remote state to S3: %#v" , i )
if _ , err := c . s3Client . PutObject ( i ) ; err == nil {
return nil
} else {
_ , err := c . s3Client . PutObject ( i )
if err != nil {
return fmt . Errorf ( "Failed to upload state: %v" , err )
}
sum := md5 . Sum ( data )
if err := c . putMD5 ( sum [ : ] ) ; err != nil {
// if this errors out, we unfortunately have to error out altogether,
// since the next Get will inevitably fail.
return fmt . Errorf ( "failed to store state MD5: %s" , err )
}
return nil
}
func ( c * RemoteClient ) Delete ( ) error {
@ -105,7 +174,15 @@ func (c *RemoteClient) Delete() error {
Key : & c . path ,
} )
return err
if err != nil {
return err
}
if err := c . deleteMD5 ( ) ; err != nil {
log . Printf ( "error deleting state md5: %s" , err )
}
return nil
}
func ( c * RemoteClient ) Lock ( info * state . LockInfo ) ( string , error ) {
@ -146,9 +223,84 @@ func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
}
return "" , lockErr
}
return info . ID , nil
}
func ( c * RemoteClient ) getMD5 ( ) ( [ ] byte , error ) {
if c . lockTable == "" {
return nil , nil
}
getParams := & dynamodb . GetItemInput {
Key : map [ string ] * dynamodb . AttributeValue {
"LockID" : { S : aws . String ( c . lockPath ( ) + stateIDSuffix ) } ,
} ,
ProjectionExpression : aws . String ( "LockID, Digest" ) ,
TableName : aws . String ( c . lockTable ) ,
}
resp , err := c . dynClient . GetItem ( getParams )
if err != nil {
return nil , err
}
var val string
if v , ok := resp . Item [ "Digest" ] ; ok && v . S != nil {
val = * v . S
}
sum , err := hex . DecodeString ( val )
if err != nil || len ( sum ) != md5 . Size {
return nil , errors . New ( "invalid md5" )
}
return sum , nil
}
// store the hash of the state to that clients can check for stale state files.
func ( c * RemoteClient ) putMD5 ( sum [ ] byte ) error {
if c . lockTable == "" {
return nil
}
if len ( sum ) != md5 . Size {
return errors . New ( "invalid payload md5" )
}
putParams := & dynamodb . PutItemInput {
Item : map [ string ] * dynamodb . AttributeValue {
"LockID" : { S : aws . String ( c . lockPath ( ) + stateIDSuffix ) } ,
"Digest" : { S : aws . String ( hex . EncodeToString ( sum ) ) } ,
} ,
TableName : aws . String ( c . lockTable ) ,
}
_ , err := c . dynClient . PutItem ( putParams )
if err != nil {
log . Printf ( "[WARNING] failed to record state serial in dynamodb: %s" , err )
}
return nil
}
// remove the hash value for a deleted state
func ( c * RemoteClient ) deleteMD5 ( ) error {
if c . lockTable == "" {
return nil
}
params := & dynamodb . DeleteItemInput {
Key : map [ string ] * dynamodb . AttributeValue {
"LockID" : { S : aws . String ( c . lockPath ( ) + stateIDSuffix ) } ,
} ,
TableName : aws . String ( c . lockTable ) ,
}
if _ , err := c . dynClient . DeleteItem ( params ) ; err != nil {
return err
}
return nil
}
func ( c * RemoteClient ) getLockInfo ( ) ( * state . LockInfo , error ) {
getParams := & dynamodb . GetItemInput {
Key : map [ string ] * dynamodb . AttributeValue {