diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index e36e1a7b28..840d1ae723 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -1809,6 +1809,30 @@ func TestBackendWrongRegion(t *testing.T) { } } +func TestBackendS3ObjectLock(t *testing.T) { + testACC(t) + + ctx := context.TODO() + + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + keyName := "testState" + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{ + "bucket": bucketName, + "key": keyName, + "encrypt": true, + "region": "us-west-1", + })).(*Backend) + + createS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region, + s3BucketWithVersioning, + s3BucketWithObjectLock(s3types.ObjectLockRetentionModeCompliance), + ) + defer deleteS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region) + + backend.TestBackendStates(t, b) +} + func TestKeyEnv(t *testing.T) { testACC(t) @@ -2156,17 +2180,32 @@ func checkStateList(b backend.Backend, expected []string) error { return nil } -func createS3Bucket(ctx context.Context, t *testing.T, s3Client *s3.Client, bucketName, region string) { +type createS3BucketOptions struct { + versioning bool + objectLockMode s3types.ObjectLockRetentionMode +} + +type createS3BucketOptionsFunc func(*createS3BucketOptions) + +func createS3Bucket(ctx context.Context, t *testing.T, s3Client *s3.Client, bucketName, region string, optFns ...createS3BucketOptionsFunc) { t.Helper() + var opts createS3BucketOptions + for _, f := range optFns { + f(&opts) + } + createBucketReq := &s3.CreateBucketInput{ - Bucket: &bucketName, + Bucket: aws.String(bucketName), } if region != "us-east-1" { createBucketReq.CreateBucketConfiguration = &s3types.CreateBucketConfiguration{ LocationConstraint: s3types.BucketLocationConstraint(region), } } + if opts.objectLockMode != "" { + createBucketReq.ObjectLockEnabledForBucket = true + } // Be clear about what we're doing in case the user needs to clean // this up later. @@ -2175,6 +2214,46 @@ func createS3Bucket(ctx context.Context, t *testing.T, s3Client *s3.Client, buck if err != nil { t.Fatal("failed to create test S3 bucket:", err) } + + if opts.versioning { + _, err := s3Client.PutBucketVersioning(ctx, &s3.PutBucketVersioningInput{ + Bucket: aws.String(bucketName), + VersioningConfiguration: &s3types.VersioningConfiguration{ + Status: s3types.BucketVersioningStatusEnabled, + }, + }) + if err != nil { + t.Fatalf("failed enabling versioning: %s", err) + } + } + + if opts.objectLockMode != "" { + _, err = s3Client.PutObjectLockConfiguration(ctx, &s3.PutObjectLockConfigurationInput{ + Bucket: aws.String(bucketName), + ObjectLockConfiguration: &s3types.ObjectLockConfiguration{ + ObjectLockEnabled: s3types.ObjectLockEnabledEnabled, + Rule: &s3types.ObjectLockRule{ + DefaultRetention: &s3types.DefaultRetention{ + Days: 1, + Mode: opts.objectLockMode, + }, + }, + }, + }) + if err != nil { + t.Fatalf("failed enabling object locking: %s", err) + } + } +} + +func s3BucketWithVersioning(opts *createS3BucketOptions) { + opts.versioning = true +} + +func s3BucketWithObjectLock(mode s3types.ObjectLockRetentionMode) createS3BucketOptionsFunc { + return func(opts *createS3BucketOptions) { + opts.objectLockMode = mode + } } func deleteS3Bucket(ctx context.Context, t *testing.T, s3Client *s3.Client, bucketName, region string) { diff --git a/internal/backend/remote-state/s3/client.go b/internal/backend/remote-state/s3/client.go index 8135880ad1..c4528e861e 100644 --- a/internal/backend/remote-state/s3/client.go +++ b/internal/backend/remote-state/s3/client.go @@ -190,11 +190,14 @@ func (c *RemoteClient) Put(data []byte) error { contentType := "application/json" + sum := md5.Sum(data) + input := &s3.PutObjectInput{ - ContentType: aws.String(contentType), - Body: bytes.NewReader(data), - Bucket: aws.String(c.bucketName), - Key: aws.String(c.path), + ContentType: aws.String(contentType), + Body: bytes.NewReader(data), + Bucket: aws.String(c.bucketName), + Key: aws.String(c.path), + ChecksumAlgorithm: s3types.ChecksumAlgorithmSha256, } if c.serverSideEncryption { @@ -222,7 +225,6 @@ func (c *RemoteClient) Put(data []byte) error { return fmt.Errorf("failed to upload state: %s", err) } - sum := md5.Sum(data) if err := c.putMD5(ctx, sum[:]); err != nil { // if this errors out, we unfortunately have to error out altogether, // since the next Get will inevitably fail. diff --git a/internal/backend/remote-state/s3/client_test.go b/internal/backend/remote-state/s3/client_test.go index 818c4fd900..2197e766f9 100644 --- a/internal/backend/remote-state/s3/client_test.go +++ b/internal/backend/remote-state/s3/client_test.go @@ -8,9 +8,12 @@ import ( "context" "crypto/md5" "fmt" + "io" "testing" "time" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/hashicorp/terraform/internal/backend" "github.com/hashicorp/terraform/internal/states/remote" "github.com/hashicorp/terraform/internal/states/statefile" @@ -333,3 +336,50 @@ func TestRemoteClient_stateChecksum(t *testing.T) { t.Fatal(err) } } + +func TestRemoteClientPutLargeUploadWithObjectLock(t *testing.T) { + testACC(t) + + ctx := context.TODO() + + bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix()) + keyName := "testState" + + b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{ + "bucket": bucketName, + "key": keyName, + })).(*Backend) + + createS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region, + s3BucketWithVersioning, + s3BucketWithObjectLock(s3types.ObjectLockRetentionModeCompliance), + ) + defer deleteS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region) + + s1, err := b.StateMgr(backend.DefaultStateName) + if err != nil { + t.Fatal(err) + } + client := s1.(*remote.State).Client + + var state bytes.Buffer + dataW := io.LimitReader(neverEnding('x'), manager.DefaultUploadPartSize*2) + _, err = state.ReadFrom(dataW) + if err != nil { + t.Fatalf("writing dummy data: %s", err) + } + + err = client.Put(state.Bytes()) + if err != nil { + t.Fatalf("putting data: %s", err) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +}