// boundary/internal/daemon/worker/worker.go
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package worker
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/daemon/cluster"
"github.com/hashicorp/boundary/internal/daemon/cluster/handlers"
"github.com/hashicorp/boundary/internal/daemon/worker/common"
"github.com/hashicorp/boundary/internal/daemon/worker/internal/metric"
"github.com/hashicorp/boundary/internal/daemon/worker/proxy"
"github.com/hashicorp/boundary/internal/daemon/worker/session"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
pb "github.com/hashicorp/boundary/internal/gen/controller/servers"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
wpbs "github.com/hashicorp/boundary/internal/gen/worker/servers/services"
"github.com/hashicorp/boundary/internal/server"
"github.com/hashicorp/boundary/internal/storage"
boundary_plugin_assets "github.com/hashicorp/boundary/plugins/boundary"
plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin"
external_plugins "github.com/hashicorp/boundary/sdk/plugins"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/base62"
"github.com/hashicorp/go-secure-stdlib/mlock"
"github.com/hashicorp/go-secure-stdlib/pluginutil/v2"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/nodeenrollment"
nodeenet "github.com/hashicorp/nodeenrollment/net"
nodeefile "github.com/hashicorp/nodeenrollment/storage/file"
nodeeinmem "github.com/hashicorp/nodeenrollment/storage/inmem"
"github.com/hashicorp/nodeenrollment/types"
"github.com/mr-tron/base58"
"github.com/prometheus/client_golang/prometheus"
ua "go.uber.org/atomic"
"google.golang.org/grpc"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/protobuf/proto"
)
// randFn is the signature of a function returning a random string of the
// requested length; used to generate nonces for controller connections.
type randFn func(length int) (string, error)

// reverseConnReceiver defines a min interface which must be met by a
// Worker.downstreamReceiver field
type reverseConnReceiver interface {
	// StartConnectionMgmtTicking starts a ticker which manages the receiver's
	// connections. The int argument controls how many ticks run; a negative
	// value means tick until the context is canceled (see usage in Start).
	StartConnectionMgmtTicking(context.Context, func() string, int) error

	// StartProcessingPendingConnections is a function that continually
	// processes incoming connections. This only returns when the provided context
	// is done.
	StartProcessingPendingConnections(context.Context, func() string) error
}
// downstreamersContainer is a struct that exists purely so we can perform
// atomic swap operations on the interface, to avoid/fix data races in tests
// (and any other potential location).
type downstreamersContainer struct {
	downstreamers
}

// downstreamers provides at least a minimum interface that must be met by a
// Worker.downstreamWorkers field which is far better than allowing any (empty
// interface)
type downstreamers interface {
	// RootId returns the root ID of the downstreamers' graph
	RootId() string
}
// recorderManager updates the status updates with relevant recording
// information
type recorderManager interface {
	// ReauthorizeAllExcept should be called with the result of the status update
	// to reauthorize all recorders for the relevant sessions except the ones provided
	ReauthorizeAllExcept(ctx context.Context, closedSessions []string) error

	// SessionsManaged gets the list of session ids managed by this recorderManager
	SessionsManaged(ctx context.Context) ([]string, error)

	// Shutdown must be called prior to exiting the process
	Shutdown(ctx context.Context)
}
// reverseConnReceiverFactory provides a simple factory which a Worker can use to
// create its reverseConnReceiver
var reverseConnReceiverFactory func(*atomic.Int64) (reverseConnReceiver, error)

// recordingStorageFactory, when non-nil, creates the worker's recording
// storage backed by the given path and storage plugin clients. Only
// consulted in New when a recording storage path is configured.
var recordingStorageFactory func(
	ctx context.Context,
	path string,
	plgClients map[string]plgpb.StoragePluginServiceClient,
	enableLoopback bool,
	minimumAvailableDiskSpace uint64,
) (storage.RecordingStorage, error)

// recorderManagerFactory, when non-nil, creates the worker's recorderManager.
var recorderManagerFactory func(*Worker) (recorderManager, error)

// eventListenerFactory, when non-nil, creates the worker's storage event
// listener.
var eventListenerFactory func(*Worker) (event.EventListener, error)

// initializeReverseGrpcClientCollectors registers prometheus collectors for
// reverse gRPC clients; defaults to a no-op and may be replaced at link time.
var initializeReverseGrpcClientCollectors = noopInitializePromCollectors

// noopInitializePromCollectors is the default (do-nothing) implementation of
// initializeReverseGrpcClientCollectors.
func noopInitializePromCollectors(r prometheus.Registerer) {}

