Migrate a bunch of stuff around to prep for worker command (#2)

Right now it's essentially a copy; it needs some updating, integration
into dev, and so on.
pull/3/head
Jeff Mitchell 6 years ago committed by GitHub
parent c2d385d851
commit 1645c3106e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,18 +1,44 @@
package listener
package base
import (
"errors"
"fmt"
"io"
"net"
"strings"
"time"
// We must import sha512 so that it registers with the runtime so that
// certificates that use it can be parsed.
_ "crypto/sha512"
"github.com/hashicorp/go-alpnmux"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/internalshared/listenerutil"
"github.com/hashicorp/vault/internalshared/reloadutil"
"github.com/mitchellh/cli"
"github.com/pires/go-proxyproto"
)
// Factory is the factory function to create a listener.
type ListenerFactory func(*configutil.Listener, io.Writer, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error)
// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
"tcp": tcpListenerFactory,
}
// New creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(l *configutil.Listener, w io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) {
f, ok := BuiltinListeners[l.Type]
if !ok {
return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type)
}
return f(l, w, ui)
}
func tcpListenerFactory(l *configutil.Listener, _ io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) {
if l.Address == "" {
l.Address = "127.0.0.1:9200"
@ -69,6 +95,49 @@ func tcpListenerFactory(l *configutil.Listener, _ io.Writer, ui cli.Ui) (*alpnmu
return alpnMux, props, reloadFunc, nil
}
func listenerWrapProxy(ln net.Listener, l *configutil.Listener) (net.Listener, error) {
behavior := l.ProxyProtocolBehavior
if behavior == "" {
return ln, nil
}
authorizedAddrs := make([]string, 0, len(l.ProxyProtocolAuthorizedAddrs))
for _, v := range l.ProxyProtocolAuthorizedAddrs {
authorizedAddrs = append(authorizedAddrs, v.String())
}
var policyFunc proxyproto.PolicyFunc
switch behavior {
case "use_always":
policyFunc = func(upstream net.Addr) (proxyproto.Policy, error) {
return proxyproto.USE, nil
}
case "allow_authorized":
if len(authorizedAddrs) == 0 {
return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value")
}
policyFunc = proxyproto.MustLaxWhiteListPolicy(authorizedAddrs)
case "deny_unauthorized":
if len(authorizedAddrs) == 0 {
return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value")
}
policyFunc = proxyproto.MustStrictWhiteListPolicy(authorizedAddrs)
default:
return nil, fmt.Errorf("unknown %q value: %q", "proxy_protocol_behavior", behavior)
}
proxyListener := &proxyproto.Listener{
Listener: ln,
Policy: policyFunc,
}
return proxyListener, nil
}
// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually

@ -0,0 +1,8 @@
// +build !memprofiler
package base
import "github.com/hashicorp/go-hclog"
func StartMemProfiler(_ hclog.Logger) {
}

@ -1,6 +1,6 @@
// +build memprofiler
package dev
package base
import (
"os"
@ -8,16 +8,18 @@ import (
"runtime"
"runtime/pprof"
"time"
"github.com/hashicorp/go-hclog"
)
func init() {
memProfilerEnabled = true
}
func (d *Command) startMemProfiler() {
func StartMemProfiler(logger hclog.Logger) {
profileDir := filepath.Join(os.TempDir(), "watchtowerprof")
if err := os.MkdirAll(profileDir, 0700); err != nil {
d.logger.Debug("could not create profile directory", "error", err)
logger.Debug("could not create profile directory", "error", err)
return
}
@ -26,14 +28,14 @@ func (d *Command) startMemProfiler() {
filename := filepath.Join(profileDir, time.Now().UTC().Format("20060102_150405")) + ".pprof"
f, err := os.Create(filename)
if err != nil {
d.logger.Debug("could not create memory profile", "error", err)
logger.Debug("could not create memory profile", "error", err)
}
runtime.GC()
if err := pprof.WriteHeapProfile(f); err != nil {
d.logger.Debug("could not write memory profile", "error", err)
logger.Debug("could not write memory profile", "error", err)
}
f.Close()
d.logger.Debug("wrote memory profile", "filename", filename)
logger.Debug("wrote memory profile", "filename", filename)
time.Sleep(5 * time.Minute)
}
}()

@ -28,7 +28,6 @@ import (
"github.com/hashicorp/vault/sdk/helper/mlock"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/watchtower/globals"
"github.com/hashicorp/watchtower/internal/cmd/commands/controller/listener"
"github.com/hashicorp/watchtower/version"
"github.com/mitchellh/cli"
"github.com/ory/dockertest/v3"
@ -250,7 +249,7 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig) erro
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
}
lnMux, props, reloadFunc, err := listener.New(lnConfig, b.GatedWriter, ui)
lnMux, props, reloadFunc, err := NewListener(lnConfig, b.GatedWriter, ui)
if err != nil {
return fmt.Errorf("Error initializing listener of type %s: %w", lnConfig.Type, err)
}

@ -8,6 +8,7 @@ import (
"github.com/hashicorp/watchtower/internal/cmd/base"
"github.com/hashicorp/watchtower/internal/cmd/commands/controller"
"github.com/hashicorp/watchtower/internal/cmd/commands/dev"
"github.com/hashicorp/watchtower/internal/cmd/commands/worker"
"github.com/mitchellh/cli"
)
@ -36,6 +37,17 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
SigUSR2Ch: MakeSigUSR2Ch(),
}, nil
},
"worker": func() (cli.Command, error) {
return &worker.Command{
Command: &base.Command{
UI: serverCmdUi,
Address: runOpts.Address,
},
ShutdownCh: MakeShutdownCh(),
SighupCh: MakeSighupCh(),
SigUSR2Ch: MakeSigUSR2Ch(),
}, nil
},
"dev": func() (cli.Command, error) {
return &dev.Command{
Command: &base.Command{

@ -54,7 +54,7 @@ Usage: watchtower controller [options]
Start a controller with a configuration file:
$ watchtower controller -config=/etc/controller/config.hcl
$ watchtower controller -config=/etc/watchtower/controller.hcl
For a full list of examples, please see the documentation.
@ -147,15 +147,15 @@ func (c *Command) Run(args []string) int {
return result
}
if memProfilerEnabled {
c.startMemProfiler()
}
if err := c.SetupLogging(c.flagLogLevel, c.flagLogFormat, c.Config.LogLevel, c.Config.LogFormat); err != nil {
c.UI.Error(err.Error())
return 1
}
if memProfilerEnabled {
base.StartMemProfiler(c.Logger)
}
if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil {
c.UI.Error(err.Error())
return 1
@ -259,21 +259,21 @@ func (c *Command) ParseFlagsAndConfig(args []string) int {
func (c *Command) Start() int {
// Instantiate the wait group
controllerConfig := &controller.Config{
conf := &controller.Config{
RawConfig: c.Config,
Server: c.Server,
}
// Initialize the core
controller, err := controller.New(controllerConfig)
ctlr, err := controller.New(conf)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing controller: %w", err))
return 1
}
if err := controller.Start(); err != nil {
if err := ctlr.Start(); err != nil {
c.UI.Error(fmt.Sprint("Error starting controller: %w", err))
if err := controller.Shutdown(); err != nil {
if err := ctlr.Shutdown(); err != nil {
c.UI.Error(fmt.Sprintf("Error with controller shutdown: %w", err))
}
return 1
@ -287,7 +287,7 @@ func (c *Command) Start() int {
case <-c.ShutdownCh:
c.UI.Output("==> Watchtower controller shutdown triggered")
if err := controller.Shutdown(); err != nil {
if err := ctlr.Shutdown(); err != nil {
c.UI.Error(fmt.Sprintf("Error with controller shutdown: %w", err))
}
@ -337,7 +337,7 @@ func (c *Command) Start() int {
c.Logger.Error("unknown log level found on reload", "level", newConf.LogLevel)
goto RUNRELOADFUNCS
}
controller.SetLogLevel(level)
ctlr.SetLogLevel(level)
}
RUNRELOADFUNCS:

@ -1,6 +0,0 @@
// +build !memprofiler
package controller
func (c *Command) startMemProfiler() {
}

@ -1,40 +0,0 @@
// +build memprofiler
package controller
import (
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"time"
)
func init() {
memProfilerEnabled = true
}
func (c *Command) startMemProfiler() {
profileDir := filepath.Join(os.TempDir(), "watchtowerprof")
if err := os.MkdirAll(profileDir, 0700); err != nil {
c.logger.Debug("could not create profile directory", "error", err)
return
}
go func() {
for {
filename := filepath.Join(profileDir, time.Now().UTC().Format("20060102_150405")) + ".pprof"
f, err := os.Create(filename)
if err != nil {
c.logger.Debug("could not create memory profile", "error", err)
}
runtime.GC()
if err := pprof.WriteHeapProfile(f); err != nil {
c.logger.Debug("could not write memory profile", "error", err)
}
f.Close()
c.logger.Debug("wrote memory profile", "filename", filename)
time.Sleep(5 * time.Minute)
}
}()
}

@ -1,81 +0,0 @@
package listener
import (
"errors"
"io"
"github.com/pires/go-proxyproto"
// We must import sha512 so that it registers with the runtime so that
// certificates that use it can be parsed.
_ "crypto/sha512"
"fmt"
"net"
"github.com/hashicorp/go-alpnmux"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/internalshared/reloadutil"
"github.com/mitchellh/cli"
)
// Factory is the factory function to create a listener.
type Factory func(*configutil.Listener, io.Writer, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error)
// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]Factory{
"tcp": tcpListenerFactory,
}
// New creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func New(l *configutil.Listener, w io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) {
f, ok := BuiltinListeners[l.Type]
if !ok {
return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type)
}
return f(l, w, ui)
}
func listenerWrapProxy(ln net.Listener, l *configutil.Listener) (net.Listener, error) {
behavior := l.ProxyProtocolBehavior
if behavior == "" {
return ln, nil
}
authorizedAddrs := make([]string, 0, len(l.ProxyProtocolAuthorizedAddrs))
for _, v := range l.ProxyProtocolAuthorizedAddrs {
authorizedAddrs = append(authorizedAddrs, v.String())
}
var policyFunc proxyproto.PolicyFunc
switch behavior {
case "use_always":
policyFunc = func(upstream net.Addr) (proxyproto.Policy, error) {
return proxyproto.USE, nil
}
case "allow_authorized":
if len(authorizedAddrs) == 0 {
return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value")
}
policyFunc = proxyproto.MustLaxWhiteListPolicy(authorizedAddrs)
case "deny_unauthorized":
if len(authorizedAddrs) == 0 {
return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value")
}
policyFunc = proxyproto.MustStrictWhiteListPolicy(authorizedAddrs)
default:
return nil, fmt.Errorf("unknown %q value: %q", "proxy_protocol_behavior", behavior)
}
proxyListener := &proxyproto.Listener{
Listener: ln,
Policy: policyFunc,
}
return proxyListener, nil
}

@ -138,15 +138,15 @@ func (c *Command) Run(args []string) int {
return 1
}
if memProfilerEnabled {
c.startMemProfiler()
}
if err := c.SetupLogging(c.flagLogLevel, c.flagLogFormat, "", ""); err != nil {
c.UI.Error(err.Error())
return 1
}
if memProfilerEnabled {
base.StartMemProfiler(c.Logger)
}
if err := c.SetupMetrics(c.UI, devControllerConfig.Telemetry); err != nil {
c.UI.Error(err.Error())
return 1

@ -1,6 +0,0 @@
// +build !memprofiler
package dev
func (d *Command) startMemProfiler() {
}

@ -0,0 +1,122 @@
package config
import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"github.com/hashicorp/hcl"
"github.com/hashicorp/vault/internalshared/configutil"
)
// Config is the configuration for the watchtower controller
type Config struct {
*configutil.SharedConfig `hcl:"-"`
}
// Dev is a Config that is used for dev mode of Watchtower
func Dev() (*Config, error) {
randBuf := new(bytes.Buffer)
n, err := randBuf.ReadFrom(&io.LimitedReader{
R: rand.Reader,
N: 64,
})
if err != nil {
return nil, err
}
if n != 64 {
return nil, fmt.Errorf("expected to read 64 bytes, read %d", n)
}
controllerKey := base64.StdEncoding.EncodeToString(randBuf.Bytes()[0:32])
workerAuthKey := base64.StdEncoding.EncodeToString(randBuf.Bytes()[32:64])
hclStr := `
disable_mlock = true
listener "tcp" {
tls_disable = true
proxy_protocol_behavior = "allow_authorized"
proxy_protocol_authorized_addrs = "127.0.0.1"
}
telemetry {
prometheus_retention_time = "24h"
disable_hostname = true
}
`
hclStr = fmt.Sprintf(hclStr, controllerKey, workerAuthKey)
parsed, err := Parse(hclStr)
if err != nil {
return nil, fmt.Errorf("error parsing dev config: %w", err)
}
return parsed, nil
}
func New() *Config {
return &Config{
SharedConfig: new(configutil.SharedConfig),
}
}
// LoadFile loads the configuration from the given file.
func LoadFile(path string) (*Config, error) {
// Read the file
d, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
conf, err := Parse(string(d))
if err != nil {
return nil, err
}
return conf, nil
}
func Parse(d string) (*Config, error) {
obj, err := hcl.Parse(d)
if err != nil {
return nil, err
}
// Nothing to do here right now
result := New()
if err := hcl.DecodeObject(result, obj); err != nil {
return nil, err
}
sharedConfig, err := configutil.ParseConfig(d)
if err != nil {
return nil, err
}
result.SharedConfig = sharedConfig
return result, nil
}
// Sanitized returns a copy of the config with all values that are considered
// sensitive stripped. It also strips all `*Raw` values that are mainly
// used for parsing.
//
// Specifically, the fields that this method strips are:
// - KMS.Config
// - Telemetry.CirconusAPIToken
func (c *Config) Sanitized() map[string]interface{} {
// Create shared config if it doesn't exist (e.g. in tests) so that map
// keys are actually populated
if c.SharedConfig == nil {
c.SharedConfig = new(configutil.SharedConfig)
}
sharedResult := c.SharedConfig.Sanitized()
result := map[string]interface{}{}
for k, v := range sharedResult {
result[k] = v
}
return result
}

@ -0,0 +1,385 @@
package worker
import (
"fmt"
"runtime"
"strings"
"sync"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/helper/mlock"
"github.com/hashicorp/watchtower/globals"
"github.com/hashicorp/watchtower/internal/cmd/base"
"github.com/hashicorp/watchtower/internal/cmd/commands/worker/config"
"github.com/hashicorp/watchtower/internal/servers/worker"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)
var _ cli.Command = (*Command)(nil)
var _ cli.CommandAutocomplete = (*Command)(nil)
var memProfilerEnabled = false
type Command struct {
*base.Command
*base.Server
ShutdownCh chan struct{}
SighupCh chan struct{}
ReloadedCh chan struct{}
SigUSR2Ch chan struct{}
cleanupGuard sync.Once
Config *config.Config
flagConfig string
flagLogLevel string
flagLogFormat string
flagDev bool
flagDevAdminToken string
flagDevWorkerListenAddr string
flagCombineLogs bool
}
func (c *Command) Synopsis() string {
return "Start a Watchtower worker"
}
func (c *Command) Help() string {
helpText := `
Usage: watchtower worker [options]
Start a worker with a configuration file:
$ watchtower worker -config=/etc/watchtower/worker.hcl
For a full list of examples, please see the documentation.
` + c.Flags().Help()
return strings.TrimSpace(helpText)
}
func (c *Command) Flags() *base.FlagSets {
set := c.FlagSet(base.FlagSetHTTP)
f := set.NewFlagSet("Command Options")
f.StringVar(&base.StringVar{
Name: "config",
Target: &c.flagConfig,
Completion: complete.PredictOr(
complete.PredictFiles("*.hcl"),
complete.PredictFiles("*.json"),
),
Usage: "Path to a configuration file.",
})
f.StringVar(&base.StringVar{
Name: "log-level",
Target: &c.flagLogLevel,
Default: base.NotSetValue,
EnvVar: "WATCHTOWER_LOG_LEVEL",
Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"),
Usage: "Log verbosity level. Supported values (in order of more detail to less) are " +
"\"trace\", \"debug\", \"info\", \"warn\", and \"err\".",
})
f.StringVar(&base.StringVar{
Name: "log-format",
Target: &c.flagLogFormat,
Default: base.NotSetValue,
Completion: complete.PredictSet("standard", "json"),
Usage: `Log format. Supported values are "standard" and "json".`,
})
f = set.NewFlagSet("Dev Options")
f.BoolVar(&base.BoolVar{
Name: "dev",
Target: &c.flagDev,
Usage: "Enable development mode. As the name implies, do not run \"dev\" mode in " +
"production.",
})
f.StringVar(&base.StringVar{
Name: "dev-admin-token",
Target: &c.flagDevAdminToken,
Default: "",
EnvVar: "WATCHTWER_DEV_ADMIN_TOKEN",
Usage: "Initial admin token. This only applies when running in \"dev\" " +
"mode.",
})
f.StringVar(&base.StringVar{
Name: "dev-listen-address",
Target: &c.flagDevWorkerListenAddr,
Default: "127.0.0.1:9200",
EnvVar: "WATCHTOWER_DEV_WORKER_LISTEN_ADDRESS",
Usage: "Address to bind the worker to in \"dev\" mode.",
})
f.BoolVar(&base.BoolVar{
Name: "combine-logs",
Target: &c.flagCombineLogs,
Default: false,
Usage: "If set, both startup information and logs will be sent to stdout. If not set (the default), startup information will go to stdout and logs will be sent to stderr.",
})
return set
}
func (c *Command) AutocompleteArgs() complete.Predictor {
return complete.PredictNothing
}
func (c *Command) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *Command) Run(args []string) int {
c.Server = base.NewServer()
c.CombineLogs = c.flagCombineLogs
if result := c.ParseFlagsAndConfig(args); result > 0 {
return result
}
if err := c.SetupLogging(c.flagLogLevel, c.flagLogFormat, c.Config.LogLevel, c.Config.LogFormat); err != nil {
c.UI.Error(err.Error())
return 1
}
if memProfilerEnabled {
base.StartMemProfiler(c.Logger)
}
if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil {
c.UI.Error(err.Error())
return 1
}
if err := c.SetupKMSes(c.UI, c.Config.SharedConfig, 2); err != nil {
c.UI.Error(err.Error())
return 1
}
if c.Config.DefaultMaxRequestDuration != 0 {
globals.DefaultMaxRequestDuration = c.Config.DefaultMaxRequestDuration
}
// If mlockall(2) isn't supported, show a warning. We disable this in dev
// because it is quite scary to see when first using Vault. We also disable
// this if the user has explicitly disabled mlock in configuration.
if !c.flagDev && !c.Config.DisableMlock && !mlock.Supported() {
c.UI.Warn(base.WrapAtLength(
"WARNING! mlock is not supported on this system! An mlockall(2)-like " +
"syscall to prevent memory from being swapped to disk is not " +
"supported on this system. For better security, only run Vault on " +
"systems where this call is supported. If you are running Vault " +
"in a Docker container, provide the IPC_LOCK cap to the container."))
}
if err := c.SetupListeners(c.UI, c.Config.SharedConfig); err != nil {
c.UI.Error(err.Error())
return 1
}
// Write out the PID to the file now that server has successfully started
if err := c.StorePidFile(c.Config.PidFile); err != nil {
c.UI.Error(fmt.Sprintf("Error storing PID: %w", err))
return 1
}
if c.flagDev {
if err := c.CreateDevDatabase(); err != nil {
c.UI.Error(fmt.Sprintf("Error creating dev database container: %s", err.Error()))
return 1
}
c.ShutdownFuncs = append(c.ShutdownFuncs, c.DestroyDevDatabase)
}
defer c.RunShutdownFuncs(c.UI)
c.PrintInfo(c.UI, "worker")
c.ReleaseLogGate()
return c.Start()
}
func (c *Command) ParseFlagsAndConfig(args []string) int {
var err error
f := c.Flags()
if err = f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
// Validation
if !c.flagDev {
switch {
case len(c.flagConfig) == 0:
c.UI.Error("Must specify a config file using -config")
return 1
case c.flagDevAdminToken != "":
c.UI.Warn(base.WrapAtLength(
"You cannot specify a custom admin token ID outside of \"dev\" mode. " +
"Your request has been ignored."))
c.flagDevAdminToken = ""
}
if len(c.flagConfig) == 0 {
c.UI.Error("Must supply a config file with -config")
return 1
}
c.Config, err = config.LoadFile(c.flagConfig)
if err != nil {
c.UI.Error("Error parsing config: " + err.Error())
return 1
}
} else {
c.Config, err = config.Dev()
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating dev config: %s", err))
return 1
}
if c.flagDevWorkerListenAddr != "" {
c.Config.Listeners[0].Address = c.flagDevWorkerListenAddr
}
}
return 0
}
func (c *Command) Start() int {
// Instantiate the wait group
conf := &worker.Config{
RawConfig: c.Config,
Server: c.Server,
}
// Initialize the core
wrkr, err := worker.New(conf)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing worker: %w", err))
return 1
}
if err := wrkr.Start(); err != nil {
c.UI.Error(fmt.Sprint("Error starting worker: %w", err))
if err := wrkr.Shutdown(); err != nil {
c.UI.Error(fmt.Sprintf("Error with worker shutdown: %w", err))
}
return 1
}
// Wait for shutdown
shutdownTriggered := false
for !shutdownTriggered {
select {
case <-c.ShutdownCh:
c.UI.Output("==> Watchtower worker shutdown triggered")
if err := wrkr.Shutdown(); err != nil {
c.UI.Error(fmt.Sprintf("Error with worker shutdown: %w", err))
}
shutdownTriggered = true
case <-c.SighupCh:
c.UI.Output("==> Watchtower worker reload triggered")
// Check for new log level
var level hclog.Level
var err error
var newConf *config.Config
if c.flagConfig == "" {
goto RUNRELOADFUNCS
}
newConf, err = config.LoadFile(c.flagConfig)
if err != nil {
c.Logger.Error("could not reload config", "path", c.flagConfig, "error", err)
goto RUNRELOADFUNCS
}
// Ensure at least one config was found.
if newConf == nil {
c.Logger.Error("no config found at reload time")
goto RUNRELOADFUNCS
}
// Commented out until we need this
//wrkr.SetConfig(config)
if newConf.LogLevel != "" {
configLogLevel := strings.ToLower(strings.TrimSpace(newConf.LogLevel))
switch configLogLevel {
case "trace":
level = hclog.Trace
case "debug":
level = hclog.Debug
case "notice", "info", "":
level = hclog.Info
case "warn", "warning":
level = hclog.Warn
case "err", "error":
level = hclog.Error
default:
c.Logger.Error("unknown log level found on reload", "level", newConf.LogLevel)
goto RUNRELOADFUNCS
}
wrkr.SetLogLevel(level)
}
RUNRELOADFUNCS:
if err := c.Reload(); err != nil {
c.UI.Error(fmt.Sprintf("Error(s) were encountered during worker reload: %w", err))
}
case <-c.SigUSR2Ch:
buf := make([]byte, 32*1024*1024)
n := runtime.Stack(buf[:], true)
c.Logger.Info("goroutine trace", "stack", string(buf[:n]))
}
}
return 0
}
func (c *Command) Reload() error {
c.ReloadFuncsLock.RLock()
defer c.ReloadFuncsLock.RUnlock()
var reloadErrors *multierror.Error
for k, relFuncs := range c.ReloadFuncs {
switch {
case strings.HasPrefix(k, "listener|"):
for _, relFunc := range relFuncs {
if relFunc != nil {
if err := relFunc(); err != nil {
reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("error encountered reloading listener: %w", err))
}
}
}
}
}
// Send a message that we reloaded. This prevents "guessing" sleep times
// in tests.
select {
case c.ReloadedCh <- struct{}{}:
default:
}
return reloadErrors.ErrorOrNil()
}

@ -0,0 +1,182 @@
package worker
import (
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io/ioutil"
"math/big"
mathrand "math/rand"
"net"
"os"
"path/filepath"
"time"
"github.com/hashicorp/vault/sdk/helper/base62"
)
type workerTLSOpts struct {
Address string
Protos []string
DumpDir string
}
type certInfo struct {
CACert []byte `json:"ca_cert"`
CAKey []byte `json:"ca_key"`
}
func (c Worker) workerTLS(opts workerTLSOpts) (*tls.Config, *certInfo, error) {
info := new(certInfo)
certIPs := []net.IP{
net.IPv6loopback,
net.ParseIP("127.0.0.1"),
}
if opts.Address != "" {
baseAddr, err := net.ResolveTCPAddr("tcp", opts.Address)
if err != nil {
return nil, nil, err
}
certIPs = append(certIPs, baseAddr.IP)
}
_, caKey, err := ed25519.GenerateKey(c.conf.SecureRandomReader)
if err != nil {
return nil, nil, err
}
info.CAKey = caKey
caHost, err := base62.Random(20)
if err != nil {
return nil, nil, err
}
caCertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: caHost,
},
DNSNames: []string{caHost},
IPAddresses: certIPs,
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
BasicConstraintsValid: true,
IsCA: true,
}
caBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
if err != nil {
return nil, nil, err
}
info.CACert = caBytes
caCert, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, err
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(caCert)
caCertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
}
caCertPEM := pem.EncodeToMemory(caCertPEMBlock)
caCertPEMFile := filepath.Join(opts.DumpDir, "ca_cert.pem")
marshaledCAKey, err := x509.MarshalPKCS8PrivateKey(caKey)
if err != nil {
return nil, nil, err
}
caKeyPEMBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: marshaledCAKey,
}
caKeyPEM := pem.EncodeToMemory(caKeyPEMBlock)
//
// Certs generation
//
_, key, err := ed25519.GenerateKey(c.conf.SecureRandomReader)
if err != nil {
return nil, nil, err
}
host, err := base62.Random(20)
if err != nil {
return nil, nil, err
}
certTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: host,
},
DNSNames: []string{host},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
certBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey)
if err != nil {
return nil, nil, err
}
certPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
}
certPEM := pem.EncodeToMemory(certPEMBlock)
marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return nil, nil, err
}
keyPEMBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: marshaledKey,
}
keyPEM := pem.EncodeToMemory(keyPEMBlock)
certFile := filepath.Join(opts.DumpDir, "cert.pem")
keyFile := filepath.Join(opts.DumpDir, "key.pem")
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
RootCAs: rootCAs,
ClientCAs: rootCAs,
ClientAuth: tls.RequestClientCert,
NextProtos: opts.Protos,
MinVersion: tls.VersionTLS13,
}
tlsConfig.BuildNameToCertificate()
if opts.DumpDir != "" {
if _, err := os.Stat(opts.DumpDir); os.IsNotExist(err) {
if err := os.MkdirAll(opts.DumpDir, 0700); err != nil {
return nil, nil, err
}
}
if err := ioutil.WriteFile(filepath.Join(opts.DumpDir, "ca_key.pem"), caKeyPEM, 0755); err != nil {
return nil, nil, err
}
if err := ioutil.WriteFile(caCertPEMFile, caCertPEM, 0755); err != nil {
return nil, nil, err
}
if err := ioutil.WriteFile(certFile, certPEM, 0755); err != nil {
return nil, nil, err
}
if err := ioutil.WriteFile(keyFile, keyPEM, 0755); err != nil {
return nil, nil, err
}
}
return tlsConfig, info, nil
}

