diff --git a/internal/cmd/commands/server/worker_shutdown_reload_test.go b/internal/cmd/commands/server/worker_shutdown_reload_test.go index 3fa771c7d4..9fafaee11c 100644 --- a/internal/cmd/commands/server/worker_shutdown_reload_test.go +++ b/internal/cmd/commands/server/worker_shutdown_reload_test.go @@ -98,37 +98,31 @@ func TestServer_ShutdownWorker(t *testing.T) { UseDevAuthMethod: true, UseDevTarget: true, }) - controllerDoneCh := make(chan struct{}) - defer func() { + t.Cleanup(func() { if controllerCmd.DevDatabaseCleanupFunc != nil { require.NoError(controllerCmd.DevDatabaseCleanupFunc()) } - }() + }) controllerCmd.presetConfig = atomic.NewString(fmt.Sprintf(shutdownReloadControllerConfig, controllerCmd.DatabaseUrl, controllerKey, workerAuthKey)) + // Use code channel so that we can use test assertions on the returned integer. + // It is illegal to call `t.FailNow()` from a goroutine. + // https://pkg.go.dev/testing#T.FailNow + controllerCodeChan := make(chan int) go func() { - defer close(controllerDoneCh) - if code := controllerCmd.Run(nil); code != 0 { - output := controllerCmd.UI.(*cli.MockUi).ErrorWriter.String() + controllerCmd.UI.(*cli.MockUi).OutputWriter.String() - require.FailNow(output, "command exited with non-zero error code") - } + controllerCodeChan <- controllerCmd.Run(nil) }() waitCh(t, controllerCmd.startedCh) // Start the worker workerCmd := testServerCommand(t, testServerCommandOpts{}) - workerDoneCh := make(chan struct{}) workerCmd.presetConfig = atomic.NewString(fmt.Sprintf(shutdownReloadWorkerConfig, workerAuthKey)) + workerCodeChan := make(chan int) go func() { - defer close(workerDoneCh) - if code := workerCmd.Run(nil); code != 0 { - output := workerCmd.UI.(*cli.MockUi).ErrorWriter.String() + workerCmd.UI.(*cli.MockUi).OutputWriter.String() - require.FailNow(output, "command exited with non-zero error code") - } + workerCodeChan <- workerCmd.Run(nil) }() - waitCh(t, workerCmd.startedCh) // Set up the target @@ -159,7 +153,10 @@ func TestServer_ShutdownWorker(t *testing.T) { // Now, shut the worker down. close(workerCmd.ShutdownCh) - waitCh(t, workerDoneCh) + if <-workerCodeChan != 0 { + output := workerCmd.UI.(*cli.MockUi).ErrorWriter.String() + workerCmd.UI.(*cli.MockUi).OutputWriter.String() + require.FailNow(output, "command exited with non-zero error code") + } // Connection should fail, and the session should be closed on the DB. sConn.TestSendRecvFail(t) @@ -167,7 +164,10 @@ func TestServer_ShutdownWorker(t *testing.T) { // We're done! Shutdown the controller, and that's it. close(controllerCmd.ShutdownCh) - waitCh(t, controllerDoneCh) + if <-controllerCodeChan != 0 { + output := controllerCmd.UI.(*cli.MockUi).ErrorWriter.String() + controllerCmd.UI.(*cli.MockUi).OutputWriter.String() + require.FailNow(output, "command exited with non-zero error code") + } } // largely copied from controller/testing.go diff --git a/internal/cmd/ops/server_test.go b/internal/cmd/ops/server_test.go index 3bbf9dfe44..c481ee30ce 100644 --- a/internal/cmd/ops/server_test.go +++ b/internal/cmd/ops/server_test.go @@ -411,10 +411,17 @@ func TestShutdown(t *testing.T) { require.NoError(t, err) s1 := &http.Server{} + // Use error channel so that we can use test assertions on the returned error. + // It is illegal to call `t.FailNow()` from a goroutine. + // https://pkg.go.dev/testing#T.FailNow + errChan := make(chan error) go func() { - require.ErrorIs(t, s1.Serve(l1), http.ErrServerClosed) + errChan <- s1.Serve(l1) }() - + t.Cleanup(func() { + // Will block until we stopped serving + require.ErrorIs(t, <-errChan, http.ErrServerClosed) + }) err = s1.Shutdown(context.Background()) require.NoError(t, err) @@ -445,17 +452,28 @@ func TestShutdown(t *testing.T) { require.NoError(t, err) s1 := &http.Server{} + // Use error channel so that we can use test assertions on the returned error. + // It is illegal to call `t.FailNow()` from a goroutine. + // https://pkg.go.dev/testing#T.FailNow + s1ErrChan := make(chan error) go func() { - require.ErrorIs(t, s1.Serve(l1), http.ErrServerClosed) + s1ErrChan <- s1.Serve(l1) }() + t.Cleanup(func() { + require.ErrorIs(t, <-s1ErrChan, http.ErrServerClosed) + }) l2, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) s2 := &http.Server{} + s2ErrChan := make(chan error) go func() { - require.ErrorIs(t, s2.Serve(l2), http.ErrServerClosed) + s2ErrChan <- s2.Serve(l2) }() + t.Cleanup(func() { + require.ErrorIs(t, <-s2ErrChan, http.ErrServerClosed) + }) return &Server{ bundles: []*opsBundle{ diff --git a/internal/servers/controller/handler_test.go b/internal/servers/controller/handler_test.go index 953278b8b2..d43b5ebdd3 100644 --- a/internal/servers/controller/handler_test.go +++ b/internal/servers/controller/handler_test.go @@ -250,13 +250,18 @@ func TestCallbackInterceptor(t *testing.T) { Handler: wrapHandlerWithCallbackInterceptor(noopHandler, nil), } + // Use error channel so that we can use test assertions on the returned error. + // It is illegal to call `t.FailNow()` from a goroutine. + // https://pkg.go.dev/testing#T.FailNow + errChan := make(chan error) go func() { - if err := server.Serve(listener); err != nil { - if err != http.ErrServerClosed { - require.NoError(t, err) - } - } + errChan <- server.Serve(listener) }() + t.Cleanup(func() { + if err := <-errChan; err != http.ErrServerClosed { + require.NoError(t, err) + } + }) testCases := []struct { name string diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 8797692379..bf9e46059c 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -790,10 +790,17 @@ func startTestGreeterService(t *testing.T, greeter interceptor.GreeterServiceSer grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(interceptors...)), ) interceptor.RegisterGreeterServiceServer(s, greeter) + // Use error channel so that we can use test assertions on the returned error. + // It is illegal to call `t.FailNow()` from a goroutine. + // https://pkg.go.dev/testing#T.FailNow + errChan := make(chan error) go func() { - err := s.Serve(listener) - require.NoError(err) + errChan <- s.Serve(listener) }() + t.Cleanup(func() { + // Will block until we stopped serving + require.NoError(<-errChan) + }) conn, _ := grpc.DialContext(dialCtx, "", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return listener.Dial() diff --git a/internal/servers/worker/proxy/tcp/tcp_test.go b/internal/servers/worker/proxy/tcp/tcp_test.go index 6c8c1b24f2..6ad0179ef2 100644 --- a/internal/servers/worker/proxy/tcp/tcp_test.go +++ b/internal/servers/worker/proxy/tcp/tcp_test.go @@ -31,10 +31,10 @@ func TestHandleTcpProxyV1(t *testing.T) { defer l.Close() var endpointConn net.Conn + var endpointErr error ready := make(chan struct{}) go func() { - endpointConn, err = l.Accept() - require.NoError(err) + endpointConn, endpointErr = l.Accept() defer endpointConn.Close() ready <- struct{}{} @@ -71,13 +71,20 @@ func TestHandleTcpProxyV1(t *testing.T) { UserClientIp: net.ParseIP("127.0.0.1"), } + // Use error channel so that we can use test assertions on the returned error. + // It is illegal to call `t.FailNow()` from a goroutine. + // https://pkg.go.dev/testing#T.FailNow + errChan := make(chan error) go func() { - err = handleProxy(ctx, conf) - require.NoError(err) + errChan <- handleProxy(ctx, conf) }() + t.Cleanup(func() { + require.NoError(<-errChan) + }) // wait for HandleTcpProxyV1 to dial endpoint <-ready + require.NoError(endpointErr) netConn := websocket.NetConn(ctx, clientConn, websocket.MessageBinary) // Write from endpoint to client diff --git a/internal/servers/worker/proxy/testing_test.go b/internal/servers/worker/proxy/testing_test.go index 232627c9b0..620354b97c 100644 --- a/internal/servers/worker/proxy/testing_test.go +++ b/internal/servers/worker/proxy/testing_test.go @@ -9,6 +9,11 @@ import ( "nhooyr.io/websocket" ) +type connMsg struct { + msg []byte + err error +} + func Test_TestWsConn(t *testing.T) { t.Parallel() require, assert := require.New(t), assert.New(t) @@ -16,32 +21,35 @@ func Test_TestWsConn(t *testing.T) { ctx, cancelCtx := context.WithCancel(context.Background()) clientConn, proxyConn := TestWsConn(t, ctx) - successfulRead := make(chan struct{}) + // Use msg channel so that we can use test assertions on the returned content. + // It is illegal to call `t.FailNow()` from a goroutine. + // https://pkg.go.dev/testing#T.FailNow + readChan := make(chan connMsg) go func() { _, msg, err := proxyConn.Read(ctx) - require.NoError(err) - assert.Equal("client to proxy", string(msg)) - successfulRead <- struct{}{} + readChan <- connMsg{msg, err} }() err := clientConn.Write(ctx, websocket.MessageBinary, []byte("client to proxy")) require.NoError(err) // Wait for read to verify success - <-successfulRead + msg := <-readChan + require.NoError(msg.err) + assert.Equal("client to proxy", string(msg.msg)) go func() { _, msg, err := clientConn.Read(ctx) - require.NoError(err) - assert.Equal("proxy to client", string(msg)) - successfulRead <- struct{}{} + readChan <- connMsg{msg, err} }() err = proxyConn.Write(ctx, websocket.MessageBinary, []byte("proxy to client")) require.NoError(err) // Wait for read to verify success - <-successfulRead + msg = <-readChan + require.NoError(msg.err) + assert.Equal("proxy to client", string(msg.msg)) cancelCtx() }