From 1645c3106ee7aefe5ab9e6ebbccfbb32e3c98f1c Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 19 Mar 2020 15:17:40 -0400 Subject: [PATCH] Migrate a bunch of stuff around to prep for worker command (#2) Right now it's essentially a copy; it needs some updating, integration into dev, and so on. --- .../listener_tcp.go => base/listener.go} | 71 +++- internal/cmd/base/profiling_off.go | 8 + .../dev_profile.go => base/profiling_on.go} | 14 +- internal/cmd/base/servers.go | 3 +- internal/cmd/commands.go | 12 + .../cmd/commands/controller/controller.go | 22 +- .../controller/controller_noprofile.go | 6 - .../commands/controller/controller_profile.go | 40 -- .../commands/controller/listener/listener.go | 81 ---- internal/cmd/commands/dev/dev.go | 8 +- internal/cmd/commands/dev/dev_noprofile.go | 6 - internal/cmd/commands/worker/config/config.go | 122 ++++++ internal/cmd/commands/worker/worker.go | 385 ++++++++++++++++++ internal/servers/worker/client_tls.go | 182 +++++++++ internal/servers/worker/config.go | 14 + internal/servers/worker/handler.go | 159 ++++++++ internal/servers/worker/listeners.go | 148 +++++++ internal/servers/worker/worker.go | 76 ++++ 18 files changed, 1200 insertions(+), 157 deletions(-) rename internal/cmd/{commands/controller/listener/listener_tcp.go => base/listener.go} (50%) create mode 100644 internal/cmd/base/profiling_off.go rename internal/cmd/{commands/dev/dev_profile.go => base/profiling_on.go} (62%) delete mode 100644 internal/cmd/commands/controller/controller_noprofile.go delete mode 100644 internal/cmd/commands/controller/controller_profile.go delete mode 100644 internal/cmd/commands/controller/listener/listener.go delete mode 100644 internal/cmd/commands/dev/dev_noprofile.go create mode 100644 internal/cmd/commands/worker/config/config.go create mode 100644 internal/cmd/commands/worker/worker.go create mode 100644 internal/servers/worker/client_tls.go create mode 100644 internal/servers/worker/config.go create mode 100644 internal/servers/worker/handler.go create mode 100644 internal/servers/worker/listeners.go create mode 100644 internal/servers/worker/worker.go diff --git a/internal/cmd/commands/controller/listener/listener_tcp.go b/internal/cmd/base/listener.go similarity index 50% rename from internal/cmd/commands/controller/listener/listener_tcp.go rename to internal/cmd/base/listener.go index 095649c74d..8ba363f749 100644 --- a/internal/cmd/commands/controller/listener/listener_tcp.go +++ b/internal/cmd/base/listener.go @@ -1,18 +1,44 @@ -package listener +package base import ( + "errors" + "fmt" "io" "net" "strings" "time" + // We must import sha512 so that it registers with the runtime so that + // certificates that use it can be parsed. + _ "crypto/sha512" + "github.com/hashicorp/go-alpnmux" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/internalshared/reloadutil" "github.com/mitchellh/cli" + "github.com/pires/go-proxyproto" ) +// Factory is the factory function to create a listener. +type ListenerFactory func(*configutil.Listener, io.Writer, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) + +// BuiltinListeners is the list of built-in listener types. +var BuiltinListeners = map[string]ListenerFactory{ + "tcp": tcpListenerFactory, +} + +// New creates a new listener of the given type with the given +// configuration. The type is looked up in the BuiltinListeners map. +func NewListener(l *configutil.Listener, w io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { + f, ok := BuiltinListeners[l.Type] + if !ok { + return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type) + } + + return f(l, w, ui) +} + func tcpListenerFactory(l *configutil.Listener, _ io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { if l.Address == "" { l.Address = "127.0.0.1:9200" @@ -69,6 +95,49 @@ func tcpListenerFactory(l *configutil.Listener, _ io.Writer, ui cli.Ui) (*alpnmu return alpnMux, props, reloadFunc, nil } +func listenerWrapProxy(ln net.Listener, l *configutil.Listener) (net.Listener, error) { + behavior := l.ProxyProtocolBehavior + if behavior == "" { + return ln, nil + } + + authorizedAddrs := make([]string, 0, len(l.ProxyProtocolAuthorizedAddrs)) + for _, v := range l.ProxyProtocolAuthorizedAddrs { + authorizedAddrs = append(authorizedAddrs, v.String()) + } + + var policyFunc proxyproto.PolicyFunc + + switch behavior { + case "use_always": + policyFunc = func(upstream net.Addr) (proxyproto.Policy, error) { + return proxyproto.USE, nil + } + + case "allow_authorized": + if len(authorizedAddrs) == 0 { + return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value") + } + policyFunc = proxyproto.MustLaxWhiteListPolicy(authorizedAddrs) + + case "deny_unauthorized": + if len(authorizedAddrs) == 0 { + return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value") + } + policyFunc = proxyproto.MustStrictWhiteListPolicy(authorizedAddrs) + + default: + return nil, fmt.Errorf("unknown %q value: %q", "proxy_protocol_behavior", behavior) + } + + proxyListener := &proxyproto.Listener{ + Listener: ln, + Policy: policyFunc, + } + + return proxyListener, nil +} + // TCPKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used by ListenAndServe and ListenAndServeTLS so // dead TCP connections (e.g. closing laptop mid-download) eventually diff --git a/internal/cmd/base/profiling_off.go b/internal/cmd/base/profiling_off.go new file mode 100644 index 0000000000..c1578a6301 --- /dev/null +++ b/internal/cmd/base/profiling_off.go @@ -0,0 +1,8 @@ +// +build !memprofiler + +package base + +import "github.com/hashicorp/go-hclog" + +func StartMemProfiler(_ hclog.Logger) { +} diff --git a/internal/cmd/commands/dev/dev_profile.go b/internal/cmd/base/profiling_on.go similarity index 62% rename from internal/cmd/commands/dev/dev_profile.go rename to internal/cmd/base/profiling_on.go index 95001e7630..b10936345d 100644 --- a/internal/cmd/commands/dev/dev_profile.go +++ b/internal/cmd/base/profiling_on.go @@ -1,6 +1,6 @@ // +build memprofiler -package dev +package base import ( "os" @@ -8,16 +8,18 @@ import ( "runtime" "runtime/pprof" "time" + + "github.com/hashicorp/go-hclog" ) func init() { memProfilerEnabled = true } -func (d *Command) startMemProfiler() { +func StartMemProfiler(logger hclog.Logger) { profileDir := filepath.Join(os.TempDir(), "watchtowerprof") if err := os.MkdirAll(profileDir, 0700); err != nil { - d.logger.Debug("could not create profile directory", "error", err) + logger.Debug("could not create profile directory", "error", err) return } @@ -26,14 +28,14 @@ func (d *Command) startMemProfiler() { filename := filepath.Join(profileDir, time.Now().UTC().Format("20060102_150405")) + ".pprof" f, err := os.Create(filename) if err != nil { - d.logger.Debug("could not create memory profile", "error", err) + logger.Debug("could not create memory profile", "error", err) } runtime.GC() if err := pprof.WriteHeapProfile(f); err != nil { - d.logger.Debug("could not write memory profile", "error", err) + logger.Debug("could not write memory profile", "error", err) } f.Close() - d.logger.Debug("wrote memory profile", "filename", filename) + logger.Debug("wrote memory profile", "filename", filename) time.Sleep(5 * time.Minute) } }() diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 5ae6367eb4..3256e32c23 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -28,7 +28,6 @@ import ( "github.com/hashicorp/vault/sdk/helper/mlock" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/watchtower/globals" - "github.com/hashicorp/watchtower/internal/cmd/commands/controller/listener" "github.com/hashicorp/watchtower/version" "github.com/mitchellh/cli" "github.com/ory/dockertest/v3" @@ -250,7 +249,7 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig) erro tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, } - lnMux, props, reloadFunc, err := listener.New(lnConfig, b.GatedWriter, ui) + lnMux, props, reloadFunc, err := NewListener(lnConfig, b.GatedWriter, ui) if err != nil { return fmt.Errorf("Error initializing listener of type %s: %w", lnConfig.Type, err) } diff --git a/internal/cmd/commands.go b/internal/cmd/commands.go index 20d5642e17..2dbbf88241 100644 --- a/internal/cmd/commands.go +++ b/internal/cmd/commands.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/watchtower/internal/cmd/base" "github.com/hashicorp/watchtower/internal/cmd/commands/controller" "github.com/hashicorp/watchtower/internal/cmd/commands/dev" + "github.com/hashicorp/watchtower/internal/cmd/commands/worker" "github.com/mitchellh/cli" ) @@ -36,6 +37,17 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) { SigUSR2Ch: MakeSigUSR2Ch(), }, nil }, + "worker": func() (cli.Command, error) { + return &worker.Command{ + Command: &base.Command{ + UI: serverCmdUi, + Address: runOpts.Address, + }, + ShutdownCh: MakeShutdownCh(), + SighupCh: MakeSighupCh(), + SigUSR2Ch: MakeSigUSR2Ch(), + }, nil + }, "dev": func() (cli.Command, error) { return &dev.Command{ Command: &base.Command{ diff --git a/internal/cmd/commands/controller/controller.go b/internal/cmd/commands/controller/controller.go index f4cd628f37..dcd742d70b 100644 --- a/internal/cmd/commands/controller/controller.go +++ b/internal/cmd/commands/controller/controller.go @@ -54,7 +54,7 @@ Usage: watchtower controller [options] Start a controller with a configuration file: - $ watchtower controller -config=/etc/controller/config.hcl + $ watchtower controller -config=/etc/watchtower/controller.hcl For a full list of examples, please see the documentation. @@ -147,15 +147,15 @@ func (c *Command) Run(args []string) int { return result } - if memProfilerEnabled { - c.startMemProfiler() - } - if err := c.SetupLogging(c.flagLogLevel, c.flagLogFormat, c.Config.LogLevel, c.Config.LogFormat); err != nil { c.UI.Error(err.Error()) return 1 } + if memProfilerEnabled { + base.StartMemProfiler(c.Logger) + } + if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil { c.UI.Error(err.Error()) return 1 @@ -259,21 +259,21 @@ func (c *Command) ParseFlagsAndConfig(args []string) int { func (c *Command) Start() int { // Instantiate the wait group - controllerConfig := &controller.Config{ + conf := &controller.Config{ RawConfig: c.Config, Server: c.Server, } // Initialize the core - controller, err := controller.New(controllerConfig) + ctlr, err := controller.New(conf) if err != nil { c.UI.Error(fmt.Sprintf("Error initializing controller: %w", err)) return 1 } - if err := controller.Start(); err != nil { + if err := ctlr.Start(); err != nil { c.UI.Error(fmt.Sprint("Error starting controller: %w", err)) - if err := controller.Shutdown(); err != nil { + if err := ctlr.Shutdown(); err != nil { c.UI.Error(fmt.Sprintf("Error with controller shutdown: %w", err)) } return 1 @@ -287,7 +287,7 @@ func (c *Command) Start() int { case <-c.ShutdownCh: c.UI.Output("==> Watchtower controller shutdown triggered") - if err := controller.Shutdown(); err != nil { + if err := ctlr.Shutdown(); err != nil { c.UI.Error(fmt.Sprintf("Error with controller shutdown: %w", err)) } @@ -337,7 +337,7 @@ func (c *Command) Start() int { c.Logger.Error("unknown log level found on reload", "level", newConf.LogLevel) goto RUNRELOADFUNCS } - controller.SetLogLevel(level) + ctlr.SetLogLevel(level) } RUNRELOADFUNCS: diff --git a/internal/cmd/commands/controller/controller_noprofile.go b/internal/cmd/commands/controller/controller_noprofile.go deleted file mode 100644 index eb90994217..0000000000 --- a/internal/cmd/commands/controller/controller_noprofile.go +++ /dev/null @@ -1,6 +0,0 @@ -// +build !memprofiler - -package controller - -func (c *Command) startMemProfiler() { -} diff --git a/internal/cmd/commands/controller/controller_profile.go b/internal/cmd/commands/controller/controller_profile.go deleted file mode 100644 index 53e28483d9..0000000000 --- a/internal/cmd/commands/controller/controller_profile.go +++ /dev/null @@ -1,40 +0,0 @@ -// +build memprofiler - -package controller - -import ( - "os" - "path/filepath" - "runtime" - "runtime/pprof" - "time" -) - -func init() { - memProfilerEnabled = true -} - -func (c *Command) startMemProfiler() { - profileDir := filepath.Join(os.TempDir(), "watchtowerprof") - if err := os.MkdirAll(profileDir, 0700); err != nil { - c.logger.Debug("could not create profile directory", "error", err) - return - } - - go func() { - for { - filename := filepath.Join(profileDir, time.Now().UTC().Format("20060102_150405")) + ".pprof" - f, err := os.Create(filename) - if err != nil { - c.logger.Debug("could not create memory profile", "error", err) - } - runtime.GC() - if err := pprof.WriteHeapProfile(f); err != nil { - c.logger.Debug("could not write memory profile", "error", err) - } - f.Close() - c.logger.Debug("wrote memory profile", "filename", filename) - time.Sleep(5 * time.Minute) - } - }() -} diff --git a/internal/cmd/commands/controller/listener/listener.go b/internal/cmd/commands/controller/listener/listener.go deleted file mode 100644 index 5f14606f5a..0000000000 --- a/internal/cmd/commands/controller/listener/listener.go +++ /dev/null @@ -1,81 +0,0 @@ -package listener - -import ( - "errors" - "io" - - "github.com/pires/go-proxyproto" - - // We must import sha512 so that it registers with the runtime so that - // certificates that use it can be parsed. - _ "crypto/sha512" - "fmt" - "net" - - "github.com/hashicorp/go-alpnmux" - "github.com/hashicorp/vault/internalshared/configutil" - "github.com/hashicorp/vault/internalshared/reloadutil" - "github.com/mitchellh/cli" -) - -// Factory is the factory function to create a listener. -type Factory func(*configutil.Listener, io.Writer, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) - -// BuiltinListeners is the list of built-in listener types. -var BuiltinListeners = map[string]Factory{ - "tcp": tcpListenerFactory, -} - -// New creates a new listener of the given type with the given -// configuration. The type is looked up in the BuiltinListeners map. -func New(l *configutil.Listener, w io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { - f, ok := BuiltinListeners[l.Type] - if !ok { - return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type) - } - - return f(l, w, ui) -} - -func listenerWrapProxy(ln net.Listener, l *configutil.Listener) (net.Listener, error) { - behavior := l.ProxyProtocolBehavior - if behavior == "" { - return ln, nil - } - - authorizedAddrs := make([]string, 0, len(l.ProxyProtocolAuthorizedAddrs)) - for _, v := range l.ProxyProtocolAuthorizedAddrs { - authorizedAddrs = append(authorizedAddrs, v.String()) - } - - var policyFunc proxyproto.PolicyFunc - - switch behavior { - case "use_always": - policyFunc = func(upstream net.Addr) (proxyproto.Policy, error) { - return proxyproto.USE, nil - } - - case "allow_authorized": - if len(authorizedAddrs) == 0 { - return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value") - } - policyFunc = proxyproto.MustLaxWhiteListPolicy(authorizedAddrs) - - case "deny_unauthorized": - if len(authorizedAddrs) == 0 { - return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value") - } - policyFunc = proxyproto.MustStrictWhiteListPolicy(authorizedAddrs) - - default: - return nil, fmt.Errorf("unknown %q value: %q", "proxy_protocol_behavior", behavior) - } - - proxyListener := &proxyproto.Listener{ - Listener: ln, - Policy: policyFunc, - } - - return proxyListener, nil -} diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 6c7f0d2832..335d5b1356 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -138,15 +138,15 @@ func (c *Command) Run(args []string) int { return 1 } - if memProfilerEnabled { - c.startMemProfiler() - } - if err := c.SetupLogging(c.flagLogLevel, c.flagLogFormat, "", ""); err != nil { c.UI.Error(err.Error()) return 1 } + if memProfilerEnabled { + base.StartMemProfiler(c.Logger) + } + if err := c.SetupMetrics(c.UI, devControllerConfig.Telemetry); err != nil { c.UI.Error(err.Error()) return 1 diff --git a/internal/cmd/commands/dev/dev_noprofile.go b/internal/cmd/commands/dev/dev_noprofile.go deleted file mode 100644 index 51aa876425..0000000000 --- a/internal/cmd/commands/dev/dev_noprofile.go +++ /dev/null @@ -1,6 +0,0 @@ -// +build !memprofiler - -package dev - -func (d *Command) startMemProfiler() { -} diff --git a/internal/cmd/commands/worker/config/config.go b/internal/cmd/commands/worker/config/config.go new file mode 100644 index 0000000000..93e237d8bf --- /dev/null +++ b/internal/cmd/commands/worker/config/config.go @@ -0,0 +1,122 @@ +package config + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + + "github.com/hashicorp/hcl" + "github.com/hashicorp/vault/internalshared/configutil" +) + +// Config is the configuration for the watchtower controller +type Config struct { + *configutil.SharedConfig `hcl:"-"` +} + +// Dev is a Config that is used for dev mode of Watchtower +func Dev() (*Config, error) { + randBuf := new(bytes.Buffer) + n, err := randBuf.ReadFrom(&io.LimitedReader{ + R: rand.Reader, + N: 64, + }) + if err != nil { + return nil, err + } + if n != 64 { + return nil, fmt.Errorf("expected to read 64 bytes, read %d", n) + } + controllerKey := base64.StdEncoding.EncodeToString(randBuf.Bytes()[0:32]) + workerAuthKey := base64.StdEncoding.EncodeToString(randBuf.Bytes()[32:64]) + + hclStr := ` +disable_mlock = true + +listener "tcp" { + tls_disable = true + proxy_protocol_behavior = "allow_authorized" + proxy_protocol_authorized_addrs = "127.0.0.1" +} + +telemetry { + prometheus_retention_time = "24h" + disable_hostname = true +} +` + + hclStr = fmt.Sprintf(hclStr, controllerKey, workerAuthKey) + parsed, err := Parse(hclStr) + if err != nil { + return nil, fmt.Errorf("error parsing dev config: %w", err) + } + return parsed, nil +} + +func New() *Config { + return &Config{ + SharedConfig: new(configutil.SharedConfig), + } +} + +// LoadFile loads the configuration from the given file. +func LoadFile(path string) (*Config, error) { + // Read the file + d, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + conf, err := Parse(string(d)) + if err != nil { + return nil, err + } + + return conf, nil +} + +func Parse(d string) (*Config, error) { + obj, err := hcl.Parse(d) + if err != nil { + return nil, err + } + + // Nothing to do here right now + result := New() + if err := hcl.DecodeObject(result, obj); err != nil { + return nil, err + } + + sharedConfig, err := configutil.ParseConfig(d) + if err != nil { + return nil, err + } + result.SharedConfig = sharedConfig + + return result, nil +} + +// Sanitized returns a copy of the config with all values that are considered +// sensitive stripped. It also strips all `*Raw` values that are mainly +// used for parsing. +// +// Specifically, the fields that this method strips are: +// - KMS.Config +// - Telemetry.CirconusAPIToken +func (c *Config) Sanitized() map[string]interface{} { + // Create shared config if it doesn't exist (e.g. in tests) so that map + // keys are actually populated + if c.SharedConfig == nil { + c.SharedConfig = new(configutil.SharedConfig) + } + sharedResult := c.SharedConfig.Sanitized() + result := map[string]interface{}{} + for k, v := range sharedResult { + result[k] = v + } + + return result +} diff --git a/internal/cmd/commands/worker/worker.go b/internal/cmd/commands/worker/worker.go new file mode 100644 index 0000000000..824836ac78 --- /dev/null +++ b/internal/cmd/commands/worker/worker.go @@ -0,0 +1,385 @@ +package worker + +import ( + "fmt" + "runtime" + "strings" + "sync" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/sdk/helper/mlock" + "github.com/hashicorp/watchtower/globals" + "github.com/hashicorp/watchtower/internal/cmd/base" + "github.com/hashicorp/watchtower/internal/cmd/commands/worker/config" + "github.com/hashicorp/watchtower/internal/servers/worker" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*Command)(nil) +var _ cli.CommandAutocomplete = (*Command)(nil) + +var memProfilerEnabled = false + +type Command struct { + *base.Command + *base.Server + + ShutdownCh chan struct{} + SighupCh chan struct{} + ReloadedCh chan struct{} + SigUSR2Ch chan struct{} + + cleanupGuard sync.Once + + Config *config.Config + + flagConfig string + flagLogLevel string + flagLogFormat string + flagDev bool + flagDevAdminToken string + flagDevWorkerListenAddr string + flagCombineLogs bool +} + +func (c *Command) Synopsis() string { + return "Start a Watchtower worker" +} + +func (c *Command) Help() string { + helpText := ` +Usage: watchtower worker [options] + + Start a worker with a configuration file: + + $ watchtower worker -config=/etc/watchtower/worker.hcl + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + return strings.TrimSpace(helpText) +} + +func (c *Command) Flags() *base.FlagSets { + set := c.FlagSet(base.FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.StringVar(&base.StringVar{ + Name: "config", + Target: &c.flagConfig, + Completion: complete.PredictOr( + complete.PredictFiles("*.hcl"), + complete.PredictFiles("*.json"), + ), + Usage: "Path to a configuration file.", + }) + + f.StringVar(&base.StringVar{ + Name: "log-level", + Target: &c.flagLogLevel, + Default: base.NotSetValue, + EnvVar: "WATCHTOWER_LOG_LEVEL", + Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"), + Usage: "Log verbosity level. Supported values (in order of more detail to less) are " + + "\"trace\", \"debug\", \"info\", \"warn\", and \"err\".", + }) + + f.StringVar(&base.StringVar{ + Name: "log-format", + Target: &c.flagLogFormat, + Default: base.NotSetValue, + Completion: complete.PredictSet("standard", "json"), + Usage: `Log format. Supported values are "standard" and "json".`, + }) + + f = set.NewFlagSet("Dev Options") + + f.BoolVar(&base.BoolVar{ + Name: "dev", + Target: &c.flagDev, + Usage: "Enable development mode. As the name implies, do not run \"dev\" mode in " + + "production.", + }) + + f.StringVar(&base.StringVar{ + Name: "dev-admin-token", + Target: &c.flagDevAdminToken, + Default: "", + EnvVar: "WATCHTWER_DEV_ADMIN_TOKEN", + Usage: "Initial admin token. This only applies when running in \"dev\" " + + "mode.", + }) + + f.StringVar(&base.StringVar{ + Name: "dev-listen-address", + Target: &c.flagDevWorkerListenAddr, + Default: "127.0.0.1:9200", + EnvVar: "WATCHTOWER_DEV_WORKER_LISTEN_ADDRESS", + Usage: "Address to bind the worker to in \"dev\" mode.", + }) + + f.BoolVar(&base.BoolVar{ + Name: "combine-logs", + Target: &c.flagCombineLogs, + Default: false, + Usage: "If set, both startup information and logs will be sent to stdout. If not set (the default), startup information will go to stdout and logs will be sent to stderr.", + }) + + return set +} + +func (c *Command) AutocompleteArgs() complete.Predictor { + return complete.PredictNothing +} + +func (c *Command) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *Command) Run(args []string) int { + c.Server = base.NewServer() + c.CombineLogs = c.flagCombineLogs + + if result := c.ParseFlagsAndConfig(args); result > 0 { + return result + } + + if err := c.SetupLogging(c.flagLogLevel, c.flagLogFormat, c.Config.LogLevel, c.Config.LogFormat); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + if memProfilerEnabled { + base.StartMemProfiler(c.Logger) + } + + if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + if err := c.SetupKMSes(c.UI, c.Config.SharedConfig, 2); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + if c.Config.DefaultMaxRequestDuration != 0 { + globals.DefaultMaxRequestDuration = c.Config.DefaultMaxRequestDuration + } + + // If mlockall(2) isn't supported, show a warning. We disable this in dev + // because it is quite scary to see when first using Vault. We also disable + // this if the user has explicitly disabled mlock in configuration. + if !c.flagDev && !c.Config.DisableMlock && !mlock.Supported() { + c.UI.Warn(base.WrapAtLength( + "WARNING! mlock is not supported on this system! An mlockall(2)-like " + + "syscall to prevent memory from being swapped to disk is not " + + "supported on this system. For better security, only run Vault on " + + "systems where this call is supported. If you are running Vault " + + "in a Docker container, provide the IPC_LOCK cap to the container.")) + } + + if err := c.SetupListeners(c.UI, c.Config.SharedConfig); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Write out the PID to the file now that server has successfully started + if err := c.StorePidFile(c.Config.PidFile); err != nil { + c.UI.Error(fmt.Sprintf("Error storing PID: %w", err)) + return 1 + } + + if c.flagDev { + if err := c.CreateDevDatabase(); err != nil { + c.UI.Error(fmt.Sprintf("Error creating dev database container: %s", err.Error())) + return 1 + } + c.ShutdownFuncs = append(c.ShutdownFuncs, c.DestroyDevDatabase) + } + + defer c.RunShutdownFuncs(c.UI) + + c.PrintInfo(c.UI, "worker") + c.ReleaseLogGate() + + return c.Start() +} + +func (c *Command) ParseFlagsAndConfig(args []string) int { + var err error + + f := c.Flags() + + if err = f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Validation + if !c.flagDev { + switch { + case len(c.flagConfig) == 0: + c.UI.Error("Must specify a config file using -config") + return 1 + case c.flagDevAdminToken != "": + c.UI.Warn(base.WrapAtLength( + "You cannot specify a custom admin token ID outside of \"dev\" mode. " + + "Your request has been ignored.")) + c.flagDevAdminToken = "" + } + + if len(c.flagConfig) == 0 { + c.UI.Error("Must supply a config file with -config") + return 1 + } + c.Config, err = config.LoadFile(c.flagConfig) + if err != nil { + c.UI.Error("Error parsing config: " + err.Error()) + return 1 + } + + } else { + c.Config, err = config.Dev() + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating dev config: %s", err)) + return 1 + } + + if c.flagDevWorkerListenAddr != "" { + c.Config.Listeners[0].Address = c.flagDevWorkerListenAddr + } + } + + return 0 +} + +func (c *Command) Start() int { + // Instantiate the wait group + conf := &worker.Config{ + RawConfig: c.Config, + Server: c.Server, + } + + // Initialize the core + wrkr, err := worker.New(conf) + if err != nil { + c.UI.Error(fmt.Sprintf("Error initializing worker: %w", err)) + return 1 + } + + if err := wrkr.Start(); err != nil { + c.UI.Error(fmt.Sprint("Error starting worker: %w", err)) + if err := wrkr.Shutdown(); err != nil { + c.UI.Error(fmt.Sprintf("Error with worker shutdown: %w", err)) + } + return 1 + } + + // Wait for shutdown + shutdownTriggered := false + + for !shutdownTriggered { + select { + case <-c.ShutdownCh: + c.UI.Output("==> Watchtower worker shutdown triggered") + + if err := wrkr.Shutdown(); err != nil { + c.UI.Error(fmt.Sprintf("Error with worker shutdown: %w", err)) + } + + shutdownTriggered = true + + case <-c.SighupCh: + c.UI.Output("==> Watchtower worker reload triggered") + + // Check for new log level + var level hclog.Level + var err error + var newConf *config.Config + + if c.flagConfig == "" { + goto RUNRELOADFUNCS + } + + newConf, err = config.LoadFile(c.flagConfig) + if err != nil { + c.Logger.Error("could not reload config", "path", c.flagConfig, "error", err) + goto RUNRELOADFUNCS + } + + // Ensure at least one config was found. + if newConf == nil { + c.Logger.Error("no config found at reload time") + goto RUNRELOADFUNCS + } + + // Commented out until we need this + //wrkr.SetConfig(config) + + if newConf.LogLevel != "" { + configLogLevel := strings.ToLower(strings.TrimSpace(newConf.LogLevel)) + switch configLogLevel { + case "trace": + level = hclog.Trace + case "debug": + level = hclog.Debug + case "notice", "info", "": + level = hclog.Info + case "warn", "warning": + level = hclog.Warn + case "err", "error": + level = hclog.Error + default: + c.Logger.Error("unknown log level found on reload", "level", newConf.LogLevel) + goto RUNRELOADFUNCS + } + wrkr.SetLogLevel(level) + } + + RUNRELOADFUNCS: + if err := c.Reload(); err != nil { + c.UI.Error(fmt.Sprintf("Error(s) were encountered during worker reload: %w", err)) + } + + case <-c.SigUSR2Ch: + buf := make([]byte, 32*1024*1024) + n := runtime.Stack(buf[:], true) + c.Logger.Info("goroutine trace", "stack", string(buf[:n])) + } + } + + return 0 +} + +func (c *Command) Reload() error { + c.ReloadFuncsLock.RLock() + defer c.ReloadFuncsLock.RUnlock() + + var reloadErrors *multierror.Error + + for k, relFuncs := range c.ReloadFuncs { + switch { + case strings.HasPrefix(k, "listener|"): + for _, relFunc := range relFuncs { + if relFunc != nil { + if err := relFunc(); err != nil { + reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("error encountered reloading listener: %w", err)) + } + } + } + } + } + + // Send a message that we reloaded. This prevents "guessing" sleep times + // in tests. + select { + case c.ReloadedCh <- struct{}{}: + default: + } + + return reloadErrors.ErrorOrNil() +} diff --git a/internal/servers/worker/client_tls.go b/internal/servers/worker/client_tls.go new file mode 100644 index 0000000000..ee855fed1a --- /dev/null +++ b/internal/servers/worker/client_tls.go @@ -0,0 +1,182 @@ +package worker + +import ( + "crypto/ed25519" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "io/ioutil" + "math/big" + mathrand "math/rand" + "net" + "os" + "path/filepath" + "time" + + "github.com/hashicorp/vault/sdk/helper/base62" +) + +type workerTLSOpts struct { + Address string + Protos []string + DumpDir string +} + +type certInfo struct { + CACert []byte `json:"ca_cert"` + CAKey []byte `json:"ca_key"` +} + +func (c Worker) workerTLS(opts workerTLSOpts) (*tls.Config, *certInfo, error) { + info := new(certInfo) + + certIPs := []net.IP{ + net.IPv6loopback, + net.ParseIP("127.0.0.1"), + } + + if opts.Address != "" { + baseAddr, err := net.ResolveTCPAddr("tcp", opts.Address) + if err != nil { + return nil, nil, err + } + certIPs = append(certIPs, baseAddr.IP) + } + + _, caKey, err := ed25519.GenerateKey(c.conf.SecureRandomReader) + if err != nil { + return nil, nil, err + } + info.CAKey = caKey + caHost, err := base62.Random(20) + if err != nil { + return nil, nil, err + } + + caCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: caHost, + }, + DNSNames: []string{caHost}, + IPAddresses: certIPs, + KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + caBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) + if err != nil { + return nil, nil, err + } + info.CACert = caBytes + caCert, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, nil, err + } + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(caCert) + caCertPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + caCertPEM := pem.EncodeToMemory(caCertPEMBlock) + caCertPEMFile := filepath.Join(opts.DumpDir, "ca_cert.pem") + + marshaledCAKey, err := x509.MarshalPKCS8PrivateKey(caKey) + if err != nil { + return nil, nil, err + } + caKeyPEMBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: marshaledCAKey, + } + caKeyPEM := pem.EncodeToMemory(caKeyPEMBlock) + + // + // Certs generation + // + _, key, err := ed25519.GenerateKey(c.conf.SecureRandomReader) + if err != nil { + return nil, nil, err + } + host, err := base62.Random(20) + if err != nil { + return nil, nil, err + } + certTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + IPAddresses: certIPs, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + certBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey) + if err != nil { + return nil, nil, err + } + certPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + } + certPEM := pem.EncodeToMemory(certPEMBlock) + marshaledKey, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return nil, nil, err + } + keyPEMBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: marshaledKey, + } + keyPEM := pem.EncodeToMemory(keyPEMBlock) + + certFile := filepath.Join(opts.DumpDir, "cert.pem") + keyFile := filepath.Join(opts.DumpDir, "key.pem") + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, nil, err + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: rootCAs, + ClientCAs: rootCAs, + ClientAuth: tls.RequestClientCert, + NextProtos: opts.Protos, + MinVersion: tls.VersionTLS13, + } + tlsConfig.BuildNameToCertificate() + + if opts.DumpDir != "" { + if _, err := os.Stat(opts.DumpDir); os.IsNotExist(err) { + if err := os.MkdirAll(opts.DumpDir, 0700); err != nil { + return nil, nil, err + } + } + if err := ioutil.WriteFile(filepath.Join(opts.DumpDir, "ca_key.pem"), caKeyPEM, 0755); err != nil { + return nil, nil, err + } + if err := ioutil.WriteFile(caCertPEMFile, caCertPEM, 0755); err != nil { + return nil, nil, err + } + if err := ioutil.WriteFile(certFile, certPEM, 0755); err != nil { + return nil, nil, err + } + if err := ioutil.WriteFile(keyFile, keyPEM, 0755); err != nil { + return nil, nil, err + } + } + + return tlsConfig, info, nil +} diff --git a/internal/servers/worker/config.go b/internal/servers/worker/config.go new file mode 100644 index 0000000000..a747b8744f --- /dev/null +++ b/internal/servers/worker/config.go @@ -0,0 +1,14 @@ +package worker + +import ( + "context" + + "github.com/hashicorp/watchtower/internal/cmd/base" + "github.com/hashicorp/watchtower/internal/cmd/commands/worker/config" +) + +type Config struct { + *base.Server + RawConfig *config.Config + BaseContext context.Context +} diff --git a/internal/servers/worker/handler.go b/internal/servers/worker/handler.go new file mode 100644 index 0000000000..e444a30662 --- /dev/null +++ b/internal/servers/worker/handler.go @@ -0,0 +1,159 @@ +package worker + +import ( + "context" + "net/http" + "time" + + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/watchtower/globals" +) + +type HandlerProperties struct { + ListenerConfig *configutil.Listener +} + +// Handler returns an http.Handler for the API. This can be used on +// its own to mount the Vault API within another web server. +func (c *Worker) Handler(props HandlerProperties) http.Handler { + // Create the muxer to handle the actual endpoints + mux := http.NewServeMux() + + mux.Handle("/v1/", handleDummy()) + + genericWrappedHandler := c.wrapGenericHandler(mux, props) + + return genericWrappedHandler +} + +func handleDummy() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"foo": "bar"}`)) + }) +} + +func (c *Worker) wrapGenericHandler(h http.Handler, props HandlerProperties) http.Handler { + var maxRequestDuration time.Duration + var maxRequestSize int64 + if props.ListenerConfig != nil { + maxRequestDuration = props.ListenerConfig.MaxRequestDuration + maxRequestSize = props.ListenerConfig.MaxRequestSize + } + if maxRequestDuration == 0 { + maxRequestDuration = globals.DefaultMaxRequestDuration + } + if maxRequestSize == 0 { + maxRequestSize = globals.DefaultMaxRequestSize + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set the Cache-Control header for all responses returned + w.Header().Set("Cache-Control", "no-store") + + // Start with the request context + ctx := r.Context() + var cancelFunc context.CancelFunc + // Add our timeout + ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration) + // Add a size limiter if desired + if maxRequestSize > 0 { + ctx = context.WithValue(ctx, "max_request_size", maxRequestSize) + } + ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) + r = r.WithContext(ctx) + + h.ServeHTTP(w, r) + cancelFunc() + return + }) +} + +/* +func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers, headersOK := r.Header[textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")] + if !headersOK || len(headers) == 0 { + if !rejectNotPresent { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present")) + return + } + + host, port, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // If not rejecting treat it like we just don't have a valid + // header because we can't do a comparison against an address we + // can't understand + if !rejectNotPresent { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client hostport: {{err}}", err)) + return + } + + addr, err := sockaddr.NewIPAddr(host) + if err != nil { + // We treat this the same as the case above + if !rejectNotPresent { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client address: {{err}}", err)) + return + } + + var found bool + for _, authz := range authorizedAddrs { + if authz.Contains(addr) { + found = true + break + } + } + if !found { + // If we didn't find it and aren't configured to reject, simply + // don't trust it + if !rejectNonAuthz { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection")) + return + } + + // At this point we have at least one value and it's authorized + + // Split comma separated ones, which are common. This brings it in line + // to the multiple-header case. + var acc []string + for _, header := range headers { + vals := strings.Split(header, ",") + for _, v := range vals { + acc = append(acc, strings.TrimSpace(v)) + } + } + + indexToUse := len(acc) - 1 - hopSkips + if indexToUse < 0 { + // This is likely an error in either configuration or other + // infrastructure. We could either deny the request, or we + // could simply not trust the value. Denying the request is + // "safer" since if this logic is configured at all there may + // be an assumption it can always be trusted. Given that we can + // deny accepting the request at all if it's not from an + // authorized address, if we're at this point the address is + // authorized (or we've turned off explicit rejection) and we + // should assume that what comes in should be properly + // formatted. + respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers))) + return + } + + r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port) + h.ServeHTTP(w, r) + return + }) +} +*/ diff --git a/internal/servers/worker/listeners.go b/internal/servers/worker/listeners.go new file mode 100644 index 0000000000..dfcd16de85 --- /dev/null +++ b/internal/servers/worker/listeners.go @@ -0,0 +1,148 @@ +package worker + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/hashicorp/go-alpnmux" + "github.com/hashicorp/go-multierror" +) + +func (c *Worker) startListeners() error { + var retErr *multierror.Error + servers := make([]func(), 0, len(c.conf.Listeners)) + for _, ln := range c.conf.Listeners { + handler := c.Handler(HandlerProperties{ + ListenerConfig: ln.Config, + }) + + /* + // TODO: As I write this Vault's having this code audited, make sure to + // port over any recommendations + // + // We perform validation on the config earlier, we can just cast here + if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok { + hopSkips := ln.config["x_forwarded_for_hop_skips"].(int) + authzdAddrs := ln.config["x_forwarded_for_authorized_addrs"].([]*sockaddr.SockAddrMarshaler) + rejectNotPresent := ln.config["x_forwarded_for_reject_not_present"].(bool) + rejectNonAuthz := ln.config["x_forwarded_for_reject_not_authorized"].(bool) + if len(authzdAddrs) > 0 { + handler = vaulthttp.WrapForwardedForHandler(handler, authzdAddrs, rejectNotPresent, rejectNonAuthz, hopSkips) + } + } + */ + + server := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: c.conf.Logger.StandardLogger(nil), + BaseContext: func(net.Listener) context.Context { + return c.baseContext + }, + } + ln.HTTPServer = server + + if ln.Config.HTTPReadHeaderTimeout > 0 { + server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout + } + if ln.Config.HTTPReadTimeout > 0 { + server.ReadTimeout = ln.Config.HTTPReadTimeout + } + if ln.Config.HTTPWriteTimeout > 0 { + server.WriteTimeout = ln.Config.HTTPWriteTimeout + } + if ln.Config.HTTPIdleTimeout > 0 { + server.IdleTimeout = ln.Config.HTTPIdleTimeout + } + + switch ln.Config.TLSDisable { + case true: + l := ln.Mux.GetListener(alpnmux.NoProto) + if l == nil { + retErr = multierror.Append(retErr, errors.New("could not get non-tls listener")) + continue + } + servers = append(servers, func() { + go server.Serve(l) + }) + + default: + protos := []string{"", "http/1.1", "h2"} + for _, v := range protos { + l := ln.Mux.GetListener(v) + if l == nil { + retErr = multierror.Append(retErr, fmt.Errorf("could not get tls proto %q listener", v)) + continue + } + servers = append(servers, func() { + go server.Serve(l) + }) + } + } + + workerTLSConfig, peeringInfo, err := c.workerTLS(workerTLSOpts{ + Address: ln.Config.Address, + Protos: []string{"watchtower-worker-v1"}, + }) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error getting TLS configuration: %w", err)) + continue + } + l, err := ln.Mux.RegisterProto("watchtower-worker-v1", workerTLSConfig) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error getting sub-listener for worker proto: %w", err)) + continue + } + + // TODO: Start listner for real; for now send it to the http server just for testing + servers = append(servers, func() { + go server.Serve(l) + }) + + // TODO: Add peering info into database + _ = peeringInfo + } + + err := retErr.ErrorOrNil() + if err != nil { + return err + } + + for _, s := range servers { + s() + } + + return nil +} + +func (c *Worker) stopListeners() error { + serverWg := new(sync.WaitGroup) + for _, ln := range c.conf.Listeners { + if ln.HTTPServer == nil { + continue + } + serverWg.Add(1) + go func() { + shutdownKill, shutdownKillCancel := context.WithTimeout(c.baseContext, ln.Config.MaxRequestDuration) + defer shutdownKillCancel() + defer serverWg.Done() + ln.HTTPServer.Shutdown(shutdownKill) + }() + } + serverWg.Wait() + + var retErr *multierror.Error + for _, ln := range c.conf.Listeners { + if err := ln.Mux.Close(); err != nil { + retErr = multierror.Append(retErr, err) + } + } + return retErr.ErrorOrNil() +} diff --git a/internal/servers/worker/worker.go b/internal/servers/worker/worker.go new file mode 100644 index 0000000000..7f20eb8d0a --- /dev/null +++ b/internal/servers/worker/worker.go @@ -0,0 +1,76 @@ +package worker + +import ( + "context" + "crypto/rand" + "fmt" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/mlock" +) + +type Worker struct { + conf *Config + + baseContext context.Context + baseCancel context.CancelFunc +} + +func New(conf *Config) (*Worker, error) { + if conf.Logger == nil { + conf.Logger = hclog.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + }) + conf.AllLoggers = append(conf.AllLoggers, conf.Logger) + } + + if conf.SecureRandomReader == nil { + conf.SecureRandomReader = rand.Reader + } + + if !conf.RawConfig.DisableMlock { + // Ensure our memory usage is locked into physical RAM + if err := mlock.LockMemory(); err != nil { + return nil, fmt.Errorf( + "Failed to lock memory: %v\n\n"+ + "This usually means that the mlock syscall is not available.\n"+ + "Watchtower uses mlock to prevent memory from being swapped to\n"+ + "disk. This requires root privileges as well as a machine\n"+ + "that supports mlock. Please enable mlock on your system or\n"+ + "disable Watchtower from using it. To disable Watchtower from using it,\n"+ + "set the `disable_mlock` configuration option in your configuration\n"+ + "file.", + err) + } + } + + conf.Logger = conf.Logger.Named("worker") + + c := &Worker{ + conf: conf, + } + + c.baseContext, c.baseCancel = context.WithCancel(context.Background()) + + return c, nil +} + +func (c *Worker) Start() error { + if err := c.startListeners(); err != nil { + return err + } + return nil +} + +func (c *Worker) Shutdown() error { + if err := c.stopListeners(); err != nil { + return err + } + return nil +} + +func (c *Worker) SetLogLevel(level hclog.Level) { + for _, logger := range c.conf.AllLoggers { + logger.SetLevel(level) + } +}