mirror of https://github.com/hashicorp/boundary
Migrate a bunch of stuff around to prep for worker command (#2)
Right now it's essentially a copy; it needs some updating, integration into dev, and so on.pull/3/head
parent
c2d385d851
commit
1645c3106e
@ -0,0 +1,8 @@
|
||||
// +build !memprofiler
|
||||
|
||||
package base
|
||||
|
||||
import "github.com/hashicorp/go-hclog"
|
||||
|
||||
func StartMemProfiler(_ hclog.Logger) {
|
||||
}
|
||||
@ -1,6 +0,0 @@
|
||||
// +build !memprofiler
|
||||
|
||||
package controller
|
||||
|
||||
func (c *Command) startMemProfiler() {
|
||||
}
|
||||
@ -1,40 +0,0 @@
|
||||
// +build memprofiler
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
memProfilerEnabled = true
|
||||
}
|
||||
|
||||
func (c *Command) startMemProfiler() {
|
||||
profileDir := filepath.Join(os.TempDir(), "watchtowerprof")
|
||||
if err := os.MkdirAll(profileDir, 0700); err != nil {
|
||||
c.logger.Debug("could not create profile directory", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
filename := filepath.Join(profileDir, time.Now().UTC().Format("20060102_150405")) + ".pprof"
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
c.logger.Debug("could not create memory profile", "error", err)
|
||||
}
|
||||
runtime.GC()
|
||||
if err := pprof.WriteHeapProfile(f); err != nil {
|
||||
c.logger.Debug("could not write memory profile", "error", err)
|
||||
}
|
||||
f.Close()
|
||||
c.logger.Debug("wrote memory profile", "filename", filename)
|
||||
time.Sleep(5 * time.Minute)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -1,81 +0,0 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
|
||||
// We must import sha512 so that it registers with the runtime so that
|
||||
// certificates that use it can be parsed.
|
||||
_ "crypto/sha512"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-alpnmux"
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
"github.com/hashicorp/vault/internalshared/reloadutil"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
// Factory is the factory function to create a listener.
|
||||
type Factory func(*configutil.Listener, io.Writer, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error)
|
||||
|
||||
// BuiltinListeners is the list of built-in listener types.
|
||||
var BuiltinListeners = map[string]Factory{
|
||||
"tcp": tcpListenerFactory,
|
||||
}
|
||||
|
||||
// New creates a new listener of the given type with the given
|
||||
// configuration. The type is looked up in the BuiltinListeners map.
|
||||
func New(l *configutil.Listener, w io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) {
|
||||
f, ok := BuiltinListeners[l.Type]
|
||||
if !ok {
|
||||
return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type)
|
||||
}
|
||||
|
||||
return f(l, w, ui)
|
||||
}
|
||||
|
||||
func listenerWrapProxy(ln net.Listener, l *configutil.Listener) (net.Listener, error) {
|
||||
behavior := l.ProxyProtocolBehavior
|
||||
if behavior == "" {
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
authorizedAddrs := make([]string, 0, len(l.ProxyProtocolAuthorizedAddrs))
|
||||
for _, v := range l.ProxyProtocolAuthorizedAddrs {
|
||||
authorizedAddrs = append(authorizedAddrs, v.String())
|
||||
}
|
||||
|
||||
var policyFunc proxyproto.PolicyFunc
|
||||
|
||||
switch behavior {
|
||||
case "use_always":
|
||||
policyFunc = func(upstream net.Addr) (proxyproto.Policy, error) {
|
||||
return proxyproto.USE, nil
|
||||
}
|
||||
|
||||
case "allow_authorized":
|
||||
if len(authorizedAddrs) == 0 {
|
||||
return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value")
|
||||
}
|
||||
policyFunc = proxyproto.MustLaxWhiteListPolicy(authorizedAddrs)
|
||||
|
||||
case "deny_unauthorized":
|
||||
if len(authorizedAddrs) == 0 {
|
||||
return nil, errors.New("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value")
|
||||
}
|
||||
policyFunc = proxyproto.MustStrictWhiteListPolicy(authorizedAddrs)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown %q value: %q", "proxy_protocol_behavior", behavior)
|
||||
}
|
||||
|
||||
proxyListener := &proxyproto.Listener{
|
||||
Listener: ln,
|
||||
Policy: policyFunc,
|
||||
}
|
||||
|
||||
return proxyListener, nil
|
||||
}
|
||||
@ -1,6 +0,0 @@
|
||||
// +build !memprofiler
|
||||
|
||||
package dev
|
||||
|
||||
func (d *Command) startMemProfiler() {
|
||||
}
|
||||
@ -0,0 +1,122 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/hashicorp/hcl"
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
)
|
||||
|
||||
// Config is the configuration for the watchtower controller
|
||||
type Config struct {
|
||||
*configutil.SharedConfig `hcl:"-"`
|
||||
}
|
||||
|
||||
// Dev is a Config that is used for dev mode of Watchtower
|
||||
func Dev() (*Config, error) {
|
||||
randBuf := new(bytes.Buffer)
|
||||
n, err := randBuf.ReadFrom(&io.LimitedReader{
|
||||
R: rand.Reader,
|
||||
N: 64,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n != 64 {
|
||||
return nil, fmt.Errorf("expected to read 64 bytes, read %d", n)
|
||||
}
|
||||
controllerKey := base64.StdEncoding.EncodeToString(randBuf.Bytes()[0:32])
|
||||
workerAuthKey := base64.StdEncoding.EncodeToString(randBuf.Bytes()[32:64])
|
||||
|
||||
hclStr := `
|
||||
disable_mlock = true
|
||||
|
||||
listener "tcp" {
|
||||
tls_disable = true
|
||||
proxy_protocol_behavior = "allow_authorized"
|
||||
proxy_protocol_authorized_addrs = "127.0.0.1"
|
||||
}
|
||||
|
||||
telemetry {
|
||||
prometheus_retention_time = "24h"
|
||||
disable_hostname = true
|
||||
}
|
||||
`
|
||||
|
||||
hclStr = fmt.Sprintf(hclStr, controllerKey, workerAuthKey)
|
||||
parsed, err := Parse(hclStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing dev config: %w", err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func New() *Config {
|
||||
return &Config{
|
||||
SharedConfig: new(configutil.SharedConfig),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFile loads the configuration from the given file.
|
||||
func LoadFile(path string) (*Config, error) {
|
||||
// Read the file
|
||||
d, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf, err := Parse(string(d))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func Parse(d string) (*Config, error) {
|
||||
obj, err := hcl.Parse(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Nothing to do here right now
|
||||
result := New()
|
||||
if err := hcl.DecodeObject(result, obj); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sharedConfig, err := configutil.ParseConfig(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.SharedConfig = sharedConfig
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Sanitized returns a copy of the config with all values that are considered
|
||||
// sensitive stripped. It also strips all `*Raw` values that are mainly
|
||||
// used for parsing.
|
||||
//
|
||||
// Specifically, the fields that this method strips are:
|
||||
// - KMS.Config
|
||||
// - Telemetry.CirconusAPIToken
|
||||
func (c *Config) Sanitized() map[string]interface{} {
|
||||
// Create shared config if it doesn't exist (e.g. in tests) so that map
|
||||
// keys are actually populated
|
||||
if c.SharedConfig == nil {
|
||||
c.SharedConfig = new(configutil.SharedConfig)
|
||||
}
|
||||
sharedResult := c.SharedConfig.Sanitized()
|
||||
result := map[string]interface{}{}
|
||||
for k, v := range sharedResult {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,385 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/sdk/helper/mlock"
|
||||
"github.com/hashicorp/watchtower/globals"
|
||||
"github.com/hashicorp/watchtower/internal/cmd/base"
|
||||
"github.com/hashicorp/watchtower/internal/cmd/commands/worker/config"
|
||||
"github.com/hashicorp/watchtower/internal/servers/worker"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
var _ cli.Command = (*Command)(nil)
|
||||
var _ cli.CommandAutocomplete = (*Command)(nil)
|
||||
|
||||
var memProfilerEnabled = false
|
||||
|
||||
type Command struct {
|
||||
*base.Command
|
||||
*base.Server
|
||||
|
||||
ShutdownCh chan struct{}
|
||||
SighupCh chan struct{}
|
||||
ReloadedCh chan struct{}
|
||||
SigUSR2Ch chan struct{}
|
||||
|
||||
cleanupGuard sync.Once
|
||||
|
||||
Config *config.Config
|
||||
|
||||
flagConfig string
|
||||
flagLogLevel string
|
||||
flagLogFormat string
|
||||
flagDev bool
|
||||
flagDevAdminToken string
|
||||
flagDevWorkerListenAddr string
|
||||
flagCombineLogs bool
|
||||
}
|
||||
|
||||
func (c *Command) Synopsis() string {
|
||||
return "Start a Watchtower worker"
|
||||
}
|
||||
|
||||
func (c *Command) Help() string {
|
||||
helpText := `
|
||||
Usage: watchtower worker [options]
|
||||
|
||||
Start a worker with a configuration file:
|
||||
|
||||
$ watchtower worker -config=/etc/watchtower/worker.hcl
|
||||
|
||||
For a full list of examples, please see the documentation.
|
||||
|
||||
` + c.Flags().Help()
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
||||
|
||||
func (c *Command) Flags() *base.FlagSets {
|
||||
set := c.FlagSet(base.FlagSetHTTP)
|
||||
|
||||
f := set.NewFlagSet("Command Options")
|
||||
|
||||
f.StringVar(&base.StringVar{
|
||||
Name: "config",
|
||||
Target: &c.flagConfig,
|
||||
Completion: complete.PredictOr(
|
||||
complete.PredictFiles("*.hcl"),
|
||||
complete.PredictFiles("*.json"),
|
||||
),
|
||||
Usage: "Path to a configuration file.",
|
||||
})
|
||||
|
||||
f.StringVar(&base.StringVar{
|
||||
Name: "log-level",
|
||||
Target: &c.flagLogLevel,
|
||||
Default: base.NotSetValue,
|
||||
EnvVar: "WATCHTOWER_LOG_LEVEL",
|
||||
Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"),
|
||||
Usage: "Log verbosity level. Supported values (in order of more detail to less) are " +
|
||||
"\"trace\", \"debug\", \"info\", \"warn\", and \"err\".",
|
||||
})
|
||||
|
||||
f.StringVar(&base.StringVar{
|
||||
Name: "log-format",
|
||||
Target: &c.flagLogFormat,
|
||||
Default: base.NotSetValue,
|
||||
Completion: complete.PredictSet("standard", "json"),
|
||||
Usage: `Log format. Supported values are "standard" and "json".`,
|
||||
})
|
||||
|
||||
f = set.NewFlagSet("Dev Options")
|
||||
|
||||
f.BoolVar(&base.BoolVar{
|
||||
Name: "dev",
|
||||
Target: &c.flagDev,
|
||||
Usage: "Enable development mode. As the name implies, do not run \"dev\" mode in " +
|
||||
"production.",
|
||||
})
|
||||
|
||||
f.StringVar(&base.StringVar{
|
||||
Name: "dev-admin-token",
|
||||
Target: &c.flagDevAdminToken,
|
||||
Default: "",
|
||||
EnvVar: "WATCHTWER_DEV_ADMIN_TOKEN",
|
||||
Usage: "Initial admin token. This only applies when running in \"dev\" " +
|
||||
"mode.",
|
||||
})
|
||||
|
||||
f.StringVar(&base.StringVar{
|
||||
Name: "dev-listen-address",
|
||||
Target: &c.flagDevWorkerListenAddr,
|
||||
Default: "127.0.0.1:9200",
|
||||
EnvVar: "WATCHTOWER_DEV_WORKER_LISTEN_ADDRESS",
|
||||
Usage: "Address to bind the worker to in \"dev\" mode.",
|
||||
})
|
||||
|
||||
f.BoolVar(&base.BoolVar{
|
||||
Name: "combine-logs",
|
||||
Target: &c.flagCombineLogs,
|
||||
Default: false,
|
||||
Usage: "If set, both startup information and logs will be sent to stdout. If not set (the default), startup information will go to stdout and logs will be sent to stderr.",
|
||||
})
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
func (c *Command) AutocompleteArgs() complete.Predictor {
|
||||
return complete.PredictNothing
|
||||
}
|
||||
|
||||
func (c *Command) AutocompleteFlags() complete.Flags {
|
||||
return c.Flags().Completions()
|
||||
}
|
||||
|
||||
func (c *Command) Run(args []string) int {
|
||||
c.Server = base.NewServer()
|
||||
c.CombineLogs = c.flagCombineLogs
|
||||
|
||||
if result := c.ParseFlagsAndConfig(args); result > 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
if err := c.SetupLogging(c.flagLogLevel, c.flagLogFormat, c.Config.LogLevel, c.Config.LogFormat); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
if memProfilerEnabled {
|
||||
base.StartMemProfiler(c.Logger)
|
||||
}
|
||||
|
||||
if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
if err := c.SetupKMSes(c.UI, c.Config.SharedConfig, 2); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.Config.DefaultMaxRequestDuration != 0 {
|
||||
globals.DefaultMaxRequestDuration = c.Config.DefaultMaxRequestDuration
|
||||
}
|
||||
|
||||
// If mlockall(2) isn't supported, show a warning. We disable this in dev
|
||||
// because it is quite scary to see when first using Vault. We also disable
|
||||
// this if the user has explicitly disabled mlock in configuration.
|
||||
if !c.flagDev && !c.Config.DisableMlock && !mlock.Supported() {
|
||||
c.UI.Warn(base.WrapAtLength(
|
||||
"WARNING! mlock is not supported on this system! An mlockall(2)-like " +
|
||||
"syscall to prevent memory from being swapped to disk is not " +
|
||||
"supported on this system. For better security, only run Vault on " +
|
||||
"systems where this call is supported. If you are running Vault " +
|
||||
"in a Docker container, provide the IPC_LOCK cap to the container."))
|
||||
}
|
||||
|
||||
if err := c.SetupListeners(c.UI, c.Config.SharedConfig); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
// Write out the PID to the file now that server has successfully started
|
||||
if err := c.StorePidFile(c.Config.PidFile); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error storing PID: %w", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.flagDev {
|
||||
if err := c.CreateDevDatabase(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error creating dev database container: %s", err.Error()))
|
||||
return 1
|
||||
}
|
||||
c.ShutdownFuncs = append(c.ShutdownFuncs, c.DestroyDevDatabase)
|
||||
}
|
||||
|
||||
defer c.RunShutdownFuncs(c.UI)
|
||||
|
||||
c.PrintInfo(c.UI, "worker")
|
||||
c.ReleaseLogGate()
|
||||
|
||||
return c.Start()
|
||||
}
|
||||
|
||||
func (c *Command) ParseFlagsAndConfig(args []string) int {
|
||||
var err error
|
||||
|
||||
f := c.Flags()
|
||||
|
||||
if err = f.Parse(args); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
// Validation
|
||||
if !c.flagDev {
|
||||
switch {
|
||||
case len(c.flagConfig) == 0:
|
||||
c.UI.Error("Must specify a config file using -config")
|
||||
return 1
|
||||
case c.flagDevAdminToken != "":
|
||||
c.UI.Warn(base.WrapAtLength(
|
||||
"You cannot specify a custom admin token ID outside of \"dev\" mode. " +
|
||||
"Your request has been ignored."))
|
||||
c.flagDevAdminToken = ""
|
||||
}
|
||||
|
||||
if len(c.flagConfig) == 0 {
|
||||
c.UI.Error("Must supply a config file with -config")
|
||||
return 1
|
||||
}
|
||||
c.Config, err = config.LoadFile(c.flagConfig)
|
||||
if err != nil {
|
||||
c.UI.Error("Error parsing config: " + err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
} else {
|
||||
c.Config, err = config.Dev()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error creating dev config: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if c.flagDevWorkerListenAddr != "" {
|
||||
c.Config.Listeners[0].Address = c.flagDevWorkerListenAddr
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *Command) Start() int {
|
||||
// Instantiate the wait group
|
||||
conf := &worker.Config{
|
||||
RawConfig: c.Config,
|
||||
Server: c.Server,
|
||||
}
|
||||
|
||||
// Initialize the core
|
||||
wrkr, err := worker.New(conf)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing worker: %w", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if err := wrkr.Start(); err != nil {
|
||||
c.UI.Error(fmt.Sprint("Error starting worker: %w", err))
|
||||
if err := wrkr.Shutdown(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error with worker shutdown: %w", err))
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Wait for shutdown
|
||||
shutdownTriggered := false
|
||||
|
||||
for !shutdownTriggered {
|
||||
select {
|
||||
case <-c.ShutdownCh:
|
||||
c.UI.Output("==> Watchtower worker shutdown triggered")
|
||||
|
||||
if err := wrkr.Shutdown(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error with worker shutdown: %w", err))
|
||||
}
|
||||
|
||||
shutdownTriggered = true
|
||||
|
||||
case <-c.SighupCh:
|
||||
c.UI.Output("==> Watchtower worker reload triggered")
|
||||
|
||||
// Check for new log level
|
||||
var level hclog.Level
|
||||
var err error
|
||||
var newConf *config.Config
|
||||
|
||||
if c.flagConfig == "" {
|
||||
goto RUNRELOADFUNCS
|
||||
}
|
||||
|
||||
newConf, err = config.LoadFile(c.flagConfig)
|
||||
if err != nil {
|
||||
c.Logger.Error("could not reload config", "path", c.flagConfig, "error", err)
|
||||
goto RUNRELOADFUNCS
|
||||
}
|
||||
|
||||
// Ensure at least one config was found.
|
||||
if newConf == nil {
|
||||
c.Logger.Error("no config found at reload time")
|
||||
goto RUNRELOADFUNCS
|
||||
}
|
||||
|
||||
// Commented out until we need this
|
||||
//wrkr.SetConfig(config)
|
||||
|
||||
if newConf.LogLevel != "" {
|
||||
configLogLevel := strings.ToLower(strings.TrimSpace(newConf.LogLevel))
|
||||
switch configLogLevel {
|
||||
case "trace":
|
||||
level = hclog.Trace
|
||||
case "debug":
|
||||
level = hclog.Debug
|
||||
case "notice", "info", "":
|
||||
level = hclog.Info
|
||||
case "warn", "warning":
|
||||
level = hclog.Warn
|
||||
case "err", "error":
|
||||
level = hclog.Error
|
||||
default:
|
||||
c.Logger.Error("unknown log level found on reload", "level", newConf.LogLevel)
|
||||
goto RUNRELOADFUNCS
|
||||
}
|
||||
wrkr.SetLogLevel(level)
|
||||
}
|
||||
|
||||
RUNRELOADFUNCS:
|
||||
if err := c.Reload(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error(s) were encountered during worker reload: %w", err))
|
||||
}
|
||||
|
||||
case <-c.SigUSR2Ch:
|
||||
buf := make([]byte, 32*1024*1024)
|
||||
n := runtime.Stack(buf[:], true)
|
||||
c.Logger.Info("goroutine trace", "stack", string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *Command) Reload() error {
|
||||
c.ReloadFuncsLock.RLock()
|
||||
defer c.ReloadFuncsLock.RUnlock()
|
||||
|
||||
var reloadErrors *multierror.Error
|
||||
|
||||
for k, relFuncs := range c.ReloadFuncs {
|
||||
switch {
|
||||
case strings.HasPrefix(k, "listener|"):
|
||||
for _, relFunc := range relFuncs {
|
||||
if relFunc != nil {
|
||||
if err := relFunc(); err != nil {
|
||||
reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("error encountered reloading listener: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send a message that we reloaded. This prevents "guessing" sleep times
|
||||
// in tests.
|
||||
select {
|
||||
case c.ReloadedCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
return reloadErrors.ErrorOrNil()
|
||||
}
|
||||
@ -0,0 +1,182 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/base62"
|
||||
)
|
||||
|
||||
type workerTLSOpts struct {
|
||||
Address string
|
||||
Protos []string
|
||||
DumpDir string
|
||||
}
|
||||
|
||||
type certInfo struct {
|
||||
CACert []byte `json:"ca_cert"`
|
||||
CAKey []byte `json:"ca_key"`
|
||||
}
|
||||
|
||||
func (c Worker) workerTLS(opts workerTLSOpts) (*tls.Config, *certInfo, error) {
|
||||
info := new(certInfo)
|
||||
|
||||
certIPs := []net.IP{
|
||||
net.IPv6loopback,
|
||||
net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
|
||||
if opts.Address != "" {
|
||||
baseAddr, err := net.ResolveTCPAddr("tcp", opts.Address)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
certIPs = append(certIPs, baseAddr.IP)
|
||||
}
|
||||
|
||||
_, caKey, err := ed25519.GenerateKey(c.conf.SecureRandomReader)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
info.CAKey = caKey
|
||||
caHost, err := base62.Random(20)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
caCertTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: caHost,
|
||||
},
|
||||
DNSNames: []string{caHost},
|
||||
IPAddresses: certIPs,
|
||||
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
info.CACert = caBytes
|
||||
caCert, err := x509.ParseCertificate(caBytes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
caCertPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: caBytes,
|
||||
}
|
||||
caCertPEM := pem.EncodeToMemory(caCertPEMBlock)
|
||||
caCertPEMFile := filepath.Join(opts.DumpDir, "ca_cert.pem")
|
||||
|
||||
marshaledCAKey, err := x509.MarshalPKCS8PrivateKey(caKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
caKeyPEMBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: marshaledCAKey,
|
||||
}
|
||||
caKeyPEM := pem.EncodeToMemory(caKeyPEMBlock)
|
||||
|
||||
//
|
||||
// Certs generation
|
||||
//
|
||||
_, key, err := ed25519.GenerateKey(c.conf.SecureRandomReader)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
host, err := base62.Random(20)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
certTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: host,
|
||||
},
|
||||
DNSNames: []string{host},
|
||||
IPAddresses: certIPs,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
}
|
||||
certBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
certPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(certPEMBlock)
|
||||
marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
keyPEMBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: marshaledKey,
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(keyPEMBlock)
|
||||
|
||||
certFile := filepath.Join(opts.DumpDir, "cert.pem")
|
||||
keyFile := filepath.Join(opts.DumpDir, "key.pem")
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
RootCAs: rootCAs,
|
||||
ClientCAs: rootCAs,
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
NextProtos: opts.Protos,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
|
||||
if opts.DumpDir != "" {
|
||||
if _, err := os.Stat(opts.DumpDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(opts.DumpDir, 0700); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
if err := ioutil.WriteFile(filepath.Join(opts.DumpDir, "ca_key.pem"), caKeyPEM, 0755); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ioutil.WriteFile(caCertPEMFile, caCertPEM, 0755); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ioutil.WriteFile(certFile, certPEM, 0755); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ioutil.WriteFile(keyFile, keyPEM, 0755); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return tlsConfig, info, nil
|
||||
}
|
||||
@ -0,0 +1,14 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/watchtower/internal/cmd/base"
|
||||
"github.com/hashicorp/watchtower/internal/cmd/commands/worker/config"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
*base.Server
|
||||
RawConfig *config.Config
|
||||
BaseContext context.Context
|
||||
}
|
||||
@ -0,0 +1,159 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
"github.com/hashicorp/watchtower/globals"
|
||||
)
|
||||
|
||||
type HandlerProperties struct {
|
||||
ListenerConfig *configutil.Listener
|
||||
}
|
||||
|
||||
// Handler returns an http.Handler for the API. This can be used on
|
||||
// its own to mount the Vault API within another web server.
|
||||
func (c *Worker) Handler(props HandlerProperties) http.Handler {
|
||||
// Create the muxer to handle the actual endpoints
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.Handle("/v1/", handleDummy())
|
||||
|
||||
genericWrappedHandler := c.wrapGenericHandler(mux, props)
|
||||
|
||||
return genericWrappedHandler
|
||||
}
|
||||
|
||||
func handleDummy() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"foo": "bar"}`))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Worker) wrapGenericHandler(h http.Handler, props HandlerProperties) http.Handler {
|
||||
var maxRequestDuration time.Duration
|
||||
var maxRequestSize int64
|
||||
if props.ListenerConfig != nil {
|
||||
maxRequestDuration = props.ListenerConfig.MaxRequestDuration
|
||||
maxRequestSize = props.ListenerConfig.MaxRequestSize
|
||||
}
|
||||
if maxRequestDuration == 0 {
|
||||
maxRequestDuration = globals.DefaultMaxRequestDuration
|
||||
}
|
||||
if maxRequestSize == 0 {
|
||||
maxRequestSize = globals.DefaultMaxRequestSize
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set the Cache-Control header for all responses returned
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
|
||||
// Start with the request context
|
||||
ctx := r.Context()
|
||||
var cancelFunc context.CancelFunc
|
||||
// Add our timeout
|
||||
ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration)
|
||||
// Add a size limiter if desired
|
||||
if maxRequestSize > 0 {
|
||||
ctx = context.WithValue(ctx, "max_request_size", maxRequestSize)
|
||||
}
|
||||
ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
cancelFunc()
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
headers, headersOK := r.Header[textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")]
|
||||
if !headersOK || len(headers) == 0 {
|
||||
if !rejectNotPresent {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"))
|
||||
return
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
// If not rejecting treat it like we just don't have a valid
|
||||
// header because we can't do a comparison against an address we
|
||||
// can't understand
|
||||
if !rejectNotPresent {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client hostport: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
addr, err := sockaddr.NewIPAddr(host)
|
||||
if err != nil {
|
||||
// We treat this the same as the case above
|
||||
if !rejectNotPresent {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client address: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
var found bool
|
||||
for _, authz := range authorizedAddrs {
|
||||
if authz.Contains(addr) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// If we didn't find it and aren't configured to reject, simply
|
||||
// don't trust it
|
||||
if !rejectNonAuthz {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"))
|
||||
return
|
||||
}
|
||||
|
||||
// At this point we have at least one value and it's authorized
|
||||
|
||||
// Split comma separated ones, which are common. This brings it in line
|
||||
// to the multiple-header case.
|
||||
var acc []string
|
||||
for _, header := range headers {
|
||||
vals := strings.Split(header, ",")
|
||||
for _, v := range vals {
|
||||
acc = append(acc, strings.TrimSpace(v))
|
||||
}
|
||||
}
|
||||
|
||||
indexToUse := len(acc) - 1 - hopSkips
|
||||
if indexToUse < 0 {
|
||||
// This is likely an error in either configuration or other
|
||||
// infrastructure. We could either deny the request, or we
|
||||
// could simply not trust the value. Denying the request is
|
||||
// "safer" since if this logic is configured at all there may
|
||||
// be an assumption it can always be trusted. Given that we can
|
||||
// deny accepting the request at all if it's not from an
|
||||
// authorized address, if we're at this point the address is
|
||||
// authorized (or we've turned off explicit rejection) and we
|
||||
// should assume that what comes in should be properly
|
||||
// formatted.
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)))
|
||||
return
|
||||
}
|
||||
|
||||
r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port)
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
})
|
||||
}
|
||||
*/
|
||||
@ -0,0 +1,148 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-alpnmux"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
func (c *Worker) startListeners() error {
|
||||
var retErr *multierror.Error
|
||||
servers := make([]func(), 0, len(c.conf.Listeners))
|
||||
for _, ln := range c.conf.Listeners {
|
||||
handler := c.Handler(HandlerProperties{
|
||||
ListenerConfig: ln.Config,
|
||||
})
|
||||
|
||||
/*
|
||||
// TODO: As I write this Vault's having this code audited, make sure to
|
||||
// port over any recommendations
|
||||
//
|
||||
// We perform validation on the config earlier, we can just cast here
|
||||
if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok {
|
||||
hopSkips := ln.config["x_forwarded_for_hop_skips"].(int)
|
||||
authzdAddrs := ln.config["x_forwarded_for_authorized_addrs"].([]*sockaddr.SockAddrMarshaler)
|
||||
rejectNotPresent := ln.config["x_forwarded_for_reject_not_present"].(bool)
|
||||
rejectNonAuthz := ln.config["x_forwarded_for_reject_not_authorized"].(bool)
|
||||
if len(authzdAddrs) > 0 {
|
||||
handler = vaulthttp.WrapForwardedForHandler(handler, authzdAddrs, rejectNotPresent, rejectNonAuthz, hopSkips)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: c.conf.Logger.StandardLogger(nil),
|
||||
BaseContext: func(net.Listener) context.Context {
|
||||
return c.baseContext
|
||||
},
|
||||
}
|
||||
ln.HTTPServer = server
|
||||
|
||||
if ln.Config.HTTPReadHeaderTimeout > 0 {
|
||||
server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout
|
||||
}
|
||||
if ln.Config.HTTPReadTimeout > 0 {
|
||||
server.ReadTimeout = ln.Config.HTTPReadTimeout
|
||||
}
|
||||
if ln.Config.HTTPWriteTimeout > 0 {
|
||||
server.WriteTimeout = ln.Config.HTTPWriteTimeout
|
||||
}
|
||||
if ln.Config.HTTPIdleTimeout > 0 {
|
||||
server.IdleTimeout = ln.Config.HTTPIdleTimeout
|
||||
}
|
||||
|
||||
switch ln.Config.TLSDisable {
|
||||
case true:
|
||||
l := ln.Mux.GetListener(alpnmux.NoProto)
|
||||
if l == nil {
|
||||
retErr = multierror.Append(retErr, errors.New("could not get non-tls listener"))
|
||||
continue
|
||||
}
|
||||
servers = append(servers, func() {
|
||||
go server.Serve(l)
|
||||
})
|
||||
|
||||
default:
|
||||
protos := []string{"", "http/1.1", "h2"}
|
||||
for _, v := range protos {
|
||||
l := ln.Mux.GetListener(v)
|
||||
if l == nil {
|
||||
retErr = multierror.Append(retErr, fmt.Errorf("could not get tls proto %q listener", v))
|
||||
continue
|
||||
}
|
||||
servers = append(servers, func() {
|
||||
go server.Serve(l)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
workerTLSConfig, peeringInfo, err := c.workerTLS(workerTLSOpts{
|
||||
Address: ln.Config.Address,
|
||||
Protos: []string{"watchtower-worker-v1"},
|
||||
})
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, fmt.Errorf("error getting TLS configuration: %w", err))
|
||||
continue
|
||||
}
|
||||
l, err := ln.Mux.RegisterProto("watchtower-worker-v1", workerTLSConfig)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, fmt.Errorf("error getting sub-listener for worker proto: %w", err))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: Start listner for real; for now send it to the http server just for testing
|
||||
servers = append(servers, func() {
|
||||
go server.Serve(l)
|
||||
})
|
||||
|
||||
// TODO: Add peering info into database
|
||||
_ = peeringInfo
|
||||
}
|
||||
|
||||
err := retErr.ErrorOrNil()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, s := range servers {
|
||||
s()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Worker) stopListeners() error {
|
||||
serverWg := new(sync.WaitGroup)
|
||||
for _, ln := range c.conf.Listeners {
|
||||
if ln.HTTPServer == nil {
|
||||
continue
|
||||
}
|
||||
serverWg.Add(1)
|
||||
go func() {
|
||||
shutdownKill, shutdownKillCancel := context.WithTimeout(c.baseContext, ln.Config.MaxRequestDuration)
|
||||
defer shutdownKillCancel()
|
||||
defer serverWg.Done()
|
||||
ln.HTTPServer.Shutdown(shutdownKill)
|
||||
}()
|
||||
}
|
||||
serverWg.Wait()
|
||||
|
||||
var retErr *multierror.Error
|
||||
for _, ln := range c.conf.Listeners {
|
||||
if err := ln.Mux.Close(); err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
}
|
||||
}
|
||||
return retErr.ErrorOrNil()
|
||||
}
|
||||
@ -0,0 +1,76 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/helper/mlock"
|
||||
)
|
||||
|
||||
type Worker struct {
|
||||
conf *Config
|
||||
|
||||
baseContext context.Context
|
||||
baseCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func New(conf *Config) (*Worker, error) {
|
||||
if conf.Logger == nil {
|
||||
conf.Logger = hclog.New(&hclog.LoggerOptions{
|
||||
Level: hclog.Trace,
|
||||
})
|
||||
conf.AllLoggers = append(conf.AllLoggers, conf.Logger)
|
||||
}
|
||||
|
||||
if conf.SecureRandomReader == nil {
|
||||
conf.SecureRandomReader = rand.Reader
|
||||
}
|
||||
|
||||
if !conf.RawConfig.DisableMlock {
|
||||
// Ensure our memory usage is locked into physical RAM
|
||||
if err := mlock.LockMemory(); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"Failed to lock memory: %v\n\n"+
|
||||
"This usually means that the mlock syscall is not available.\n"+
|
||||
"Watchtower uses mlock to prevent memory from being swapped to\n"+
|
||||
"disk. This requires root privileges as well as a machine\n"+
|
||||
"that supports mlock. Please enable mlock on your system or\n"+
|
||||
"disable Watchtower from using it. To disable Watchtower from using it,\n"+
|
||||
"set the `disable_mlock` configuration option in your configuration\n"+
|
||||
"file.",
|
||||
err)
|
||||
}
|
||||
}
|
||||
|
||||
conf.Logger = conf.Logger.Named("worker")
|
||||
|
||||
c := &Worker{
|
||||
conf: conf,
|
||||
}
|
||||
|
||||
c.baseContext, c.baseCancel = context.WithCancel(context.Background())
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Worker) Start() error {
|
||||
if err := c.startListeners(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Worker) Shutdown() error {
|
||||
if err := c.stopListeners(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Worker) SetLogLevel(level hclog.Level) {
|
||||
for _, logger := range c.conf.AllLoggers {
|
||||
logger.SetLevel(level)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue