diff --git a/internal/daemon/worker/controller_connection_test.go b/internal/daemon/worker/controller_connection_test.go index d90e35b180..f03557f366 100644 --- a/internal/daemon/worker/controller_connection_test.go +++ b/internal/daemon/worker/controller_connection_test.go @@ -12,8 +12,6 @@ import ( "testing" "time" - opsservices "github.com/hashicorp/boundary/internal/gen/ops/services" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/nettest" "google.golang.org/grpc" @@ -22,48 +20,22 @@ import ( ) func TestMonitorUpstreamConnectionState(t *testing.T) { - ctx := context.Background() - stateCtx, cancelStateCtx := context.WithCancel(ctx) - - upstreamConnectionState := new(atomic.Value) - servers, err := createTestServers(t) require.NoError(t, err) scheme := strconv.FormatInt(time.Now().UnixNano(), 36) res := manual.NewBuilderWithScheme(scheme) - grpcResolver := &grpcResolverReceiver{res} - - dialOpts := createDefaultGRPCDialOptions(res, nil) - - cc, err := grpc.Dial( - fmt.Sprintf("%s:///%s", res.Scheme(), servers[0].address), - dialOpts..., - ) - - require.NoError(t, err) - - // track GRPC state changes - go monitorUpstreamConnectionState(stateCtx, cc, upstreamConnectionState) - - grpcResolver.InitialAddresses([]string{servers[0].address}) t.Cleanup(func() { - cc.Close() - cancelStateCtx() - - assert.Equal(t, connectivity.Shutdown, cc.GetState()) - for _, s := range servers { s.srv.GracefulStop() } }) tests := []struct { - name string - expectedResponse *opsservices.GetHealthResponse - addresses []string - expectedState connectivity.State + name string + addresses []string + expectedState connectivity.State }{ { name: "connection with 1 good address", @@ -89,7 +61,23 @@ func TestMonitorUpstreamConnectionState(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + stateCtx, cancelStateCtx := context.WithCancel(context.Background()) + upstreamConnectionState := new(atomic.Value) doneWait := make(chan struct{}) + + grpcResolver := &grpcResolverReceiver{res} + grpcResolver.InitialAddresses([]string{servers[0].address}) + + dialOpts := createDefaultGRPCDialOptions(res, nil) + cc, err := grpc.Dial( + fmt.Sprintf("%s:///%s", res.Scheme(), servers[0].address), + dialOpts..., + ) + require.NoError(t, err) + + // track GRPC state changes + go monitorUpstreamConnectionState(stateCtx, cc, upstreamConnectionState) + grpcResolver.SetAddresses(tt.addresses) go waitForConnectionStateCondition(upstreamConnectionState, tt.expectedState, doneWait) @@ -101,8 +89,8 @@ func TestMonitorUpstreamConnectionState(t *testing.T) { t.Error("Time out waiting for condition") } - got := upstreamConnectionState.Load() - assert.Equal(t, tt.expectedState, got) + require.NoError(t, cc.Close()) + cancelStateCtx() }) } }