From 2a79f5fa53ab3e96612ad58a63a27dbcee463826 Mon Sep 17 00:00:00 2001 From: Samsondeen <40821565+dsa0x@users.noreply.github.com> Date: Tue, 3 Jun 2025 13:57:13 +0200 Subject: [PATCH] Use iterators for graph vertices (#36558) * Use iterators for graph vertices * use func filter * use type param instead of function filter One type parameter seem to be enough instead of 2 --- internal/dag/seq.go | 74 ++++++++++++++++ internal/dag/seq_test.go | 85 +++++++++++++++++++ .../moduletest/graph/test_graph_builder.go | 13 ++- .../moduletest/graph/transform_close_graph.go | 6 +- .../moduletest/graph/transform_context.go | 2 +- .../moduletest/graph/transform_providers.go | 9 +- .../graph/transform_state_cleanup.go | 8 +- 7 files changed, 171 insertions(+), 26 deletions(-) create mode 100644 internal/dag/seq.go create mode 100644 internal/dag/seq_test.go diff --git a/internal/dag/seq.go b/internal/dag/seq.go new file mode 100644 index 0000000000..4b2d643287 --- /dev/null +++ b/internal/dag/seq.go @@ -0,0 +1,74 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dag + +import ( + "iter" + "slices" +) + +type VertexSeq[T Vertex] iter.Seq[T] + +func (seq VertexSeq[T]) Collect() []T { + return slices.Collect(iter.Seq[T](seq)) +} + +func (seq VertexSeq[T]) AsGeneric() VertexSeq[Vertex] { + return func(yield func(Vertex) bool) { + for v := range seq { + if !yield(v) { + return + } + } + } +} + +// Vertices returns an iterator over all the vertices in the graph. +func (g *Graph) VerticesSeq() VertexSeq[Vertex] { + return func(yield func(v Vertex) bool) { + for _, v := range g.vertices { + v, ok := v.(Vertex) + if !ok { + continue + } + if !yield(v) { + return + } + } + } +} + +// SelectSeq filters a sequence to include only elements that can be type-asserted to type U. +// It returns a new sequence containing only the matching elements. +// The yield function can return false to stop iteration early. +func SelectSeq[U Vertex](seq VertexSeq[Vertex]) VertexSeq[U] { + return func(yield func(U) bool) { + for v := range seq { + // if the item is not of the type we're looking for, skip it + u, ok := any(v).(U) + if !ok { + continue + } + if !yield(u) { + return + } + } + } +} + +// ExcludeSeq filters a sequence to exclude elements that can be type-asserted to type U. +// It returns a new sequence containing only the non-matching elements. +// The yield function can return false to stop iteration early. +func ExcludeSeq[U Vertex](seq VertexSeq[Vertex]) VertexSeq[Vertex] { + return func(yield func(Vertex) bool) { + for v := range seq { + if _, ok := any(v).(U); ok { + continue + } + if !yield(v) { + return + } + } + } +} diff --git a/internal/dag/seq_test.go b/internal/dag/seq_test.go new file mode 100644 index 0000000000..8a7780007d --- /dev/null +++ b/internal/dag/seq_test.go @@ -0,0 +1,85 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dag + +import ( + "testing" +) + +// Mock implementation of SeqVertex for testing +type MockVertex struct { + id int +} + +func (v MockVertex) ZeroValue() any { + return MockVertex{} +} + +type MockVertex2 struct { + id int +} + +func TestSelectSeq(t *testing.T) { + v1 := MockVertex{id: 1} + v11 := MockVertex{id: 11} + v2 := MockVertex2{id: 2} + vertices := Set{v1: v1, v11: v11, v2: v2} + + graph := &Graph{vertices: vertices} + seq := SelectSeq[MockVertex](graph.VerticesSeq()) + t.Run("Select objects of given type", func(t *testing.T) { + count := len(seq.Collect()) + if count != 2 { + t.Errorf("Expected 2, got %d", count) + } + }) + + t.Run("Returns empty when looking for incompatible types", func(t *testing.T) { + seq := SelectSeq[MockVertex2](seq.AsGeneric()) + count := len(seq.Collect()) + if count != 0 { + t.Errorf("Expected empty, got %d", count) + } + }) + + t.Run("Select objects of given interface", func(t *testing.T) { + seq := SelectSeq[interface{ ZeroValue() any }](graph.VerticesSeq()) + count := len(seq.Collect()) + if count != 2 { + t.Errorf("Expected 1, got %d", count) + } + }) +} + +func TestExcludeSeq(t *testing.T) { + v1 := MockVertex{id: 1} + v11 := MockVertex{id: 11} + v2 := MockVertex2{id: 2} + vertices := Set{v1: v1, v11: v11, v2: v2} + + graph := &Graph{vertices: vertices} + seq := ExcludeSeq[MockVertex](graph.VerticesSeq()) + t.Run("Exclude objects of given type", func(t *testing.T) { + count := len(seq.Collect()) + if count != 1 { + t.Errorf("Expected 1, got %d", count) + } + }) + + t.Run("Returns empty when looking for incompatible types", func(t *testing.T) { + seq := ExcludeSeq[MockVertex2](seq) + count := len(seq.Collect()) + if count != 0 { + t.Errorf("Expected empty, got %d", count) + } + }) + + t.Run("Exclude objects of given interface", func(t *testing.T) { + seq := ExcludeSeq[interface{ ZeroValue() any }](graph.VerticesSeq()) + count := len(seq.Collect()) + if count != 1 { + t.Errorf("Expected 1, got %d", count) + } + }) +} diff --git a/internal/moduletest/graph/test_graph_builder.go b/internal/moduletest/graph/test_graph_builder.go index c642a74081..1730248510 100644 --- a/internal/moduletest/graph/test_graph_builder.go +++ b/internal/moduletest/graph/test_graph_builder.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/terraform/internal/addrs" "github.com/hashicorp/terraform/internal/backend/backendrun" + "github.com/hashicorp/terraform/internal/dag" "github.com/hashicorp/terraform/internal/moduletest" "github.com/hashicorp/terraform/internal/terraform" "github.com/hashicorp/terraform/internal/tfdiags" @@ -62,13 +63,11 @@ func (b *TestGraphBuilder) Steps() []terraform.GraphTransformer { } func validateRunConfigs(g *terraform.Graph) error { - for _, v := range g.Vertices() { - if node, ok := v.(*NodeTestRun); ok { - diags := node.run.Config.Validate(node.run.ModuleConfig) - node.run.Diagnostics = node.run.Diagnostics.Append(diags) - if diags.HasErrors() { - node.run.Status = moduletest.Error - } + for node := range dag.SelectSeq[*NodeTestRun](g.VerticesSeq()) { + diags := node.run.Config.Validate(node.run.ModuleConfig) + node.run.Diagnostics = node.run.Diagnostics.Append(diags) + if diags.HasErrors() { + node.run.Status = moduletest.Error } } return nil diff --git a/internal/moduletest/graph/transform_close_graph.go b/internal/moduletest/graph/transform_close_graph.go index c366c9fa33..1de7562106 100644 --- a/internal/moduletest/graph/transform_close_graph.go +++ b/internal/moduletest/graph/transform_close_graph.go @@ -15,11 +15,7 @@ func (t *CloseTestGraphTransformer) Transform(g *terraform.Graph) error { closeRoot := &nodeCloseTest{} g.Add(closeRoot) - for _, v := range g.Vertices() { - if v == closeRoot { - continue - } - + for v := range dag.ExcludeSeq[*nodeCloseTest](g.VerticesSeq()) { // since this is closing the graph, make it depend on everything in // the graph that does not have a parent. Such nodes are the real roots // of the graph, and since they are now siblings of the closing root node, diff --git a/internal/moduletest/graph/transform_context.go b/internal/moduletest/graph/transform_context.go index 1b693e48bd..84079dc665 100644 --- a/internal/moduletest/graph/transform_context.go +++ b/internal/moduletest/graph/transform_context.go @@ -49,7 +49,7 @@ func (e *EvalContextTransformer) Transform(graph *terraform.Graph) error { } graph.Add(node) - for _, v := range graph.Vertices() { + for v := range graph.VerticesSeq() { if v == node { continue } diff --git a/internal/moduletest/graph/transform_providers.go b/internal/moduletest/graph/transform_providers.go index 64d083c21f..fc519359c3 100644 --- a/internal/moduletest/graph/transform_providers.go +++ b/internal/moduletest/graph/transform_providers.go @@ -21,12 +21,7 @@ func (t *TestProvidersTransformer) Transform(g *terraform.Graph) error { // a root provider node that will add the providers to the context rootProviderNode := t.createRootNode(g, runProviderMap) - for _, v := range g.Vertices() { - node, ok := v.(*NodeTestRun) - if !ok { - continue - } - + for node := range dag.SelectSeq[*NodeTestRun](g.VerticesSeq()) { // Get the providers that the test run depends on configKey := node.run.GetModuleConfigID() if _, ok := configsProviderMap[configKey]; !ok { @@ -36,7 +31,7 @@ func (t *TestProvidersTransformer) Transform(g *terraform.Graph) error { runProviderMap[node] = configsProviderMap[configKey] // Add an edge from the test run node to the root provider node - g.Connect(dag.BasicEdge(v, rootProviderNode)) + g.Connect(dag.BasicEdge(node, rootProviderNode)) } return nil diff --git a/internal/moduletest/graph/transform_state_cleanup.go b/internal/moduletest/graph/transform_state_cleanup.go index 9cf2b25e64..997dcf6199 100644 --- a/internal/moduletest/graph/transform_state_cleanup.go +++ b/internal/moduletest/graph/transform_state_cleanup.go @@ -20,11 +20,7 @@ type TestStateCleanupTransformer struct { func (t *TestStateCleanupTransformer) Transform(g *terraform.Graph) error { cleanupMap := make(map[string]*NodeStateCleanup) - for _, v := range g.Vertices() { - node, ok := v.(*NodeTestRun) - if !ok { - continue - } + for node := range dag.SelectSeq[*NodeTestRun](g.VerticesSeq()) { key := node.run.GetStateKey() if _, exists := cleanupMap[key]; !exists { cleanupMap[key] = &NodeStateCleanup{stateKey: key, opts: t.opts} @@ -40,7 +36,7 @@ func (t *TestStateCleanupTransformer) Transform(g *terraform.Graph) error { // existing CLI output. rootCleanupNode := t.addRootCleanupNode(g) - for _, v := range g.Vertices() { + for v := range g.VerticesSeq() { switch node := v.(type) { case *NodeTestRun: // All the runs that share the same state, must share the same cleanup node,