Worker graceful shutdown (#2455)

* add graceful shutdown to worker
pull/2470/head
Irena Rindos 4 years ago committed by GitHub
parent a2ef14283b
commit d951e1ebc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -76,9 +76,10 @@ var reRemoveWhitespace = regexp.MustCompile(`[\s]+`)
var DevOnlyControllerFlags = func(*Command, *FlagSet) {}
type Command struct {
Context context.Context
UI cli.Ui
ShutdownCh chan struct{}
Context context.Context
ContextCancel context.CancelFunc
UI cli.Ui
ShutdownCh chan struct{}
flags *FlagSets
flagsOnce sync.Once
@ -143,7 +144,6 @@ func NewCommand(ui cli.Ui) *Command {
ShutdownCh: MakeShutdownCh(),
Context: ctx,
}
go func() {
<-ret.ShutdownCh
cancel()
@ -152,6 +152,19 @@ func NewCommand(ui cli.Ui) *Command {
return ret
}
// NewServerCommand returns a new base Command for server-style commands.
//
// Unlike NewCommand, it does not start a goroutine that receives from
// ShutdownCh in order to cancel the command context. The cancel function is
// instead stored in ContextCancel, leaving the caller in charge of both
// draining ShutdownCh and deciding when the context gets cancelled.
func NewServerCommand(ui cli.Ui) *Command {
	cmdCtx, cmdCancel := context.WithCancel(context.Background())
	return &Command{
		UI:            ui,
		ShutdownCh:    MakeShutdownCh(),
		Context:       cmdCtx,
		ContextCancel: cmdCancel,
	}
}
// MakeShutdownCh returns a channel that can be used for shutdown
// notifications for commands. This channel will send a message for every
// SIGINT or SIGTERM received.
@ -161,8 +174,10 @@ func MakeShutdownCh() chan struct{} {
shutdownCh := make(chan os.Signal, 4)
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-shutdownCh
close(resultCh)
for {
<-shutdownCh
resultCh <- struct{}{}
}
}()
return resultCh
}

@ -371,18 +371,8 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet
}
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
b.DevOidcSetup.authMethod, err = oidcRepo.CreateAuthMethod(
cancelCtx,
ctx,
authMethod,
oidc.WithPublicId(b.DevOidcAuthMethodId))
if err != nil {
@ -402,7 +392,7 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet
return fmt.Errorf("error generating %s oidc account: %w", typ, err)
}
acct, err = oidcRepo.CreateAccount(
cancelCtx,
ctx,
b.DevOidcSetup.authMethod.GetScopeId(),
acct,
oidc.WithPublicId(accountId),
@ -417,11 +407,11 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet
return fmt.Errorf("unable to create iam repo: %w", err)
}
u, _, err := iamRepo.LookupUser(cancelCtx, userId)
u, _, err := iamRepo.LookupUser(ctx, userId)
if err != nil {
return fmt.Errorf("error looking up %s user: %w", typ, err)
}
if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil {
if _, err = iamRepo.AddUserAccounts(ctx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil {
return fmt.Errorf("error associating initial %s user with account: %w", typ, err)
}

@ -32,15 +32,6 @@ func (b *Server) CreateInitialLoginRole(ctx context.Context) (*iam.Role, error)
); err != nil {
return nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
iamRepo, err := iam.NewRepository(rw, rw, kmsCache, iam.WithRandomReader(b.SecureRandomReader))
if err != nil {
@ -54,11 +45,11 @@ func (b *Server) CreateInitialLoginRole(ctx context.Context) (*iam.Role, error)
if err != nil {
return nil, fmt.Errorf("error creating in memory role for generated grants: %w", err)
}
role, err := iamRepo.CreateRole(cancelCtx, pr)
role, err := iamRepo.CreateRole(ctx, pr)
if err != nil {
return nil, fmt.Errorf("error creating role for default generated grants: %w", err)
}
if _, err := iamRepo.AddRoleGrants(cancelCtx, role.PublicId, role.Version, []string{
if _, err := iamRepo.AddRoleGrants(ctx, role.PublicId, role.Version, []string{
"id=*;type=scope;actions=list,no-op",
"id=*;type=auth-method;actions=authenticate,list",
"id={{account.id}};actions=read,change-password",
@ -66,7 +57,7 @@ func (b *Server) CreateInitialLoginRole(ctx context.Context) (*iam.Role, error)
}); err != nil {
return nil, fmt.Errorf("error creating grant for default generated grants: %w", err)
}
if _, err := iamRepo.AddPrincipalRoles(cancelCtx, role.PublicId, role.Version+1, []string{auth.AnonymousUserId}, nil); err != nil {
if _, err := iamRepo.AddPrincipalRoles(ctx, role.PublicId, role.Version+1, []string{auth.AnonymousUserId}, nil); err != nil {
return nil, fmt.Errorf("error adding principal to role for default generated grants: %w", err)
}
@ -106,17 +97,7 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password
}
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
am, err := pwRepo.CreateAuthMethod(cancelCtx, authMethod,
am, err := pwRepo.CreateAuthMethod(ctx, authMethod,
password.WithPublicId(b.DevPasswordAuthMethodId))
if err != nil {
return nil, nil, fmt.Errorf("error saving auth method to the db: %w", err)
@ -165,7 +146,7 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password
return nil, fmt.Errorf("error creating new in memory password auth account: %w", err)
}
acct, err = pwRepo.CreateAccount(
cancelCtx,
ctx,
scope.Global.String(),
acct,
password.WithPassword(loginPassword),
@ -201,10 +182,10 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password
if err != nil {
return nil, fmt.Errorf("error creating in memory user: %w", err)
}
if u, err = iamRepo.CreateUser(cancelCtx, u, opts...); err != nil {
if u, err = iamRepo.CreateUser(ctx, u, opts...); err != nil {
return nil, fmt.Errorf("error creating initial %s user: %w", typeStr, err)
}
if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil {
if _, err = iamRepo.AddUserAccounts(ctx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil {
return nil, fmt.Errorf("error associating initial %s user with account: %w", typeStr, err)
}
if !admin {
@ -218,14 +199,14 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password
if err != nil {
return nil, fmt.Errorf("error creating in memory role for generated grants: %w", err)
}
defPermsRole, err := iamRepo.CreateRole(cancelCtx, pr)
defPermsRole, err := iamRepo.CreateRole(ctx, pr)
if err != nil {
return nil, fmt.Errorf("error creating role for default generated grants: %w", err)
}
if _, err := iamRepo.AddRoleGrants(cancelCtx, defPermsRole.PublicId, defPermsRole.Version, []string{"id=*;type=*;actions=*"}); err != nil {
if _, err := iamRepo.AddRoleGrants(ctx, defPermsRole.PublicId, defPermsRole.Version, []string{"id=*;type=*;actions=*"}); err != nil {
return nil, fmt.Errorf("error creating grant for default generated grants: %w", err)
}
if _, err := iamRepo.AddPrincipalRoles(cancelCtx, defPermsRole.PublicId, defPermsRole.Version+1, []string{u.GetPublicId()}, nil); err != nil {
if _, err := iamRepo.AddPrincipalRoles(ctx, defPermsRole.PublicId, defPermsRole.Version+1, []string{u.GetPublicId()}, nil); err != nil {
return nil, fmt.Errorf("error adding principal to role for default generated grants: %w", err)
}
return u, nil
@ -282,16 +263,6 @@ func (b *Server) CreateInitialScopes(ctx context.Context) (*iam.Scope, *iam.Scop
return nil, nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
iamRepo, err := iam.NewRepository(rw, rw, kmsCache)
if err != nil {
return nil, nil, fmt.Errorf("error creating scopes repository: %w", err)
@ -314,7 +285,7 @@ func (b *Server) CreateInitialScopes(ctx context.Context) (*iam.Scope, *iam.Scop
if err != nil {
return nil, nil, fmt.Errorf("error creating new in memory org scope: %w", err)
}
orgScope, err = iamRepo.CreateScope(cancelCtx, orgScope, b.DevUserId, opts...)
orgScope, err = iamRepo.CreateScope(ctx, orgScope, b.DevUserId, opts...)
if err != nil {
return nil, nil, fmt.Errorf("error saving org scope to the db: %w", err)
}
@ -337,7 +308,7 @@ func (b *Server) CreateInitialScopes(ctx context.Context) (*iam.Scope, *iam.Scop
if err != nil {
return nil, nil, fmt.Errorf("error creating new in memory project scope: %w", err)
}
projScope, err = iamRepo.CreateScope(cancelCtx, projScope, b.DevUserId, opts...)
projScope, err = iamRepo.CreateScope(ctx, projScope, b.DevUserId, opts...)
if err != nil {
return nil, nil, fmt.Errorf("error saving project scope to the db: %w", err)
}
@ -361,16 +332,6 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa
return nil, nil, nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
staticRepo, err := static.NewRepository(rw, rw, kmsCache)
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating static repository: %w", err)
@ -392,7 +353,7 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating in memory host catalog: %w", err)
}
if hc, err = staticRepo.CreateCatalog(cancelCtx, hc, opts...); err != nil {
if hc, err = staticRepo.CreateCatalog(ctx, hc, opts...); err != nil {
return nil, nil, nil, fmt.Errorf("error saving host catalog to the db: %w", err)
}
b.InfoKeys = append(b.InfoKeys, "generated host catalog id")
@ -418,7 +379,7 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating in memory host: %w", err)
}
if h, err = staticRepo.CreateHost(cancelCtx, b.DevProjectId, h, opts...); err != nil {
if h, err = staticRepo.CreateHost(ctx, b.DevProjectId, h, opts...); err != nil {
return nil, nil, nil, fmt.Errorf("error saving host to the db: %w", err)
}
b.InfoKeys = append(b.InfoKeys, "generated host id")
@ -440,14 +401,14 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating in memory host set: %w", err)
}
if hs, err = staticRepo.CreateSet(cancelCtx, b.DevProjectId, hs, opts...); err != nil {
if hs, err = staticRepo.CreateSet(ctx, b.DevProjectId, hs, opts...); err != nil {
return nil, nil, nil, fmt.Errorf("error saving host set to the db: %w", err)
}
b.InfoKeys = append(b.InfoKeys, "generated host set id")
b.Info["generated host set id"] = b.DevHostSetId
// Associate members
if _, err := staticRepo.AddSetMembers(cancelCtx, b.DevProjectId, b.DevHostSetId, hs.GetVersion(), []string{h.GetPublicId()}); err != nil {
if _, err := staticRepo.AddSetMembers(ctx, b.DevProjectId, b.DevHostSetId, hs.GetVersion(), []string{h.GetPublicId()}); err != nil {
return nil, nil, nil, fmt.Errorf("error associating host set to host in the db: %w", err)
}
@ -468,16 +429,6 @@ func (b *Server) CreateInitialTarget(ctx context.Context) (target.Target, error)
return nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
targetRepo, err := target.NewRepository(ctx, rw, rw, kmsCache)
if err != nil {
return nil, fmt.Errorf("error creating target repository: %w", err)
@ -505,7 +456,7 @@ func (b *Server) CreateInitialTarget(ctx context.Context) (target.Target, error)
if err != nil {
return nil, fmt.Errorf("error creating in memory target: %w", err)
}
tt, _, _, err := targetRepo.CreateTarget(cancelCtx, t, opts...)
tt, _, _, err := targetRepo.CreateTarget(ctx, t, opts...)
if err != nil {
return nil, fmt.Errorf("error saving target to the db: %w", err)
}
@ -559,18 +510,18 @@ func (b *Server) CreateInitialTarget(ctx context.Context) (target.Target, error)
if err != nil {
return nil, fmt.Errorf("error creating in memory role for generated grants: %w", err)
}
sessionRole, err := iamRepo.CreateRole(cancelCtx, pr)
sessionRole, err := iamRepo.CreateRole(ctx, pr)
if err != nil {
return nil, fmt.Errorf("error creating role for unprivileged user generated grants: %w", err)
}
if _, err := iamRepo.AddRoleGrants(cancelCtx,
if _, err := iamRepo.AddRoleGrants(ctx,
sessionRole.PublicId,
sessionRole.Version,
[]string{fmt.Sprintf("id=%s;actions=authorize-session", b.DevTargetId)},
); err != nil {
return nil, fmt.Errorf("error creating grant for unprivileged user generated grants: %w", err)
}
if _, err := iamRepo.AddPrincipalRoles(cancelCtx, sessionRole.PublicId, sessionRole.Version+1, []string{b.DevUnprivilegedUserId}, nil); err != nil {
if _, err := iamRepo.AddPrincipalRoles(ctx, sessionRole.PublicId, sessionRole.Version+1, []string{b.DevUnprivilegedUserId}, nil); err != nil {
return nil, fmt.Errorf("error adding principal to role for unprivileged user generated grants: %w", err)
}
}
@ -599,16 +550,6 @@ func (b *Server) RegisterHostPlugin(ctx context.Context, name string, plg plgpb.
return nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
hpRepo, err := hostplugin.NewRepository(rw, rw, kmsCache)
if err != nil {
return nil, fmt.Errorf("error creating host plugin repository: %w", err)
@ -622,7 +563,7 @@ func (b *Server) RegisterHostPlugin(ctx context.Context, name string, plg plgpb.
if plugin == nil {
opt = append(opt, hostplugin.WithName(name))
plugin = hostplugin.NewPlugin(opt...)
plugin, err = hpRepo.CreatePlugin(cancelCtx, plugin, opt...)
plugin, err = hpRepo.CreatePlugin(ctx, plugin, opt...)
if err != nil {
return nil, fmt.Errorf("error creating host plugin: %w", err)
}

@ -729,17 +729,7 @@ func (b *Server) CreateGlobalKmsKeys(ctx context.Context) error {
return fmt.Errorf("error adding config keys to kms: %w", err)
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
if err = kmsCache.CreateKeys(cancelCtx, scope.Global.String(), kms.WithRandomReader(b.SecureRandomReader)); err != nil {
if err = kmsCache.CreateKeys(ctx, scope.Global.String(), kms.WithRandomReader(b.SecureRandomReader)); err != nil {
return fmt.Errorf("error creating global scope kms keys: %w", err)
}

@ -38,14 +38,14 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
Commands = map[string]cli.CommandFactory{
"server": func() (cli.Command, error) {
return &server.Command{
Server: base.NewServer(base.NewCommand(serverCmdUi)),
Server: base.NewServer(base.NewServerCommand(serverCmdUi)),
SighupCh: base.MakeSighupCh(),
SigUSR2Ch: MakeSigUSR2Ch(),
}, nil
},
"dev": func() (cli.Command, error) {
return &dev.Command{
Server: base.NewServer(base.NewCommand(serverCmdUi)),
Server: base.NewServer(base.NewServerCommand(serverCmdUi)),
SighupCh: base.MakeSighupCh(),
SigUSR2Ch: MakeSigUSR2Ch(),
}, nil

@ -7,10 +7,10 @@ import (
"math/rand"
"net"
"os"
"os/signal"
"runtime"
"strings"
"syscall"
"sync"
atm "sync/atomic"
"time"
"github.com/hashicorp/boundary/globals"
@ -824,74 +824,92 @@ func (c *Command) Run(args []string) int {
c.opsServer = opsServer
c.opsServer.Start()
// Wait for shutdown
shutdownTriggered := false
var shutdownCompleted atm.Bool
shutdownTriggerCount := 0
// Add a force-shutdown goroutine to consume more interrupts
abortForceShutdownCh := make(chan struct{})
defer close(abortForceShutdownCh)
var workerShutdownOnce sync.Once
workerShutdownFunc := func() {
if err := c.worker.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error())
}
if !c.flagWorkerAuthStorageSkipCleanup && c.worker.WorkerAuthStorage != nil {
c.worker.WorkerAuthStorage.Cleanup()
}
}
workerGracefulShutdownFunc := func() {
if err := c.worker.GracefulShutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down worker gracefully: %w", err).Error())
}
workerShutdownOnce.Do(workerShutdownFunc)
}
var controllerOnce sync.Once
controllerShutdownFunc := func() {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error())
}
err := c.opsServer.Shutdown()
if err != nil {
c.UI.Error(fmt.Errorf("Failed to shutdown ops listeners: %w", err).Error())
}
}
runShutdownLogic := func() {
go func() {
shutdownCh := make(chan os.Signal, 4)
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM)
for {
select {
case <-shutdownCh:
c.UI.Error("Forcing shutdown")
os.Exit(base.CommandUserError)
case <-c.ServerSideShutdownCh:
// Drain connections in case this is hit more than once
case <-abortForceShutdownCh:
// No-op, we just use this to shut down the goroutine
return
switch {
case shutdownTriggerCount == 1:
c.ContextCancel()
go func() {
if c.Config.Controller != nil {
c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI)
}
if !c.flagControllerOnly {
c.UI.Output("==> Boundary dev environment graceful shutdown triggered, interrupt again to enter shutdown")
workerGracefulShutdownFunc()
} else {
c.UI.Output("==> Boundary dev shutdown triggered, interrupt again to force")
}
}
}()
shutdownTriggered = true
controllerOnce.Do(controllerShutdownFunc)
shutdownCompleted.Store(true)
}()
case shutdownTriggerCount == 2 && !c.flagControllerOnly:
go func() {
if !c.flagControllerOnly {
workerShutdownOnce.Do(workerShutdownFunc)
}
if c.Config.Controller != nil {
controllerOnce.Do(controllerShutdownFunc)
}
shutdownCompleted.Store(true)
}()
case shutdownTriggerCount >= 2:
go func() {
c.UI.Error("Forcing shutdown")
os.Exit(base.CommandCliError)
}()
}
}
for !shutdownTriggered && !errorEncountered.Load() {
for !errorEncountered.Load() && !shutdownCompleted.Load() {
select {
case <-c.ServerSideShutdownCh:
c.UI.Output("==> Boundary dev environment self-terminating")
shutdownTriggerCount++
runShutdownLogic()
case <-c.ShutdownCh:
c.UI.Output("==> Boundary dev environment shutdown triggered, interrupt again to force")
shutdownTriggerCount++
runShutdownLogic()
case <-c.SigUSR2Ch:
buf := make([]byte, 32*1024*1024)
n := runtime.Stack(buf[:], true)
event.WriteSysEvent(context.TODO(), op, "goroutine trace", "stack", string(buf[:n]))
}
if shutdownTriggered {
if c.Config.Controller != nil {
c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI)
}
if !c.flagControllerOnly {
if !c.flagWorkerAuthStorageSkipCleanup && c.worker.WorkerAuthStorage != nil {
c.worker.WorkerAuthStorage.Cleanup()
}
if err := c.worker.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error())
}
}
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error())
}
err := c.opsServer.Shutdown()
if err != nil {
c.UI.Error(fmt.Errorf("Failed to shutdown ops listeners: %w", err).Error())
}
case <-time.After(10 * time.Millisecond):
}
}

@ -215,7 +215,7 @@ func TestReloadControllerDatabase(t *testing.T) {
require.NoError(t, row.Scan(&lock))
require.Equal(t, 1, lock)
close(cmd.ShutdownCh)
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}
@ -317,7 +317,7 @@ func TestReloadControllerDatabase_InvalidNewDatabaseState(t *testing.T) {
require.NoError(t, row.Scan(&lock))
require.Equal(t, 1, lock)
close(cmd.ShutdownCh)
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}

@ -167,7 +167,6 @@ func TestServer_ReloadListener(t *testing.T) {
testCertificateSerial("193080739105342897219784862820114567438786419504")
close(cmd.ShutdownCh)
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}

@ -6,10 +6,10 @@ import (
"fmt"
"net"
"os"
"os/signal"
"runtime"
"strings"
"syscall"
"sync"
atm "sync/atomic"
"time"
"github.com/hashicorp/boundary/globals"
@ -667,69 +667,83 @@ func (c *Command) StartWorker() error {
func (c *Command) WaitForInterrupt() int {
const op = "server.(Command).WaitForInterrupt"
// Wait for shutdown
shutdownTriggered := false
// Add a force-shutdown goroutine to consume another interrupt
abortForceShutdownCh := make(chan struct{})
defer close(abortForceShutdownCh)
var shutdownCompleted atm.Bool
shutdownTriggerCount := 0
runShutdownLogic := func() {
go func() {
shutdownCh := make(chan os.Signal, 4)
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM)
for {
select {
case <-shutdownCh:
c.UI.Error("Forcing shutdown")
os.Exit(base.CommandUserError)
case <-c.ServerSideShutdownCh:
// Drain connections in case this is hit more than once
case <-abortForceShutdownCh:
// No-op, we just use this to shut down the goroutine
return
}
}
}()
if c.Config.Controller != nil && c.opsServer != nil {
c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI)
var workerShutdownOnce sync.Once
workerShutdownFunc := func() {
if err := c.worker.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error())
}
// Do worker shutdown
if c.Config.Worker != nil {
if err := c.worker.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error())
}
}
workerGracefulShutdownFunc := func() {
if err := c.worker.GracefulShutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down worker gracefully: %w", err).Error())
}
// Do controller shutdown
if c.Config.Controller != nil {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error())
}
workerShutdownOnce.Do(workerShutdownFunc)
}
var controllerOnce sync.Once
controllerShutdownFunc := func() {
if err := c.controller.Shutdown(); err != nil {
c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error())
}
if c.opsServer != nil {
err := c.opsServer.Shutdown()
if err != nil {
c.UI.Error(fmt.Errorf("Error shutting down ops listeners: %w", err).Error())
c.UI.Error(fmt.Errorf("Failed to shutdown ops listeners: %w", err).Error())
}
}
}
runShutdownLogic := func() {
switch {
case shutdownTriggerCount == 1:
c.ContextCancel()
go func() {
if c.Config.Controller != nil && c.opsServer != nil {
c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI)
}
if c.Config.Worker != nil {
c.UI.Output("==> Boundary server graceful shutdown triggered, interrupt again to enter shutdown")
workerGracefulShutdownFunc()
} else {
c.UI.Output("==> Boundary server shutdown triggered, interrupt again to force")
}
shutdownTriggered = true
if c.Config.Controller != nil {
controllerOnce.Do(controllerShutdownFunc)
}
shutdownCompleted.Store(true)
}()
case shutdownTriggerCount == 2 && c.Config.Worker != nil:
go func() {
if c.Config.Worker != nil {
workerShutdownOnce.Do(workerShutdownFunc)
}
if c.Config.Controller != nil {
controllerOnce.Do(controllerShutdownFunc)
}
shutdownCompleted.Store(true)
}()
case shutdownTriggerCount >= 2:
c.UI.Error("Forcing shutdown")
os.Exit(base.CommandCliError)
}
}
for !shutdownTriggered {
for !shutdownCompleted.Load() {
select {
case <-c.ServerSideShutdownCh:
c.UI.Output("==> Boundary server self-terminating")
shutdownTriggerCount++
runShutdownLogic()
case <-c.ShutdownCh:
c.UI.Output("==> Boundary server shutdown triggered, interrupt again to force")
shutdownTriggerCount++
runShutdownLogic()
case <-c.SighupCh:
@ -784,6 +798,8 @@ func (c *Command) WaitForInterrupt() int {
buf := make([]byte, 32*1024*1024)
n := runtime.Stack(buf[:], true)
event.WriteSysEvent(context.TODO(), op, "goroutine trace", "stack", string(buf[:n]))
case <-time.After(10 * time.Millisecond):
}
}

@ -46,7 +46,7 @@ func testServerCommand(t *testing.T, opts testServerCommandOpts) *Command {
require := require.New(t)
t.Helper()
cmd := &Command{
Server: base.NewServer(base.NewCommand(cli.NewMockUi())),
Server: base.NewServer(base.NewServerCommand(cli.NewMockUi())),
SighupCh: base.MakeSighupCh(),
startedCh: make(chan struct{}),
reloadedCh: make(chan struct{}, 5),

@ -135,7 +135,6 @@ pollSecondController:
}
}
close(cmd.ShutdownCh)
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}

@ -98,13 +98,13 @@ func TestServer_ShutdownWorker(t *testing.T) {
t.Log("running initial send/recv test")
sConn.TestSendRecvAll(t)
// Now, shut the worker down.
close(workerCmd.ShutdownCh)
// Shutdown the worker and close the connection, as the worker will otherwise wait for it to close.
sConn.Close()
workerCmd.ShutdownCh <- struct{}{}
if <-workerCodeChan != 0 {
output := workerCmd.UI.(*cli.MockUi).ErrorWriter.String() + workerCmd.UI.(*cli.MockUi).OutputWriter.String()
require.FailNow(output, "command exited with non-zero error code")
}
// Connection should fail, and the session should be closed on the DB.
sConn.TestSendRecvFail(t)
sess.ExpectConnectionStateOnController(ctx, t, testController.Controller().ConnectionRepoFn, session.StatusClosed)

@ -130,7 +130,7 @@ func TestServer_ReloadWorkerTags(t *testing.T) {
time.Sleep(10 * time.Second)
fetchWorkerTags("test", "foo", []string{"bar", "baz"})
close(cmd.ShutdownCh)
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}

@ -287,7 +287,7 @@ func (tc *TestController) buildClient() {
func (tc *TestController) Shutdown() {
tc.shutdownOnce.Do(func() {
if tc.b != nil {
close(tc.b.ShutdownCh)
tc.b.ContextCancel()
}
tc.cancel()
@ -492,10 +492,13 @@ func TestControllerConfig(t testing.TB, ctx context.Context, tc *TestController,
opts = new(TestControllerOpts)
}
ctxTest, cancel := context.WithCancel(context.Background())
// Base server
tc.b = base.NewServer(&base.Command{
Context: ctx,
ShutdownCh: make(chan struct{}),
Context: ctxTest,
ContextCancel: cancel,
ShutdownCh: make(chan struct{}),
})
// Get dev config, or use a provided one

@ -423,6 +423,38 @@ func (w *Worker) Start() error {
return nil
}
// hasActiveConnection reports whether any session known to the worker's
// session manager has at least one local connection whose status is
// CONNECTED.
func (w *Worker) hasActiveConnection() bool {
	var found bool
	w.sessionManager.ForEachLocalSession(func(sess session.Session) bool {
		for _, conn := range sess.GetLocalConnections() {
			if conn.Status != pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED {
				continue
			}
			found = true
			// NOTE(review): presumably returning false tells
			// ForEachLocalSession to stop iterating -- confirm against the
			// session manager's contract.
			return false
		}
		return true
	})
	return found
}
// GracefulShutdown moves the worker into the shutdown operational state and
// then blocks until hasActiveConnection reports that no connections remain,
// polling every 250 milliseconds. A system event is written on entry and once
// the connections have drained.
//
// The wait has no timeout of its own: the call only returns when the
// connections are gone. The returned error is currently always nil.
func (w *Worker) GracefulShutdown() error {
	const op = "worker.(Worker).GracefulShutdown"
	const drainPollInterval = 250 * time.Millisecond

	event.WriteSysEvent(w.baseContext, op, "worker entering graceful shutdown")
	w.operationalState.Store(server.ShutdownOperationalState)

	// Poll until every connection has gone away.
	for {
		if !w.hasActiveConnection() {
			break
		}
		time.Sleep(drainPollInterval)
	}

	event.WriteSysEvent(w.baseContext, op, "worker connections have drained")
	return nil
}
// Shutdown shuts down the workers. skipListeners can be used to not stop
// listeners, useful for tests if we want to stop and start a worker. In order
// to create new listeners we'd have to migrate listener setup logic here --
@ -433,9 +465,11 @@ func (w *Worker) Shutdown() error {
event.WriteSysEvent(w.baseContext, op, "already shut down, skipping")
return nil
}
event.WriteSysEvent(w.baseContext, op, "worker shutting down")
// Set state to shutdown
w.operationalState.Store(server.ShutdownOperationalState)
// Stop listeners first to prevent new connections to the
// controller.
defer w.started.Store(false)
@ -479,6 +513,7 @@ func (w *Worker) Shutdown() error {
}
}
event.WriteSysEvent(w.baseContext, op, "worker finished shutting down")
return nil
}

@ -301,6 +301,11 @@ type TestSessionConnection struct {
conn net.Conn
}
// Close closes the underlying network connection of the test session
// connection and returns whatever error that close reports.
func (c *TestSessionConnection) Close() error {
	err := c.conn.Close()
	return err
}
// Connect returns a TestSessionConnection for a TestSession. Check
// the unexported connect method for the lower-level details.
func (s *TestSession) Connect(
@ -344,7 +349,8 @@ func (c *TestSessionConnection) testSendRecv(t *testing.T) bool {
t.Log("received error during write", "err", err)
if errors.Is(err, net.ErrClosed) ||
errors.Is(err, io.EOF) ||
errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) {
errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) ||
errors.Is(err, websocket.CloseError{Code: websocket.StatusNormalClosure, Reason: ""}) {
return false
}

Loading…
Cancel
Save