diff --git a/internal/backend/remote-state/s3/backend.go b/internal/backend/remote-state/s3/backend.go index aa4aa4bbc5..cf651c50ab 100644 --- a/internal/backend/remote-state/s3/backend.go +++ b/internal/backend/remote-state/s3/backend.go @@ -4,9 +4,7 @@ package s3 import ( - "context" "encoding/base64" - "errors" "fmt" "os" "strings" @@ -21,6 +19,7 @@ import ( "github.com/hashicorp/terraform/internal/tfdiags" "github.com/hashicorp/terraform/version" "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/gocty" ) // New creates a new backend for S3 remote state. @@ -247,9 +246,24 @@ func New() backend.Backend { }, } - result := &Backend{Backend: s} - result.Backend.ConfigureFunc = result.configure - return result + return &Backend{Backend: s} +} + +type Backend struct { + *schema.Backend + + // The fields below are set from configure + s3Client *s3.S3 + dynClient *dynamodb.DynamoDB + + bucketName string + keyName string + serverSideEncryption bool + customerEncryptionKey []byte + acl string + kmsKeyID string + ddbTable string + workspaceKeyPrefix string } func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) { @@ -337,78 +351,62 @@ func (b *Backend) PrepareConfig(obj cty.Value) (cty.Value, tfdiags.Diagnostics) return obj, diags } -type Backend struct { - *schema.Backend - - // The fields below are set from configure - s3Client *s3.S3 - dynClient *dynamodb.DynamoDB - - bucketName string - keyName string - serverSideEncryption bool - customerEncryptionKey []byte - acl string - kmsKeyID string - ddbTable string - workspaceKeyPrefix string -} - -func (b *Backend) configure(ctx context.Context) error { - if b.s3Client != nil { - return nil +func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics { + var diags tfdiags.Diagnostics + if obj.IsNull() { + return diags } - // Grab the resource data - data := schema.FromContextBackendConfig(ctx) - - if !data.Get("skip_region_validation").(bool) { - if err := awsbase.ValidateRegion(data.Get("region").(string)); err != nil { - return err - } + var region string + if v, ok := stringAttrOk(obj, "region"); ok { + region = v } - b.bucketName = data.Get("bucket").(string) - b.keyName = data.Get("key").(string) - b.acl = data.Get("acl").(string) - b.workspaceKeyPrefix = data.Get("workspace_key_prefix").(string) - b.serverSideEncryption = data.Get("encrypt").(bool) - b.kmsKeyID = data.Get("kms_key_id").(string) - b.ddbTable = data.Get("dynamodb_table").(string) - - customerKeyString := data.Get("sse_customer_key").(string) - if customerKeyString != "" { - if b.kmsKeyID != "" { - return errors.New(encryptionKeyConflictError) + if boolAttr(obj, "skip_region_validation") { + if err := awsbase.ValidateRegion(region); err != nil { + diags = diags.Append(tfdiags.AttributeValue( + tfdiags.Error, + "Invalid region value", + err.Error(), + cty.Path{cty.GetAttrStep{Name: "region"}}, + )) + return diags } + } - var err error - b.customerEncryptionKey, err = base64.StdEncoding.DecodeString(customerKeyString) - if err != nil { - return fmt.Errorf("Failed to decode sse_customer_key: %s", err.Error()) - } + b.bucketName = stringAttr(obj, "bucket") + b.keyName = stringAttr(obj, "key") + b.acl = stringAttr(obj, "acl") + b.workspaceKeyPrefix = stringAttrDefault(obj, "workspace_key_prefix", "env:") + b.serverSideEncryption = boolAttr(obj, "encrypt") + b.kmsKeyID = stringAttr(obj, "kms_key_id") + b.ddbTable = stringAttr(obj, "dynamodb_table") + + if customerKeyString, ok := stringAttrOk(obj, "sse_customer_key"); ok { + // Validation is handled in PrepareConfig, so ignore it here + b.customerEncryptionKey, _ = base64.StdEncoding.DecodeString(customerKeyString) } cfg := &awsbase.Config{ - AccessKey: data.Get("access_key").(string), - AssumeRoleARN: data.Get("role_arn").(string), - AssumeRoleDurationSeconds: data.Get("assume_role_duration_seconds").(int), - AssumeRoleExternalID: data.Get("external_id").(string), - AssumeRolePolicy: data.Get("assume_role_policy").(string), - AssumeRoleSessionName: data.Get("session_name").(string), + AccessKey: stringAttr(obj, "access_key"), + AssumeRoleARN: stringAttr(obj, "role_arn"), + AssumeRoleDurationSeconds: intAttr(obj, "assume_role_duration_seconds"), + AssumeRoleExternalID: stringAttr(obj, "external_id"), + AssumeRolePolicy: stringAttr(obj, "assume_role_policy"), + AssumeRoleSessionName: stringAttr(obj, "session_name"), CallerDocumentationURL: "https://www.terraform.io/docs/language/settings/backends/s3.html", CallerName: "S3 Backend", - CredsFilename: data.Get("shared_credentials_file").(string), + CredsFilename: stringAttr(obj, "shared_credentials_file"), DebugLogging: logging.IsDebugOrHigher(), - IamEndpoint: data.Get("iam_endpoint").(string), - MaxRetries: data.Get("max_retries").(int), - Profile: data.Get("profile").(string), - Region: data.Get("region").(string), - SecretKey: data.Get("secret_key").(string), - SkipCredsValidation: data.Get("skip_credentials_validation").(bool), - SkipMetadataApiCheck: data.Get("skip_metadata_api_check").(bool), - StsEndpoint: data.Get("sts_endpoint").(string), - Token: data.Get("token").(string), + IamEndpoint: stringAttr(obj, "iam_endpoint"), + MaxRetries: intAttrDefault(obj, "max_retries", 5), + Profile: stringAttr(obj, "profile"), + Region: stringAttr(obj, "region"), + SecretKey: stringAttr(obj, "secret_key"), + SkipCredsValidation: boolAttr(obj, "skip_credentials_validation"), + SkipMetadataApiCheck: boolAttr(obj, "skip_metadata_api_check"), + StsEndpoint: stringAttr(obj, "sts_endpoint"), + Token: stringAttr(obj, "token"), UserAgentProducts: []*awsbase.UserAgentProduct{ {Name: "APN", Version: "1.0"}, {Name: "HashiCorp", Version: "1.0"}, @@ -416,58 +414,124 @@ func (b *Backend) configure(ctx context.Context) error { }, } - if policyARNSet := data.Get("assume_role_policy_arns").(*schema.Set); policyARNSet.Len() > 0 { - for _, policyARNRaw := range policyARNSet.List() { - policyARN, ok := policyARNRaw.(string) - - if !ok { - continue + if policyARNSet := obj.GetAttr("assume_role_policy_arns"); !policyARNSet.IsNull() { + policyARNSet.ForEachElement(func(key, val cty.Value) (stop bool) { + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRolePolicyARNs = append(cfg.AssumeRolePolicyARNs, v) } - - cfg.AssumeRolePolicyARNs = append(cfg.AssumeRolePolicyARNs, policyARN) - } + return + }) } - if tagMap := data.Get("assume_role_tags").(map[string]interface{}); len(tagMap) > 0 { - cfg.AssumeRoleTags = make(map[string]string) - - for k, vRaw := range tagMap { - v, ok := vRaw.(string) - - if !ok { - continue + if tagMap := obj.GetAttr("assume_role_tags"); !tagMap.IsNull() { + cfg.AssumeRoleTags = make(map[string]string, tagMap.LengthInt()) + tagMap.ForEachElement(func(key, val cty.Value) (stop bool) { + k := stringValue(key) + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRoleTags[k] = v } - - cfg.AssumeRoleTags[k] = v - } + return + }) } - if transitiveTagKeySet := data.Get("assume_role_transitive_tag_keys").(*schema.Set); transitiveTagKeySet.Len() > 0 { - for _, transitiveTagKeyRaw := range transitiveTagKeySet.List() { - transitiveTagKey, ok := transitiveTagKeyRaw.(string) - - if !ok { - continue + if transitiveTagKeySet := obj.GetAttr("assume_role_transitive_tag_keys"); !transitiveTagKeySet.IsNull() { + transitiveTagKeySet.ForEachElement(func(key, val cty.Value) (stop bool) { + v, ok := stringValueOk(val) + if ok { + cfg.AssumeRoleTransitiveTagKeys = append(cfg.AssumeRoleTransitiveTagKeys, v) } - - cfg.AssumeRoleTransitiveTagKeys = append(cfg.AssumeRoleTransitiveTagKeys, transitiveTagKey) - } + return + }) } sess, err := awsbase.GetSession(cfg) if err != nil { - return fmt.Errorf("error configuring S3 Backend: %w", err) + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Failed to configure AWS client", + fmt.Sprintf(`The "S3" backend encountered an unexpected error while creating the AWS client: %s`, err), + )) + return diags } b.dynClient = dynamodb.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(data.Get("dynamodb_endpoint").(string)), + Endpoint: aws.String(stringAttr(obj, "dynamodb_endpoint")), })) b.s3Client = s3.New(sess.Copy(&aws.Config{ - Endpoint: aws.String(data.Get("endpoint").(string)), - S3ForcePathStyle: aws.Bool(data.Get("force_path_style").(bool)), + Endpoint: aws.String(stringAttr(obj, "endpoint")), + S3ForcePathStyle: aws.Bool(boolAttr(obj, "force_path_style")), })) - return nil + return diags +} + +func stringValue(val cty.Value) string { + v, _ := stringValueOk(val) + return v +} + +func stringValueOk(val cty.Value) (string, bool) { + if val.IsNull() { + return "", false + } else { + return val.AsString(), true + } +} + +func stringAttr(obj cty.Value, name string) string { + return stringValue(obj.GetAttr(name)) +} + +func stringAttrOk(obj cty.Value, name string) (string, bool) { + return stringValueOk(obj.GetAttr(name)) +} + +func stringAttrDefault(obj cty.Value, name, def string) string { + if v, ok := stringAttrOk(obj, name); !ok { + return def + } else { + return v + } +} + +func boolAttr(obj cty.Value, name string) bool { + v, _ := boolAttrOk(obj, name) + return v +} + +func boolAttrOk(obj cty.Value, name string) (bool, bool) { + if val := obj.GetAttr(name); val.IsNull() { + return false, false + } else { + return val.True(), true + } +} + +func intAttr(obj cty.Value, name string) int { + v, _ := intAttrOk(obj, name) + return v +} + +func intAttrOk(obj cty.Value, name string) (int, bool) { + if val := obj.GetAttr(name); val.IsNull() { + return 0, false + } else { + var v int + if err := gocty.FromCtyValue(val, &v); err != nil { + return 0, false + } + return v, true + } +} + +func intAttrDefault(obj cty.Value, name string, def int) int { + if v, ok := intAttrOk(obj, name); !ok { + return def + } else { + return v + } } const encryptionKeyConflictError = `Only one of "kms_key_id" and "sse_customer_key" can be set. diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index feae781a03..85a62add9a 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -62,6 +62,9 @@ func TestBackendConfig(t *testing.T) { if *b.s3Client.Config.Region != "us-west-1" { t.Fatalf("Incorrect region was populated") } + if *b.s3Client.Config.MaxRetries != 5 { + t.Fatalf("Default max_retries was not set") + } if b.bucketName != "tf-test" { t.Fatalf("Incorrect bucketName was populated") } @@ -307,7 +310,8 @@ func TestBackendConfig_AssumeRole(t *testing.T) { testCase.Config["sts_endpoint"] = aws.StringValue(mockStsSession.Config.Endpoint) } - diags := New().Configure(hcl2shim.HCL2ValueFromConfigValue(testCase.Config)) + b := New() + diags := b.Configure(populateSchema(t, b.ConfigSchema(), hcl2shim.HCL2ValueFromConfigValue(testCase.Config))) if diags.HasErrors() { for _, diag := range diags { @@ -917,11 +921,15 @@ func populateSchema(t *testing.T, schema *configschema.Block, value cty.Value) c func unmarshal(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { switch { case ty.IsPrimitiveType(): - val, err := unmarshalPrimitive(value, ty, path) - if err != nil { - return cty.NilVal, err - } - return val, nil + return value, nil + // case ty.IsListType(): + // return unmarshalList(value, ty.ElementType(), path) + case ty.IsSetType(): + return unmarshalSet(value, ty.ElementType(), path) + case ty.IsMapType(): + return unmarshalMap(value, ty.ElementType(), path) + // case ty.IsTupleType(): + // return unmarshalTuple(value, ty.TupleElementTypes(), path) case ty.IsObjectType(): return unmarshalObject(value, ty.AttributeTypes(), path) default: @@ -929,8 +937,45 @@ func unmarshal(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { } } -func unmarshalPrimitive(value cty.Value, ty cty.Type, path cty.Path) (cty.Value, error) { - return value, nil +func unmarshalSet(dec cty.Value, ety cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + + length := dec.LengthInt() + + if length == 0 { + return cty.SetValEmpty(ety), nil + } + + vals := make([]cty.Value, 0, length) + dec.ForEachElement(func(key, val cty.Value) (stop bool) { + vals = append(vals, val) + return + }) + + return cty.SetVal(vals), nil +} + +func unmarshalMap(dec cty.Value, ety cty.Type, path cty.Path) (cty.Value, error) { + if dec.IsNull() { + return dec, nil + } + + length := dec.LengthInt() + + if length == 0 { + return cty.MapValEmpty(ety), nil + } + + vals := make(map[string]cty.Value, length) + dec.ForEachElement(func(key, val cty.Value) (stop bool) { + k := stringValue(key) + vals[k] = val + return + }) + + return cty.MapVal(vals), nil } func unmarshalObject(dec cty.Value, atys map[string]cty.Type, path cty.Path) (cty.Value, error) {