refactor(controller): Decouple listeners/servers shutdown from Worker

These changes follow the same approach as the earlier
changes to the controller's `startListeners`.

Specifically, these changes also decouple the Controller and the Worker
shutdown processes. Currently, a call to `stopListeners` on an instance
that is running as both a Controller and a Worker would stop the
listeners for both.

Lastly, we also add unit/integration tests to the shutdown code.
pull/1908/head
Hugo Vieira 4 years ago
parent d438a6d2dd
commit eebdfb88dc

@ -618,7 +618,7 @@ func (c *Command) Run(args []string) int {
if err := c.controller.Start(); err != nil {
retErr := fmt.Errorf("Error starting controller: %w", err)
if err := c.controller.Shutdown(false); err != nil {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(retErr.Error())
retErr = fmt.Errorf("Error shutting down controller: %w", err)
}
@ -647,7 +647,7 @@ func (c *Command) Run(args []string) int {
retErr = fmt.Errorf("Error shutting down worker: %w", err)
}
c.UI.Error(retErr.Error())
if err := c.controller.Shutdown(false); err != nil {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error with controller shutdown: %w", err).Error())
}
return base.CommandCliError
@ -684,7 +684,7 @@ func (c *Command) Run(args []string) int {
}
}
if err := c.controller.Shutdown(false); err != nil {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error())
}