@ -0,0 +1,14 @@
package worker
import (
"context"
"github.com/hashicorp/watchtower/internal/cmd/base"
"github.com/hashicorp/watchtower/internal/cmd/commands/worker/config"
)
type Config struct {
*base.Server
RawConfig *config.Config
BaseContext context.Context
}

@ -0,0 +1,159 @@
package worker
import (
"context"
"net/http"
"time"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/watchtower/globals"
)
type HandlerProperties struct {
ListenerConfig *configutil.Listener
}
// Handler returns an http.Handler for the API. This can be used on
// its own to mount the Vault API within another web server.
func (c *Worker) Handler(props HandlerProperties) http.Handler {
// Create the muxer to handle the actual endpoints
mux := http.NewServeMux()
mux.Handle("/v1/", handleDummy())
genericWrappedHandler := c.wrapGenericHandler(mux, props)
return genericWrappedHandler
}
func handleDummy() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"foo": "bar"}`))
})
}
func (c *Worker) wrapGenericHandler(h http.Handler, props HandlerProperties) http.Handler {
var maxRequestDuration time.Duration
var maxRequestSize int64
if props.ListenerConfig != nil {
maxRequestDuration = props.ListenerConfig.MaxRequestDuration
maxRequestSize = props.ListenerConfig.MaxRequestSize
}
if maxRequestDuration == 0 {
maxRequestDuration = globals.DefaultMaxRequestDuration
}
if maxRequestSize == 0 {
maxRequestSize = globals.DefaultMaxRequestSize
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set the Cache-Control header for all responses returned
w.Header().Set("Cache-Control", "no-store")
// Start with the request context
ctx := r.Context()
var cancelFunc context.CancelFunc
// Add our timeout
ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration)
// Add a size limiter if desired
if maxRequestSize > 0 {
ctx = context.WithValue(ctx, "max_request_size", maxRequestSize)
}
ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
cancelFunc()
return
})
}
/*
func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers, headersOK := r.Header[textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")]
if !headersOK || len(headers) == 0 {
if !rejectNotPresent {
h.ServeHTTP(w, r)
return
}
respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"))
return
}
host, port, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// If not rejecting treat it like we just don't have a valid
// header because we can't do a comparison against an address we
// can't understand
if !rejectNotPresent {
h.ServeHTTP(w, r)
return
}
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client hostport: {{err}}", err))
return
}
addr, err := sockaddr.NewIPAddr(host)
if err != nil {
// We treat this the same as the case above
if !rejectNotPresent {
h.ServeHTTP(w, r)
return
}
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client address: {{err}}", err))
return
}
var found bool
for _, authz := range authorizedAddrs {
if authz.Contains(addr) {
found = true
break
}
}
if !found {
// If we didn't find it and aren't configured to reject, simply
// don't trust it
if !rejectNonAuthz {
h.ServeHTTP(w, r)
return
}
respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"))
return
}
// At this point we have at least one value and it's authorized
// Split comma separated ones, which are common. This brings it in line
// to the multiple-header case.
var acc []string
for _, header := range headers {
vals := strings.Split(header, ",")
for _, v := range vals {
acc = append(acc, strings.TrimSpace(v))
}
}
indexToUse := len(acc) - 1 - hopSkips
if indexToUse < 0 {
// This is likely an error in either configuration or other
// infrastructure. We could either deny the request, or we
// could simply not trust the value. Denying the request is
// "safer" since if this logic is configured at all there may
// be an assumption it can always be trusted. Given that we can
// deny accepting the request at all if it's not from an
// authorized address, if we're at this point the address is
// authorized (or we've turned off explicit rejection) and we
// should assume that what comes in should be properly
// formatted.
respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)))
return
}
r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port)
h.ServeHTTP(w, r)
return
})
}
*/

@ -0,0 +1,148 @@
package worker
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/hashicorp/go-alpnmux"
"github.com/hashicorp/go-multierror"
)
func (c *Worker) startListeners() error {
var retErr *multierror.Error
servers := make([]func(), 0, len(c.conf.Listeners))
for _, ln := range c.conf.Listeners {
handler := c.Handler(HandlerProperties{
ListenerConfig: ln.Config,
})
/*
// TODO: As I write this Vault's having this code audited, make sure to
// port over any recommendations
//
// We perform validation on the config earlier, we can just cast here
if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok {
hopSkips := ln.config["x_forwarded_for_hop_skips"].(int)
authzdAddrs := ln.config["x_forwarded_for_authorized_addrs"].([]*sockaddr.SockAddrMarshaler)
rejectNotPresent := ln.config["x_forwarded_for_reject_not_present"].(bool)
rejectNonAuthz := ln.config["x_forwarded_for_reject_not_authorized"].(bool)
if len(authzdAddrs) > 0 {
handler = vaulthttp.WrapForwardedForHandler(handler, authzdAddrs, rejectNotPresent, rejectNonAuthz, hopSkips)
}
}
*/
server := &http.Server{
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: c.conf.Logger.StandardLogger(nil),
BaseContext: func(net.Listener) context.Context {
return c.baseContext
},
}
ln.HTTPServer = server
if ln.Config.HTTPReadHeaderTimeout > 0 {
server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout
}
if ln.Config.HTTPReadTimeout > 0 {
server.ReadTimeout = ln.Config.HTTPReadTimeout
}
if ln.Config.HTTPWriteTimeout > 0 {
server.WriteTimeout = ln.Config.HTTPWriteTimeout
}
if ln.Config.HTTPIdleTimeout > 0 {
server.IdleTimeout = ln.Config.HTTPIdleTimeout
}
switch ln.Config.TLSDisable {
case true:
l := ln.Mux.GetListener(alpnmux.NoProto)
if l == nil {
retErr = multierror.Append(retErr, errors.New("could not get non-tls listener"))
continue
}
servers = append(servers, func() {
go server.Serve(l)
})
default:
protos := []string{"", "http/1.1", "h2"}
for _, v := range protos {
l := ln.Mux.GetListener(v)
if l == nil {
retErr = multierror.Append(retErr, fmt.Errorf("could not get tls proto %q listener", v))
continue
}
servers = append(servers, func() {
go server.Serve(l)
})
}
}
workerTLSConfig, peeringInfo, err := c.workerTLS(workerTLSOpts{
Address: ln.Config.Address,
Protos: []string{"watchtower-worker-v1"},
})
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error getting TLS configuration: %w", err))
continue
}
l, err := ln.Mux.RegisterProto("watchtower-worker-v1", workerTLSConfig)
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error getting sub-listener for worker proto: %w", err))
continue
}
// TODO: Start listner for real; for now send it to the http server just for testing
servers = append(servers, func() {
go server.Serve(l)
})
// TODO: Add peering info into database
_ = peeringInfo
}
err := retErr.ErrorOrNil()
if err != nil {
return err
}
for _, s := range servers {
s()
}
return nil
}
func (c *Worker) stopListeners() error {
serverWg := new(sync.WaitGroup)
for _, ln := range c.conf.Listeners {
if ln.HTTPServer == nil {
continue
}
serverWg.Add(1)
go func() {
shutdownKill, shutdownKillCancel := context.WithTimeout(c.baseContext, ln.Config.MaxRequestDuration)
defer shutdownKillCancel()
defer serverWg.Done()
ln.HTTPServer.Shutdown(shutdownKill)
}()
}
serverWg.Wait()
var retErr *multierror.Error
for _, ln := range c.conf.Listeners {
if err := ln.Mux.Close(); err != nil {
retErr = multierror.Append(retErr, err)
}
}
return retErr.ErrorOrNil()
}

@ -0,0 +1,76 @@
package worker
import (
"context"
"crypto/rand"
"fmt"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/mlock"
)
type Worker struct {
conf *Config
baseContext context.Context
baseCancel context.CancelFunc
}
func New(conf *Config) (*Worker, error) {
if conf.Logger == nil {
conf.Logger = hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
conf.AllLoggers = append(conf.AllLoggers, conf.Logger)
}
if conf.SecureRandomReader == nil {
conf.SecureRandomReader = rand.Reader
}
if !conf.RawConfig.DisableMlock {
// Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil {
return nil, fmt.Errorf(
"Failed to lock memory: %v\n\n"+
"This usually means that the mlock syscall is not available.\n"+
"Watchtower uses mlock to prevent memory from being swapped to\n"+
"disk. This requires root privileges as well as a machine\n"+
"that supports mlock. Please enable mlock on your system or\n"+
"disable Watchtower from using it. To disable Watchtower from using it,\n"+
"set the `disable_mlock` configuration option in your configuration\n"+
"file.",
err)
}
}
conf.Logger = conf.Logger.Named("worker")
c := &Worker{
conf: conf,
}
c.baseContext, c.baseCancel = context.WithCancel(context.Background())
return c, nil
}
func (c *Worker) Start() error {
if err := c.startListeners(); err != nil {
return err
}
return nil
}
func (c *Worker) Shutdown() error {
if err := c.stopListeners(); err != nil {
return err
}
return nil
}
func (c *Worker) SetLogLevel(level hclog.Level) {
for _, logger := range c.conf.AllLoggers {
logger.SetLevel(level)
}
}
Loading…
Cancel
Save