diff --git a/internal/cmd/base/base.go b/internal/cmd/base/base.go index 8534b12569..56b5ef23f3 100644 --- a/internal/cmd/base/base.go +++ b/internal/cmd/base/base.go @@ -76,9 +76,10 @@ var reRemoveWhitespace = regexp.MustCompile(`[\s]+`) var DevOnlyControllerFlags = func(*Command, *FlagSet) {} type Command struct { - Context context.Context - UI cli.Ui - ShutdownCh chan struct{} + Context context.Context + ContextCancel context.CancelFunc + UI cli.Ui + ShutdownCh chan struct{} flags *FlagSets flagsOnce sync.Once @@ -143,7 +144,6 @@ func NewCommand(ui cli.Ui) *Command { ShutdownCh: MakeShutdownCh(), Context: ctx, } - go func() { <-ret.ShutdownCh cancel() @@ -152,6 +152,19 @@ func NewCommand(ui cli.Ui) *Command { return ret } +// NewServerCommand returns a new instance of a base.Command type that does not intercept the shutdown channel +func NewServerCommand(ui cli.Ui) *Command { + ctx, cancel := context.WithCancel(context.Background()) + ret := &Command{ + UI: ui, + ShutdownCh: MakeShutdownCh(), + Context: ctx, + ContextCancel: cancel, + } + + return ret +} + // MakeShutdownCh returns a channel that can be used for shutdown // notifications for commands. This channel will send a message for every // SIGINT or SIGTERM received. 
@@ -161,8 +174,10 @@ func MakeShutdownCh() chan struct{} { shutdownCh := make(chan os.Signal, 4) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { - <-shutdownCh - close(resultCh) + for { + <-shutdownCh + resultCh <- struct{}{} + } }() return resultCh } diff --git a/internal/cmd/base/dev.go b/internal/cmd/base/dev.go index 4df533da25..f8578e1d68 100644 --- a/internal/cmd/base/dev.go +++ b/internal/cmd/base/dev.go @@ -371,18 +371,8 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet } } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() - b.DevOidcSetup.authMethod, err = oidcRepo.CreateAuthMethod( - cancelCtx, + ctx, authMethod, oidc.WithPublicId(b.DevOidcAuthMethodId)) if err != nil { @@ -402,7 +392,7 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet return fmt.Errorf("error generating %s oidc account: %w", typ, err) } acct, err = oidcRepo.CreateAccount( - cancelCtx, + ctx, b.DevOidcSetup.authMethod.GetScopeId(), acct, oidc.WithPublicId(accountId), @@ -417,11 +407,11 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet return fmt.Errorf("unable to create iam repo: %w", err) } - u, _, err := iamRepo.LookupUser(cancelCtx, userId) + u, _, err := iamRepo.LookupUser(ctx, userId) if err != nil { return fmt.Errorf("error looking up %s user: %w", typ, err) } - if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil { + if _, err = iamRepo.AddUserAccounts(ctx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil { return fmt.Errorf("error associating initial %s user with account: %w", typ, err) } diff --git a/internal/cmd/base/initial_resources.go b/internal/cmd/base/initial_resources.go index 38218c2936..dddbdf23c0 100644 --- a/internal/cmd/base/initial_resources.go 
+++ b/internal/cmd/base/initial_resources.go @@ -32,15 +32,6 @@ func (b *Server) CreateInitialLoginRole(ctx context.Context) (*iam.Role, error) ); err != nil { return nil, fmt.Errorf("error adding config keys to kms: %w", err) } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() iamRepo, err := iam.NewRepository(rw, rw, kmsCache, iam.WithRandomReader(b.SecureRandomReader)) if err != nil { @@ -54,11 +45,11 @@ func (b *Server) CreateInitialLoginRole(ctx context.Context) (*iam.Role, error) if err != nil { return nil, fmt.Errorf("error creating in memory role for generated grants: %w", err) } - role, err := iamRepo.CreateRole(cancelCtx, pr) + role, err := iamRepo.CreateRole(ctx, pr) if err != nil { return nil, fmt.Errorf("error creating role for default generated grants: %w", err) } - if _, err := iamRepo.AddRoleGrants(cancelCtx, role.PublicId, role.Version, []string{ + if _, err := iamRepo.AddRoleGrants(ctx, role.PublicId, role.Version, []string{ "id=*;type=scope;actions=list,no-op", "id=*;type=auth-method;actions=authenticate,list", "id={{account.id}};actions=read,change-password", @@ -66,7 +57,7 @@ func (b *Server) CreateInitialLoginRole(ctx context.Context) (*iam.Role, error) }); err != nil { return nil, fmt.Errorf("error creating grant for default generated grants: %w", err) } - if _, err := iamRepo.AddPrincipalRoles(cancelCtx, role.PublicId, role.Version+1, []string{auth.AnonymousUserId}, nil); err != nil { + if _, err := iamRepo.AddPrincipalRoles(ctx, role.PublicId, role.Version+1, []string{auth.AnonymousUserId}, nil); err != nil { return nil, fmt.Errorf("error adding principal to role for default generated grants: %w", err) } @@ -106,17 +97,7 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password } } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - 
case <-cancelCtx.Done(): - } - }() - - am, err := pwRepo.CreateAuthMethod(cancelCtx, authMethod, + am, err := pwRepo.CreateAuthMethod(ctx, authMethod, password.WithPublicId(b.DevPasswordAuthMethodId)) if err != nil { return nil, nil, fmt.Errorf("error saving auth method to the db: %w", err) @@ -165,7 +146,7 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password return nil, fmt.Errorf("error creating new in memory password auth account: %w", err) } acct, err = pwRepo.CreateAccount( - cancelCtx, + ctx, scope.Global.String(), acct, password.WithPassword(loginPassword), @@ -201,10 +182,10 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password if err != nil { return nil, fmt.Errorf("error creating in memory user: %w", err) } - if u, err = iamRepo.CreateUser(cancelCtx, u, opts...); err != nil { + if u, err = iamRepo.CreateUser(ctx, u, opts...); err != nil { return nil, fmt.Errorf("error creating initial %s user: %w", typeStr, err) } - if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil { + if _, err = iamRepo.AddUserAccounts(ctx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil { return nil, fmt.Errorf("error associating initial %s user with account: %w", typeStr, err) } if !admin { @@ -218,14 +199,14 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password if err != nil { return nil, fmt.Errorf("error creating in memory role for generated grants: %w", err) } - defPermsRole, err := iamRepo.CreateRole(cancelCtx, pr) + defPermsRole, err := iamRepo.CreateRole(ctx, pr) if err != nil { return nil, fmt.Errorf("error creating role for default generated grants: %w", err) } - if _, err := iamRepo.AddRoleGrants(cancelCtx, defPermsRole.PublicId, defPermsRole.Version, []string{"id=*;type=*;actions=*"}); err != nil { + if _, err := iamRepo.AddRoleGrants(ctx, defPermsRole.PublicId, defPermsRole.Version, 
[]string{"id=*;type=*;actions=*"}); err != nil { return nil, fmt.Errorf("error creating grant for default generated grants: %w", err) } - if _, err := iamRepo.AddPrincipalRoles(cancelCtx, defPermsRole.PublicId, defPermsRole.Version+1, []string{u.GetPublicId()}, nil); err != nil { + if _, err := iamRepo.AddPrincipalRoles(ctx, defPermsRole.PublicId, defPermsRole.Version+1, []string{u.GetPublicId()}, nil); err != nil { return nil, fmt.Errorf("error adding principal to role for default generated grants: %w", err) } return u, nil @@ -282,16 +263,6 @@ func (b *Server) CreateInitialScopes(ctx context.Context) (*iam.Scope, *iam.Scop return nil, nil, fmt.Errorf("error adding config keys to kms: %w", err) } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() - iamRepo, err := iam.NewRepository(rw, rw, kmsCache) if err != nil { return nil, nil, fmt.Errorf("error creating scopes repository: %w", err) @@ -314,7 +285,7 @@ func (b *Server) CreateInitialScopes(ctx context.Context) (*iam.Scope, *iam.Scop if err != nil { return nil, nil, fmt.Errorf("error creating new in memory org scope: %w", err) } - orgScope, err = iamRepo.CreateScope(cancelCtx, orgScope, b.DevUserId, opts...) + orgScope, err = iamRepo.CreateScope(ctx, orgScope, b.DevUserId, opts...) if err != nil { return nil, nil, fmt.Errorf("error saving org scope to the db: %w", err) } @@ -337,7 +308,7 @@ func (b *Server) CreateInitialScopes(ctx context.Context) (*iam.Scope, *iam.Scop if err != nil { return nil, nil, fmt.Errorf("error creating new in memory project scope: %w", err) } - projScope, err = iamRepo.CreateScope(cancelCtx, projScope, b.DevUserId, opts...) + projScope, err = iamRepo.CreateScope(ctx, projScope, b.DevUserId, opts...) 
if err != nil { return nil, nil, fmt.Errorf("error saving project scope to the db: %w", err) } @@ -361,16 +332,6 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa return nil, nil, nil, fmt.Errorf("error adding config keys to kms: %w", err) } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() - staticRepo, err := static.NewRepository(rw, rw, kmsCache) if err != nil { return nil, nil, nil, fmt.Errorf("error creating static repository: %w", err) @@ -392,7 +353,7 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa if err != nil { return nil, nil, nil, fmt.Errorf("error creating in memory host catalog: %w", err) } - if hc, err = staticRepo.CreateCatalog(cancelCtx, hc, opts...); err != nil { + if hc, err = staticRepo.CreateCatalog(ctx, hc, opts...); err != nil { return nil, nil, nil, fmt.Errorf("error saving host catalog to the db: %w", err) } b.InfoKeys = append(b.InfoKeys, "generated host catalog id") @@ -418,7 +379,7 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa if err != nil { return nil, nil, nil, fmt.Errorf("error creating in memory host: %w", err) } - if h, err = staticRepo.CreateHost(cancelCtx, b.DevProjectId, h, opts...); err != nil { + if h, err = staticRepo.CreateHost(ctx, b.DevProjectId, h, opts...); err != nil { return nil, nil, nil, fmt.Errorf("error saving host to the db: %w", err) } b.InfoKeys = append(b.InfoKeys, "generated host id") @@ -440,14 +401,14 @@ func (b *Server) CreateInitialHostResources(ctx context.Context) (*static.HostCa if err != nil { return nil, nil, nil, fmt.Errorf("error creating in memory host set: %w", err) } - if hs, err = staticRepo.CreateSet(cancelCtx, b.DevProjectId, hs, opts...); err != nil { + if hs, err = staticRepo.CreateSet(ctx, b.DevProjectId, hs, opts...); err != nil { return nil, nil, nil, fmt.Errorf("error 
saving host set to the db: %w", err) } b.InfoKeys = append(b.InfoKeys, "generated host set id") b.Info["generated host set id"] = b.DevHostSetId // Associate members - if _, err := staticRepo.AddSetMembers(cancelCtx, b.DevProjectId, b.DevHostSetId, hs.GetVersion(), []string{h.GetPublicId()}); err != nil { + if _, err := staticRepo.AddSetMembers(ctx, b.DevProjectId, b.DevHostSetId, hs.GetVersion(), []string{h.GetPublicId()}); err != nil { return nil, nil, nil, fmt.Errorf("error associating host set to host in the db: %w", err) } @@ -468,16 +429,6 @@ func (b *Server) CreateInitialTarget(ctx context.Context) (target.Target, error) return nil, fmt.Errorf("error adding config keys to kms: %w", err) } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() - targetRepo, err := target.NewRepository(ctx, rw, rw, kmsCache) if err != nil { return nil, fmt.Errorf("error creating target repository: %w", err) @@ -505,7 +456,7 @@ func (b *Server) CreateInitialTarget(ctx context.Context) (target.Target, error) if err != nil { return nil, fmt.Errorf("error creating in memory target: %w", err) } - tt, _, _, err := targetRepo.CreateTarget(cancelCtx, t, opts...) + tt, _, _, err := targetRepo.CreateTarget(ctx, t, opts...) 
if err != nil { return nil, fmt.Errorf("error saving target to the db: %w", err) } @@ -559,18 +510,18 @@ func (b *Server) CreateInitialTarget(ctx context.Context) (target.Target, error) if err != nil { return nil, fmt.Errorf("error creating in memory role for generated grants: %w", err) } - sessionRole, err := iamRepo.CreateRole(cancelCtx, pr) + sessionRole, err := iamRepo.CreateRole(ctx, pr) if err != nil { return nil, fmt.Errorf("error creating role for unprivileged user generated grants: %w", err) } - if _, err := iamRepo.AddRoleGrants(cancelCtx, + if _, err := iamRepo.AddRoleGrants(ctx, sessionRole.PublicId, sessionRole.Version, []string{fmt.Sprintf("id=%s;actions=authorize-session", b.DevTargetId)}, ); err != nil { return nil, fmt.Errorf("error creating grant for unprivileged user generated grants: %w", err) } - if _, err := iamRepo.AddPrincipalRoles(cancelCtx, sessionRole.PublicId, sessionRole.Version+1, []string{b.DevUnprivilegedUserId}, nil); err != nil { + if _, err := iamRepo.AddPrincipalRoles(ctx, sessionRole.PublicId, sessionRole.Version+1, []string{b.DevUnprivilegedUserId}, nil); err != nil { return nil, fmt.Errorf("error adding principal to role for unprivileged user generated grants: %w", err) } } @@ -599,16 +550,6 @@ func (b *Server) RegisterHostPlugin(ctx context.Context, name string, plg plgpb. return nil, fmt.Errorf("error adding config keys to kms: %w", err) } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() - hpRepo, err := hostplugin.NewRepository(rw, rw, kmsCache) if err != nil { return nil, fmt.Errorf("error creating host plugin repository: %w", err) @@ -622,7 +563,7 @@ func (b *Server) RegisterHostPlugin(ctx context.Context, name string, plg plgpb. if plugin == nil { opt = append(opt, hostplugin.WithName(name)) plugin = hostplugin.NewPlugin(opt...) - plugin, err = hpRepo.CreatePlugin(cancelCtx, plugin, opt...) 
+ plugin, err = hpRepo.CreatePlugin(ctx, plugin, opt...) if err != nil { return nil, fmt.Errorf("error creating host plugin: %w", err) } diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 433d3ceb36..205bd524e1 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -729,17 +729,7 @@ func (b *Server) CreateGlobalKmsKeys(ctx context.Context) error { return fmt.Errorf("error adding config keys to kms: %w", err) } - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-b.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() - - if err = kmsCache.CreateKeys(cancelCtx, scope.Global.String(), kms.WithRandomReader(b.SecureRandomReader)); err != nil { + if err = kmsCache.CreateKeys(ctx, scope.Global.String(), kms.WithRandomReader(b.SecureRandomReader)); err != nil { return fmt.Errorf("error creating global scope kms keys: %w", err) } diff --git a/internal/cmd/commands.go b/internal/cmd/commands.go index 3075ee6ff0..f09bba006f 100644 --- a/internal/cmd/commands.go +++ b/internal/cmd/commands.go @@ -38,14 +38,14 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) { Commands = map[string]cli.CommandFactory{ "server": func() (cli.Command, error) { return &server.Command{ - Server: base.NewServer(base.NewCommand(serverCmdUi)), + Server: base.NewServer(base.NewServerCommand(serverCmdUi)), SighupCh: base.MakeSighupCh(), SigUSR2Ch: MakeSigUSR2Ch(), }, nil }, "dev": func() (cli.Command, error) { return &dev.Command{ - Server: base.NewServer(base.NewCommand(serverCmdUi)), + Server: base.NewServer(base.NewServerCommand(serverCmdUi)), SighupCh: base.MakeSighupCh(), SigUSR2Ch: MakeSigUSR2Ch(), }, nil diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 32661c9427..44dce1bfbf 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -7,10 +7,10 @@ import ( "math/rand" "net" "os" - "os/signal" "runtime" "strings" 
- "syscall" + "sync" + atm "sync/atomic" "time" "github.com/hashicorp/boundary/globals" @@ -824,74 +824,92 @@ func (c *Command) Run(args []string) int { c.opsServer = opsServer c.opsServer.Start() - // Wait for shutdown - shutdownTriggered := false + var shutdownCompleted atm.Bool + shutdownTriggerCount := 0 - // Add a force-shutdown goroutine to consume more interrupts - abortForceShutdownCh := make(chan struct{}) - defer close(abortForceShutdownCh) + var workerShutdownOnce sync.Once + workerShutdownFunc := func() { + if err := c.worker.Shutdown(); err != nil { + c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error()) + } + if !c.flagWorkerAuthStorageSkipCleanup && c.worker.WorkerAuthStorage != nil { + c.worker.WorkerAuthStorage.Cleanup() + } + } + workerGracefulShutdownFunc := func() { + if err := c.worker.GracefulShutdown(); err != nil { + c.UI.Error(fmt.Errorf("Error shutting down worker gracefully: %w", err).Error()) + } + workerShutdownOnce.Do(workerShutdownFunc) + } + var controllerOnce sync.Once + controllerShutdownFunc := func() { + if err := c.controller.Shutdown(); err != nil { + c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error()) + } + err := c.opsServer.Shutdown() + if err != nil { + c.UI.Error(fmt.Errorf("Failed to shutdown ops listeners: %w", err).Error()) + } + } runShutdownLogic := func() { - go func() { - shutdownCh := make(chan os.Signal, 4) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) - for { - select { - case <-shutdownCh: - c.UI.Error("Forcing shutdown") - os.Exit(base.CommandUserError) - - case <-c.ServerSideShutdownCh: - // Drain connections in case this is hit more than once - - case <-abortForceShutdownCh: - // No-op, we just use this to shut down the goroutine - return + switch { + case shutdownTriggerCount == 1: + c.ContextCancel() + go func() { + if c.Config.Controller != nil { + c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI) + } + + if 
!c.flagControllerOnly { + c.UI.Output("==> Boundary dev environment graceful shutdown triggered, interrupt again to enter shutdown") + workerGracefulShutdownFunc() + } else { + c.UI.Output("==> Boundary dev shutdown triggered, interrupt again to force") } - } - }() - shutdownTriggered = true + controllerOnce.Do(controllerShutdownFunc) + + shutdownCompleted.Store(true) + }() + case shutdownTriggerCount == 2 && !c.flagControllerOnly: + go func() { + if !c.flagControllerOnly { + workerShutdownOnce.Do(workerShutdownFunc) + } + + if c.Config.Controller != nil { + controllerOnce.Do(controllerShutdownFunc) + } + shutdownCompleted.Store(true) + }() + + case shutdownTriggerCount >= 2: + go func() { + c.UI.Error("Forcing shutdown") + os.Exit(base.CommandCliError) + }() + } } - for !shutdownTriggered && !errorEncountered.Load() { + for !errorEncountered.Load() && !shutdownCompleted.Load() { select { case <-c.ServerSideShutdownCh: c.UI.Output("==> Boundary dev environment self-terminating") + shutdownTriggerCount++ runShutdownLogic() case <-c.ShutdownCh: - c.UI.Output("==> Boundary dev environment shutdown triggered, interrupt again to force") + shutdownTriggerCount++ runShutdownLogic() case <-c.SigUSR2Ch: buf := make([]byte, 32*1024*1024) n := runtime.Stack(buf[:], true) event.WriteSysEvent(context.TODO(), op, "goroutine trace", "stack", string(buf[:n])) - } - - if shutdownTriggered { - if c.Config.Controller != nil { - c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI) - } - if !c.flagControllerOnly { - if !c.flagWorkerAuthStorageSkipCleanup && c.worker.WorkerAuthStorage != nil { - c.worker.WorkerAuthStorage.Cleanup() - } - if err := c.worker.Shutdown(); err != nil { - c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error()) - } - } - - if err := c.controller.Shutdown(); err != nil { - c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error()) - } - - err := c.opsServer.Shutdown() - if err != nil { - 
c.UI.Error(fmt.Errorf("Failed to shutdown ops listeners: %w", err).Error()) - } + case <-time.After(10 * time.Millisecond): } } diff --git a/internal/cmd/commands/server/controller_db_swap_test.go b/internal/cmd/commands/server/controller_db_swap_test.go index 580d5702a8..c3d8c615aa 100644 --- a/internal/cmd/commands/server/controller_db_swap_test.go +++ b/internal/cmd/commands/server/controller_db_swap_test.go @@ -215,7 +215,7 @@ func TestReloadControllerDatabase(t *testing.T) { require.NoError(t, row.Scan(&lock)) require.Equal(t, 1, lock) - close(cmd.ShutdownCh) + cmd.ShutdownCh <- struct{}{} wg.Wait() } @@ -317,7 +317,7 @@ func TestReloadControllerDatabase_InvalidNewDatabaseState(t *testing.T) { require.NoError(t, row.Scan(&lock)) require.Equal(t, 1, lock) - close(cmd.ShutdownCh) + cmd.ShutdownCh <- struct{}{} wg.Wait() } diff --git a/internal/cmd/commands/server/listener_reload_test.go b/internal/cmd/commands/server/listener_reload_test.go index fc56c97b4e..a62cdaa83d 100644 --- a/internal/cmd/commands/server/listener_reload_test.go +++ b/internal/cmd/commands/server/listener_reload_test.go @@ -167,7 +167,6 @@ func TestServer_ReloadListener(t *testing.T) { testCertificateSerial("193080739105342897219784862820114567438786419504") - close(cmd.ShutdownCh) - + cmd.ShutdownCh <- struct{}{} wg.Wait() } diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go index 93a24cc8ac..19a591890d 100644 --- a/internal/cmd/commands/server/server.go +++ b/internal/cmd/commands/server/server.go @@ -6,10 +6,10 @@ import ( "fmt" "net" "os" - "os/signal" "runtime" "strings" - "syscall" + "sync" + atm "sync/atomic" "time" "github.com/hashicorp/boundary/globals" @@ -667,69 +667,83 @@ func (c *Command) StartWorker() error { func (c *Command) WaitForInterrupt() int { const op = "server.(Command).WaitForInterrupt" - // Wait for shutdown - shutdownTriggered := false - // Add a force-shutdown goroutine to consume another interrupt - abortForceShutdownCh 
:= make(chan struct{}) - defer close(abortForceShutdownCh) + var shutdownCompleted atm.Bool + shutdownTriggerCount := 0 - runShutdownLogic := func() { - go func() { - shutdownCh := make(chan os.Signal, 4) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) - for { - select { - case <-shutdownCh: - c.UI.Error("Forcing shutdown") - os.Exit(base.CommandUserError) - - case <-c.ServerSideShutdownCh: - // Drain connections in case this is hit more than once - - case <-abortForceShutdownCh: - // No-op, we just use this to shut down the goroutine - return - } - } - }() - - if c.Config.Controller != nil && c.opsServer != nil { - c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI) + var workerShutdownOnce sync.Once + workerShutdownFunc := func() { + if err := c.worker.Shutdown(); err != nil { + c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error()) } - - // Do worker shutdown - if c.Config.Worker != nil { - if err := c.worker.Shutdown(); err != nil { - c.UI.Error(fmt.Errorf("Error shutting down worker: %w", err).Error()) - } + } + workerGracefulShutdownFunc := func() { + if err := c.worker.GracefulShutdown(); err != nil { + c.UI.Error(fmt.Errorf("Error shutting down worker gracefully: %w", err).Error()) } - - // Do controller shutdown - if c.Config.Controller != nil { - if err := c.controller.Shutdown(); err != nil { - c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error()) - } + workerShutdownOnce.Do(workerShutdownFunc) + } + var controllerOnce sync.Once + controllerShutdownFunc := func() { + if err := c.controller.Shutdown(); err != nil { + c.UI.Error(fmt.Errorf("Error shutting down controller: %w", err).Error()) } - if c.opsServer != nil { err := c.opsServer.Shutdown() if err != nil { - c.UI.Error(fmt.Errorf("Error shutting down ops listeners: %w", err).Error()) + c.UI.Error(fmt.Errorf("Failed to shutdown ops listeners: %w", err).Error()) } } + } + + runShutdownLogic := func() { + switch { + case 
shutdownTriggerCount == 1: + c.ContextCancel() + go func() { + if c.Config.Controller != nil && c.opsServer != nil { + c.opsServer.WaitIfHealthExists(c.Config.Controller.GracefulShutdownWaitDuration, c.UI) + } + + if c.Config.Worker != nil { + c.UI.Output("==> Boundary server graceful shutdown triggered, interrupt again to enter shutdown") + workerGracefulShutdownFunc() + } else { + c.UI.Output("==> Boundary server shutdown triggered, interrupt again to force") + } - shutdownTriggered = true + if c.Config.Controller != nil { + controllerOnce.Do(controllerShutdownFunc) + } + shutdownCompleted.Store(true) + }() + + case shutdownTriggerCount == 2 && c.Config.Worker != nil: + go func() { + if c.Config.Worker != nil { + workerShutdownOnce.Do(workerShutdownFunc) + } + if c.Config.Controller != nil { + controllerOnce.Do(controllerShutdownFunc) + } + shutdownCompleted.Store(true) + }() + + case shutdownTriggerCount >= 2: + c.UI.Error("Forcing shutdown") + os.Exit(base.CommandCliError) + } } - for !shutdownTriggered { + for !shutdownCompleted.Load() { select { case <-c.ServerSideShutdownCh: c.UI.Output("==> Boundary server self-terminating") + shutdownTriggerCount++ runShutdownLogic() case <-c.ShutdownCh: - c.UI.Output("==> Boundary server shutdown triggered, interrupt again to force") + shutdownTriggerCount++ runShutdownLogic() case <-c.SighupCh: @@ -784,6 +798,8 @@ func (c *Command) WaitForInterrupt() int { buf := make([]byte, 32*1024*1024) n := runtime.Stack(buf[:], true) event.WriteSysEvent(context.TODO(), op, "goroutine trace", "stack", string(buf[:n])) + + case <-time.After(10 * time.Millisecond): } } diff --git a/internal/cmd/commands/server/server_test.go b/internal/cmd/commands/server/server_test.go index bfe1619c8e..2454f09a2c 100644 --- a/internal/cmd/commands/server/server_test.go +++ b/internal/cmd/commands/server/server_test.go @@ -46,7 +46,7 @@ func testServerCommand(t *testing.T, opts testServerCommandOpts) *Command { require := require.New(t) t.Helper() cmd 
:= &Command{ - Server: base.NewServer(base.NewCommand(cli.NewMockUi())), + Server: base.NewServer(base.NewServerCommand(cli.NewMockUi())), SighupCh: base.MakeSighupCh(), startedCh: make(chan struct{}), reloadedCh: make(chan struct{}, 5), diff --git a/internal/cmd/commands/server/worker_initial_upstreams_reload_test.go b/internal/cmd/commands/server/worker_initial_upstreams_reload_test.go index c83a505584..5901b11cb1 100644 --- a/internal/cmd/commands/server/worker_initial_upstreams_reload_test.go +++ b/internal/cmd/commands/server/worker_initial_upstreams_reload_test.go @@ -135,7 +135,6 @@ pollSecondController: } } - close(cmd.ShutdownCh) - + cmd.ShutdownCh <- struct{}{} wg.Wait() } diff --git a/internal/cmd/commands/server/worker_shutdown_reload_test.go b/internal/cmd/commands/server/worker_shutdown_reload_test.go index 546840464b..7982393eba 100644 --- a/internal/cmd/commands/server/worker_shutdown_reload_test.go +++ b/internal/cmd/commands/server/worker_shutdown_reload_test.go @@ -98,13 +98,13 @@ func TestServer_ShutdownWorker(t *testing.T) { t.Log("running initial send/recv test") sConn.TestSendRecvAll(t) - // Now, shut the worker down. - close(workerCmd.ShutdownCh) + // Close the connection, then shut down the worker; the worker will otherwise wait for the connection to close. + sConn.Close() + workerCmd.ShutdownCh <- struct{}{} if <-workerCodeChan != 0 { output := workerCmd.UI.(*cli.MockUi).ErrorWriter.String() + workerCmd.UI.(*cli.MockUi).OutputWriter.String() require.FailNow(output, "command exited with non-zero error code") } - // Connection should fail, and the session should be closed on the DB. 
sConn.TestSendRecvFail(t) sess.ExpectConnectionStateOnController(ctx, t, testController.Controller().ConnectionRepoFn, session.StatusClosed) diff --git a/internal/cmd/commands/server/worker_tags_reload_test.go b/internal/cmd/commands/server/worker_tags_reload_test.go index 0fbd9a13fe..373eef5555 100644 --- a/internal/cmd/commands/server/worker_tags_reload_test.go +++ b/internal/cmd/commands/server/worker_tags_reload_test.go @@ -130,7 +130,7 @@ func TestServer_ReloadWorkerTags(t *testing.T) { time.Sleep(10 * time.Second) fetchWorkerTags("test", "foo", []string{"bar", "baz"}) - close(cmd.ShutdownCh) + cmd.ShutdownCh <- struct{}{} wg.Wait() } diff --git a/internal/daemon/controller/testing.go b/internal/daemon/controller/testing.go index ed542ca314..ae50315758 100644 --- a/internal/daemon/controller/testing.go +++ b/internal/daemon/controller/testing.go @@ -287,7 +287,7 @@ func (tc *TestController) buildClient() { func (tc *TestController) Shutdown() { tc.shutdownOnce.Do(func() { if tc.b != nil { - close(tc.b.ShutdownCh) + tc.b.ContextCancel() } tc.cancel() @@ -492,10 +492,13 @@ func TestControllerConfig(t testing.TB, ctx context.Context, tc *TestController, opts = new(TestControllerOpts) } + ctxTest, cancel := context.WithCancel(context.Background()) + // Base server tc.b = base.NewServer(&base.Command{ - Context: ctx, - ShutdownCh: make(chan struct{}), + Context: ctxTest, + ContextCancel: cancel, + ShutdownCh: make(chan struct{}), }) // Get dev config, or use a provided one diff --git a/internal/daemon/worker/worker.go b/internal/daemon/worker/worker.go index 4fb779dbc4..b88aa864c1 100644 --- a/internal/daemon/worker/worker.go +++ b/internal/daemon/worker/worker.go @@ -423,6 +423,38 @@ func (w *Worker) Start() error { return nil } +func (w *Worker) hasActiveConnection() bool { + activeConnection := false + w.sessionManager.ForEachLocalSession( + func(s session.Session) bool { + conns := s.GetLocalConnections() + for _, v := range conns { + if v.Status == 
pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED { + activeConnection = true + return false + } + } + return true + }) + return activeConnection +} + +// GracefulShutdown sets the worker state to "shutdown" and will wait to return until there +// are no longer any active connections. +func (w *Worker) GracefulShutdown() error { + const op = "worker.(Worker).GracefulShutdown" + event.WriteSysEvent(w.baseContext, op, "worker entering graceful shutdown") + w.operationalState.Store(server.ShutdownOperationalState) + + // Wait for connections to drain + for w.hasActiveConnection() { + time.Sleep(time.Millisecond * 250) + } + event.WriteSysEvent(w.baseContext, op, "worker connections have drained") + + return nil +} + // Shutdown shuts down the workers. skipListeners can be used to not stop // listeners, useful for tests if we want to stop and start a worker. In order // to create new listeners we'd have to migrate listener setup logic here -- @@ -433,9 +465,11 @@ func (w *Worker) Shutdown() error { event.WriteSysEvent(w.baseContext, op, "already shut down, skipping") return nil } + event.WriteSysEvent(w.baseContext, op, "worker shutting down") // Set state to shutdown w.operationalState.Store(server.ShutdownOperationalState) + // Stop listeners first to prevent new connections to the // controller. defer w.started.Store(false) @@ -479,6 +513,7 @@ func (w *Worker) Shutdown() error { } } + event.WriteSysEvent(w.baseContext, op, "worker finished shutting down") return nil } diff --git a/internal/tests/helper/testing_helper.go b/internal/tests/helper/testing_helper.go index 24b9cb6024..0267ee1b23 100644 --- a/internal/tests/helper/testing_helper.go +++ b/internal/tests/helper/testing_helper.go @@ -301,6 +301,11 @@ type TestSessionConnection struct { conn net.Conn } +// Close closes the test connection. +func (t *TestSessionConnection) Close() error { + return t.conn.Close() +} + // Connect returns a TestSessionConnection for a TestSession. 
Check // the unexported connect method for the lower-level details. func (s *TestSession) Connect( @@ -344,7 +349,8 @@ func (c *TestSessionConnection) testSendRecv(t *testing.T) bool { t.Log("received error during write", "err", err) if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || - errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) { + errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) || + errors.Is(err, websocket.CloseError{Code: websocket.StatusNormalClosure, Reason: ""}) { return false }