You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/cmd/base/servers.go

885 lines
28 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package base
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base/internal/metric"
"github.com/hashicorp/boundary/internal/cmd/base/logging"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db"
berrors "github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/internal/util"
kms_plugin_assets "github.com/hashicorp/boundary/plugins/kms"
plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin"
"github.com/hashicorp/boundary/version"
"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
"github.com/hashicorp/go-kms-wrapping/v2/extras/multi"
configutil "github.com/hashicorp/go-secure-stdlib/configutil/v2"
"github.com/hashicorp/go-secure-stdlib/gatedwriter"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/hashicorp/go-secure-stdlib/mlock"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-secure-stdlib/pluginutil/v2"
"github.com/hashicorp/go-secure-stdlib/reloadutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/mitchellh/cli"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// WorkerAuthReqFile is the name of the file in which a worker auth request is
// stored when it is written to disk.
const WorkerAuthReqFile = "auth_request_token"
// init hands the default Prometheus registerer to metric.InitializeBuildInfo
// when the package is loaded.
// NOTE(review): presumably this registers the build-info metric; confirm in
// the metric package.
func init() {
metric.InitializeBuildInfo(prometheus.DefaultRegisterer)
}
// Server carries the state shared by the server-style commands in this
// package: logging and eventing handles, KMS wrappers, listeners, reload and
// shutdown hooks, dev-mode identifiers, and database settings. Use NewServer
// to obtain a value whose channel, maps, and locks are already allocated.
type Server struct {
*Command
// ServerSideShutdownCh is allocated (unbuffered) by NewServer.
ServerSideShutdownCh chan struct{}
// InfoKeys holds the keys of Info; PrintInfo sorts InfoKeys and prints each
// key together with its Info value.
InfoKeys []string
Info map[string]string
// logOutput is chosen by SetupLogging (stdout when CombineLogs is set,
// stderr otherwise). GatedWriter wraps it and is the logger's output until
// ReleaseLogGate resets the logger to write to logOutput directly.
logOutput io.Writer
GatedWriter *gatedwriter.Writer
Logger hclog.Logger
CombineLogs bool
// StderrLock is passed to the logger as its mutex by SetupLogging.
StderrLock *sync.Mutex
// Eventer is set by SetupEventing.
Eventer *event.Eventer
// NOTE: Unlike the other wrappers below, if set, DownstreamWorkerAuthKms
// should always be a PooledWrapper, so that we can allow multiple KMSes to
// accept downstream connections. As such it's made explicit here.
DownstreamWorkerAuthKms *multi.PooledWrapper
// The wrappers below are populated by SetupKMSes, one per KMS purpose.
RootKms wrapping.Wrapper
WorkerAuthKms wrapping.Wrapper
WorkerAuthStorageKms wrapping.Wrapper
RecoveryKms wrapping.Wrapper
BsrKms wrapping.Wrapper
Kms *kms.Kms
// SecureRandomReader defaults to crypto/rand.Reader in NewServer and is
// replaced at the end of SetupKMSes.
SecureRandomReader io.Reader
WorkerAuthDebuggingEnabled *atomic.Bool
// PrometheusRegisterer defaults to prometheus.DefaultRegisterer in NewServer.
PrometheusRegisterer prometheus.Registerer
// ReloadFuncsLock is held by SetupListeners while it appends to ReloadFuncs.
ReloadFuncsLock *sync.RWMutex
ReloadFuncs map[string][]reloadutil.ReloadFunc
// ShutdownFuncs are invoked in slice order by RunShutdownFuncs.
ShutdownFuncs []func() error
// Listeners is rebuilt by SetupListeners, one entry per configured listener.
Listeners []*ServerListener
// Dev* fields hold identifiers and settings for dev-mode resources.
DevPasswordAuthMethodId string
DevOidcAuthMethodId string
DevLdapAuthMethodId string
DevLoginName string
DevPassword string
DevUserId string
DevPasswordAccountId string
DevOidcAccountId string
DevUnprivilegedLoginName string
DevUnprivilegedPassword string
DevUnprivilegedUserId string
DevUnprivilegedPasswordAccountId string
DevUnprivilegedOidcAccountId string
DevOrgId string
DevProjectId string
DevHostCatalogId string
DevHostSetId string
DevHostId string
DevTargetId string // Target using address.
DevSecondaryTargetId string // Target using host sources.
DevHostAddress string // Host address for target using host sources.
DevTargetAddress string // Network address for target with address.
DevTargetDefaultPort uint16
DevTargetSessionMaxSeconds int64
DevTargetSessionConnectionLimit int64
DevLoopbackPluginId string
EnabledPlugins []EnabledPlugin
HostPlugins map[string]plgpb.HostPluginServiceClient
DevOidcSetup oidcSetup
DevLdapSetup ldapSetup
SkipPlugins bool // Useful when running on platforms that we can't easily compile plugins on, like Windows
SkipAliasTargetCreation bool
WorkerDnsServer string
// DatabaseUrl is the URL OpenAndSetServerDatabase passes to OpenDatabase; the
// three connection settings below are forwarded to db.Open by OpenDatabase.
DatabaseUrl string
DatabaseMaxOpenConnections int
DatabaseMaxIdleConnections *int
DatabaseConnMaxIdleTimeDuration *time.Duration
// DevDatabaseCleanupFunc, when set, is called by DestroyDevDatabase.
DevDatabaseCleanupFunc func() error
// Database is set by OpenAndSetServerDatabase and closed by
// DestroyDevDatabase.
Database *db.DB
}
// NewServer returns a Server that embeds cmd and has its shutdown channel,
// info collections, reload-function registry, and locks allocated. The secure
// random reader defaults to crypto/rand.Reader and the Prometheus registerer
// to prometheus.DefaultRegisterer.
func NewServer(cmd *Command) *Server {
	srv := &Server{Command: cmd}

	// Collections and channels.
	srv.ServerSideShutdownCh = make(chan struct{})
	srv.InfoKeys = make([]string, 0, 20)
	srv.Info = make(map[string]string)
	srv.ReloadFuncs = make(map[string][]reloadutil.ReloadFunc)

	// Synchronization primitives.
	srv.ReloadFuncsLock = new(sync.RWMutex)
	srv.StderrLock = new(sync.Mutex)
	srv.WorkerAuthDebuggingEnabled = new(atomic.Bool)

	// Defaults that later setup steps may replace.
	srv.SecureRandomReader = rand.Reader
	srv.PrometheusRegisterer = prometheus.DefaultRegisterer

	return srv
}
// SetupEventing builds the server's eventer, stores it on b.Eventer, and
// initializes the system-wide eventer with that same instance.
//
// logger, serializationLock, and serverName are required. When no eventer
// config option is supplied the default config is used; a supplied config is
// validated first. Event flags, when supplied, are validated and then override
// the corresponding config fields (format and filters are applied to every
// sink).
func (b *Server) SetupEventing(ctx context.Context, logger hclog.Logger, serializationLock *sync.Mutex, serverName string, opt ...Option) error {
	const op = "base.(Server).SetupEventing"
	switch {
	case logger == nil:
		return berrors.New(ctx, berrors.InvalidParameter, op, "missing logger")
	case serializationLock == nil:
		return berrors.New(ctx, berrors.InvalidParameter, op, "missing serialization lock")
	case serverName == "":
		return berrors.New(ctx, berrors.InvalidParameter, op, "missing server name")
	}

	opts := GetOpts(opt...)
	if opts.withEventerConfig == nil {
		opts.withEventerConfig = event.DefaultEventerConfig()
	} else if err := opts.withEventerConfig.Validate(); err != nil {
		return berrors.Wrap(ctx, err, op, berrors.WithMsg("invalid eventer config"))
	}

	if flags := opts.withEventFlags; flags != nil {
		if err := flags.Validate(); err != nil {
			return berrors.Wrap(ctx, err, op, berrors.WithMsg("invalid event flags"))
		}
		cfg := opts.withEventerConfig

		// Config-wide toggles: only flags that were explicitly set override.
		if flags.AuditEnabled != nil {
			cfg.AuditEnabled = *flags.AuditEnabled
		}
		if flags.ObservationsEnabled != nil {
			cfg.ObservationsEnabled = *flags.ObservationsEnabled
		}
		if flags.SysEventsEnabled != nil {
			cfg.SysEventsEnabled = *flags.SysEventsEnabled
		}
		if flags.TelemetryEnabled != nil {
			cfg.TelemetryEnabled = *flags.TelemetryEnabled
		}

		// Per-sink overrides, applied in a single pass over the sinks.
		overrideFormat := flags.Format != ""
		overrideAllow := len(flags.AllowFilters) > 0
		overrideDeny := len(flags.DenyFilters) > 0
		for i := range cfg.Sinks {
			if overrideFormat {
				cfg.Sinks[i].Format = flags.Format
			}
			if overrideAllow {
				cfg.Sinks[i].AllowFilters = flags.AllowFilters
			}
			if overrideDeny {
				cfg.Sinks[i].DenyFilters = flags.DenyFilters
			}
		}
	}

	// Note: the audit wrapper may be nil at this point; it is updated later on
	// in SetupKMSes. There is a cyclic dependency between the eventer and the
	// wrapper, so the eventer is instantiated with a nil wrapper until a
	// wrapper is available.
	eventer, err := event.NewEventer(
		logger,
		serializationLock,
		serverName,
		*opts.withEventerConfig,
		event.WithAuditWrapper(opts.withEventWrapper),
		event.WithGating(opts.withEventGating),
	)
	if err != nil {
		return berrors.Wrap(ctx, err, op, berrors.WithMsg("unable to create eventer"))
	}
	b.Eventer = eventer

	err = event.InitSysEventer(logger, serializationLock, serverName, event.WithEventer(eventer))
	if err != nil {
		return berrors.Wrap(ctx, err, op, berrors.WithMsg("unable to initialize system eventer"))
	}
	return nil
}
// AddEventerToContext returns a context derived from ctx that carries the
// server's eventer. It fails when the server has no eventer yet or when the
// eventer cannot be attached.
func (b *Server) AddEventerToContext(ctx context.Context) (context.Context, error) {
	const op = "base.(Server).AddEventerToContext"
	if b.Eventer == nil {
		return nil, berrors.New(ctx, berrors.InvalidParameter, op, "missing server eventer")
	}

	eventCtx, err := event.NewEventerContext(ctx, b.Eventer)
	if err == nil {
		return eventCtx, nil
	}
	return nil, berrors.Wrap(ctx, err, op, berrors.WithMsg("unable to add eventer to context"))
}
// SetupLogging configures the command's logger. Logging is mostly historical
// now that eventing exists, but the logger remains the fallback when events
// cannot be sent.
//
// Output goes to stdout when CombineLogs is set and to stderr otherwise, and
// is held behind a gated writer. The resolved log level is recorded in the
// server info.
func (b *Server) SetupLogging(flagLogLevel, flagLogFormat, configLogLevel, configLogFormat string) error {
	if b.CombineLogs {
		b.logOutput = os.Stdout
	} else {
		b.logOutput = os.Stderr
	}
	b.GatedWriter = gatedwriter.NewWriter(b.logOutput)

	// Resolve level and format from the flag and config values.
	logLevel, logFormat, err := ProcessLogLevelAndFormat(flagLogLevel, flagLogFormat, configLogLevel, configLogFormat)
	if err != nil {
		return err
	}

	loggerOpts := &hclog.LoggerOptions{
		Output: b.GatedWriter,
		Level:  logLevel,
		// Anything other than the standard format (including an unspecified
		// format) results in a JSON logger.
		JSONFormat: logFormat != logging.StandardFormat,
		Mutex:      b.StderrLock,
	}
	b.Logger = hclog.New(loggerOpts)
	if err = event.InitFallbackLogger(b.Logger); err != nil {
		return err
	}

	// GRPC logger: enabled only when BOUNDARY_GRPC_LOGGING is non-empty.
	grpcLogger := b.Logger.Named("grpclogfaker")
	grpcLogFaker.SetLogOnOff(os.Getenv("BOUNDARY_GRPC_LOGGING") != "")
	grpcLogFaker.SetLogger(grpcLogger)

	b.InfoKeys = append(b.InfoKeys, "log level")
	b.Info["log level"] = logLevel.String()

	// TODO: Log proxy settings. It would be good to show this but Vault has,
	// or will soon, address the fact that this can log users/passwords if
	// they are part of the proxy URL. When they change things to address that
	// we should update the below logic and re-enable.
	/*
		proxyCfg := httpproxy.FromEnvironment()
		event.WriteSysEvent(context.TODO(), op,
			"proxy environment",
			"http_proxy", proxyCfg.HTTPProxy,
			"https_proxy", proxyCfg.HTTPSProxy,
			"no_proxy", proxyCfg.NoProxy,
		})
	*/

	return nil
}
// ReleaseLogGate switches the logger from the gated writer to the real log
// output, flushing what the gated writer holds, and then releases the
// eventer's gate when an eventer is present.
func (b *Server) ReleaseLogGate() error {
	resettable := b.Logger.(hclog.OutputResettable)
	directOutput := &hclog.LoggerOptions{Output: b.logOutput}
	resettable.ResetOutputWithFlush(directOutput, b.GatedWriter)

	if b.Eventer == nil {
		return nil
	}
	return b.Eventer.ReleaseGate()
}
// StorePidFile writes the current process ID to pidPath (created or truncated
// with mode 0644) and registers a shutdown function that removes the file.
// An empty pidPath is a no-op.
//
// It returns an error when the file cannot be opened, written, or closed. The
// close error is checked explicitly because a failure to flush the data may
// only surface at close time.
func (b *Server) StorePidFile(pidPath string) error {
	// Quit fast if no pidfile
	if pidPath == "" {
		return nil
	}

	// Open the PID file
	pidFile, err := os.OpenFile(pidPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
	if err != nil {
		return fmt.Errorf("could not open pid file: %w", err)
	}

	// Write out the PID, then always close the file before inspecting either
	// result so the descriptor is never leaked.
	_, writeErr := pidFile.WriteString(strconv.Itoa(os.Getpid()))
	closeErr := pidFile.Close()
	if writeErr != nil {
		return fmt.Errorf("could not write to pid file: %w", writeErr)
	}
	if closeErr != nil {
		return fmt.Errorf("could not close pid file: %w", closeErr)
	}

	b.ShutdownFuncs = append(b.ShutdownFuncs, func() error {
		if err := b.RemovePidFile(pidPath); err != nil {
			return fmt.Errorf("Error deleting the PID file: %w", err)
		}
		return nil
	})
	return nil
}
// RemovePidFile deletes the file at pidPath and returns the error from
// os.Remove, if any. An empty pidPath is a no-op.
func (b *Server) RemovePidFile(pidPath string) error {
	if pidPath != "" {
		return os.Remove(pidPath)
	}
	return nil
}
// StoreWorkerAuthReq writes authReq to the file named WorkerAuthReqFile inside
// the directory workerAuthReqPath, creating or truncating it with mode 0640.
// An empty workerAuthReqPath is a no-op.
//
// It returns an error when the file cannot be opened, written, or closed. The
// close error is checked explicitly because a failure to flush the data may
// only surface at close time.
func (b *Server) StoreWorkerAuthReq(authReq, workerAuthReqPath string) error {
	// Quit fast if no workerAuthReqFile
	if workerAuthReqPath == "" {
		return nil
	}
	workerAuthReqFilePath := filepath.Join(workerAuthReqPath, WorkerAuthReqFile)

	// Open the workerAuthReq file
	workerAuthReqFile, err := os.OpenFile(workerAuthReqFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
	if err != nil {
		return fmt.Errorf("could not open file for worker auth request: %w", err)
	}

	// Write out the workerAuthRequest, then always close the file before
	// inspecting either result so the descriptor is never leaked.
	_, writeErr := workerAuthReqFile.WriteString(authReq)
	closeErr := workerAuthReqFile.Close()
	if writeErr != nil {
		return fmt.Errorf("could not write to file for worker auth request: %w", writeErr)
	}
	if closeErr != nil {
		return fmt.Errorf("could not close file for worker auth request: %w", closeErr)
	}
	return nil
}
// PrintInfo adds version, revision, and cgo entries to the server info and
// then writes every info entry to ui, sorted by key, title-cased, and
// right-aligned on the colon. Unless logs are combined with the UI output it
// finishes with a banner announcing that log data follows.
func (b *Server) PrintInfo(ui cli.Ui) {
	build := version.Get()
	if build.Version != "" {
		b.InfoKeys = append(b.InfoKeys, "version")
		b.Info["version"] = build.FullVersionNumber(false)
	}
	if build.Revision != "" {
		b.InfoKeys = append(b.InfoKeys, "version sha")
		b.Info["version sha"] = strings.Trim(build.Revision, "'")
	}

	cgoState := "disabled"
	if version.CgoEnabled {
		cgoState = "enabled"
	}
	b.InfoKeys = append(b.InfoKeys, "cgo")
	b.Info["cgo"] = cgoState

	// The column is as wide as the longest key plus two spaces.
	width := 0
	for _, key := range b.InfoKeys {
		if needed := len(key) + 2; needed > width {
			width = needed
		}
	}

	sort.Strings(b.InfoKeys)
	ui.Output("==> Boundary server configuration:\n")
	for _, key := range b.InfoKeys {
		indent := strings.Repeat(" ", width-len(key))
		title := cases.Title(language.AmericanEnglish).String(key)
		ui.Output(indent + title + ": " + b.Info[key])
	}
	ui.Output("")

	// Output the header that the server has started
	if b.CombineLogs {
		return
	}
	ui.Output("==> Boundary server started! Log data will stream in below:\n")
}
// SetupListeners creates one listener per entry in config.Listeners and
// records it in b.Listeners, along with a human-readable summary in the
// server info.
//
// Each listener config must carry exactly one purpose, and that purpose
// (lower-cased) must appear in allowedPurposes. Missing TLS cipher suites,
// max request size, and max request duration are filled in with defaults.
// Reload functions returned for listeners are registered under the
// "listeners" key while ReloadFuncsLock is held. A shutdown function that
// closes every listener is registered before any listener is created.
func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allowedPurposes []string) error {
	b.Listeners = make([]*ServerListener, 0, len(config.Listeners))

	// Make sure everything is closed before exit. If a controller started
	// successfully this will already have happened, so close errors are
	// deliberately ignored.
	closeAll := func() error {
		for _, sl := range b.Listeners {
			if sl.ProxyListener != nil {
				sl.ProxyListener.Close()
			}
			if sl.ClusterListener != nil {
				sl.ClusterListener.Close()
			}
			if sl.ApiListener != nil {
				sl.ApiListener.Close()
			}
			if sl.OpsListener != nil {
				sl.OpsListener.Close()
			}
		}
		return nil
	}
	b.ShutdownFuncs = append(b.ShutdownFuncs, closeAll)

	b.ReloadFuncsLock.Lock()
	defer b.ReloadFuncsLock.Unlock()

	for idx, lnConfig := range config.Listeners {
		if count := len(lnConfig.Purpose); count != 1 {
			return fmt.Errorf("Invalid size of listener purposes (%d)", count)
		}
		purpose := strings.ToLower(lnConfig.Purpose[0])
		if !strutil.StrListContains(allowedPurposes, purpose) {
			return fmt.Errorf("Unknown listener purpose %q", purpose)
		}
		lnConfig.Purpose[0] = purpose

		if lnConfig.TLSCipherSuites == nil {
			lnConfig.TLSCipherSuites = []uint16{
				// 1.3
				tls.TLS_AES_128_GCM_SHA256,
				tls.TLS_AES_256_GCM_SHA384,
				tls.TLS_CHACHA20_POLY1305_SHA256,
				// 1.2
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
			}
		}

		ln, props, reloadFunc, err := NewListener(lnConfig, ui)
		if err != nil {
			return fmt.Errorf("Error initializing listener of type %s: %w", lnConfig.Type, err)
		}

		// CORS props, reported for API listeners only.
		if purpose == "api" {
			corsOn := lnConfig.CorsEnabled != nil && *lnConfig.CorsEnabled
			props["cors_enabled"] = strconv.FormatBool(corsOn)
			if corsOn {
				props["cors_allowed_origins"] = fmt.Sprintf("%v", lnConfig.CorsAllowedOrigins)
				props["cors_allowed_headers"] = fmt.Sprintf("%v", lnConfig.CorsAllowedHeaders)
			}
		}

		// X-Forwarded-For props. Hop skips default to "0" when authorized
		// addresses are configured and are overridden by a positive setting.
		if len(lnConfig.XForwardedForAuthorizedAddrs) > 0 {
			props["x_forwarded_for_authorized_addrs"] = fmt.Sprintf("%v", lnConfig.XForwardedForAuthorizedAddrs)
			props["x_forwarded_for_reject_not_present"] = strconv.FormatBool(lnConfig.XForwardedForRejectNotPresent)
			props["x_forwarded_for_hop_skips"] = "0"
		}
		if lnConfig.XForwardedForHopSkips > 0 {
			props["x_forwarded_for_hop_skips"] = fmt.Sprintf("%d", lnConfig.XForwardedForHopSkips)
		}

		if reloadFunc != nil {
			b.ReloadFuncs["listeners"] = append(b.ReloadFuncs["listeners"], reloadFunc)
		}

		// TODO: We don't actually limit the request size yet, so it is not
		// added to props.
		if lnConfig.MaxRequestSize == 0 {
			lnConfig.MaxRequestSize = globals.DefaultMaxRequestSize
		}
		if lnConfig.MaxRequestDuration == 0 {
			lnConfig.MaxRequestDuration = globals.DefaultMaxRequestDuration
		}
		props["max_request_duration"] = lnConfig.MaxRequestDuration.String()

		entry := &ServerListener{Config: lnConfig}
		switch purpose {
		case "api":
			entry.ApiListener = ln
		case "cluster":
			entry.ClusterListener = ln
		case "proxy":
			entry.ProxyListener = ln
		case "ops":
			entry.OpsListener = ln
		}
		b.Listeners = append(b.Listeners, entry)

		props["purpose"] = strings.Join(lnConfig.Purpose, ",")

		// Store the listener props, sorted, for output later.
		summary := make([]string, 0, len(props))
		for name, value := range props {
			summary = append(summary, fmt.Sprintf("%s: %q", name, value))
		}
		sort.Strings(summary)

		infoKey := fmt.Sprintf("listener %d", idx+1)
		b.InfoKeys = append(b.InfoKeys, infoKey)
		b.Info[infoKey] = fmt.Sprintf("%s (%s)", lnConfig.Type, strings.Join(summary, ", "))
	}
	return nil
}
// SetupKMSes takes in a parsed config, does some minor checking on purposes,
// and sends each off to configutil to instantiate a wrapper.
//
// Every purpose of every KMS block in config.SharedConfig.Seals is configured
// as its own wrapper. The resulting wrapper is stored on the Server field that
// matches the purpose; a second block for a single-wrapper purpose is an
// error, while downstream worker auth wrappers are collected in a pooled
// wrapper. A previous-root wrapper is combined with the root wrapper into a
// pooled wrapper whose encrypting wrapper is the root. Finally the secure
// random reader is created and mlock status is recorded in the server info.
//
// Finalizers and cleanup functions for the wrappers are appended to
// b.ShutdownFuncs. Wrappers and shutdown functions registered before an error
// is returned are left in place.
func (b *Server) SetupKMSes(ctx context.Context, ui cli.Ui, config *config.Config, opt ...Option) error {
sharedConfig := config.SharedConfig
var pluginLogger hclog.Logger
var err error
var previousRootKms wrapping.Wrapper
// purposeCount tracks how many times each purpose has been seen; it is only
// used to label the plugin logger for the wrapper.
purposeCount := map[string]uint{}
// NOTE: the loop variable kms shadows the imported kms package inside the
// loop body.
for _, kms := range sharedConfig.Seals {
for _, purpose := range kms.Purpose {
purpose = strings.ToLower(purpose)
purposeCount[purpose] = purposeCount[purpose] + 1
// First pass over the purpose: reject empty/unknown purposes and apply
// the dev recovery key override.
switch purpose {
case "":
return errors.New("KMS block missing 'purpose'")
case globals.KmsPurposeRoot,
globals.KmsPurposePreviousRoot,
globals.KmsPurposeConfig,
globals.KmsPurposeWorkerAuth,
globals.KmsPurposeDownstreamWorkerAuth,
globals.KmsPurposeWorkerAuthStorage,
globals.KmsPurposeBsr:
case globals.KmsPurposeRecovery:
// With a controller configured and a dev recovery key present, the dev
// key replaces the block's configured key.
if config.Controller != nil && config.DevRecoveryKey != "" {
kms.Config["key"] = config.DevRecoveryKey
}
default:
return fmt.Errorf("Unknown KMS purpose %q", kms.Purpose)
}
// The plugin logger is created lazily on the first purpose and reused for
// all later wrappers.
if pluginLogger == nil {
pluginLogger, err = event.NewHclogLogger(b.Context, b.Eventer)
if err != nil {
return fmt.Errorf("Error creating KMS plugin logger: %w", err)
}
}
// This can be modified by configutil so store the original value
origPurpose := kms.Purpose
// Configure the wrapper for this single purpose only; the original purpose
// list is restored below once the wrapper has been set up.
kms.Purpose = []string{purpose}
wrapper, cleanupFunc, wrapperConfigError := configutil.ConfigureWrapper(
ctx,
kms,
&b.InfoKeys,
&b.Info,
configutil.WithPluginOptions(
pluginutil.WithPluginsMap(kms_plugin_assets.BuiltinKmsPlugins()),
pluginutil.WithPluginsFilesystem(kms_plugin_assets.KmsPluginPrefix, kms_plugin_assets.FileSystem()),
pluginutil.WithPluginExecutionDirectory(config.Plugins.ExecutionDir),
),
configutil.WithLogger(pluginLogger.Named(kms.Type).With("purpose", fmt.Sprintf("%s-%d", purpose, purposeCount[purpose]))),
)
if wrapperConfigError != nil {
return fmt.Errorf(
"Error parsing KMS configuration: %s", wrapperConfigError)
}
if wrapper == nil {
return fmt.Errorf(
"After configuration nil KMS returned, KMS type was %s", kms.Type)
}
// Wrappers that implement InitFinalizer are initialized now and finalized
// at shutdown; "not implemented" is tolerated for both calls.
if ifWrapper, ok := wrapper.(wrapping.InitFinalizer); ok {
if err := ifWrapper.Init(ctx); err != nil && !errors.Is(err, wrapping.ErrFunctionNotImplemented) {
return fmt.Errorf("Error initializing KMS: %w", err)
}
// Ensure that the seal finalizer is called, even if using verify-only
b.ShutdownFuncs = append(b.ShutdownFuncs, func() error {
if err := ifWrapper.Finalize(context.Background()); err != nil && !errors.Is(err, wrapping.ErrFunctionNotImplemented) {
return fmt.Errorf("Error finalizing kms of type %s and purpose %s: %v", kms.Type, purpose, err)
}
return nil
})
}
if cleanupFunc != nil {
b.ShutdownFuncs = append(b.ShutdownFuncs, func() error {
return cleanupFunc()
})
}
kms.Purpose = origPurpose
// Second pass over the purpose: store the wrapper on the matching Server
// field, rejecting duplicates for single-wrapper purposes.
switch purpose {
case globals.KmsPurposePreviousRoot:
if previousRootKms != nil {
return fmt.Errorf("Duplicate KMS block for purpose '%s'. You may need to remove all but the last KMS block for this purpose.", purpose)
}
previousRootKms = wrapper
case globals.KmsPurposeRoot:
if b.RootKms != nil {
return fmt.Errorf("Duplicate KMS block for purpose '%s'. You may need to remove all but the last KMS block for this purpose.", purpose)
}
b.RootKms = wrapper
case globals.KmsPurposeWorkerAuth:
if b.WorkerAuthKms != nil {
return fmt.Errorf("Duplicate KMS block for purpose '%s'. You may need to remove all but the last KMS block for this purpose.", purpose)
}
b.WorkerAuthKms = wrapper
case globals.KmsPurposeDownstreamWorkerAuth:
// Multiple downstream worker auth blocks are allowed: the first creates
// the pool and later ones are added to it.
if b.DownstreamWorkerAuthKms == nil {
b.DownstreamWorkerAuthKms, err = multi.NewPooledWrapper(ctx, wrapper)
if err != nil {
return fmt.Errorf("Error instantiating pooled wrapper for downstream worker auth: %w.", err)
}
} else {
added, err := b.DownstreamWorkerAuthKms.AddWrapper(ctx, wrapper)
if err != nil {
return fmt.Errorf("Error adding additional wrapper to downstream worker auth wrapper pool: %w.", err)
}
if !added {
return fmt.Errorf("Wrapper already added to downstream worker auth wrapper pool.")
}
}
case globals.KmsPurposeWorkerAuthStorage:
if b.WorkerAuthStorageKms != nil {
return fmt.Errorf("Duplicate KMS block for purpose '%s'. You may need to remove all but the last KMS block for this purpose.", purpose)
}
b.WorkerAuthStorageKms = wrapper
case globals.KmsPurposeBsr:
if b.BsrKms != nil {
return fmt.Errorf("Duplicate KMS block for purpose '%s'. You may need to remove all but the last KMS block for this purpose.", purpose)
}
b.BsrKms = wrapper
case globals.KmsPurposeRecovery:
if b.RecoveryKms != nil {
return fmt.Errorf("Duplicate KMS block for purpose '%s'. You may need to remove all but the last KMS block for this purpose.", purpose)
}
b.RecoveryKms = wrapper
case globals.KmsPurposeConfig:
// Do nothing, can be set in same file but not needed at runtime
continue
default:
return fmt.Errorf("KMS purpose of %q is unknown", purpose)
}
}
}
// Handle previous root KMS
// A previous-root wrapper requires a root wrapper; the two are pooled with
// the root wrapper set as the encrypting wrapper, and the pool replaces
// b.RootKms.
if previousRootKms != nil {
if util.IsNil(b.RootKms) {
return fmt.Errorf("KMS block contains '%s' without '%s'", globals.KmsPurposePreviousRoot, globals.KmsPurposeRoot)
}
mw, err := multi.NewPooledWrapper(ctx, previousRootKms)
if err != nil {
return fmt.Errorf("failed to create multi wrapper: %w", err)
}
ok, err := mw.SetEncryptingWrapper(ctx, b.RootKms)
if err != nil {
return fmt.Errorf("failed to set root wrapper as active in multi wrapper: %w", err)
}
if !ok {
return fmt.Errorf("KMS blocks with purposes '%s' and '%s' must have different key IDs", globals.KmsPurposeRoot, globals.KmsPurposePreviousRoot)
}
b.RootKms = mw
}
// prepare a secure random reader
b.SecureRandomReader, err = configutil.CreateSecureRandomReaderFunc(config.SharedConfig, b.RootKms)
if err != nil {
return err
}
// This might not be the _best_ place for this but we have access to config
// here
b.Info["mlock"] = fmt.Sprintf(
"supported: %v, enabled: %v",
mlock.Supported(), !config.DisableMlock && mlock.Supported())
b.InfoKeys = append(b.InfoKeys, "mlock")
return nil
}
// RunShutdownFuncs calls every registered shutdown function in registration
// order. A failing function does not stop the remaining ones; all returned
// errors are joined into the result, which is nil when none failed.
func (b *Server) RunShutdownFuncs() error {
	var joined error
	funcs := b.ShutdownFuncs
	for i := 0; i < len(funcs); i++ {
		err := funcs[i]()
		if err == nil {
			continue
		}
		joined = errors.Join(joined, err)
	}
	return joined
}
// OpenAndSetServerDatabase opens a database for the given dialect using the
// Server's DatabaseUrl and, on success, stores the resulting *db.DB in the
// Server's `Database` field. On failure the field is left untouched and the
// error from OpenDatabase is returned.
func (b *Server) OpenAndSetServerDatabase(ctx context.Context, dialect string) error {
	conn, err := b.OpenDatabase(ctx, dialect, b.DatabaseUrl)
	if err == nil {
		b.Database = conn
	}
	return err
}
// OpenDatabase opens a connection to the database at url for the given
// dialect and hands it back to the caller without storing it on the Server.
// The connection limits and idle time come from the Server's Database* fields,
// which must be set beforehand. The gorm formatter is attached unless the
// BOUNDARY_DISABLE_GORM_FORMATTER environment variable is non-empty.
func (b *Server) OpenDatabase(ctx context.Context, dialect, url string) (*db.DB, error) {
	dbType, err := db.StringToDbType(dialect)
	if err != nil {
		return nil, fmt.Errorf("unable to create db object with dialect %s: %w", dialect, err)
	}

	dbOpts := make([]db.Option, 0, 4)
	dbOpts = append(dbOpts,
		db.WithMaxOpenConnections(b.DatabaseMaxOpenConnections),
		db.WithMaxIdleConnections(b.DatabaseMaxIdleConnections),
		db.WithConnMaxIdleTimeDuration(b.DatabaseConnMaxIdleTimeDuration),
	)
	if disabled := os.Getenv("BOUNDARY_DISABLE_GORM_FORMATTER"); disabled == "" {
		dbOpts = append(dbOpts, db.WithGormFormatter(b.Logger))
	}

	conn, err := db.Open(ctx, dbType, url, dbOpts...)
	if err != nil {
		return nil, fmt.Errorf("unable to create db object with dialect %s: %w", dialect, err)
	}
	return conn, nil
}
// CreateGlobalKmsKeys builds a KMS on top of the Server's database, registers
// the root wrapper with it as an external wrapper, and creates the keys for
// the global scope using the Server's secure random reader.
func (b *Server) CreateGlobalKmsKeys(ctx context.Context) error {
	readWriter := db.New(b.Database)

	kmsCache, err := kms.New(ctx, readWriter, readWriter)
	if err != nil {
		return fmt.Errorf("error creating kms cache: %w", err)
	}

	err = kmsCache.AddExternalWrappers(b.Context, kms.WithRootWrapper(b.RootKms))
	if err != nil {
		return fmt.Errorf("error adding config keys to kms: %w", err)
	}

	err = kmsCache.CreateKeys(ctx, scope.Global.String(), kms.WithRandomReader(b.SecureRandomReader))
	if err != nil {
		return fmt.Errorf("error creating global scope kms keys: %w", err)
	}
	return nil
}
// DestroyDevDatabase closes the Server's database connection, if any, and
// then runs DevDatabaseCleanupFunc, if set.
//
// The cleanup function runs even when closing the connection fails, so that a
// close error cannot leave the dev database's resources behind. When only one
// step fails its error is returned as-is; when both fail the two errors are
// joined.
func (b *Server) DestroyDevDatabase(ctx context.Context) error {
	var closeErr error
	if b.Database != nil {
		closeErr = b.Database.Close(ctx)
	}

	var cleanupErr error
	if b.DevDatabaseCleanupFunc != nil {
		cleanupErr = b.DevDatabaseCleanupFunc()
	}

	switch {
	case closeErr == nil:
		return cleanupErr
	case cleanupErr == nil:
		return closeErr
	default:
		return errors.Join(closeErr, cleanupErr)
	}
}
// SetupWorkerPublicAddress determines the worker's public address and stores
// the result, as host:port, in conf.Worker.PublicAddr. conf.Worker is
// allocated when nil.
//
// The address is taken from flagValue when non-empty, otherwise from the
// existing conf.Worker.PublicAddr. If that is still empty, the address of the
// first listener with the "proxy" purpose is used; otherwise the value is run
// through parseutil.ParsePath and IP-template parsing. The port defaults to
// 9202 when none is given. The final address is normalized unless the host is
// empty or the address came from a unix listener.
//
// An error is returned when parsing, splitting, or normalizing the address
// fails; each error wraps its underlying cause.
func (b *Server) SetupWorkerPublicAddress(conf *config.Config, flagValue string) error {
	if conf.Worker == nil {
		conf.Worker = new(config.Worker)
	}
	if flagValue != "" {
		conf.Worker.PublicAddr = flagValue
	}

	isUnixListener := false
	if conf.Worker.PublicAddr == "" {
		// No explicit address: fall back to the first proxy listener.
	FindAddr:
		for _, listener := range conf.Listeners {
			for _, purpose := range listener.Purpose {
				if purpose == "proxy" {
					conf.Worker.PublicAddr = listener.Address
					if strings.EqualFold(listener.Type, "unix") {
						isUnixListener = true
					}
					break FindAddr
				}
			}
		}
	} else {
		var err error
		// ErrNotAUrl is tolerated here; any other parse failure is an error.
		conf.Worker.PublicAddr, err = parseutil.ParsePath(conf.Worker.PublicAddr)
		if err != nil && !errors.Is(err, parseutil.ErrNotAUrl) {
			return fmt.Errorf("Error parsing public addr: %w", err)
		}
		conf.Worker.PublicAddr, err = listenerutil.ParseSingleIPTemplate(conf.Worker.PublicAddr)
		if err != nil {
			return fmt.Errorf("Error parsing IP template on worker public addr: %w", err)
		}
	}

	// A missing port is fine; it is defaulted below.
	host, port, err := util.SplitHostPort(conf.Worker.PublicAddr)
	if err != nil && !errors.Is(err, util.ErrMissingPort) {
		return fmt.Errorf("Error splitting public adddress host/port: %w", err)
	}
	if host != "" {
		if host, err = parseutil.NormalizeAddr(host); err != nil {
			// Wrap the cause so callers can see why normalization failed.
			return fmt.Errorf("Error normalizing worker address: %w", err)
		}
	}
	if port == "" {
		port = "9202"
	}
	conf.Worker.PublicAddr = util.JoinHostPort(host, port)

	if host != "" && !isUnixListener {
		// NormalizeAddr requires that a host be present, but that is not
		// guaranteed in this code path. Additionally, if no host is present,
		// there's no need to normalize.
		conf.Worker.PublicAddr, err = parseutil.NormalizeAddr(conf.Worker.PublicAddr)
		if err != nil {
			return fmt.Errorf("Failed to normalize worker public adddress: %w", err)
		}
	}
	return nil
}
// MakeSighupCh returns an unbuffered channel that receives one message for
// every SIGHUP the process is notified of, for use in triggering reloads.
// The forwarding goroutine it starts runs for the life of the process.
func MakeSighupCh() chan struct{} {
	// Signals are delivered into a small buffer and forwarded one by one.
	sighups := make(chan os.Signal, 4)
	signal.Notify(sighups, syscall.SIGHUP)

	notifyCh := make(chan struct{})
	go func() {
		// sighups is never closed, so this loop does not terminate.
		for range sighups {
			notifyCh <- struct{}{}
		}
	}()
	return notifyCh
}