diff --git a/go.mod b/go.mod index 78db9a1628..21cb8c2d11 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/fatih/color v1.9.0 github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 github.com/hashicorp/errwrap v1.0.0 - github.com/hashicorp/go-alpnmux v0.0.0-20200323001347-b5ec1528c52d + github.com/hashicorp/go-alpnmux v0.0.0-20200323180452-dee08f00df54 github.com/hashicorp/go-hclog v0.12.1 github.com/hashicorp/go-kms-wrapping v0.5.7-0.20200322213809-e2a819ec93db github.com/hashicorp/go-multierror v1.0.0 diff --git a/go.sum b/go.sum index fec9ac4153..070742c95b 100644 --- a/go.sum +++ b/go.sum @@ -317,8 +317,8 @@ github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyN github.com/hashicorp/consul/sdk v0.2.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-alpnmux v0.0.0-20200323001347-b5ec1528c52d h1:5r4nKvUcI+vx2yDn9ysFf53mSY4kUbXqbZ4LF10skFE= -github.com/hashicorp/go-alpnmux v0.0.0-20200323001347-b5ec1528c52d/go.mod h1:KvpteZzIafT4tRAuQ9vVRBgZyqeVCS2B2177fNAyEZc= +github.com/hashicorp/go-alpnmux v0.0.0-20200323180452-dee08f00df54 h1:WhMHPQosFuXFjt2wpRzX2eqR1WpnS+35MFzmG3trg3Y= +github.com/hashicorp/go-alpnmux v0.0.0-20200323180452-dee08f00df54/go.mod h1:KvpteZzIafT4tRAuQ9vVRBgZyqeVCS2B2177fNAyEZc= github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= diff --git a/internal/cmd/base/listener.go b/internal/cmd/base/listener.go index ba38f4c870..a13f99c2c5 100644 --- a/internal/cmd/base/listener.go +++ b/internal/cmd/base/listener.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/http" + "os" "strings" "time" @@ -14,6 +15,7 @@ import ( _ "crypto/sha512" "github.com/hashicorp/go-alpnmux" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/internalshared/reloadutil" @@ -35,7 +37,7 @@ type WorkerAuthCertInfo struct { } // Factory is the factory function to create a listener. -type ListenerFactory func(*configutil.Listener, io.Writer, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) +type ListenerFactory func(*configutil.Listener, io.Writer, hclog.Logger, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) // BuiltinListeners is the list of built-in listener types. var BuiltinListeners = map[string]ListenerFactory{ @@ -44,16 +46,16 @@ var BuiltinListeners = map[string]ListenerFactory{ // New creates a new listener of the given type with the given // configuration. The type is looked up in the BuiltinListeners map. -func NewListener(l *configutil.Listener, w io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { +func NewListener(l *configutil.Listener, w io.Writer, logger hclog.Logger, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { f, ok := BuiltinListeners[l.Type] if !ok { return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type) } - return f(l, w, ui) + return f(l, w, logger, ui) } -func tcpListenerFactory(l *configutil.Listener, _ io.Writer, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { +func tcpListenerFactory(l *configutil.Listener, _ io.Writer, logger hclog.Logger, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { if l.Address == "" { if len(l.Purpose) == 1 && l.Purpose[0] == "cluster" { l.Address = "127.0.0.1:9201" @@ -86,7 +88,10 @@ func tcpListenerFactory(l *configutil.Listener, _ io.Writer, ui cli.Ui) (*alpnmu "addr": l.Address, } - alpnMux := alpnmux.New(ln, nil) + if _, ok := os.LookupEnv("WATCHTOWER_LOG_CONNECTION_MUXING"); !ok { + logger = nil + } + alpnMux := alpnmux.New(ln, logger) if l.TLSDisable { if _, err = alpnMux.RegisterProto(alpnmux.NoProto, nil); err != nil { diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index d824eae51b..900db697b1 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -241,7 +241,7 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig) erro tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, } - lnMux, props, reloadFunc, err := NewListener(lnConfig, b.GatedWriter, ui) + lnMux, props, reloadFunc, err := NewListener(lnConfig, b.GatedWriter, b.Logger, ui) if err != nil { return fmt.Errorf("Error initializing listener of type %s: %w", lnConfig.Type, err) } diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index 2d58b18b6d..4af56170d2 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -10,18 +10,24 @@ import ( ) type Controller struct { - conf *Config + conf *Config + logger hclog.Logger baseContext context.Context baseCancel context.CancelFunc } func New(conf *Config) (*Controller, error) { - if conf.Logger == nil { - conf.Logger = hclog.New(&hclog.LoggerOptions{ + c := &Controller{ + conf: conf, + logger: conf.Logger, + } + + if c.logger == nil { + c.logger = hclog.New(&hclog.LoggerOptions{ Level: hclog.Trace, }) - conf.AllLoggers = append(conf.AllLoggers, conf.Logger) + conf.AllLoggers = append(conf.AllLoggers, c.logger) } if conf.SecureRandomReader == nil { @@ -44,11 +50,7 @@ func New(conf *Config) (*Controller, error) { } } - conf.Logger = conf.Logger.Named("controller") - - c := &Controller{ - conf: conf, - } + c.logger = c.logger.Named("controller") c.baseContext, c.baseCancel = context.WithCancel(context.Background()) @@ -57,14 +59,14 @@ func New(conf *Config) (*Controller, error) { func (c *Controller) Start() error { if err := c.startListeners(); err != nil { - return err + return fmt.Errorf("error starting controller listeners: %w", err) } return nil } func (c *Controller) Shutdown() error { if err := c.stopListeners(); err != nil { - return err + return fmt.Errorf("error stopping controller listeners: %w", err) } return nil } diff --git a/internal/servers/controller/listeners.go b/internal/servers/controller/listeners.go index 11e3268c9e..d050910ecc 100644 --- a/internal/servers/controller/listeners.go +++ b/internal/servers/controller/listeners.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "strings" "sync" "time" @@ -45,7 +46,7 @@ func (c *Controller) startListeners() error { ReadHeaderTimeout: 10 * time.Second, ReadTimeout: 30 * time.Second, IdleTimeout: 5 * time.Minute, - ErrorLog: c.conf.Logger.StandardLogger(nil), + ErrorLog: c.logger.StandardLogger(nil), BaseContext: func(net.Listener) context.Context { return c.baseContext }, @@ -107,9 +108,19 @@ func (c *Controller) startListeners() error { for { conn, err := ln.ALPNListener.Accept() if err != nil { - c.conf.Logger.Info("default alpn listener errored, exiting") + if !strings.Contains(err.Error(), "use of closed network connection") { + c.logger.Info("default alpn listener errored, exiting", "error", err) + } return } + _, err = conn.Read(make([]byte, 3)) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error reading test string from worker for worker auth: %w", err)) + } + _, err = conn.Write([]byte("bar")) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error writing test string to worker for worker auth: %w", err)) + } conn.Close() } }() diff --git a/internal/servers/worker/cluster_tls.go b/internal/servers/worker/cluster_tls.go index d535257bc6..7ef02670cc 100644 --- a/internal/servers/worker/cluster_tls.go +++ b/internal/servers/worker/cluster_tls.go @@ -39,7 +39,7 @@ func (c Worker) workerAuthTLSConfig() (*tls.Config, error) { KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), SerialNumber: big.NewInt(mathrand.Int63()), NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), + NotAfter: time.Now().Add(3 * time.Minute), BasicConstraintsValid: true, IsCA: true, } @@ -80,7 +80,7 @@ func (c Worker) workerAuthTLSConfig() (*tls.Config, error) { KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, SerialNumber: big.NewInt(mathrand.Int63()), NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), + NotAfter: time.Now().Add(2 * time.Minute), } certBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey) if err != nil { diff --git a/internal/servers/worker/listeners.go b/internal/servers/worker/listeners.go index a5147b5a57..9cdacdfa1c 100644 --- a/internal/servers/worker/listeners.go +++ b/internal/servers/worker/listeners.go @@ -3,6 +3,7 @@ package worker import ( "context" "crypto/tls" + "errors" "fmt" "sync" @@ -37,13 +38,41 @@ func (c *Worker) startListeners() error { continue } if ln.ALPNListener != nil { - c.conf.Logger.Info("testing the waters with proto", "protos", tlsConf.NextProtos) conn, err := tls.Dial(ln.ALPNListener.Addr().Network(), ln.ALPNListener.Addr().String(), tlsConf) if err != nil { retErr = multierror.Append(retErr, fmt.Errorf("error dialing controller for worker auth: %w", err)) continue } - c.conf.Logger.Info("negotiated a protocol", "proto", conn.ConnectionState().NegotiatedProtocol, "mutual", conn.ConnectionState().NegotiatedProtocolIsMutual) + _, err = conn.Write([]byte("foo")) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error writing test string to controller for worker auth: %w", err)) + continue + } + _, err = conn.Read(make([]byte, 3)) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error reading test string from controller for worker auth: %w", err)) + continue + } + c.logger.Info("done good writing/reading") + conn.Close() + newTLSConf, _ := c.workerAuthTLSConfig() + tlsConf.Certificates = newTLSConf.Certificates + conn, err = tls.Dial(ln.ALPNListener.Addr().Network(), ln.ALPNListener.Addr().String(), tlsConf) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error dialing controller for worker auth: %w", err)) + continue + } + _, err = conn.Write([]byte("foo")) + if err != nil { + retErr = multierror.Append(retErr, fmt.Errorf("error writing test string to controller for worker auth: %w", err)) + continue + } + _, err = conn.Read(make([]byte, 3)) + if err == nil { + retErr = multierror.Append(retErr, errors.New("expected error reading test string from controller for worker auth")) + continue + } + c.logger.Info("done bad writing/reading") conn.Close() } } diff --git a/internal/servers/worker/worker.go b/internal/servers/worker/worker.go index 7f20eb8d0a..db66ab024c 100644 --- a/internal/servers/worker/worker.go +++ b/internal/servers/worker/worker.go @@ -10,18 +10,24 @@ import ( ) type Worker struct { - conf *Config + conf *Config + logger hclog.Logger baseContext context.Context baseCancel context.CancelFunc } func New(conf *Config) (*Worker, error) { - if conf.Logger == nil { - conf.Logger = hclog.New(&hclog.LoggerOptions{ + c := &Worker{ + conf: conf, + logger: conf.Logger, + } + + if c.logger == nil { + c.logger = hclog.New(&hclog.LoggerOptions{ Level: hclog.Trace, }) - conf.AllLoggers = append(conf.AllLoggers, conf.Logger) + conf.AllLoggers = append(conf.AllLoggers, c.logger) } if conf.SecureRandomReader == nil { @@ -44,11 +50,7 @@ func New(conf *Config) (*Worker, error) { } } - conf.Logger = conf.Logger.Named("worker") - - c := &Worker{ - conf: conf, - } + c.logger = c.logger.Named("worker") c.baseContext, c.baseCancel = context.WithCancel(context.Background()) @@ -57,14 +59,14 @@ func New(conf *Config) (*Worker, error) { func (c *Worker) Start() error { if err := c.startListeners(); err != nil { - return err + return fmt.Errorf("error starting worker listeners: %w", err) } return nil } func (c *Worker) Shutdown() error { if err := c.stopListeners(); err != nil { - return err + return fmt.Errorf("error stopping worker listeners: %w", err) } return nil }