boundary/internal/daemon/controller/controller.go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package controller
import (
"context"
"crypto/rand"
"fmt"
"strings"
"sync"
"sync/atomic"
"github.com/hashicorp/boundary/internal/alias"
talias "github.com/hashicorp/boundary/internal/alias/target"
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/auth/ldap"
"github.com/hashicorp/boundary/internal/auth/oidc"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/billing"
"github.com/hashicorp/boundary/internal/census"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/credential"
credstatic "github.com/hashicorp/boundary/internal/credential/static"
"github.com/hashicorp/boundary/internal/credential/vault"
"github.com/hashicorp/boundary/internal/daemon/cluster"
"github.com/hashicorp/boundary/internal/daemon/controller/common"
"github.com/hashicorp/boundary/internal/daemon/controller/handlers/health"
"github.com/hashicorp/boundary/internal/daemon/controller/internal/metric"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
intglobals "github.com/hashicorp/boundary/internal/globals"
"github.com/hashicorp/boundary/internal/host"
pluginhost "github.com/hashicorp/boundary/internal/host/plugin"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
kmsjob "github.com/hashicorp/boundary/internal/kms/job"
"github.com/hashicorp/boundary/internal/pagination/purge"
"github.com/hashicorp/boundary/internal/plugin"
"github.com/hashicorp/boundary/internal/plugin/loopback"
"github.com/hashicorp/boundary/internal/ratelimit"
"github.com/hashicorp/boundary/internal/recording"
"github.com/hashicorp/boundary/internal/scheduler"
"github.com/hashicorp/boundary/internal/scheduler/cleaner"
"github.com/hashicorp/boundary/internal/scheduler/job"
"github.com/hashicorp/boundary/internal/server"
serversjob "github.com/hashicorp/boundary/internal/server/job"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/boundary/internal/snapshot"
pluginstorage "github.com/hashicorp/boundary/internal/storage/plugin"
"github.com/hashicorp/boundary/internal/target"
"github.com/hashicorp/boundary/internal/types/scope"
boundary_plugin_assets "github.com/hashicorp/boundary/plugins/boundary"
plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin"
external_plugins "github.com/hashicorp/boundary/sdk/plugins"
"github.com/hashicorp/boundary/version"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/mlock"
"github.com/hashicorp/go-secure-stdlib/pluginutil/v2"
"github.com/hashicorp/nodeenrollment"
ua "go.uber.org/atomic"
"google.golang.org/grpc"
)
// downstreamReceiver defines the minimal interface that must be satisfied by
// a Controller's downstreamConns field.
type downstreamReceiver interface {
// StartConnectionMgmtTicking starts a ticker which manages the receiver's
// connections.
StartConnectionMgmtTicking(context.Context, func() string, int) error
// StartProcessingPendingConnections starts a function that continually processes
// incoming client connections. This only returns when the provided context
// is done.
StartProcessingPendingConnections(context.Context, func() string) error
}
// downstreamWorkersTicker defines an interface for a ticker that maintains the
// graph of the controller's downstream workers
type downstreamWorkersTicker interface {
// StartDownstreamWorkersTicking is used by a Controller to maintain its
// graph of downstream workers.
StartDownstreamWorkersTicking(context.Context, int) error
}
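// These package-level factory hooks are nil unless assigned elsewhere; New and
// Start check each for nil before using it, so they only enable optional
// behavior such as downstream worker management and controller extensions when
// set.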
var (
downstreamReceiverFactory func() downstreamReceiver
downstreamersFactory func(context.Context, string, string) (common.Downstreamers, error)
downstreamWorkersTickerFactory func(context.Context, string, string, common.Downstreamers, downstreamReceiver) (downstreamWorkersTicker, error)
commandClientFactory func(context.Context, *Controller) error
extControllerFactory func(ctx context.Context, c *Controller, r db.Reader, w db.Writer, kms *kms.Kms) (intglobals.ControllerExtension, error)
)
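// Controller is the Boundary controller daemon. It holds the API and cluster
// listeners, the KMS cache, the repository factory functions, the job
// scheduler, and the state used to track downstream workers.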
type Controller struct {
conf *Config
logger hclog.Logger
baseContext context.Context
baseCancel context.CancelFunc
started *ua.Bool
tickerWg *sync.WaitGroup
schedulerWg *sync.WaitGroup
workerAuthCache *sync.Map
// downstream workers and routes to those workers
downstreamWorkers common.Downstreamers
downstreamConns downstreamReceiver
apiListeners []*base.ServerListener
clusterListener *base.ServerListener
// Used for testing and tracking worker health
workerStatusUpdateTimes *sync.Map
// Timing variables. These are atomics for SIGHUP support, and are int64
// because they are cast to time.Duration.
workerStatusGracePeriod *atomic.Int64
livenessTimeToStale *atomic.Int64
apiGrpcServer *grpc.Server
apiGrpcServerListener grpcServerListener
apiGrpcGatewayTicket string
rateLimiter ratelimit.Limiter
rateLimiterMu sync.RWMutex
// Repo factory methods
AuthTokenRepoFn common.AuthTokenRepoFactory
VaultCredentialRepoFn common.VaultCredentialRepoFactory
StaticCredentialRepoFn common.StaticCredentialRepoFactory
CredentialStoreRepoFn common.CredentialStoreRepoFactory
HostCatalogRepoFn common.HostCatalogRepoFactory
IamRepoFn common.IamRepoFactory
OidcRepoFn common.OidcAuthRepoFactory
LdapRepoFn common.LdapAuthRepoFactory
PasswordAuthRepoFn common.PasswordAuthRepoFactory
AuthMethodRepoFn common.AuthMethodRepoFactory
ServersRepoFn common.ServersRepoFactory
SessionRepoFn session.RepositoryFactory
ConnectionRepoFn common.ConnectionRepoFactory
StaticHostRepoFn common.StaticRepoFactory
PluginHostRepoFn common.PluginHostRepoFactory
PluginStorageBucketRepoFn common.PluginStorageBucketRepoFactory
PluginRepoFn common.PluginRepoFactory
TargetRepoFn target.RepositoryFactory
WorkerAuthRepoStorageFn common.WorkerAuthRepoStorageFactory
BillingRepoFn common.BillingRepoFactory
AliasRepoFn common.AliasRepoFactory
TargetAliasRepoFn common.TargetAliasRepoFactory
scheduler *scheduler.Scheduler
kms *kms.Kms
enabledPlugins []base.EnabledPlugin
// Used to signal the Health Service to start
// replying to queries with "503 Service Unavailable".
HealthService *health.Service
downstreamConnManager *cluster.DownstreamManager
// ControllerExtension defines a standard way to extend the controller.
ControllerExtension intglobals.ControllerExtension
}
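// New creates a Controller from the given Config. It validates listener
// purposes, registers any enabled plugins, and initializes the KMS cache,
// repository factory functions, and scheduler, but starts no background work;
// call Start for that. A minimal usage sketch (assuming a *Config assembled by
// the caller, for example from a parsed configuration file):
//
//	c, err := controller.New(ctx, conf)
//	if err != nil {
//		return err
//	}
//	if err := c.Start(); err != nil {
//		return err
//	}
//	defer c.Shutdown()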
func New(ctx context.Context, conf *Config) (*Controller, error) {
const op = "controller.New"
metric.InitializeApiCollectors(conf.PrometheusRegisterer)
ratelimit.InitializeMetrics(conf.PrometheusRegisterer)
c := &Controller{
conf: conf,
logger: conf.Logger.Named("controller"),
started: ua.NewBool(false),
tickerWg: new(sync.WaitGroup),
schedulerWg: new(sync.WaitGroup),
workerAuthCache: new(sync.Map),
workerStatusUpdateTimes: new(sync.Map),
enabledPlugins: conf.Server.EnabledPlugins,
apiListeners: make([]*base.ServerListener, 0),
downstreamConnManager: cluster.NewDownstreamManager(),
workerStatusGracePeriod: new(atomic.Int64),
livenessTimeToStale: new(atomic.Int64),
}
if downstreamReceiverFactory != nil {
c.downstreamConns = downstreamReceiverFactory()
}
c.started.Store(false)
if conf.SecureRandomReader == nil {
conf.SecureRandomReader = rand.Reader
}
var err error
if conf.RawConfig.Controller == nil {
conf.RawConfig.Controller = new(config.Controller)
}
if err := conf.RawConfig.Controller.InitNameIfEmpty(ctx); err != nil {
return nil, fmt.Errorf("error auto-generating controller name: %w", err)
}
if !conf.RawConfig.DisableMlock {
// Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil {
return nil, fmt.Errorf(
"Failed to lock memory: %v\n\n"+
"This usually means that the mlock syscall is not available.\n"+
"Boundary uses mlock to prevent memory from being swapped to\n"+
"disk. This requires root privileges as well as a machine\n"+
"that supports mlock. Please enable mlock on your system or\n"+
"disable Boundary from using it. To disable Boundary from using it,\n"+
"set the `disable_mlock` configuration option in your configuration\n"+
"file.",
err)
}
}
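// Fall back to the server default liveness for the worker status grace period
// and the liveness-to-stale duration when they are not set in the config.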
switch conf.RawConfig.Controller.WorkerStatusGracePeriodDuration {
case 0:
c.workerStatusGracePeriod.Store(int64(server.DefaultLiveness))
default:
c.workerStatusGracePeriod.Store(int64(conf.RawConfig.Controller.WorkerStatusGracePeriodDuration))
}
switch conf.RawConfig.Controller.LivenessTimeToStaleDuration {
case 0:
c.livenessTimeToStale.Store(int64(server.DefaultLiveness))
default:
c.livenessTimeToStale.Store(int64(conf.RawConfig.Controller.LivenessTimeToStaleDuration))
}
clusterListeners := make([]*base.ServerListener, 0)
for i := range conf.Listeners {
l := conf.Listeners[i]
if l == nil || l.Config == nil || l.Config.Purpose == nil {
continue
}
if len(l.Config.Purpose) != 1 {
return nil, fmt.Errorf("found listener with multiple purposes %q", strings.Join(l.Config.Purpose, ","))
}
switch l.Config.Purpose[0] {
case "api":
c.apiListeners = append(c.apiListeners, l)
case "cluster":
clusterListeners = append(clusterListeners, l)
}
}
if len(c.apiListeners) == 0 {
return nil, fmt.Errorf("no api listeners found")
}
if len(clusterListeners) != 1 {
// In the future, we might pick the cluster listener that is exposed to the
// outside instead of limiting it to one.
return nil, fmt.Errorf("exactly one cluster listener is required")
}
c.clusterListener = clusterListeners[0]
if err := c.initializeRateLimiter(conf.RawConfig); err != nil {
return nil, fmt.Errorf("error initializing rate limiter: %w", err)
}
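// Register each enabled plugin; the shared plugin logger is created lazily on
// first use.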
var pluginLogger hclog.Logger
for _, enabledPlugin := range c.enabledPlugins {
if pluginLogger == nil {
pluginLogger, err = event.NewHclogLogger(ctx, c.conf.Server.Eventer)
if err != nil {
return nil, fmt.Errorf("error creating host catalog plugin logger: %w", err)
}
}
switch {
case enabledPlugin == base.EnabledPluginLoopback:
lp, err := loopback.NewLoopbackPlugin()
if err != nil {
return nil, fmt.Errorf("error creating loopback plugin: %w", err)
}
plg := loopback.NewWrappingPluginHostClient(lp)
opts := []plugin.Option{
plugin.WithDescription("Provides an initial loopback storage and host plugin in Boundary"),
plugin.WithPublicId(conf.DevLoopbackPluginId),
}
if _, err = conf.RegisterPlugin(ctx, "loopback", plg, []plugin.PluginType{plugin.PluginTypeHost, plugin.PluginTypeStorage}, opts...); err != nil {
return nil, err
}
case enabledPlugin == base.EnabledPluginHostAzure && !c.conf.SkipPlugins:
pluginType := strings.ToLower(enabledPlugin.String())
client, cleanup, err := external_plugins.CreateHostPlugin(
ctx,
pluginType,
external_plugins.WithPluginOptions(
pluginutil.WithPluginExecutionDirectory(conf.RawConfig.Plugins.ExecutionDir),
pluginutil.WithPluginsFilesystem(boundary_plugin_assets.PluginPrefix, boundary_plugin_assets.FileSystem()),
),
external_plugins.WithLogger(pluginLogger.Named(pluginType)),
)
if err != nil {
return nil, fmt.Errorf("error creating %s host plugin: %w", pluginType, err)
}
conf.ShutdownFuncs = append(conf.ShutdownFuncs, cleanup)
if _, err := conf.RegisterPlugin(ctx, pluginType, client, []plugin.PluginType{plugin.PluginTypeHost}, plugin.WithDescription(fmt.Sprintf("Built-in %s host plugin", enabledPlugin.String()))); err != nil {
return nil, fmt.Errorf("error registering %s host plugin: %w", pluginType, err)
}
case enabledPlugin == base.EnabledPluginAws && !c.conf.SkipPlugins:
pluginType := strings.ToLower(enabledPlugin.String())
client, cleanup, err := external_plugins.CreateHostPlugin(
ctx,
pluginType,
external_plugins.WithPluginOptions(
pluginutil.WithPluginExecutionDirectory(conf.RawConfig.Plugins.ExecutionDir),
pluginutil.WithPluginsFilesystem(boundary_plugin_assets.PluginPrefix, boundary_plugin_assets.FileSystem()),
),
external_plugins.WithLogger(pluginLogger.Named(pluginType)),
)
if err != nil {
return nil, fmt.Errorf("error creating %s host plugin: %w", pluginType, err)
}
conf.ShutdownFuncs = append(conf.ShutdownFuncs, cleanup)
if _, err := conf.RegisterPlugin(ctx, pluginType, client, []plugin.PluginType{plugin.PluginTypeHost, plugin.PluginTypeStorage}, plugin.WithDescription(fmt.Sprintf("Built-in %s host plugin", enabledPlugin.String()))); err != nil {
return nil, fmt.Errorf("error registering %s host plugin: %w", pluginType, err)
}
case enabledPlugin == base.EnabledPluginMinio && !c.conf.SkipPlugins:
pluginType := strings.ToLower(enabledPlugin.String())
if _, err := conf.RegisterPlugin(ctx, pluginType, nil, []plugin.PluginType{plugin.PluginTypeStorage}, plugin.WithDescription(fmt.Sprintf("Built-in %s storage plugin", enabledPlugin.String()))); err != nil {
return nil, fmt.Errorf("error registering %s storage plugin: %w", pluginType, err)
}
}
}
if conf.HostPlugins == nil {
conf.HostPlugins = make(map[string]plgpb.HostPluginServiceClient)
}
// Set up repo stuff
dbase := db.New(c.conf.Database)
c.kms, err = kms.New(ctx, dbase, dbase)
if err != nil {
return nil, fmt.Errorf("error creating kms cache: %w", err)
}
if err := c.kms.AddExternalWrappers(
ctx,
kms.WithRootWrapper(c.conf.RootKms),
kms.WithWorkerAuthWrapper(c.conf.WorkerAuthKms),
kms.WithRecoveryWrapper(c.conf.RecoveryKms),
kms.WithBsrWrapper(c.conf.BsrKms),
); err != nil {
return nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
// We need to get all the scopes so we can reconcile the DEKs for each scope.
iamRepo, err := iam.NewRepository(ctx, dbase, dbase, c.kms, iam.WithRandomReader(c.conf.SecureRandomReader))
if err != nil {
return nil, fmt.Errorf("unable to initialize iam repository: %w", err)
}
allScopes, err := iamRepo.ListScopesRecursively(ctx, scope.Global.String())
if err != nil {
return nil, fmt.Errorf("error listing all scopes for reconciling keys: %w", err)
}
reconcileScopeIds := make([]string, 0, len(allScopes))
for _, s := range allScopes {
reconcileScopeIds = append(reconcileScopeIds, s.PublicId)
}
if err := c.kms.ReconcileKeys(ctx, c.conf.SecureRandomReader, kms.WithScopeIds(reconcileScopeIds...)); err != nil {
return nil, fmt.Errorf("error reconciling kms keys: %w", err)
}
// Now that the KMS is configured, get the audit wrapper and rotate the
// eventer's audit wrapper so that emitted events can include encrypted and
// hmac-sha256 data.
auditWrapper, err := c.kms.GetWrapper(ctx, scope.Global.String(), kms.KeyPurposeAudit)
if err != nil {
return nil, fmt.Errorf("error getting audit wrapper from kms: %w", err)
}
if err := c.conf.Eventer.RotateAuditWrapper(ctx, auditWrapper); err != nil {
return nil, fmt.Errorf("error rotating eventer audit wrapper: %w", err)
}
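// Create the scheduler that runs the recurring jobs registered in
// registerJobs.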
jobRepoFn := func() (*job.Repository, error) {
return job.NewRepository(ctx, dbase, dbase, c.kms)
}
// TODO: Allow setting run jobs limit from config
schedulerOpts := []scheduler.Option{scheduler.WithRunJobsLimit(-1)}
if c.conf.RawConfig.Controller.Scheduler.JobRunIntervalDuration > 0 {
schedulerOpts = append(schedulerOpts, scheduler.WithRunJobsInterval(c.conf.RawConfig.Controller.Scheduler.JobRunIntervalDuration))
}
if c.conf.RawConfig.Controller.Scheduler.MonitorIntervalDuration > 0 {
schedulerOpts = append(schedulerOpts, scheduler.WithMonitorInterval(c.conf.RawConfig.Controller.Scheduler.MonitorIntervalDuration))
}
c.scheduler, err = scheduler.New(ctx, c.conf.RawConfig.Controller.Name, jobRepoFn, schedulerOpts...)
if err != nil {
return nil, fmt.Errorf("error creating new scheduler: %w", err)
}
c.IamRepoFn = func() (*iam.Repository, error) {
return iam.NewRepository(ctx, dbase, dbase, c.kms, iam.WithRandomReader(c.conf.SecureRandomReader))
}
c.StaticHostRepoFn = func() (*static.Repository, error) {
return static.NewRepository(ctx, dbase, dbase, c.kms)
}
c.PluginHostRepoFn = func() (*pluginhost.Repository, error) {
return pluginhost.NewRepository(ctx, dbase, dbase, c.kms, c.scheduler, c.conf.HostPlugins)
}
c.PluginRepoFn = func() (*plugin.Repository, error) {
return plugin.NewRepository(ctx, dbase, dbase, c.kms)
}
c.PluginStorageBucketRepoFn = func() (*pluginstorage.Repository, error) {
return pluginstorage.NewRepository(ctx, dbase, dbase, c.kms, c.scheduler)
}
c.AuthTokenRepoFn = func() (*authtoken.Repository, error) {
return authtoken.NewRepository(ctx, dbase, dbase, c.kms,
authtoken.WithTokenTimeToLiveDuration(c.conf.RawConfig.Controller.AuthTokenTimeToLiveDuration),
authtoken.WithTokenTimeToStaleDuration(c.conf.RawConfig.Controller.AuthTokenTimeToStaleDuration))
}
c.VaultCredentialRepoFn = func() (*vault.Repository, error) {
return vault.NewRepository(ctx, dbase, dbase, c.kms, c.scheduler)
}
c.StaticCredentialRepoFn = func() (*credstatic.Repository, error) {
return credstatic.NewRepository(ctx, dbase, dbase, c.kms)
}
c.CredentialStoreRepoFn = func() (*credential.StoreRepository, error) {
return credential.NewStoreRepository(ctx, dbase, dbase)
}
c.HostCatalogRepoFn = func() (*host.CatalogRepository, error) {
return host.NewCatalogRepository(ctx, dbase, dbase)
}
c.ServersRepoFn = func() (*server.Repository, error) {
return server.NewRepository(ctx, dbase, dbase, c.kms)
}
c.OidcRepoFn = func() (*oidc.Repository, error) {
return oidc.NewRepository(ctx, dbase, dbase, c.kms)
}
c.LdapRepoFn = func() (*ldap.Repository, error) {
return ldap.NewRepository(ctx, dbase, dbase, c.kms)
}
c.PasswordAuthRepoFn = func() (*password.Repository, error) {
return password.NewRepository(ctx, dbase, dbase, c.kms)
}
c.AuthMethodRepoFn = func() (*auth.AuthMethodRepository, error) {
return auth.NewAuthMethodRepository(ctx, dbase, dbase, c.kms)
}
c.TargetRepoFn = func(o ...target.Option) (*target.Repository, error) {
return target.NewRepository(ctx, dbase, dbase, c.kms, o...)
}
c.SessionRepoFn = func(opt ...session.Option) (*session.Repository, error) {
// Always add a secure random reader to the new session repository.
// Add it as the first option so that it can be overridden by users.
opt = append([]session.Option{session.WithRandomReader(c.conf.SecureRandomReader)}, opt...)
return session.NewRepository(ctx, dbase, dbase, c.kms, opt...)
}
c.ConnectionRepoFn = func() (*session.ConnectionRepository, error) {
return session.NewConnectionRepository(ctx, dbase, dbase, c.kms)
}
c.WorkerAuthRepoStorageFn = func() (*server.WorkerAuthRepositoryStorage, error) {
return server.NewRepositoryStorage(ctx, dbase, dbase, c.kms)
}
c.BillingRepoFn = func() (*billing.Repository, error) {
return billing.NewRepository(ctx, dbase)
}
c.AliasRepoFn = func() (*alias.Repository, error) {
return alias.NewRepository(ctx, dbase, dbase, c.kms)
}
c.TargetAliasRepoFn = func() (*talias.Repository, error) {
return talias.NewRepository(ctx, dbase, dbase, c.kms)
}
// Check that worker auth credentials are available at startup, to avoid some
// harmless but nasty-looking errors later.
serversRepo, err := server.NewRepositoryStorage(ctx, dbase, dbase, c.kms)
if err != nil {
return nil, fmt.Errorf("unable to instantiate worker auth repository: %w", err)
}
_, err = server.RotateRoots(ctx, serversRepo, nodeenrollment.WithCertificateLifetime(conf.TestOverrideWorkerAuthCaCertificateLifetime), nodeenrollment.WithReinitializeRoots(conf.TestWorkerAuthCaReinitialize))
if err != nil {
event.WriteSysEvent(ctx, op, "unable to ensure worker auth roots exist, may be due to multiple controllers starting at once, continuing")
}
if downstreamersFactory != nil {
boundVer := version.Get().VersionNumber()
c.downstreamWorkers, err = downstreamersFactory(ctx, "root", boundVer)
if err != nil {
return nil, fmt.Errorf("unable to initialize downstream workers graph: %w", err)
}
if commandClientFactory != nil {
err := commandClientFactory(ctx, c)
if err != nil {
return nil, fmt.Errorf("unable to initialize issue command factory: %w", err)
}
}
}
if extControllerFactory != nil {
if c.ControllerExtension, err = extControllerFactory(ctx, c, dbase, dbase, c.kms); err != nil {
return nil, fmt.Errorf("unable to extend controller: %w", err)
}
}
return c, nil
}
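// Start starts the controller's listeners, registers and starts its scheduled
// jobs, and launches the background ticker goroutines. It is a no-op if the
// controller is already started.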
func (c *Controller) Start() error {
const op = "controller.(Controller).Start"
if c.started.Swap(true) {
event.WriteSysEvent(context.TODO(), op, "already started, skipping")
return nil
}
c.baseContext, c.baseCancel = context.WithCancel(context.Background())
if err := c.registerJobs(); err != nil {
return fmt.Errorf("error registering jobs: %w", err)
}
if err := c.startListeners(); err != nil {
return fmt.Errorf("error starting controller listeners: %w", err)
}
// Upsert controller before starting tickers and scheduler to ensure the controller exists
if err := c.upsertController(c.baseContext); err != nil {
return fmt.Errorf("error upserting controller: %w", err)
}
if err := c.scheduler.Start(c.baseContext, c.schedulerWg); err != nil {
return fmt.Errorf("error starting scheduler: %w", err)
}
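// Add is 5: the four ticker goroutines started directly below, plus the worker
// connection maintenance ticking, which is handed the same WaitGroup.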
c.tickerWg.Add(5)
go func() {
defer c.tickerWg.Done()
c.startStatusTicking(c.baseContext)
}()
go func() {
defer c.tickerWg.Done()
c.startNonceCleanupTicking(c.baseContext)
}()
go func() {
defer c.tickerWg.Done()
c.startTerminateCompletedSessionsTicking(c.baseContext)
}()
go func() {
defer c.tickerWg.Done()
c.startCloseExpiredPendingTokens(c.baseContext)
}()
if err := c.startWorkerConnectionMaintenanceTicking(c.baseContext, c.tickerWg, c.downstreamConnManager); err != nil {
return errors.Wrap(c.baseContext, err, op)
}
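// If a downstream connection receiver was installed via
// downstreamReceiverFactory, start processing its pending connections and
// ticking its connection management.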
if c.downstreamConns != nil {
c.tickerWg.Add(2)
servNameFn := func() string {
switch {
case c.conf.RawConfig.Controller.Name != "":
return c.conf.RawConfig.Controller.Name
default:
return "unknown controller name"
}
}
go func() {
defer c.tickerWg.Done()
c.downstreamConns.StartProcessingPendingConnections(c.baseContext, servNameFn)
}()
go func() {
defer c.tickerWg.Done()
err := c.downstreamConns.StartConnectionMgmtTicking(
c.baseContext,
servNameFn,
-1,
)
if err != nil {
event.WriteError(c.baseContext, op, fmt.Errorf("connection management ticker exited with error: %w", err))
}
}()
}
if downstreamWorkersTickerFactory != nil {
// We use "root" to designate that this is the root of the graph (i.e. a
// controller).
boundVer := version.Get().VersionNumber()
dswTicker, err := downstreamWorkersTickerFactory(c.baseContext, "root", boundVer, c.downstreamWorkers, c.downstreamConns)
if err != nil {
return fmt.Errorf("error creating downstream workers ticker: %w", err)
}
c.tickerWg.Add(1)
go func() {
defer c.tickerWg.Done()
err := dswTicker.StartDownstreamWorkersTicking(c.baseContext, -1)
if err != nil {
event.WriteSysEvent(c.baseContext, op, "error starting/running downstream workers ticker", "err", err.Error())
}
}()
}
if c.ControllerExtension != nil {
if err := c.ControllerExtension.Start(c.baseContext); err != nil {
return fmt.Errorf("error starting controller extension: %w", err)
}
}
return nil
}
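// registerJobs registers the controller's recurring background jobs (vault,
// plugin host, session, server, kms, cleaner, snapshot, census, purge, and
// session recording jobs) with the scheduler.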
func (c *Controller) registerJobs() error {
rw := db.New(c.conf.Database)
if err := vault.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms); err != nil {
return err
}
if err := pluginhost.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.conf.HostPlugins); err != nil {
return err
}
if err := session.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.workerStatusGracePeriod); err != nil {
return err
}
var serverJobOpts []serversjob.Option
if c.conf.TestOverrideWorkerAuthCaCertificateLifetime > 0 {
serverJobOpts = append(serverJobOpts,
serversjob.WithCertificateLifetime(c.conf.TestOverrideWorkerAuthCaCertificateLifetime),
serversjob.WithRotationFrequency(c.conf.TestOverrideWorkerAuthCaCertificateLifetime/2),
)
}
if err := serversjob.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.ControllerExtension, c.workerStatusGracePeriod, serverJobOpts...); err != nil {
return err
}
if err := kmsjob.RegisterJobs(c.baseContext, c.scheduler, c.kms); err != nil {
return err
}
if err := cleaner.RegisterJob(c.baseContext, c.scheduler, rw); err != nil {
return err
}
if err := snapshot.RegisterJob(c.baseContext, c.scheduler, rw, rw); err != nil {
return err
}
if err := census.RegisterJob(c.baseContext, c.scheduler, c.conf.RawConfig.Reporting.License.Enabled, rw, rw); err != nil {
return err
}
if err := purge.RegisterJobs(c.baseContext, c.scheduler, rw, rw); err != nil {
return err
}
if err := recording.RegisterJob(c.baseContext, c.scheduler, rw, rw, c.ControllerExtension, c.kms); err != nil {
return err
}
return nil
}
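// Shutdown cancels the controller's base context, stops its servers and
// listeners, waits for the scheduler and ticker goroutines to finish, and
// flushes any buffered eventer nodes.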
func (c *Controller) Shutdown() error {
const op = "controller.(Controller).Shutdown"
if !c.started.Load() {
event.WriteSysEvent(context.TODO(), op, "already shut down, skipping")
}
defer c.started.Store(false)
c.baseCancel()
if err := c.stopServersAndListeners(); err != nil {
return fmt.Errorf("error stopping controller servers and listeners: %w", err)
}
c.schedulerWg.Wait()
c.tickerWg.Wait()
if c.conf.Eventer != nil {
if err := c.conf.Eventer.FlushNodes(context.Background()); err != nil {
return fmt.Errorf("error flushing controller eventer nodes: %w", err)
}
}
return nil
}
// WorkerStatusUpdateTimes returns the map of worker status update times held
// in _this_ controller, not in the DB. It is used in tests to verify that a
// given controller is receiving updates from the expected set of workers, to
// exercise balancing and automatic reconnection.
func (c *Controller) WorkerStatusUpdateTimes() *sync.Map {
return c.workerStatusUpdateTimes
}