diff --git a/internal/daemon/worker/controller_connection_test.go b/internal/daemon/worker/controller_connection_test.go index da11d848e2..ba07127629 100644 --- a/internal/daemon/worker/controller_connection_test.go +++ b/internal/daemon/worker/controller_connection_test.go @@ -21,38 +21,13 @@ import ( ) func TestMonitorUpstreamConnectionState(t *testing.T) { - ctx := context.Background() - stateCtx, cancelStateCtx := context.WithCancel(ctx) - - upstreamConnectionState := new(atomic.Value) - servers, err := createTestServers(t) require.NoError(t, err) scheme := strconv.FormatInt(time.Now().UnixNano(), 36) res := manual.NewBuilderWithScheme(scheme) - grpcResolver := &grpcResolverReceiver{res} - - dialOpts := createDefaultGRPCDialOptions(res, nil) - - cc, err := grpc.Dial( - fmt.Sprintf("%s:///%s", res.Scheme(), servers[0].address), - dialOpts..., - ) - - require.NoError(t, err) - - // track GRPC state changes - go monitorUpstreamConnectionState(stateCtx, cc, upstreamConnectionState) - - grpcResolver.InitialAddresses([]string{servers[0].address}) t.Cleanup(func() { - cc.Close() - cancelStateCtx() - - assert.Equal(t, connectivity.Shutdown, cc.GetState()) - for _, s := range servers { s.srv.GracefulStop() } @@ -87,7 +62,23 @@ func TestMonitorUpstreamConnectionState(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + stateCtx, cancelStateCtx := context.WithCancel(context.Background()) + upstreamConnectionState := new(atomic.Value) doneWait := make(chan struct{}) + + grpcResolver := &grpcResolverReceiver{res} + grpcResolver.InitialAddresses([]string{servers[0].address}) + + dialOpts := createDefaultGRPCDialOptions(res, nil) + cc, err := grpc.Dial( + fmt.Sprintf("%s:///%s", res.Scheme(), servers[0].address), + dialOpts..., + ) + require.NoError(t, err) + + // track GRPC state changes + go monitorUpstreamConnectionState(stateCtx, cc, upstreamConnectionState) + grpcResolver.SetAddresses(tt.addresses) go waitForConnectionStateCondition(upstreamConnectionState, tt.expectedState, doneWait) @@ -98,6 +89,10 @@ func TestMonitorUpstreamConnectionState(t *testing.T) { case <-time.After(2 * time.Second): t.Error("Time out waiting for condition") } + + require.NoError(t, cc.Close()) + cancelStateCtx() + assert.Equal(t, connectivity.Shutdown, cc.GetState()) }) } }