// hostServiceServerFactory, when non-nil, creates the worker's
// HostServiceServer from the given host plugin clients.
var hostServiceServerFactory func(
	ctx context.Context,
	plgClients map[string]plgpb.HostPluginServiceClient,
	enableLoopback bool,
) (wpbs.HostServiceServer, error)
// Authentication lifecycle states stored in Worker.everAuthenticated, in
// order of progression.
const (
	authenticationStatusNeverAuthenticated uint32 = iota
	authenticationStatusFirstAuthentication
	authenticationStatusFirstStatusRpcSuccessful
)
// Worker holds the state of a running Boundary worker daemon: its
// configuration, upstream (controller) connection machinery, session and
// recording managers, worker auth credentials, and the background-ticker
// bookkeeping. Fields swapped at runtime are held in atomics; see the
// individual field comments.
type Worker struct {
	conf   *Config
	logger hclog.Logger

	// baseContext/baseCancel bound the lifetime of all background work;
	// baseCancel is invoked during Shutdown.
	baseContext context.Context
	baseCancel  context.CancelFunc
	started     *ua.Bool

	// tickerWg tracks the goroutines launched in Start (status ticking, auth
	// rotation, downstream receiver loops) so Shutdown can wait on them.
	tickerWg sync.WaitGroup

	// grpc.ClientConns are thread safe. See
	// https://github.com/grpc/grpc-go/blob/master/Documentation/concurrency.md#clients
	// However this is an atomic because we sometimes swap this pointer out
	// (mostly in tests) - which isn't thread safe. This is exported for tests.
	GrpcClientConn atomic.Pointer[grpc.ClientConn]

	// receives address updates and contains the grpc resolver.
	addressReceivers []addressReceiver

	sessionManager  session.Manager
	recorderManager recorderManager

	// everAuthenticated holds one of the authenticationStatus* constants.
	everAuthenticated *ua.Uint32
	// lastStatusSuccess holds a *LastStatusInformation (nil until first success).
	lastStatusSuccess *atomic.Value
	workerStartTime   time.Time
	operationalState  *atomic.Value

	// localStorageState is the current state of the local storage.
	// The local storage state is updated based on the local storage events.
	localStorageState    *atomic.Value
	storageEventListener event.EventListener

	upstreamConnectionState   *atomic.Value
	controllerMultihopConn    *atomic.Value
	controllerUpstreamMsgConn atomic.Pointer[handlers.UpstreamMessageServiceClientProducer]

	proxyListener *base.ServerListener

	// Used to generate a random nonce for Controller connections
	nonceFn randFn

	// We store the current set in an atomic value so that we can add
	// reload-on-sighup behavior later
	tags *atomic.Value
	// This stores whether or not to send updated tags on the next status
	// request. It can be set via startup in New below, or (eventually) via
	// SIGHUP.
	updateTags *ua.Bool

	// The storage for node enrollment
	WorkerAuthStorage            nodeenrollment.Storage
	WorkerAuthCurrentKeyId       *ua.String
	WorkerAuthRegistrationRequest string
	workerAuthSplitListener      *nodeenet.SplitListener

	// The storage for session recording
	RecordingStorage storage.RecordingStorage

	// downstream workers and routes to those workers
	downstreamWorkers  *atomic.Pointer[downstreamersContainer]
	downstreamReceiver reverseConnReceiver

	// Timing variables. These are atomics for SIGHUP support, and are int64
	// because they are casted to time.Duration.
	successfulStatusGracePeriod         *atomic.Int64
	statusCallTimeoutDuration           *atomic.Int64
	getDownstreamWorkersTimeoutDuration *atomic.Int64

	// AuthRotationNextRotation is useful in tests to understand how long to
	// sleep
	AuthRotationNextRotation atomic.Pointer[time.Time]

	// Test-specific options (and possibly hidden dev-mode flags)
	TestOverrideX509VerifyDnsName  string
	TestOverrideX509VerifyCertPool *x509.CertPool
	TestOverrideAuthRotationPeriod time.Duration

	// statusLock guards mutation of the initial-upstreams config in Reload.
	statusLock sync.Mutex

	downstreamConnManager *cluster.DownstreamManager

	HostServiceServer wpbs.HostServiceServer
}
// New validates the supplied configuration and assembles a Worker. It
// registers metrics collectors, creates host/storage plugin clients for the
// enabled plugins, sets up recording storage when a path is configured,
// parses worker tags, locks memory unless disabled, applies timing defaults,
// and verifies exactly one "proxy" purpose listener exists. No background
// work begins until Start is called.
//
// Returns an error if plugin creation, recording storage creation, mlock, or
// listener validation fails.
func New(ctx context.Context, conf *Config) (*Worker, error) {
	const op = "worker.New"
	metric.InitializeHttpCollectors(conf.PrometheusRegisterer)
	metric.InitializeWebsocketCollectors(conf.PrometheusRegisterer)
	metric.InitializeClusterClientCollectors(conf.PrometheusRegisterer)
	initializeReverseGrpcClientCollectors(conf.PrometheusRegisterer)

	// baseContext lives for the whole worker lifetime; baseCancel is called
	// from Shutdown.
	baseContext, baseCancel := context.WithCancel(context.Background())
	w := &Worker{
		baseContext:                         baseContext,
		baseCancel:                          baseCancel,
		conf:                                conf,
		logger:                              conf.Logger.Named("worker"),
		started:                             ua.NewBool(false),
		everAuthenticated:                   ua.NewUint32(authenticationStatusNeverAuthenticated),
		lastStatusSuccess:                   new(atomic.Value),
		controllerMultihopConn:              new(atomic.Value),
		// controllerUpstreamMsgConn: new(atomic.Value),
		tags:                                new(atomic.Value),
		updateTags:                          ua.NewBool(false),
		nonceFn:                             base62.Random,
		WorkerAuthCurrentKeyId:              new(ua.String),
		operationalState:                    new(atomic.Value),
		downstreamConnManager:               cluster.NewDownstreamManager(),
		localStorageState:                   new(atomic.Value),
		successfulStatusGracePeriod:         new(atomic.Int64),
		statusCallTimeoutDuration:           new(atomic.Int64),
		getDownstreamWorkersTimeoutDuration: new(atomic.Int64),
		upstreamConnectionState:             new(atomic.Value),
		downstreamWorkers:                   new(atomic.Pointer[downstreamersContainer]),
	}
	w.operationalState.Store(server.UnknownOperationalState)
	w.localStorageState.Store(server.UnknownLocalStorageState)
	w.lastStatusSuccess.Store((*LastStatusInformation)(nil))

	// A unique resolver scheme per process instance, derived from the current
	// time, keeps manual resolvers from colliding.
	scheme := strconv.FormatInt(time.Now().UnixNano(), 36)
	controllerResolver := manual.NewBuilderWithScheme(scheme)
	w.addressReceivers = []addressReceiver{&grpcResolverReceiver{controllerResolver}}

	if conf.RawConfig.Worker == nil {
		conf.RawConfig.Worker = new(config.Worker)
	}
	if w.conf.RawConfig.Worker.RecordingStoragePath == "" {
		w.localStorageState.Store(server.NotConfiguredLocalStorageState)
	}

	pluginLogger, err := event.NewHclogLogger(ctx, w.conf.Server.Eventer)
	if err != nil {
		return nil, fmt.Errorf("error creating plugin logger: %w", err)
	}

	// Default to the unimplemented server; replaced below if a factory is set.
	w.HostServiceServer = wpbs.UnimplementedHostServiceServer{}
	if hostServiceServerFactory != nil {
		enableLoopback := false
		hostPlgClients := make(map[string]plgpb.HostPluginServiceClient)
		for _, enabledPlugin := range w.conf.Server.EnabledPlugins {
			switch {
			case enabledPlugin == base.EnabledPluginHostAzure && !w.conf.SkipPlugins,
				enabledPlugin == base.EnabledPluginAws && !w.conf.SkipPlugins:
				pluginType := strings.ToLower(enabledPlugin.String())
				client, cleanup, err := external_plugins.CreateHostPlugin(
					ctx,
					pluginType,
					external_plugins.WithPluginOptions(
						pluginutil.WithPluginExecutionDirectory(conf.RawConfig.Plugins.ExecutionDir),
						pluginutil.WithPluginsFilesystem(boundary_plugin_assets.PluginPrefix, boundary_plugin_assets.FileSystem()),
					),
					external_plugins.WithLogger(pluginLogger.Named(pluginType)),
				)
				if err != nil {
					return nil, fmt.Errorf("error creating %s host plugin: %w", pluginType, err)
				}
				conf.ShutdownFuncs = append(conf.ShutdownFuncs, cleanup)
				hostPlgClients[pluginType] = client
			case enabledPlugin == base.EnabledPluginLoopback:
				enableLoopback = true
			}
		}
		hs, err := hostServiceServerFactory(ctx, hostPlgClients, enableLoopback)
		if err != nil {
			return nil, fmt.Errorf("failed to create host service server: %w", err)
		}
		w.HostServiceServer = hs
	}

	if w.conf.RawConfig.Worker.RecordingStoragePath != "" && recordingStorageFactory != nil {
		plgClients := make(map[string]plgpb.StoragePluginServiceClient)
		var enableStorageLoopback bool
		for _, enabledPlugin := range w.conf.Server.EnabledPlugins {
			switch {
			case enabledPlugin == base.EnabledPluginMinio && !w.conf.SkipPlugins:
				fallthrough
			case enabledPlugin == base.EnabledPluginAws && !w.conf.SkipPlugins:
				pluginType := strings.ToLower(enabledPlugin.String())
				client, cleanup, err := external_plugins.CreateStoragePlugin(
					ctx,
					pluginType,
					external_plugins.WithPluginOptions(
						pluginutil.WithPluginExecutionDirectory(conf.RawConfig.Plugins.ExecutionDir),
						pluginutil.WithPluginsFilesystem(boundary_plugin_assets.PluginPrefix, boundary_plugin_assets.FileSystem()),
					),
					external_plugins.WithLogger(pluginLogger.Named(pluginType)),
				)
				if err != nil {
					return nil, fmt.Errorf("error creating %s storage plugin: %w", pluginType, err)
				}
				conf.ShutdownFuncs = append(conf.ShutdownFuncs, cleanup)
				plgClients[pluginType] = client
			case enabledPlugin == base.EnabledPluginLoopback:
				enableStorageLoopback = true
			}
		}
		// passing in an empty context so that storage can finish syncing during an emergency shutdown or interrupt
		s, err := recordingStorageFactory(
			context.Background(),
			w.conf.RawConfig.Worker.RecordingStoragePath,
			plgClients, enableStorageLoopback,
			w.conf.RawConfig.Worker.RecordingStorageMinimumAvailableDiskSpace,
		)
		if err != nil {
			// Fixed message grammar: was "error create recording storage"
			return nil, fmt.Errorf("error creating recording storage: %w", err)
		}
		w.RecordingStorage = s
	}

	w.parseAndStoreTags(conf.RawConfig.Worker.Tags)

	if conf.SecureRandomReader == nil {
		conf.SecureRandomReader = rand.Reader
	}

	if !conf.RawConfig.DisableMlock {
		// Ensure our memory usage is locked into physical RAM
		if err := mlock.LockMemory(); err != nil {
			return nil, fmt.Errorf(
				"Failed to lock memory: %v\n\n"+
					"This usually means that the mlock syscall is not available.\n"+
					"Boundary uses mlock to prevent memory from being swapped to\n"+
					"disk. This requires root privileges as well as a machine\n"+
					"that supports mlock. Please enable mlock on your system or\n"+
					"disable Boundary from using it. To disable Boundary from using it,\n"+
					"set the `disable_mlock` configuration option in your configuration\n"+
					"file.",
				err)
		}
	}

	// Zero-valued durations fall back to package defaults.
	switch conf.RawConfig.Worker.SuccessfulStatusGracePeriodDuration {
	case 0:
		w.successfulStatusGracePeriod.Store(int64(server.DefaultLiveness))
	default:
		w.successfulStatusGracePeriod.Store(int64(conf.RawConfig.Worker.SuccessfulStatusGracePeriodDuration))
	}
	switch conf.RawConfig.Worker.StatusCallTimeoutDuration {
	case 0:
		w.statusCallTimeoutDuration.Store(int64(common.DefaultStatusTimeout))
	default:
		w.statusCallTimeoutDuration.Store(int64(conf.RawConfig.Worker.StatusCallTimeoutDuration))
	}
	switch conf.RawConfig.Worker.GetDownstreamWorkersTimeoutDuration {
	case 0:
		w.getDownstreamWorkersTimeoutDuration.Store(int64(server.DefaultLiveness))
	default:
		w.getDownstreamWorkersTimeoutDuration.Store(int64(conf.RawConfig.Worker.GetDownstreamWorkersTimeoutDuration))
	}
	// FIXME: This is really ugly, but works.
	session.CloseCallTimeout.Store(w.successfulStatusGracePeriod.Load())

	if reverseConnReceiverFactory != nil {
		var err error
		w.downstreamReceiver, err = reverseConnReceiverFactory(w.getDownstreamWorkersTimeoutDuration)
		if err != nil {
			return nil, fmt.Errorf("%s: error creating reverse connection receiver: %w", op, err)
		}
	}
	if eventListenerFactory != nil {
		var err error
		w.storageEventListener, err = eventListenerFactory(w)
		if err != nil {
			return nil, fmt.Errorf("error calling eventListenerFactory: %w", err)
		}
	}
	if recorderManagerFactory != nil {
		var err error
		w.recorderManager, err = recorderManagerFactory(w)
		if err != nil {
			return nil, fmt.Errorf("error calling recorderManagerFactory: %w", err)
		}
	}

	// Exactly one single-purpose "proxy" listener is required.
	var listenerCount int
	for i := range conf.Listeners {
		l := conf.Listeners[i]
		if l == nil || l.Config == nil || l.Config.Purpose == nil {
			continue
		}
		if len(l.Config.Purpose) != 1 {
			return nil, fmt.Errorf("found listener with multiple purposes %q", strings.Join(l.Config.Purpose, ","))
		}
		switch l.Config.Purpose[0] {
		case "proxy":
			if w.proxyListener == nil {
				w.proxyListener = l
			}
			listenerCount++
		}
	}
	if listenerCount != 1 {
		return nil, fmt.Errorf("exactly one proxy listener is required")
	}

	return w, nil
}
// Reload will update a worker with a new Config. The worker will only use
// relevant parts of the new config, specifically:
// - Worker Tags
// - Initial Upstream addresses
// - Status/timing durations (zero values fall back to package defaults)
func (w *Worker) Reload(ctx context.Context, newConf *config.Config) {
	const op = "worker.(Worker).Reload"

	w.parseAndStoreTags(newConf.Worker.Tags)

	if !strutil.EquivalentSlices(newConf.Worker.InitialUpstreams, w.conf.RawConfig.Worker.InitialUpstreams) {
		w.statusLock.Lock()
		defer w.statusLock.Unlock()
		msg := fmt.Sprintf(
			"Initial Upstreams has changed; old upstreams were: %s, new upstreams are: %s",
			w.conf.RawConfig.Worker.InitialUpstreams,
			newConf.Worker.InitialUpstreams,
		)
		event.WriteSysEvent(ctx, op, msg)
		w.conf.RawConfig.Worker.InitialUpstreams = newConf.Worker.InitialUpstreams
		for _, recv := range w.addressReceivers {
			recv.SetAddresses(w.conf.RawConfig.Worker.InitialUpstreams)
			// set InitialAddresses in case the worker has not successfully dialed yet
			recv.InitialAddresses(w.conf.RawConfig.Worker.InitialUpstreams)
		}
	}

	// storeDuration writes d into dst as int64 nanoseconds, substituting
	// fallback when d is zero (the "unset" sentinel in the config).
	storeDuration := func(dst *atomic.Int64, d, fallback time.Duration) {
		if d == 0 {
			d = fallback
		}
		dst.Store(int64(d))
	}
	storeDuration(w.successfulStatusGracePeriod, newConf.Worker.SuccessfulStatusGracePeriodDuration, server.DefaultLiveness)
	storeDuration(w.statusCallTimeoutDuration, newConf.Worker.StatusCallTimeoutDuration, common.DefaultStatusTimeout)
	storeDuration(w.getDownstreamWorkersTimeoutDuration, newConf.Worker.GetDownstreamWorkersTimeoutDuration, server.DefaultLiveness)

	// See comment about this in worker.go
	session.CloseCallTimeout.Store(w.successfulStatusGracePeriod.Load())
}
// Start brings the worker online: it loads (or creates) worker auth
// credentials, optionally produces a registration fetch request for display,
// dials the controllers, creates the session manager, starts the listeners,
// and launches the status, auth-rotation, and (if configured) downstream
// receiver goroutines. It is idempotent: a second call while started is a
// no-op returning nil.
func (w *Worker) Start() error {
	const op = "worker.(Worker).Start"

	if w.started.Load() {
		event.WriteSysEvent(w.baseContext, op, "already started, skipping")
		return nil
	}

	// In this section, we look for existing worker credentials. The two
	// variables below store whether to create new credentials and whether to
	// create a fetch request so it can be displayed in the worker startup info.
	// These may be different because if initial creds have been generated on
	// the worker side but not yet authorized/fetched from the controller, we
	// don't want to invalidate that request on restart by generating a new set
	// of credentials. However it's safe to output a new fetch request so we do
	// in fact do that.
	//
	// Note that if a controller-generated activation token has been supplied,
	// we do not output a fetch request; we attempt to use that directly later.
	//
	// If we have a stable storage path we use that; if no path is supplied
	// (e.g. when using KMS) we use inmem storage.
	var err error
	if w.conf.RawConfig.Worker.AuthStoragePath != "" {
		w.WorkerAuthStorage, err = nodeefile.New(w.baseContext,
			nodeefile.WithBaseDirectory(w.conf.RawConfig.Worker.AuthStoragePath))
		if err != nil {
			return errors.Wrap(w.baseContext, err, op, errors.WithMsg("error loading worker auth storage directory"))
		}
	} else {
		w.WorkerAuthStorage, err = nodeeinmem.New(w.baseContext)
		if err != nil {
			return errors.Wrap(w.baseContext, err, op, errors.WithMsg("error loading in-mem worker auth storage"))
		}
	}

	var createNodeAuthCreds bool
	var createFetchRequest bool
	nodeCreds, err := types.LoadNodeCredentials(
		w.baseContext,
		w.WorkerAuthStorage,
		nodeenrollment.CurrentId,
		nodeenrollment.WithStorageWrapper(w.conf.WorkerAuthStorageKms))
	switch {
	case err == nil:
		if nodeCreds == nil {
			// It's unclear why this would ever happen -- it shouldn't -- so
			// this is simply safety against panics if something goes
			// catastrophically wrong
			event.WriteSysEvent(w.baseContext, op, "no error loading worker auth creds but nil creds, creating new creds for registration")
			createNodeAuthCreds = true
			createFetchRequest = true
			break
		}

		// Check that we have valid creds, or that we have generated creds but
		// simply are still waiting on authentication (in which case we don't
		// want to invalidate what we've already sent)
		var validCreds bool
		switch len(nodeCreds.CertificateBundles) {
		case 0:
			// Still waiting on initial creds, so don't invalidate the request
			// by creating new credentials. However, we will generate and
			// display a new valid request in case the first was lost.
			createFetchRequest = true
		default:
			now := time.Now()
			for _, bundle := range nodeCreds.CertificateBundles {
				if bundle.CertificateNotBefore.AsTime().Before(now) && bundle.CertificateNotAfter.AsTime().After(now) {
					// If we have a certificate in its validity period,
					// everything is fine
					validCreds = true
					break
				}
			}

			// Certificates are both expired, so create new credentials and
			// output a request based on those
			createNodeAuthCreds = !validCreds
			createFetchRequest = !validCreds
		}
	case errors.Is(err, nodeenrollment.ErrNotFound):
		// Nothing was found on disk, so create
		createNodeAuthCreds = true
		createFetchRequest = true
	default:
		// Some other type of error happened, bail out
		return fmt.Errorf("error loading worker auth creds: %w", err)
	}

	// Don't output a fetch request if an activation token has been
	// provided. Technically we _could_ still output a fetch request, and it
	// would be valid to do so, but if a token was provided it may well be
	// confusing to a user if it seems like it was ignored because a fetch
	// request was still output.
	if actToken := w.conf.RawConfig.Worker.ControllerGeneratedActivationToken; actToken != "" {
		createFetchRequest = false
	}

	// NOTE: this block _must_ be before the `if createFetchRequest` block
	// or the fetch request may have no credentials to work with
	if createNodeAuthCreds {
		nodeCreds, err = types.NewNodeCredentials(
			w.baseContext,
			w.WorkerAuthStorage,
			nodeenrollment.WithRandomReader(w.conf.SecureRandomReader),
			nodeenrollment.WithStorageWrapper(w.conf.WorkerAuthStorageKms),
		)
		if err != nil {
			return fmt.Errorf("error generating new worker auth creds: %w", err)
		}
	}

	if createFetchRequest {
		if nodeCreds == nil {
			return fmt.Errorf("need to create fetch request but worker auth creds are nil: %w", err)
		}
		req, err := nodeCreds.CreateFetchNodeCredentialsRequest(w.baseContext, nodeenrollment.WithRandomReader(w.conf.SecureRandomReader))
		if err != nil {
			return fmt.Errorf("error creating worker auth fetch credentials request: %w", err)
		}
		reqBytes, err := proto.Marshal(req)
		if err != nil {
			return fmt.Errorf("error marshaling worker auth fetch credentials request: %w", err)
		}
		// FastBase58Encoding returns only a string, so no error check is
		// needed here. (A stale re-check of the Marshal err, and a duplicate
		// key-id derivation, were removed: the key id is derived
		// unconditionally right below.)
		w.WorkerAuthRegistrationRequest = base58.FastBase58Encoding(reqBytes)
	}

	// Regardless, we want to load the currentKeyId
	currentKeyId, err := nodeenrollment.KeyIdFromPkix(nodeCreds.CertificatePublicKeyPkix)
	if err != nil {
		return fmt.Errorf("error deriving worker auth key id: %w", err)
	}
	w.WorkerAuthCurrentKeyId.Store(currentKeyId)

	if err := w.StartControllerConnections(); err != nil {
		return errors.Wrap(w.baseContext, err, op, errors.WithMsg("error making controller connections"))
	}

	w.sessionManager, err = session.NewManager(pbs.NewSessionServiceClient(w.GrpcClientConn.Load()))
	if err != nil {
		return errors.Wrap(w.baseContext, err, op, errors.WithMsg("error creating session manager"))
	}

	if err := w.startListeners(w.sessionManager); err != nil {
		return errors.Wrap(w.baseContext, err, op, errors.WithMsg("error starting worker listeners"))
	}

	if w.storageEventListener != nil {
		if err := w.storageEventListener.Start(w.baseContext); err != nil {
			return errors.Wrap(w.baseContext, err, op, errors.WithMsg("error starting worker event listener"))
		}
		if w.RecordingStorage != nil {
			w.localStorageState.Store(w.RecordingStorage.GetLocalStorageState(w.baseContext))
		}
	}

	w.operationalState.Store(server.ActiveOperationalState)

	// Rather than deal with some of the potential error conditions for Add on
	// the waitgroup vs. Done (in case a function exits immediately), we will
	// always start rotation and simply exit early if we're using KMS
	w.tickerWg.Add(2)
	go func() {
		defer w.tickerWg.Done()
		w.startStatusTicking(w.baseContext, w.sessionManager, &w.addressReceivers, w.recorderManager)
	}()
	go func() {
		defer w.tickerWg.Done()
		w.startAuthRotationTicking(w.baseContext)
	}()

	if w.downstreamReceiver != nil {
		w.tickerWg.Add(2)
		servNameFn := func() string {
			if s := w.LastStatusSuccess(); s != nil {
				return s.WorkerId
			}
			return "unknown worker id"
		}
		go func() {
			defer w.tickerWg.Done()
			// NOTE(review): the wrapped error is discarded here in the
			// original; presumably Wrap's side effects are the intent — confirm.
			if err := w.downstreamReceiver.StartProcessingPendingConnections(w.baseContext, servNameFn); err != nil {
				errors.Wrap(w.baseContext, err, op)
			}
		}()
		go func() {
			defer w.tickerWg.Done()
			err := w.downstreamReceiver.StartConnectionMgmtTicking(
				w.baseContext,
				servNameFn,
				-1, // indicates the ticker should run until cancelled.
			)
			if err != nil {
				errors.Wrap(w.baseContext, err, op)
			}
		}()
	}

	w.workerStartTime = time.Now()
	w.started.Store(true)

	return nil
}
// GracefulShutdown sets the worker state to "shutdown" and will wait to return until there
// are no longer any active connections.
func (w *Worker) GracefulShutdown() error {
	const op = "worker.(Worker).GracefulShutdown"
	event.WriteSysEvent(w.baseContext, op, "worker entering graceful shutdown")
	w.operationalState.Store(server.ShutdownOperationalState)

	// As long as some status has been sent in the past, wait for 2 status
	// updates to be sent since we've updated our operational state.
	// (Equality with workerStartTime means no status has ever succeeded.)
	lastStatusTime := w.lastSuccessfulStatusTime()
	if lastStatusTime != w.workerStartTime {
		for i := 0; i < 2; i++ {
			// Poll until the last-success timestamp advances past the one we
			// captured, i.e. one more status round-trip has completed.
			for {
				if lastStatusTime != w.lastSuccessfulStatusTime() {
					lastStatusTime = w.lastSuccessfulStatusTime()
					break
				}
				time.Sleep(time.Millisecond * 250)
			}
		}
	}

	// Wait for running proxy connections to drain
	// NOTE(review): this polls with no timeout or context check; a connection
	// that never closes blocks graceful shutdown indefinitely — confirm intended.
	for proxy.ProxyState.CurrentProxiedConnections() > 0 {
		time.Sleep(time.Millisecond * 250)
	}
	event.WriteSysEvent(w.baseContext, op, "worker connections have drained")
	return nil
}
// Shutdown shuts down the worker: it stops servers and listeners (preventing
// new controller connections), shuts down the recorder manager and all proxy
// connections, waits (bounded by the default liveness) for one final status
// report to reach the controller, cancels the base context, and flushes
// eventer nodes. Idempotent: a second call while stopped is a no-op.
func (w *Worker) Shutdown() error {
	const op = "worker.(Worker).Shutdown"
	if !w.started.Load() {
		event.WriteSysEvent(w.baseContext, op, "already shut down, skipping")
		return nil
	}
	event.WriteSysEvent(w.baseContext, op, "worker shutting down")

	// Set state to shutdown
	w.operationalState.Store(server.ShutdownOperationalState)

	// Stop listeners first to prevent new connections to the
	// controller.
	defer w.started.Store(false)
	if err := w.stopServersAndListeners(); err != nil {
		return fmt.Errorf("error stopping worker servers and listeners: %w", err)
	}

	var recManWg sync.WaitGroup
	if w.recorderManager != nil {
		recManWg.Add(1)
		go func() {
			// Shutdown recorder manager to close all recorders, done in a go routine
			// since it will not force shutdown of channels until the passed in context
			// is Done.
			defer recManWg.Done()
			w.recorderManager.Shutdown(w.baseContext)
		}()
	}

	// Shut down all connections.
	w.cleanupConnections(w.baseContext, true, w.sessionManager)

	// Wait for next status request to succeed. Don't wait too long; time it out
	// at our default liveness value, which is also our default status grace
	// period timeout
	waitStatusStart := time.Now()
	nextStatusCtx, nextStatusCancel := context.WithTimeout(w.baseContext, server.DefaultLiveness)
	defer nextStatusCancel()
	for {
		if err := nextStatusCtx.Err(); err != nil {
			event.WriteError(w.baseContext, op, err, event.WithInfoMsg("error waiting for next status report to controller"))
			break
		}
		if w.lastSuccessfulStatusTime().After(waitStatusStart) {
			break
		}
		time.Sleep(time.Second)
	}

	// Proceed with remainder of shutdown.
	w.baseCancel()
	for _, ar := range w.addressReceivers {
		ar.SetAddresses(nil)
	}

	if w.storageEventListener != nil {
		err := w.storageEventListener.Shutdown(w.baseContext)
		if err != nil {
			return fmt.Errorf("error shutting down worker event listener: %w", err)
		}
	}

	w.started.Store(false)
	w.tickerWg.Wait()
	recManWg.Wait()
	if w.conf.Eventer != nil {
		if err := w.conf.Eventer.FlushNodes(context.Background()); err != nil {
			return fmt.Errorf("error flushing worker eventer nodes: %w", err)
		}
	}

	event.WriteSysEvent(w.baseContext, op, "worker finished shutting down")
	return nil
}
// parseAndStoreTags flattens the incoming key -> values map into a slice of
// TagPair protos, stores it atomically, and flags that updated tags should be
// sent on the next status request.
//
// NOTE(review): on empty input this stores an empty slice but returns without
// setting updateTags — confirm that clearing all tags is meant to skip the
// update flag.
func (w *Worker) parseAndStoreTags(incoming map[string][]string) {
	if len(incoming) == 0 {
		w.tags.Store([]*pb.TagPair{})
		return
	}
	// Pre-size: one pair per value across all keys.
	n := 0
	for _, vals := range incoming {
		n += len(vals)
	}
	tags := make([]*pb.TagPair, 0, n)
	for k, vals := range incoming {
		for _, v := range vals {
			tags = append(tags, &pb.TagPair{
				Key:   k,
				Value: v,
			})
		}
	}
	w.tags.Store(tags)
	w.updateTags.Store(true)
}
// getSessionTls returns a GetConfigForClient-style callback that builds a
// per-session TLS config. The session id is extracted from SNI or ALPN, the
// session is loaded from the local session manager, and the returned config
// pins both sides to the session certificate: the server presents it, and
// VerifyConnection requires the client to present the exact same certificate
// and verifies it against a pool containing only that certificate.
func (w *Worker) getSessionTls(sessionManager session.Manager) func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
	const op = "worker.(Worker).getSessionTls"
	return func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		ctx := w.baseContext

		// The session id is carried either as the SNI server name or as one of
		// the ALPN protocol entries, prefixed with the session prefix.
		var sessionId string
		switch {
		case strings.HasPrefix(hello.ServerName, fmt.Sprintf("%s_", globals.SessionPrefix)):
			sessionId = hello.ServerName
		default:
			for _, proto := range hello.SupportedProtos {
				if strings.HasPrefix(proto, fmt.Sprintf("%s_", globals.SessionPrefix)) {
					sessionId = proto
					break
				}
			}
		}

		if sessionId == "" {
			event.WriteSysEvent(ctx, op, "session_id not found in either SNI or ALPN protos", "server_name", hello.ServerName)
			return nil, fmt.Errorf("could not find session ID in SNI or ALPN protos")
		}

		lastSuccess := w.LastStatusSuccess()
		if lastSuccess == nil {
			event.WriteSysEvent(ctx, op, "no last status information found at session acceptance time")
			return nil, fmt.Errorf("no last status information found at session acceptance time")
		}

		timeoutContext, cancel := context.WithTimeout(w.baseContext, session.ValidateSessionTimeout)
		defer cancel()
		sess, err := sessionManager.LoadLocalSession(timeoutContext, sessionId, lastSuccess.GetWorkerId())
		if err != nil {
			return nil, fmt.Errorf("error refreshing session: %w", err)
		}

		if sess.GetCertificate() == nil {
			// Fixed message typo: was "certifificate"
			return nil, fmt.Errorf("requested session has no certificate")
		}
		if len(sess.GetCertificate().Raw) == 0 {
			return nil, fmt.Errorf("requested session has no certificate DER")
		}
		if len(sess.GetPrivateKey()) == 0 {
			return nil, fmt.Errorf("requested session has no private key")
		}

		certPool := x509.NewCertPool()
		certPool.AddCert(sess.GetCertificate())

		tlsConf := &tls.Config{
			Certificates: []tls.Certificate{
				{
					Certificate: [][]byte{sess.GetCertificate().Raw},
					PrivateKey:  ed25519.PrivateKey(sess.GetPrivateKey()),
					Leaf:        sess.GetCertificate(),
				},
			},
			NextProtos: []string{"http/1.1"},
			MinVersion: tls.VersionTLS13,

			// These two are set this way so we can make use of VerifyConnection,
			// which we set on this TLS config below. We are not skipping
			// verification!
			ClientAuth:         tls.RequireAnyClientCert,
			InsecureSkipVerify: true,
		}

		// We disable normal DNS SAN behavior as we don't rely on DNS or IP
		// addresses for security and want to avoid issues with including localhost
		// etc.
		verifyOpts := x509.VerifyOptions{
			DNSName: sessionId,
			Roots:   certPool,
			KeyUsages: []x509.ExtKeyUsage{
				x509.ExtKeyUsageClientAuth,
				x509.ExtKeyUsageServerAuth,
			},
		}
		if w.TestOverrideX509VerifyCertPool != nil {
			verifyOpts.Roots = w.TestOverrideX509VerifyCertPool
		}
		if w.TestOverrideX509VerifyDnsName != "" {
			verifyOpts.DNSName = w.TestOverrideX509VerifyDnsName
		}
		tlsConf.VerifyConnection = func(cs tls.ConnectionState) error {
			// Go will not run this without at least one peer certificate, but
			// doesn't hurt to check
			if len(cs.PeerCertificates) == 0 {
				return errors.New(ctx, errors.InvalidParameter, op, "no peer certificates provided")
			}
			// Constant-time compare so certificate matching doesn't leak
			// timing information.
			if subtle.ConstantTimeCompare(cs.PeerCertificates[0].Raw, sess.GetCertificate().Raw) != 1 {
				return errors.New(ctx, errors.InvalidParameter, op, "expected peer certificate to match session certificate")
			}
			_, err := cs.PeerCertificates[0].Verify(verifyOpts)
			return err
		}

		return tlsConf, nil
	}
}
// SendUpstreamMessage facilitates sending upstream messages to the controller.
// It loads the worker's current node credentials to derive the key id used to
// sign/identify the message, then forwards via the stored upstream client
// producer. Returns an error (instead of panicking) if the producer has not
// yet been set.
func (w *Worker) SendUpstreamMessage(ctx context.Context, m proto.Message) (proto.Message, error) {
	const op = "worker.(Worker).SendUpstreamMessage"
	nodeCreds, err := types.LoadNodeCredentials(w.baseContext, w.WorkerAuthStorage, nodeenrollment.CurrentId, nodeenrollment.WithStorageWrapper(w.conf.WorkerAuthStorageKms))
	if err != nil {
		return nil, errors.Wrap(ctx, err, op)
	}
	initKeyId, err := nodeenrollment.KeyIdFromPkix(nodeCreds.CertificatePublicKeyPkix)
	if err != nil {
		return nil, errors.Wrap(ctx, err, op)
	}
	clientProducer := w.controllerUpstreamMsgConn.Load()
	if clientProducer == nil {
		// atomic.Pointer.Load returns nil until a producer is stored;
		// dereferencing it below would panic.
		return nil, fmt.Errorf("%s: upstream message client producer is not initialized", op)
	}
	return handlers.SendUpstreamMessage(ctx, *clientProducer, initKeyId, m, handlers.WithKeyProducer(nodeCreds))
}