diff --git a/internal/terraform/context_apply_action_test.go b/internal/terraform/context_apply_action_test.go new file mode 100644 index 0000000000..01922d8967 --- /dev/null +++ b/internal/terraform/context_apply_action_test.go @@ -0,0 +1,304 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package terraform + +import ( + "testing" + + "github.com/hashicorp/terraform/internal/addrs" + "github.com/hashicorp/terraform/internal/configs/configschema" + "github.com/hashicorp/terraform/internal/plans" + "github.com/hashicorp/terraform/internal/providers" + testing_provider "github.com/hashicorp/terraform/internal/providers/testing" + "github.com/hashicorp/terraform/internal/states" + "github.com/hashicorp/terraform/internal/tfdiags" + "github.com/zclconf/go-cty/cty" +) + +func TestContext2Apply_actions(t *testing.T) { + for name, tc := range map[string]struct { + module map[string]string + mode plans.Mode + prevRunState *states.State + events []providers.InvokeActionEvent + callingInvokeReturnsDiagnostics tfdiags.Diagnostics + + expectInvokeActionCalled bool + + expectDiagnostics tfdiags.Diagnostics + }{ + "unreferenced": { + module: map[string]string{ + "main.tf": ` +action "test_unlinked" "hello" {} + `, + }, + expectInvokeActionCalled: false, + }, + + "before_create triggered": { + module: map[string]string{ + "main.tf": ` +action "test_unlinked" "hello" {} +resource "test_object" "a" { + lifecycle { + action_trigger { + events = [before_create] + actions = [action.test_unlinked.hello] + } + } +} +`, + }, + expectInvokeActionCalled: true, + }, + + "after_create triggered": { + module: map[string]string{ + "main.tf": ` +action "test_unlinked" "hello" {} +resource "test_object" "a" { + lifecycle { + action_trigger { + events = [after_create] + actions = [action.test_unlinked.hello] + } + } +} +`, + }, + expectInvokeActionCalled: true, + }, + + "before_update triggered": { + module: map[string]string{ + "main.tf": ` +action "test_unlinked" "hello" {} +resource "test_object" "a" { + name = "new name" + lifecycle { + action_trigger { + events = [before_update] + actions = [action.test_unlinked.hello] + } + } +} +`, + }, + prevRunState: states.BuildState(func(s *states.SyncState) { + s.SetResourceInstanceCurrent( + addrs.Resource{ + Mode: addrs.ManagedResourceMode, + Type: "test_object", + Name: "a", + }.Instance(addrs.NoKey).Absolute(addrs.RootModuleInstance), + &states.ResourceInstanceObjectSrc{ + Status: states.ObjectReady, + AttrsJSON: []byte(`{"name":"old name"}`), + }, + addrs.AbsProviderConfig{ + Provider: addrs.NewDefaultProvider("test"), + Module: addrs.RootModule, + }, + ) + }), + expectInvokeActionCalled: true, + }, + + "after_update triggered": { + module: map[string]string{ + "main.tf": ` +action "test_unlinked" "hello" {} +resource "test_object" "a" { + name = "new name" + lifecycle { + action_trigger { + events = [after_update] + actions = [action.test_unlinked.hello] + } + } +} +`, + }, + prevRunState: states.BuildState(func(s *states.SyncState) { + s.SetResourceInstanceCurrent( + addrs.Resource{ + Mode: addrs.ManagedResourceMode, + Type: "test_object", + Name: "a", + }.Instance(addrs.NoKey).Absolute(addrs.RootModuleInstance), + &states.ResourceInstanceObjectSrc{ + Status: states.ObjectReady, + AttrsJSON: []byte(`{"name":"old"}`), + }, + addrs.AbsProviderConfig{ + Provider: addrs.NewDefaultProvider("test"), + Module: addrs.RootModule, + }, + ) + }), + expectInvokeActionCalled: true, + }, + + "before_create failing": { + module: map[string]string{ + "main.tf": ` +action "test_unlinked" "hello" {} +resource "test_object" "a" { + lifecycle { + action_trigger { + events = [before_create] + actions = [action.test_unlinked.hello] + } + } +} +`, + }, + expectInvokeActionCalled: true, + events: []providers.InvokeActionEvent{ + providers.InvokeActionEvent_Completed{ + Diagnostics: tfdiags.Diagnostics{ + tfdiags.Sourceless( + tfdiags.Error, + "test case for failing", + "this simulates a provider failing", + ), + }, + }, + }, + + expectDiagnostics: tfdiags.Diagnostics{ + tfdiags.Sourceless( + tfdiags.Error, + "test case for failing", + "this simulates a provider failing", + ), + }, + }, + + "before_create failing to call invoke": { + module: map[string]string{ + "main.tf": ` +action "test_unlinked" "hello" {} +resource "test_object" "a" { + lifecycle { + action_trigger { + events = [before_create] + actions = [action.test_unlinked.hello] + } + } +} +`, + }, + expectInvokeActionCalled: true, + callingInvokeReturnsDiagnostics: tfdiags.Diagnostics{ + tfdiags.Sourceless( + tfdiags.Error, + "test case for failing", + "this simulates a provider failing before the action is invoked", + ), + }, + expectDiagnostics: tfdiags.Diagnostics{ + tfdiags.Sourceless( + tfdiags.Error, + "test case for failing", + "this simulates a provider failing before the action is invoked", + ), + }, + }, + } { + t.Run(name, func(t *testing.T) { + m := testModuleInline(t, tc.module) + + invokeActionCalls := []providers.InvokeActionRequest{} + + p := &testing_provider.MockProvider{ + GetProviderSchemaResponse: &providers.GetProviderSchemaResponse{ + Actions: map[string]providers.ActionSchema{ + "test_unlinked": { + ConfigSchema: &configschema.Block{ + Attributes: map[string]*configschema.Attribute{ + "attr": { + Type: cty.String, + Optional: true, + }, + }, + }, + + Unlinked: &providers.UnlinkedAction{}, + }, + }, + ResourceTypes: map[string]providers.Schema{ + "test_object": { + Body: &configschema.Block{ + Attributes: map[string]*configschema.Attribute{ + "name": { + Type: cty.String, + Optional: true, + }, + }, + }, + }, + }, + }, + InvokeActionFn: func(req providers.InvokeActionRequest) providers.InvokeActionResponse { + invokeActionCalls = append(invokeActionCalls, req) + + if len(tc.callingInvokeReturnsDiagnostics) > 0 { + return providers.InvokeActionResponse{ + Diagnostics: tc.callingInvokeReturnsDiagnostics, + } + } + + defaultEvents := []providers.InvokeActionEvent{} + defaultEvents = append(defaultEvents, providers.InvokeActionEvent_Progress{ + Message: "Hello world!", + }) + defaultEvents = append(defaultEvents, providers.InvokeActionEvent_Completed{}) + + events := defaultEvents + if len(tc.events) > 0 { + events = tc.events + } + + return providers.InvokeActionResponse{ + Events: func(yield func(providers.InvokeActionEvent) bool) { + for _, event := range events { + if !yield(event) { + return + } + } + }, + } + }, + } + + ctx := testContext2(t, &ContextOpts{ + Providers: map[addrs.Provider]providers.Factory{ + // The providers never actually going to get called here, we should + // catch the error long before anything happens. + addrs.NewDefaultProvider("test"): testProviderFuncFixed(p), + }, + }) + + // Just a sanity check that the module is valid + diags := ctx.Validate(m, &ValidateOpts{}) + tfdiags.AssertNoDiagnostics(t, diags) + + plan, diags := ctx.Plan(m, tc.prevRunState, SimplePlanOpts(plans.NormalMode, InputValues{})) + tfdiags.AssertNoDiagnostics(t, diags) + + _, diags = ctx.Apply(plan, m, nil) + if tc.expectDiagnostics.HasErrors() { + tfdiags.AssertDiagnosticsMatch(t, diags, tc.expectDiagnostics) + return + } + tfdiags.AssertNoDiagnostics(t, diags) + + if tc.expectInvokeActionCalled && len(invokeActionCalls) == 0 { + t.Fatalf("expected invoke action to be called, but it was not") + } + }) + } +} diff --git a/internal/terraform/node_action_apply.go b/internal/terraform/node_action_apply.go new file mode 100644 index 0000000000..38dc331292 --- /dev/null +++ b/internal/terraform/node_action_apply.go @@ -0,0 +1,140 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package terraform + +import ( + "fmt" + "sort" + + "github.com/hashicorp/terraform/internal/actions" + "github.com/hashicorp/terraform/internal/addrs" + "github.com/hashicorp/terraform/internal/plans" + "github.com/hashicorp/terraform/internal/providers" + "github.com/hashicorp/terraform/internal/tfdiags" +) + +type nodeActionApply struct { + TriggeringResourceaddrs addrs.AbsResourceInstance + ActionInvocations []*plans.ActionInvocationInstance +} + +var ( + _ GraphNodeExecutable = (*nodeActionApply)(nil) + _ GraphNodeReferencer = (*nodeActionApply)(nil) +) + +func (n *nodeActionApply) Execute(ctx EvalContext, _ walkOperation) (diags tfdiags.Diagnostics) { + return invokeActions(ctx, n.ActionInvocations) +} + +func invokeActions(ctx EvalContext, actionInvocations []*plans.ActionInvocationInstance) tfdiags.Diagnostics { + var diags tfdiags.Diagnostics + // First we order the action invocations by their trigger block index and events list index. + // This way we have the correct order of execution. + orderedActionInvocations := make([]*plans.ActionInvocationInstance, len(actionInvocations)) + copy(orderedActionInvocations, actionInvocations) + sort.Slice(orderedActionInvocations, func(i, j int) bool { + if orderedActionInvocations[i].ActionTriggerBlockIndex == orderedActionInvocations[j].ActionTriggerBlockIndex { + return orderedActionInvocations[i].ActionsListIndex < orderedActionInvocations[j].ActionsListIndex + } + return orderedActionInvocations[i].ActionTriggerBlockIndex < orderedActionInvocations[j].ActionTriggerBlockIndex + }) + + // Now we ensure we have an expanded action instance for each action invocations. + orderedActionData := make([]*actions.ActionData, len(orderedActionInvocations)) + for i, invocation := range orderedActionInvocations { + ai, ok := ctx.Actions().GetActionInstance(invocation.Addr) + if !ok { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Action instance not found", + "Could not find action instance for address "+invocation.Addr.String(), + )) + return diags + } + + orderedActionData[i] = ai + } + + // Now we have everything in place to execute the actions in the correct order. + // TODO: Handle verifying the condition here, if we have any. + + // We run every action sequentially, as the order of execution is important. We also abort if + // an action fails, as we don't want to continue executing actions or nodes that depend on it. + + for i, actionData := range orderedActionData { + ai := orderedActionInvocations[i] + provider, _, err := getProvider(ctx, actionData.ProviderAddr) + if err != nil { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Failed to get provider", + fmt.Sprintf("Failed to get provider: %s", err), + )) + return diags + } + // TODO: Change the hook identity to be sth contained in hooks + hookIdentity := addrs.AbsActionInvocationInstance{ + TriggeringResource: ai.TriggeringResourceAddr, + Action: ai.Addr, + TriggerIndex: ai.ActionTriggerBlockIndex, + } + + ctx.Hook(func(h Hook) (HookAction, error) { + return h.StartAction(hookIdentity) + }) + resp := provider.InvokeAction(providers.InvokeActionRequest{ + ActionType: orderedActionInvocations[i].Addr.Action.Action.Type, + PlannedActionData: actionData.ConfigValue, + }) + + diags = diags.Append(resp.Diagnostics) + if resp.Diagnostics.HasErrors() { + return diags + } + + for event := range resp.Events { + switch ev := event.(type) { + case providers.InvokeActionEvent_Progress: + ctx.Hook(func(h Hook) (HookAction, error) { + return h.ProgressAction(hookIdentity, ev.Message) + }) + case providers.InvokeActionEvent_Completed: + diags = diags.Append(ev.Diagnostics) + ctx.Hook(func(h Hook) (HookAction, error) { + return h.CompleteAction(hookIdentity, ev.Diagnostics.Err()) + }) + if ev.Diagnostics.HasErrors() { + // TODO: We would want to add some warning / error telling the user how to recover + // from this, or maybe attach this info to the diagnostics sent by the provider. + // For now we just return the diagnostics. + + return diags + } + default: + panic(fmt.Sprintf("unexpected action event type %T", ev)) + } + } + } + + return diags +} + +func (n *nodeActionApply) ModulePath() addrs.Module { + return n.TriggeringResourceaddrs.Module.Module() +} + +func (n *nodeActionApply) References() []*addrs.Reference { + var refs []*addrs.Reference + + // We reference each action instance that we are going to execute. + for _, invocation := range n.ActionInvocations { + // TODO: Think about how to get a source range + refs = append(refs, &addrs.Reference{ + Subject: invocation.Addr, + }) + } + + return refs +} diff --git a/internal/terraform/node_resource_apply_instance.go b/internal/terraform/node_resource_apply_instance.go index fbaecc29bc..d22da735a1 100644 --- a/internal/terraform/node_resource_apply_instance.go +++ b/internal/terraform/node_resource_apply_instance.go @@ -36,16 +36,19 @@ type NodeApplyableResourceInstance struct { // it might contain addresses that have nothing to do with the resource // that this node represents, which the node itself must therefore ignore. forceReplace []addrs.AbsResourceInstance + + beforeActionInvocations []*plans.ActionInvocationInstance } var ( - _ GraphNodeConfigResource = (*NodeApplyableResourceInstance)(nil) - _ GraphNodeResourceInstance = (*NodeApplyableResourceInstance)(nil) - _ GraphNodeCreator = (*NodeApplyableResourceInstance)(nil) - _ GraphNodeReferencer = (*NodeApplyableResourceInstance)(nil) - _ GraphNodeDeposer = (*NodeApplyableResourceInstance)(nil) - _ GraphNodeExecutable = (*NodeApplyableResourceInstance)(nil) - _ GraphNodeAttachDependencies = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeConfigResource = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeResourceInstance = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeCreator = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeReferencer = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeDeposer = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeExecutable = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeAttachDependencies = (*NodeApplyableResourceInstance)(nil) + _ GraphNodeAttachBeforeActions = (*NodeApplyableResourceInstance)(nil) ) // GraphNodeCreator @@ -210,6 +213,11 @@ func (n *NodeApplyableResourceInstance) managedResourceExecute(ctx EvalContext) var createBeforeDestroyEnabled bool var deposedKey states.DeposedKey + diags = diags.Append(invokeActions(ctx, n.beforeActionInvocations)) + if diags.HasErrors() { + return diags + } + addr := n.ResourceInstanceAddr().Resource _, providerSchema, err := getProvider(ctx, n.ResolvedProvider) diags = diags.Append(err) @@ -467,6 +475,10 @@ func (n *NodeApplyableResourceInstance) checkPlannedChange(ctx EvalContext, plan return diags } +func (n *NodeApplyableResourceInstance) AttachBeforeActions(ais []*plans.ActionInvocationInstance) { + n.beforeActionInvocations = ais +} + // maybeTainted takes the resource addr, new value, planned change, and possible // error from an apply operation and return a new instance object marked as // tainted if it appears that a create operation has failed. diff --git a/internal/terraform/transform_diff.go b/internal/terraform/transform_diff.go index 90df9a9d86..c4824c909c 100644 --- a/internal/terraform/transform_diff.go +++ b/internal/terraform/transform_diff.go @@ -80,6 +80,43 @@ func (t *DiffTransformer) Transform(g *Graph) error { resourceNodes.Put(rAddr, append(resourceNodes.Get(rAddr), rn)) } + // We will partition the action invocations into two groups based on if they are supposed to + // run before or after the resource change. + // We want to attach before-triggered action invocations to the triggering resource instance + // to be run as part of the apply phase. + // The after-triggered action invocations will be run as part of a separate node + // that will be connected to the resource instance nodes. + runBeforeNode := addrs.MakeMap[addrs.AbsResourceInstance, []*plans.ActionInvocationInstance]() + runAfterNode := addrs.MakeMap[addrs.AbsResourceInstance, []*plans.ActionInvocationInstance]() + for _, aiSrc := range changes.ActionInvocations { + ai, err := aiSrc.Decode() + if err != nil { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Invalid action invocation", + fmt.Sprintf("The plan contains an invalid action invocation for %s: %s", aiSrc.Addr, err), + )) + return diags.Err() + } + + var targetMap addrs.Map[addrs.AbsResourceInstance, []*plans.ActionInvocationInstance] + switch ai.TriggerEvent { + case configs.BeforeCreate, configs.BeforeUpdate, configs.BeforeDestroy: + targetMap = runBeforeNode + case configs.AfterCreate, configs.AfterUpdate, configs.AfterDestroy: + targetMap = runAfterNode + default: + panic("I don't know when to run this action invocation") + } + + basis := []*plans.ActionInvocationInstance{} + if targetMap.Has(ai.TriggeringResourceAddr) { + basis = targetMap.Get(ai.TriggeringResourceAddr) + } + + targetMap.Put(ai.TriggeringResourceAddr, append(basis, ai)) + } + for _, rc := range changes.Resources { addr := rc.Addr dk := rc.DeposedKey @@ -179,6 +216,14 @@ func (t *DiffTransformer) Transform(g *Graph) error { log.Printf("[TRACE] DiffTransformer: %s will be represented by %s", addr, dag.VertexName(node)) } + // We only need to attach actions to updating nodes for now + // (until before_destroy & after destroy are added) + if beforeActions, ok := runBeforeNode.GetOk(addr); ok { + if attachBeforeActionsNode, ok := node.(*NodeApplyableResourceInstance); ok { + attachBeforeActionsNode.beforeActionInvocations = beforeActions + } + } + g.Add(node) for _, rsrcNode := range resourceNodes.Get(addr.ConfigResource()) { g.Connect(dag.BasicEdge(node, rsrcNode)) @@ -229,6 +274,38 @@ func (t *DiffTransformer) Transform(g *Graph) error { } + // Create a node for each resource instance that invokes all the action invocations that are + // supposed to run after the resource change. + for key, value := range runAfterNode.Iter() { + if len(value) == 0 { + continue + } + + log.Printf("[TRACE] DiffTransformer: adding action invocations to run after %s", key) + actionNode := &nodeActionApply{ + TriggeringResourceaddrs: key, + ActionInvocations: value, + } + + // Find the config resource associated with this. While for each resource instance all + // actions need to run in sequence, for different resource instances they can run in + // parallel. + resourceNode, ok := resourceNodes.GetOk(key.ConfigResource()) + if !ok { + diags = diags.Append(tfdiags.Sourceless( + tfdiags.Error, + "Missing resource node for action invocations", + fmt.Sprintf("Could not find resource node for action invocations for %s", key), + )) + continue + } + + g.Add(actionNode) + for _, rNode := range resourceNode { + g.Connect(dag.BasicEdge(actionNode, rNode)) + } + } + log.Printf("[TRACE] DiffTransformer complete") return diags.Err()