@ -465,7 +465,7 @@ func (c *Command) Run(args []string) int {
if err := c.StartWorker(); err != nil {
c.UI.Error(err.Error())
if c.controller != nil {
if err := c.controller.Shutdown(false); err != nil {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error with controller shutdown: %w", err).Error())
}
}
@ -601,7 +601,7 @@ func (c *Command) StartController(ctx context.Context) error {
if err := c.controller.Start(); err != nil {
retErr := fmt.Errorf("Error starting controller: %w", err)
if err := c.controller.Shutdown(false); err != nil {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(retErr.Error())
retErr = fmt.Errorf("Error shutting down controller: %w", err)
}
@ -669,7 +669,7 @@ func (c *Command) WaitForInterrupt() int {
// Do controller shutdown
if c.Config.Controller != nil {
if err := c.controller.Shutdown(c.Config.Worker != nil); err != nil {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error())
}
}

@ -352,15 +352,15 @@ func (c *Controller) registerSessionConnectionCleanupJob() error {
return nil
}
func (c *Controller) Shutdown(serversOnly bool) error {
func (c *Controller) Shutdown() error {
const op = "controller.(Controller).Shutdown"
if !c.started.Load() {
event.WriteSysEvent(context.TODO(), op, "already shut down, skipping")
}
defer c.started.Store(false)
c.baseCancel()
if err := c.stopListeners(serversOnly); err != nil {
return fmt.Errorf("error stopping controller listeners: %w", err)
if err := c.stopServersAndListeners(); err != nil {
return fmt.Errorf("error stopping controller servers and listeners: %w", err)
}
c.schedulerWg.Wait()
c.tickerWg.Wait()

@ -9,11 +9,9 @@ import (
"net"
"net/http"
"os"
"sync"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
@ -162,61 +160,108 @@ func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error
return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil
}
func (c *Controller) stopListeners(serversOnly bool) error {
serverWg := new(sync.WaitGroup)
for _, ln := range c.conf.Listeners {
localLn := ln
serverWg.Add(1)
go func() {
defer serverWg.Done()
shutdownKill, shutdownKillCancel := context.WithTimeout(c.baseContext, localLn.Config.MaxRequestDuration)
defer shutdownKillCancel()
if localLn.GrpcServer != nil {
// Deal with the worst case
go func() {
<-shutdownKill.Done()
localLn.GrpcServer.Stop()
}()
localLn.GrpcServer.GracefulStop()
}
if localLn.HTTPServer != nil {
localLn.HTTPServer.Shutdown(shutdownKill)
}
}()
// stopServersAndListeners stops the controller's servers and their
// listeners: the cluster grpc server, the API HTTP servers, and the
// API grpc server are stopped concurrently, then stopAnyListeners runs
// as a final sweep over any listeners that are still open. All errors
// are collected and returned together.
func (c *Controller) stopServersAndListeners() error {
	var group multierror.Group
	for _, stop := range []func() error{
		c.stopClusterGrpcServerAndListener,
		c.stopHttpServersAndListeners,
		c.stopApiGrpcServerAndListener,
	} {
		group.Go(stop)
	}

	// Wait for the concurrent stops before the final listener sweep.
	errs := group.Wait()
	if err := c.stopAnyListeners(); err != nil {
		errs = multierror.Append(errs, err)
	}
	return errs.ErrorOrNil()
}
// stopClusterGrpcServerAndListener gracefully stops the cluster grpc
// server and closes the underlying cluster listener mux. A nil cluster
// listener is treated as "nothing to stop"; a listener missing its grpc
// server or mux is reported as an error.
func (c *Controller) stopClusterGrpcServerAndListener() error {
	ln := c.clusterListener
	switch {
	case ln == nil:
		// Nothing was ever set up; nothing to stop.
		return nil
	case ln.GrpcServer == nil:
		return fmt.Errorf("no cluster grpc server")
	case ln.Mux == nil:
		return fmt.Errorf("no cluster listener mux")
	}

	ln.GrpcServer.GracefulStop()
	// GracefulStop may already have closed the listener; the error check
	// below filters out the "already closed" case.
	return listenerCloseErrorCheck(ln.Config.Type, ln.Mux.Close())
}
func (c *Controller) stopHttpServersAndListeners() error {
var closeErrors *multierror.Error
for i := range c.apiListeners {
ln := c.apiListeners[i]
if ln.HTTPServer == nil {
continue
}
if c.apiGrpcServer != nil {
serverWg.Add(1)
go func() {
defer serverWg.Done()
shutdownKill, shutdownKillCancel := context.WithTimeout(c.baseContext, globals.DefaultMaxRequestDuration)
defer shutdownKillCancel()
go func() {
<-shutdownKill.Done()
c.apiGrpcServer.Stop()
}()
c.apiGrpcServer.GracefulStop()
}()
ctx, cancel := context.WithTimeout(c.baseContext, ln.Config.MaxRequestDuration)
ln.HTTPServer.Shutdown(ctx)
cancel()
err := ln.Mux.Close() // The HTTP Shutdown call should close this, but just in case.
err = listenerCloseErrorCheck(ln.Config.Type, err)
if err != nil {
multierror.Append(closeErrors, err)
}
}
serverWg.Wait()
if serversOnly {
return closeErrors.ErrorOrNil()
}
func (c *Controller) stopApiGrpcServerAndListener() error {
if c.apiGrpcServer == nil {
return nil
}
var retErr *multierror.Error
for _, ln := range c.conf.Listeners {
if err := ln.Mux.Close(); err != nil {
if _, ok := err.(*os.PathError); ok && ln.Config.Type == "unix" {
// The rmListener probably tried to remove the file but it
// didn't exist, ignore the error; this is a conflict
// between rmListener and the default Go behavior of
// removing auto-vivified Unix domain sockets.
} else {
retErr = multierror.Append(retErr, err)
}
c.apiGrpcServer.GracefulStop()
err := c.apiGrpcServerListener.Close()
return listenerCloseErrorCheck("ch", err) // apiGrpcServerListener is just a channel, so the type here is not important.
}
// stopAnyListeners does a final once-over of the known
// listeners to make sure we didn't miss any;
// expected to run at the end of stopServersAndListeners.
func (c *Controller) stopAnyListeners() error {
	var closeErrors *multierror.Error
	for i := range c.apiListeners {
		ln := c.apiListeners[i]
		if ln == nil || ln.Mux == nil {
			// Nothing to close for this entry.
			continue
		}

		err := ln.Mux.Close()
		err = listenerCloseErrorCheck(ln.Config.Type, err)
		if err != nil {
			// BUG FIX: the Append result was previously discarded, so
			// close errors were silently dropped; it must be reassigned.
			closeErrors = multierror.Append(closeErrors, err)
		}
	}
	// ErrorOrNil is safe on a nil *multierror.Error.
	return closeErrors.ErrorOrNil()
}
// listenerCloseErrorCheck does some validation on an error returned
// by a net.Listener's Close function, and ignores a few cases
// where we don't actually want an error to be returned.
func listenerCloseErrorCheck(lnType string, err error) error {
if errors.Is(err, net.ErrClosed) {
// Ignore net.ErrClosed - The listener was already closed,
// so there's nothing else to do.
return nil
}
if _, ok := err.(*os.PathError); ok && lnType == "unix" {
// The underlying rmListener probably tried to remove
// the file but it didn't exist, ignore the error;
// this is a conflict between rmListener and the
// default Go behavior of removing auto-vivified
// Unix domain sockets.
return nil
}
return err
}

@ -20,6 +20,7 @@ import (
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/go-secure-stdlib/base62"
"github.com/hashicorp/go-secure-stdlib/configutil/v2"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
@ -262,6 +263,454 @@ func TestStartListeners(t *testing.T) {
}
}
// TestStopClusterGrpcServerAndListener exercises
// Controller.stopClusterGrpcServerAndListener across nil/partial
// configurations, an already-closed listener, and a live grpc server
// that must be stopped gracefully.
func TestStopClusterGrpcServerAndListener(t *testing.T) {
	tests := []struct {
		name         string
		controllerFn func(t *testing.T) *Controller // builds the Controller under test
		assertions   func(t *testing.T, c *Controller) // extra post-stop checks
		expErr       bool
		expErrStr    string
	}{
		{
			// No cluster listener at all is a no-op, not an error.
			name: "no cluster listener",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{}
			},
			expErr: false,
		},
		{
			// A listener without a grpc server is reported as an error.
			name: "no cluster grpc server",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{clusterListener: &base.ServerListener{}}
			},
			expErr:    true,
			expErrStr: "no cluster grpc server",
		},
		{
			// A listener with a grpc server but no mux is also an error.
			name: "no cluster listener mux",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{
					clusterListener: &base.ServerListener{
						GrpcServer: grpc.NewServer(),
						Config:     &listenerutil.ListenerConfig{Type: "tcp"},
					},
				}
			},
			expErr:    true,
			expErrStr: "no cluster listener mux",
		},
		{
			// Closing the listener up front should not surface an error
			// (net.ErrClosed is filtered by listenerCloseErrorCheck).
			name: "listener already closed",
			controllerFn: func(t *testing.T) *Controller {
				l, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				require.NoError(t, l.Close())
				return &Controller{
					clusterListener: &base.ServerListener{
						ALPNListener: l,
						GrpcServer:   grpc.NewServer(),
						Mux:          alpnmux.New(l),
						Config:       &listenerutil.ListenerConfig{Type: "tcp"},
					},
				}
			},
			assertions: func(t *testing.T, c *Controller) {
				// A second Close confirms the mux is already closed.
				require.ErrorIs(t, c.clusterListener.Mux.Close(), net.ErrClosed)
			},
			expErr: false,
		},
		{
			// Happy path: a serving grpc server is gracefully stopped and
			// its listener ends up closed.
			name: "graceful stop",
			controllerFn: func(t *testing.T) *Controller {
				l, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)

				grpcServer := grpc.NewServer()
				go grpcServer.Serve(l)

				// Make sure it's up
				_, err = grpc.Dial(l.Addr().String(),
					grpc.WithInsecure(),
					grpc.WithBlock(),
					grpc.WithTimeout(5*time.Second),
				)
				require.NoError(t, err)

				return &Controller{
					clusterListener: &base.ServerListener{
						ALPNListener: l,
						GrpcServer:   grpcServer,
						Mux:          alpnmux.New(l),
						Config:       &listenerutil.ListenerConfig{Type: "tcp"},
					},
				}
			},
			assertions: func(t *testing.T, c *Controller) {
				// A close attempt after stopping shows the listener is closed.
				require.ErrorIs(t, c.clusterListener.Mux.Close(), net.ErrClosed)
			},
			expErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := tt.controllerFn(t)
			err := c.stopClusterGrpcServerAndListener()
			if tt.expErr {
				require.EqualError(t, err, tt.expErrStr)
				return
			}

			require.NoError(t, err)
			if tt.assertions != nil {
				tt.assertions(t, c)
			}
		})
	}
}
// TestStopHttpServersAndListeners exercises
// Controller.stopHttpServersAndListeners with empty/nil listener sets,
// listeners lacking HTTP servers, an already-closed listener, and
// multiple live servers (tcp + unix).
func TestStopHttpServersAndListeners(t *testing.T) {
	tests := []struct {
		name         string
		controllerFn func(t *testing.T) *Controller // builds the Controller under test
		assertions   func(t *testing.T, c *Controller) // extra post-stop checks
		expErr       bool
		expErrStr    string
	}{
		{
			name: "no listeners",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{
					apiListeners: []*base.ServerListener{},
				}
			},
			expErr: false,
		},
		{
			name: "nil listeners",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{
					apiListeners: nil,
				}
			},
			expErr: false,
		},
		{
			// Entries without an HTTP server are skipped entirely.
			name: "listeners with nil http server",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{
					apiListeners: []*base.ServerListener{
						{HTTPServer: nil},
						{HTTPServer: nil},
						{HTTPServer: nil},
					},
				}
			},
			expErr: false,
		},
		{
			// Shutting down against an already-closed listener should not
			// surface an error (net.ErrClosed is filtered).
			name: "listener already closed",
			controllerFn: func(t *testing.T) *Controller {
				l1, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				require.NoError(t, l1.Close())

				s1 := &http.Server{}
				return &Controller{
					baseContext: context.Background(),
					apiListeners: []*base.ServerListener{
						{
							ALPNListener: l1,
							HTTPServer:   s1,
							Mux:          alpnmux.New(l1),
							Config:       &listenerutil.ListenerConfig{Type: "tcp"},
						},
					},
				}
			},
			assertions: func(t *testing.T, c *Controller) {
				// Asserts the HTTP Servers are closed.
				require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ALPNListener), http.ErrServerClosed)

				// Asserts the underlying listeners are closed.
				require.ErrorIs(t, c.apiListeners[0].Mux.Close(), net.ErrClosed)
			},
			expErr: false,
		},
		{
			// Two live servers: one on tcp, one on a unix domain socket.
			name: "multiple listeners",
			controllerFn: func(t *testing.T) *Controller {
				l1, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				l2, err := net.Listen("unix", "/tmp/boundary-controller-TestStopHttpServersAndListeners-"+strconv.FormatInt(time.Now().UnixNano(), 10))
				require.NoError(t, err)

				s1 := &http.Server{}
				s2 := &http.Server{}
				go s1.Serve(l1)
				go s2.Serve(l2)

				// Make sure they're up
				_, err = http.Get("http://" + l1.Addr().String())
				require.NoError(t, err)

				c := http.Client{Transport: &http.Transport{
					Dial: func(network, addr string) (net.Conn, error) {
						return net.Dial("unix", l2.Addr().String())
					},
				}}
				_, err = c.Get("http://random.domain")
				require.NoError(t, err)

				return &Controller{
					baseContext: context.Background(),
					apiListeners: []*base.ServerListener{
						{
							ALPNListener: l1,
							HTTPServer:   s1,
							Mux:          alpnmux.New(l1),
							Config:       &listenerutil.ListenerConfig{Type: "tcp"},
						},
						{
							ALPNListener: l2,
							HTTPServer:   s2,
							Mux:          alpnmux.New(l2),
							// BUG FIX: l2 is a unix domain socket, so its
							// listener config type must be "unix" (this also
							// lets listenerCloseErrorCheck ignore rmListener
							// path errors for it).
							Config: &listenerutil.ListenerConfig{Type: "unix"},
						},
					},
				}
			},
			assertions: func(t *testing.T, c *Controller) {
				// Asserts the HTTP Servers are closed.
				require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ALPNListener), http.ErrServerClosed)
				require.ErrorIs(t, c.apiListeners[1].HTTPServer.Serve(c.apiListeners[1].ALPNListener), http.ErrServerClosed)

				// Asserts the underlying listeners are closed.
				require.ErrorIs(t, c.apiListeners[0].Mux.Close(), net.ErrClosed)
				require.ErrorIs(t, c.apiListeners[1].Mux.Close(), net.ErrClosed)
			},
			expErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := tt.controllerFn(t)
			err := c.stopHttpServersAndListeners()
			if tt.expErr {
				require.EqualError(t, err, tt.expErrStr)
				return
			}

			require.NoError(t, err)
			if tt.assertions != nil {
				tt.assertions(t, c)
			}
		})
	}
}
// TestStopApiGrpcServerAndListener exercises
// Controller.stopApiGrpcServerAndListener for both the no-op (nil
// server) case and a graceful stop of a live in-memory grpc server.
func TestStopApiGrpcServerAndListener(t *testing.T) {
	tests := []struct {
		name         string
		controllerFn func(t *testing.T) *Controller // builds the Controller under test
		assertions   func(t *testing.T, c *Controller) // extra post-stop checks
		expErr       bool
		expErrStr    string
	}{
		{
			// A nil API grpc server is a no-op, not an error.
			name: "nil api grpc server",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{apiGrpcServer: nil}
			},
			expErr: false,
		},
		{
			name: "graceful stop",
			controllerFn: func(t *testing.T) *Controller {
				// newGrpcServerListener is a test helper defined elsewhere
				// in this package; it appears to provide an in-memory
				// (channel-backed) listener with a Dial method.
				l := newGrpcServerListener()
				grpcServer := grpc.NewServer()
				go grpcServer.Serve(l)

				// Make sure it's up
				_, err := grpc.Dial("",
					grpc.WithInsecure(),
					grpc.WithBlock(),
					grpc.WithTimeout(10*time.Second),
					grpc.WithDialer(func(s string, d time.Duration) (net.Conn, error) {
						return l.Dial()
					}),
				)
				require.NoError(t, err)

				return &Controller{apiGrpcServer: grpcServer, apiGrpcServerListener: l}
			},
			assertions: func(t *testing.T, c *Controller) {
				// Serving again after a stop must report ErrServerStopped.
				require.ErrorIs(t, c.apiGrpcServer.Serve(c.apiGrpcServerListener), grpc.ErrServerStopped)
			},
			expErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := tt.controllerFn(t)
			err := c.stopApiGrpcServerAndListener()
			if tt.expErr {
				require.EqualError(t, err, tt.expErrStr)
				return
			}

			require.NoError(t, err)
			if tt.assertions != nil {
				tt.assertions(t, c)
			}
		})
	}
}
// TestStopAnyListeners exercises Controller.stopAnyListeners with
// nil/empty listener sets, nil entries, entries missing a mux, and a
// mix of live and already-closed listeners.
func TestStopAnyListeners(t *testing.T) {
	tests := []struct {
		name         string
		controllerFn func(t *testing.T) *Controller // builds the Controller under test
		assertions   func(t *testing.T, c *Controller) // extra post-stop checks
		expErr       bool
		expErrStr    string
	}{
		{
			name: "nil apiListeners",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{apiListeners: nil}
			},
			expErr: false,
		},
		{
			name: "no listeners",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{apiListeners: []*base.ServerListener{}}
			},
			expErr: false,
		},
		{
			// Nil entries must be skipped, not dereferenced.
			name: "nil listeners",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{apiListeners: []*base.ServerListener{nil, nil, nil}}
			},
			expErr: false,
		},
		{
			// Entries without a mux must be skipped as well.
			name: "listeners with nil mux",
			controllerFn: func(t *testing.T) *Controller {
				return &Controller{apiListeners: []*base.ServerListener{
					{Mux: nil}, {Mux: nil}, {Mux: nil},
				}}
			},
			expErr: false,
		},
		{
			// A closed listener in the set must not cause an error
			// (net.ErrClosed is filtered by listenerCloseErrorCheck).
			name: "multiple listeners, including a closed one",
			controllerFn: func(t *testing.T) *Controller {
				l1, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				l2, err := net.Listen("unix", "/tmp/boundary-controller-TestStopAnyListeners-"+strconv.FormatInt(time.Now().UnixNano(), 10))
				require.NoError(t, err)

				l3, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				require.NoError(t, l3.Close())

				return &Controller{apiListeners: []*base.ServerListener{
					{
						Config: &listenerutil.ListenerConfig{Type: "tcp"},
						Mux:    alpnmux.New(l1),
					},
					{
						// BUG FIX: l2 is a unix domain socket, so its
						// listener config type must be "unix" (this also
						// lets listenerCloseErrorCheck ignore rmListener
						// path errors for it).
						Config: &listenerutil.ListenerConfig{Type: "unix"},
						Mux:    alpnmux.New(l2),
					},
					{
						Config: &listenerutil.ListenerConfig{Type: "tcp"},
						Mux:    alpnmux.New(l3),
					},
				}}
			},
			assertions: func(t *testing.T, c *Controller) {
				// Every mux must be closed once stopAnyListeners returns.
				for i := range c.apiListeners {
					ln := c.apiListeners[i]
					require.ErrorIs(t, ln.Mux.Close(), net.ErrClosed)
				}
			},
			expErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := tt.controllerFn(t)
			err := c.stopAnyListeners()
			if tt.expErr {
				require.EqualError(t, err, tt.expErrStr)
				return
			}

			require.NoError(t, err)
			if tt.assertions != nil {
				tt.assertions(t, c)
			}
		})
	}
}
// TestListenerCloseErrorCheck verifies which listener-close errors are
// swallowed (nil, net.ErrClosed, unix path errors) and which are
// passed through unchanged.
func TestListenerCloseErrorCheck(t *testing.T) {
	cases := []struct {
		name   string
		lnType string
		err    error
		expErr error
	}{
		{name: "nil err", lnType: "tcp", err: nil, expErr: nil},
		{name: "net.Closed", lnType: "tcp", err: net.ErrClosed, expErr: nil},
		{name: "path err not unix type", lnType: "tcp", err: &os.PathError{Op: "test"}, expErr: &os.PathError{Op: "test"}},
		{name: "path err unix type", lnType: "unix", err: &os.PathError{Op: "test"}, expErr: nil},
		{name: "literally anything else", lnType: "tcp", err: fmt.Errorf("oops I errored"), expErr: fmt.Errorf("oops I errored")},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.expErr, listenerCloseErrorCheck(tc.lnType, tc.err))
		})
	}
}
func clusterGrpcDialNoError(t *testing.T, c *Controller, network, addr string) {
grpcConn, err := grpc.Dial(addr,
grpc.WithInsecure(),

@ -279,7 +279,7 @@ func (tc *TestController) Shutdown() {
tc.cancel()
if tc.c != nil {
if err := tc.c.Shutdown(false); err != nil {
if err := tc.c.Shutdown(); err != nil {
tc.t.Error(err)
}
}

@ -77,16 +77,18 @@ func TestIPv6Listener(t *testing.T) {
time.Sleep(10 * time.Second)
expectWorkers(c1, w1)
require.NoError(w1.Worker().Shutdown(true))
time.Sleep(10 * time.Second)
expectWorkers(c1)
c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{
Logger: c1.Config().Logger.ResetNamed("c2"),
})
defer c2.Shutdown()
require.NoError(c1.Controller().Shutdown(true))
time.Sleep(10 * time.Second)
expectWorkers(c2, w1)
require.NoError(c1.Controller().Start())
require.NoError(w1.Worker().Shutdown(true))
time.Sleep(10 * time.Second)
expectWorkers(c1, w1)
expectWorkers(c1)
expectWorkers(c2)
client, err := api.NewClient(nil)
require.NoError(err)

@ -64,7 +64,7 @@ func TestMultiControllerMultiWorkerConnections(t *testing.T) {
expectWorkers(t, c1, w1, w2)
expectWorkers(t, c2, w1, w2)
require.NoError(c1.Controller().Shutdown(true))
require.NoError(c2.Controller().Shutdown())
time.Sleep(10 * time.Second)
expectWorkers(t, c2, w1, w2)

@ -94,12 +94,16 @@ func TestUnixListener(t *testing.T) {
time.Sleep(10 * time.Second)
expectWorkers(c1)
require.NoError(c1.Controller().Shutdown(true))
time.Sleep(10 * time.Second)
require.NoError(c1.Controller().Shutdown())
c1 = controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
Logger: logger.Named("c1"),
DisableOidcAuthMethodCreation: true,
})
defer c1.Shutdown()
require.NoError(c1.Controller().Start())
time.Sleep(10 * time.Second)
expectWorkers(c1, w1)
expectWorkers(c1)
client, err := api.NewClient(nil)
require.NoError(err)

Loading…
Cancel
Save