From 07dc3db974330e6b5a809aa1d0520dccdecdf6cf Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 1 Apr 2022 12:06:41 -0400 Subject: [PATCH] Remove ALPN Muxer (#1965) We never ended up using this and it makes it much harder to reason about the listeners. An explanation: for the moment I separated out all of the listeners in the struct, even though one should suffice, because when I did it that way first I was hitting some really weird and hard to find intermittent behavior where it seemed like the listeners were getting crossed. So I re-did it this way with each listener very explicitly referenced. I'm happy to make a follow-up PR that attempts to combine them all into a single listener without hitting the issues I did, if desired, but at least doing it via a separate PR means that it won't be something with removing the ALPN muxer that is at fault. --- internal/cmd/base/listener.go | 69 ++--- internal/cmd/base/servers.go | 33 ++- internal/cmd/ops/server.go | 42 +-- internal/cmd/ops/server_test.go | 53 ++-- internal/libs/alpnmux/mux.go | 273 ------------------ internal/libs/alpnmux/mux_test.go | 201 ------------- internal/libs/alpnmux/testing.go | 119 -------- internal/servers/controller/listeners.go | 49 +--- internal/servers/controller/listeners_test.go | 76 +++-- internal/servers/controller/testing.go | 10 +- internal/servers/worker/listeners.go | 24 +- internal/servers/worker/listeners_test.go | 72 ++--- internal/servers/worker/testing.go | 2 +- 13 files changed, 191 insertions(+), 832 deletions(-) delete mode 100644 internal/libs/alpnmux/mux.go delete mode 100644 internal/libs/alpnmux/mux_test.go delete mode 100644 internal/libs/alpnmux/testing.go diff --git a/internal/cmd/base/listener.go b/internal/cmd/base/listener.go index fef18788a9..ef4d82eb57 100644 --- a/internal/cmd/base/listener.go +++ b/internal/cmd/base/listener.go @@ -11,8 +11,8 @@ import ( // We must import sha512 so that it registers with the runtime so that // certificates that use it can be parsed. _ "crypto/sha512" + "crypto/tls" - "github.com/hashicorp/boundary/internal/libs/alpnmux" "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/hashicorp/go-secure-stdlib/reloadutil" "github.com/mitchellh/cli" @@ -21,11 +21,13 @@ import ( ) type ServerListener struct { - Mux *alpnmux.ALPNMux - Config *listenerutil.ListenerConfig - HTTPServer *http.Server - GrpcServer *grpc.Server - ALPNListener net.Listener + Config *listenerutil.ListenerConfig + HTTPServer *http.Server + GrpcServer *grpc.Server + ApiListener net.Listener + ClusterListener net.Listener + ProxyListener net.Listener + OpsListener net.Listener } type WorkerAuthInfo struct { @@ -47,7 +49,7 @@ var BuiltinListeners = map[string]ListenerFactory{ // New creates a new listener of the given type with the given // configuration. The type is looked up in the BuiltinListeners map. -func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { +func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (net.Listener, map[string]string, reloadutil.ReloadFunc, error) { f, ok := BuiltinListeners[l.Type] if !ok { return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type) @@ -58,8 +60,23 @@ func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, m } purpose := l.Purpose[0] + finalAddr, ln, err := f(purpose, l, ui) + if err != nil { + return nil, nil, nil, err + } + + ln, err = listenerWrapProxy(ln, l) + if err != nil { + return nil, nil, nil, err + } + + props := map[string]string{ + "addr": finalAddr, + } + switch purpose { case "cluster": + // We handle our own cluster authentication l.TLSDisable = true case "proxy": // TODO: Eventually we'll support bringing your own cert, and we'd only @@ -78,24 +95,8 @@ func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, m } } - finalAddr, ln, err := f(purpose, l, ui) - if err != nil { - return nil, nil, nil, err - } - - ln, err = listenerWrapProxy(ln, l) - if err != nil { - return nil, nil, nil, err - } - - props := map[string]string{ - "addr": finalAddr, - } - - alpnMux := alpnmux.New(ln) - if l.TLSDisable { - return alpnMux, props, nil, nil + return ln, props, nil, nil } // Don't request a client cert unless they've explicitly configured it to do @@ -107,23 +108,15 @@ func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, m if err != nil { return nil, nil, nil, err } - // Register no proto, "http/1.1", and "h2", with same TLS config - if _, err = alpnMux.RegisterProto("", tlsConfig); err != nil { - return nil, nil, nil, err - } - if _, err = alpnMux.RegisterProto("http/1.1", tlsConfig); err != nil { - return nil, nil, nil, err - } - if _, err = alpnMux.RegisterProto("h2", tlsConfig); err != nil { - return nil, nil, nil, err - } - return alpnMux, props, reloadFunc, nil + return tls.NewListener(ln, tlsConfig), props, reloadFunc, nil } func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.Ui) (string, net.Listener, error) { if l.Address == "" { switch purpose { + case "api": + l.Address = "127.0.0.1:9200" case "cluster": l.Address = "127.0.0.1:9201" case "proxy": @@ -131,7 +124,7 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U case "ops": l.Address = "127.0.0.1:9203" default: - l.Address = "127.0.0.1:9200" + return "", nil, errors.New("no purpose provided for listener and no address given") } } @@ -139,6 +132,8 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U if err != nil { if strings.Contains(err.Error(), "missing port") { switch purpose { + case "api": + port = "9200" case "cluster": port = "9201" case "proxy": @@ -146,7 +141,7 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U case "ops": port = "9203" default: - port = "9200" + return "", nil, errors.New("no purpose provided for listener and no port discoverable") } host = l.Address } else { diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 4a5c297705..d1486b32ea 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -381,7 +381,18 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allo // we ignore errors b.ShutdownFuncs = append(b.ShutdownFuncs, func() error { for _, ln := range b.Listeners { - ln.Mux.Close() + if ln.ProxyListener != nil { + ln.ProxyListener.Close() + } + if ln.ClusterListener != nil { + ln.ClusterListener.Close() + } + if ln.ApiListener != nil { + ln.ApiListener.Close() + } + if ln.OpsListener != nil { + ln.OpsListener.Close() + } } return nil }) @@ -414,7 +425,7 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allo } } - lnMux, props, reloadFunc, err := NewListener(lnConfig, ui) + ln, props, reloadFunc, err := NewListener(lnConfig, ui) if err != nil { return fmt.Errorf("Error initializing listener of type %s: %w", lnConfig.Type, err) } @@ -460,10 +471,22 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allo } props["max_request_duration"] = lnConfig.MaxRequestDuration.String() - b.Listeners = append(b.Listeners, &ServerListener{ - Mux: lnMux, + serverListener := &ServerListener{ Config: lnConfig, - }) + } + + switch purpose { + case "api": + serverListener.ApiListener = ln + case "cluster": + serverListener.ClusterListener = ln + case "proxy": + serverListener.ProxyListener = ln + case "ops": + serverListener.OpsListener = ln + } + + b.Listeners = append(b.Listeners, serverListener) props["purpose"] = strings.Join(lnConfig.Purpose, ",") diff --git a/internal/cmd/ops/server.go b/internal/cmd/ops/server.go index 3ae537bdb1..e24986c00f 100644 --- a/internal/cmd/ops/server.go +++ b/internal/cmd/ops/server.go @@ -11,7 +11,6 @@ import ( "time" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/libs/alpnmux" "github.com/hashicorp/boundary/internal/servers/controller" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-hclog" @@ -51,6 +50,9 @@ func NewServer(l hclog.Logger, c *controller.Controller, listeners ...*base.Serv if ln.Config.Purpose[0] != "ops" { continue } + if ln.OpsListener == nil { + return nil, fmt.Errorf("%s: missing ops listener", op) + } h, err := createOpsHandler(ln.Config, c) if err != nil { @@ -60,11 +62,7 @@ func NewServer(l hclog.Logger, c *controller.Controller, listeners ...*base.Serv b := &opsBundle{ln: ln, h: h} b.ln.HTTPServer = createHttpServer(l, b.h, b.ln.Config) - funcs, err := getStartFn(b.ln) - if err != nil { - return nil, err - } - b.startFn = funcs + b.startFn = []func(){func() { go b.ln.HTTPServer.Serve(b.ln.OpsListener) }} bundles = append(bundles, b) } @@ -89,7 +87,7 @@ func (s *Server) Shutdown() error { var closeErrors *multierror.Error for _, b := range s.bundles { - if b == nil || b.ln == nil || b.ln.Config == nil || b.ln.Mux == nil || b.ln.HTTPServer == nil { + if b == nil || b.ln == nil || b.ln.Config == nil || b.ln.OpsListener == nil || b.ln.HTTPServer == nil { return fmt.Errorf("%s: missing bundle, listener or its fields", op) } @@ -101,7 +99,7 @@ func (s *Server) Shutdown() error { multierror.Append(closeErrors, fmt.Errorf("%s: failed to shutdown http server: %w", op, err)) } - err = b.ln.Mux.Close() + err = b.ln.OpsListener.Close() err = listenerCloseErrorCheck(b.ln.Config.Type, err) if err != nil { multierror.Append(closeErrors, fmt.Errorf("%s: failed to close listener mux: %w", op, err)) @@ -169,34 +167,6 @@ func createHttpServer(l hclog.Logger, h http.Handler, lncfg *listenerutil.Listen return s } -func getStartFn(ln *base.ServerListener) ([]func(), error) { - const op = "getStartFn()" - - funcs := make([]func(), 0) - switch ln.Config.TLSDisable { - case true: - l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) - if err != nil { - return nil, fmt.Errorf("%s: error getting non-tls listener: %w", op, err) - } - if l == nil { - return nil, fmt.Errorf("%s: could not get non-tls listener", op) - } - funcs = append(funcs, func() { go ln.HTTPServer.Serve(l) }) - - default: - for _, v := range []string{"", "http/1.1", "h2"} { - l := ln.Mux.GetListener(v) - if l == nil { - return nil, fmt.Errorf("%s: could not get tls proto %q listener", op, v) - } - funcs = append(funcs, func() { go ln.HTTPServer.Serve(l) }) - } - } - - return funcs, nil -} - func listenerCloseErrorCheck(lnType string, err error) error { if errors.Is(err, net.ErrClosed) { // Ignore net.ErrClosed - The listener was already closed, diff --git a/internal/cmd/ops/server_test.go b/internal/cmd/ops/server_test.go index 4591172142..3bbf9dfe44 100644 --- a/internal/cmd/ops/server_test.go +++ b/internal/cmd/ops/server_test.go @@ -18,7 +18,6 @@ import ( "time" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/libs/alpnmux" "github.com/hashicorp/boundary/internal/servers/controller" "github.com/hashicorp/boundary/internal/servers/controller/handlers/health" "github.com/hashicorp/go-hclog" @@ -280,7 +279,7 @@ func TestNewServerIntegration(t *testing.T) { addrs := make([]string, 0, len(s.bundles)) for _, b := range s.bundles { - addrs = append(addrs, b.ln.Mux.Addr().String()) + addrs = append(addrs, b.ln.OpsListener.Addr().String()) } if tt.assertions != nil { tt.assertions(t, addrs) @@ -334,7 +333,6 @@ func TestShutdown(t *testing.T) { bundles: []*opsBundle{ { ln: &base.ServerListener{ - Mux: &alpnmux.ALPNMux{}, HTTPServer: &http.Server{}, }, }, @@ -368,7 +366,6 @@ func TestShutdown(t *testing.T) { bundles: []*opsBundle{ { ln: &base.ServerListener{ - Mux: &alpnmux.ALPNMux{}, Config: &listenerutil.ListenerConfig{}, }, }, @@ -390,10 +387,9 @@ func TestShutdown(t *testing.T) { bundles: []*opsBundle{ { ln: &base.ServerListener{ - ALPNListener: l1, - HTTPServer: s1, - Mux: alpnmux.New(l1), - Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, + HTTPServer: s1, + OpsListener: l1, + Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, }, }, }, @@ -401,10 +397,10 @@ func TestShutdown(t *testing.T) { }, assertions: func(t *testing.T, s *Server) { // The HTTP Server must be closed. - require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.ALPNListener), http.ErrServerClosed) + require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.OpsListener), http.ErrServerClosed) // The underlying listener must be closed. - require.ErrorIs(t, s.bundles[0].ln.Mux.Close(), net.ErrClosed) + require.ErrorIs(t, s.bundles[0].ln.OpsListener.Close(), net.ErrClosed) }, expErr: false, }, @@ -426,10 +422,9 @@ func TestShutdown(t *testing.T) { bundles: []*opsBundle{ { ln: &base.ServerListener{ - ALPNListener: l1, - HTTPServer: s1, - Mux: alpnmux.New(l1), - Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, + OpsListener: l1, + HTTPServer: s1, + Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, }, }, }, @@ -437,10 +432,10 @@ func TestShutdown(t *testing.T) { }, assertions: func(t *testing.T, s *Server) { // The HTTP Server must be closed. - require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.ALPNListener), http.ErrServerClosed) + require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.OpsListener), http.ErrServerClosed) // The underlying listener must be closed. - require.ErrorIs(t, s.bundles[0].ln.Mux.Close(), net.ErrClosed) + require.ErrorIs(t, s.bundles[0].ln.OpsListener.Close(), net.ErrClosed) }, }, { @@ -466,18 +461,16 @@ func TestShutdown(t *testing.T) { bundles: []*opsBundle{ { ln: &base.ServerListener{ - ALPNListener: l1, - HTTPServer: s1, - Mux: alpnmux.New(l1), - Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, + OpsListener: l1, + HTTPServer: s1, + Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, }, }, { ln: &base.ServerListener{ - ALPNListener: l2, - HTTPServer: s2, - Mux: alpnmux.New(l2), - Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, + OpsListener: l2, + HTTPServer: s2, + Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}}, }, }, }, @@ -485,12 +478,12 @@ func TestShutdown(t *testing.T) { }, assertions: func(t *testing.T, s *Server) { // The HTTP Server must be closed. - require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.ALPNListener), http.ErrServerClosed) - require.ErrorIs(t, s.bundles[1].ln.HTTPServer.Serve(s.bundles[1].ln.ALPNListener), http.ErrServerClosed) + require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.OpsListener), http.ErrServerClosed) + require.ErrorIs(t, s.bundles[1].ln.HTTPServer.Serve(s.bundles[1].ln.OpsListener), http.ErrServerClosed) // The underlying listener must be closed. - require.ErrorIs(t, s.bundles[0].ln.Mux.Close(), net.ErrClosed) - require.ErrorIs(t, s.bundles[1].ln.Mux.Close(), net.ErrClosed) + require.ErrorIs(t, s.bundles[0].ln.OpsListener.Close(), net.ErrClosed) + require.ErrorIs(t, s.bundles[1].ln.OpsListener.Close(), net.ErrClosed) }, }, } @@ -556,7 +549,7 @@ func TestHealthEndpointLifecycle(t *testing.T) { opsServer.Start() // Assert the ops endpoint is up and returning 200 OK. - rsp, err := http.Get("http://" + tc.Config().Listeners[0].Mux.Addr().String() + "/health") + rsp, err := http.Get("http://" + tc.Config().Listeners[0].OpsListener.Addr().String() + "/health") require.NoError(t, err) require.Equal(t, http.StatusOK, rsp.StatusCode) @@ -564,7 +557,7 @@ func TestHealthEndpointLifecycle(t *testing.T) { tc.Controller().HealthService.StartServiceUnavailableReplies() // Assert we're receiving 503 Service Unavailable now instead of 200 OK. - rsp, err = http.Get("http://" + tc.Config().Listeners[0].Mux.Addr().String() + "/health") + rsp, err = http.Get("http://" + tc.Config().Listeners[0].OpsListener.Addr().String() + "/health") require.NoError(t, err) require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) } diff --git a/internal/libs/alpnmux/mux.go b/internal/libs/alpnmux/mux.go deleted file mode 100644 index 1615085c46..0000000000 --- a/internal/libs/alpnmux/mux.go +++ /dev/null @@ -1,273 +0,0 @@ -package alpnmux - -import ( - "bufio" - "context" - "crypto/tls" - "errors" - "fmt" - "net" - "strings" - "sync" - - "github.com/hashicorp/boundary/internal/observability/event" -) - -const ( - // NoProto is used when the connection isn't actually TLS - NoProto = "(none)" - - // DefaultProto is used when there is an ALPN we don't actually know about. - // If no protos are specified on an incoming TLS connection we will first - // look for a proto of ""; if not found, will use DefaultProto. On a - // connection that has protos defined, we will look for that proto first, - // then DefaultProto. - DefaultProto = "(*)" -) - -type bufferedConn struct { - net.Conn - buffer *bufio.Reader -} - -func (b *bufferedConn) Read(p []byte) (int, error) { - return b.buffer.Read(p) -} - -type muxedListener struct { - connMutex *sync.RWMutex - ctx context.Context - addr net.Addr - proto string - tlsConf *tls.Config - connCh chan net.Conn - closed bool - closeFunc func() - closeOnce *sync.Once -} - -type ALPNMux struct { - ctx context.Context - baseLn net.Listener - cancel context.CancelFunc - muxMap *sync.Map -} - -func New(baseLn net.Listener) *ALPNMux { - ctx, cancel := context.WithCancel(context.Background()) - ret := &ALPNMux{ - ctx: ctx, - cancel: cancel, - muxMap: new(sync.Map), - baseLn: baseLn, - } - go ret.accept() - return ret -} - -func (l *ALPNMux) Addr() net.Addr { - return l.baseLn.Addr() -} - -func (l *ALPNMux) Close() error { - return l.baseLn.Close() -} - -func (l *ALPNMux) RegisterProto(proto string, tlsConf *tls.Config) (net.Listener, error) { - const op = "alpnmux.(ALPNMux).RegisterProto" - switch proto { - case NoProto: - if tlsConf != nil { - return nil, errors.New("tls config cannot be non-nil when using NoProto") - } - default: - if tlsConf == nil { - return nil, errors.New("nil tls config given") - } - } - sub := &muxedListener{ - connMutex: new(sync.RWMutex), - ctx: l.ctx, - addr: l.baseLn.Addr(), - proto: proto, - tlsConf: tlsConf, - connCh: make(chan net.Conn), - closeOnce: new(sync.Once), - } - _, loaded := l.muxMap.LoadOrStore(proto, sub) - if loaded { - close(sub.connCh) - return nil, fmt.Errorf("proto %q already registered", proto) - } - - sub.closeFunc = func() { - go l.UnregisterProto(proto) - } - - return sub, nil -} - -func (l *ALPNMux) UnregisterProto(proto string) { - const op = "alpnmux.(ALPNMux).UnregisterProto" - val, ok := l.muxMap.Load(proto) - if !ok { - return - } - ml := val.(*muxedListener) - ml.closeOnce.Do(func() { - ml.connMutex.Lock() - defer ml.connMutex.Unlock() - ml.closed = true - close(ml.connCh) - }) - l.muxMap.Delete(proto) -} - -func (l *ALPNMux) GetListener(proto string) net.Listener { - val, ok := l.muxMap.Load(proto) - if !ok || val == nil { - val, ok = l.muxMap.Load(DefaultProto) - if !ok || val == nil { - return nil - } - } - return val.(*muxedListener) -} - -func (l *ALPNMux) getConfigForClient(hello *tls.ClientHelloInfo) (*tls.Config, error) { - const op = "alpnmux.(ALPNMux).getConfigForClient" - var ret *tls.Config - - supportedProtos := hello.SupportedProtos - if len(hello.SupportedProtos) == 0 { - supportedProtos = append(supportedProtos, "") - } - for _, proto := range supportedProtos { - val, ok := l.muxMap.Load(proto) - if !ok { - continue - } - ret = val.(*muxedListener).tlsConf - } - if ret == nil { - val, ok := l.muxMap.Load(DefaultProto) - if ok && val != nil { - ret = val.(*muxedListener).tlsConf - } - } - if ret == nil { - return nil, errors.New("no tls configuration available for any client protos") - } - - // If the TLS config we found has its own lookup function, chain to it - if ret.GetConfigForClient != nil { - return ret.GetConfigForClient(hello) - } - - return ret, nil -} - -func (l *ALPNMux) accept() { - const op = "alpnmux.(ALPNMux).accept" - ctx := context.TODO() - baseTLSConf := &tls.Config{ - GetConfigForClient: l.getConfigForClient, - } - for { - conn, err := l.baseLn.Accept() - if err != nil { - if strings.Contains(err.Error(), "use of closed network connection") { - l.cancel() - return - } - } - if conn == nil { - continue - } - - // Do the rest in a goroutine so that a timeout in e.g. handshaking - // doesn't block acceptance of the next connection - go func() { - bufConn := &bufferedConn{ - Conn: conn, - buffer: bufio.NewReader(conn), - } - peeked, err := bufConn.buffer.Peek(3) - if err != nil { - bufConn.Close() - return - } - switch { - // First byte should always be a handshake, second byte a 3, and - // third can be 3 or 1 depending on the implementation - case peeked[0] != 0x16 || peeked[1] != 0x03 || (peeked[2] != 0x03 && peeked[2] != 0x01): - val, ok := l.muxMap.Load(NoProto) - if !ok { - bufConn.Close() - return - } - ml := val.(*muxedListener) - ml.connMutex.RLock() - if !ml.closed { - ml.connCh <- bufConn - } - ml.connMutex.RUnlock() - - default: - tlsConn := tls.Server(bufConn, baseTLSConf) - if err := tlsConn.Handshake(); err != nil { - closeErr := tlsConn.Close() - if closeErr != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error handshaking connection", "addr", conn.RemoteAddr(), "close_error", closeErr)) - } - return - } - negProto := tlsConn.ConnectionState().NegotiatedProtocol - val, ok := l.muxMap.Load(negProto) - if !ok { - val, ok = l.muxMap.Load(DefaultProto) - if !ok { - tlsConn.Close() - return - } - } - ml := val.(*muxedListener) - ml.connMutex.RLock() - if !ml.closed { - ml.connCh <- tlsConn - } - ml.connMutex.RUnlock() - } - }() - } -} - -func (m *muxedListener) Accept() (net.Conn, error) { - for { - select { - case <-m.ctx.Done(): - // Wouldn't it be so much better if this error was an exported - // const from Go... - m.closeFunc() - return nil, fmt.Errorf("accept proto %s: use of closed network connection", m.proto) - case conn, ok := <-m.connCh: - if !ok { - // Channel closed - return nil, fmt.Errorf("accept proto %s: use of closed network connection", m.proto) - } - if conn == nil { - return nil, fmt.Errorf("accept proto %s: nil connection received", m.proto) - } - return conn, nil - } - } -} - -func (m *muxedListener) Close() error { - m.closeFunc() - return nil -} - -func (m *muxedListener) Addr() net.Addr { - return m.addr -} diff --git a/internal/libs/alpnmux/mux_test.go b/internal/libs/alpnmux/mux_test.go deleted file mode 100644 index 8ef7bad9f2..0000000000 --- a/internal/libs/alpnmux/mux_test.go +++ /dev/null @@ -1,201 +0,0 @@ -package alpnmux - -import ( - "crypto/tls" - "fmt" - "log" - "net" - "strings" - "sync" - "testing" - "time" - - "github.com/hashicorp/boundary/internal/observability/event" - "github.com/hashicorp/go-hclog" - "go.uber.org/atomic" -) - -func TestListenCloseErrMsg(t *testing.T) { - listener := getListener(t) - listener.Close() - _, err := listener.Accept() - if !strings.Contains(err.Error(), "use of closed network connection") { - t.Fatal(err) - } -} - -func TestRegistrationErrors(t *testing.T) { - listener := getListener(t) - defer listener.Close() - mux := New(listener) - p1config := getTestTLS(t, []string{"p1"}) - if _, err := mux.RegisterProto("p1", nil); err.Error() != "nil tls config given" { - t.Fatal(err) - } - l, err := mux.RegisterProto("p1", p1config) - if err != nil { - t.Fatal(err) - } - if _, err := mux.RegisterProto("p1", p1config); err.Error() != `proto "p1" already registered` { - t.Fatal(err) - } - l.Close() - // Unregister is not sync, so need to wait for it to actually be removed - var unregistered bool - for i := 0; i < 5; i++ { - _, ok := mux.muxMap.Load("p1") - if !ok { - unregistered = true - break - } - time.Sleep(100 * time.Millisecond) - } - if !unregistered { - t.Fatal("failed to unregister proto") - } - l, err = mux.RegisterProto("p1", p1config) - if err != nil { - t.Fatal(err) - } - l.Close() - l, err = mux.RegisterProto(NoProto, nil) - if err != nil { - t.Fatal(err) - } - l.Close() -} - -func TestListening(t *testing.T) { - event.TestEnableEventing(t, true) - testConfig := event.DefaultEventerConfig() - testLock := &sync.Mutex{} - testLogger := hclog.New(&hclog.LoggerOptions{ - Mutex: testLock, - }) - err := event.InitSysEventer(testLogger, testLock, "TestListening", event.WithEventerConfig(testConfig)) - if err != nil { - t.Fatal(err) - } - listener := getListener(t) - - mux := New(listener) - defer mux.Close() - - emptyconns := atomic.NewUint32(0) - noneconns := atomic.NewUint32(0) - l1conns := atomic.NewUint32(0) - l2conns := atomic.NewUint32(0) - l3conns := atomic.NewUint32(0) - defconns := atomic.NewUint32(0) - clientCountTracker := atomic.NewUint32(0) - - baseconfig := getTestTLS(t, nil) - noneconfig := baseconfig.Clone() - p1config := baseconfig.Clone() - p1config.NextProtos = []string{"p1"} - p2p3config := getTestTLS(t, []string{"p2", "p3"}) - p3config := p2p3config.Clone() - p3config.NextProtos = []string{"p3"} - defconfig := baseconfig.Clone() - defconfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { - ret := baseconfig.Clone() - ret.NextProtos = []string{fmt.Sprintf("%d", clientCountTracker.Load())} - log.Printf("returning def config with next protos = %v\n", ret.NextProtos) - clientCountTracker.Inc() - return ret, nil - } - - lempty, err := mux.RegisterProto("", noneconfig) - if err != nil { - t.Fatal(err) - } - l1, err := mux.RegisterProto("p1", p1config) - if err != nil { - t.Fatal(err) - } - l2, err := mux.RegisterProto("p2", p2p3config) - if err != nil { - t.Fatal(err) - } - l3, err := mux.RegisterProto("p3", p2p3config) - if err != nil { - t.Fatal(err) - } - lnone, err := mux.RegisterProto(NoProto, nil) - if err != nil { - t.Fatal(err) - } - ldef, err := mux.RegisterProto(DefaultProto, defconfig) - if err != nil { - t.Fatal(err) - } - - addr := listener.Addr().String() - wg := new(sync.WaitGroup) - wg.Add(6) - connWatchFunc := func(l net.Listener, connCounter *atomic.Uint32, tlsConf *tls.Config, numConns int) { - defer wg.Done() - tlsToUse := tlsConf - go func() { - for i := 0; i < numConns; i++ { - var err error - var conn net.Conn - switch tlsToUse { - case nil: - conn, err = net.Dial("tcp4", addr) - if err != nil { - t.Fatal(err) - } - // We need to send some data here because we won't have any - // from just the TLS handshake - log.Println("defconn") - n, err := conn.Write([]byte("GET ")) - if err != nil { - t.Fatal(err) - } - if n != 4 { - t.Fatal(n) - } - log.Println("defconn done") - - default: - if connCounter == defconns { - tlsToUse = baseconfig.Clone() - log.Println("FOUND CURR") - tlsToUse.NextProtos = []string{fmt.Sprintf("%d", i)} - } - log.Println(fmt.Sprintf("dialing on %d, counter = %d, protos = %v", numConns, i, tlsToUse.NextProtos)) - conn, err = tls.Dial("tcp4", addr, tlsToUse) - if err != nil { - t.Fatal(err) - } - log.Println(fmt.Sprintf("dialing done on %d, counter = %d, protos = %v", numConns, i, tlsToUse.NextProtos)) - } - conn.Close() - } - }() - for i := 0; i < numConns; i++ { - log.Println(fmt.Sprintf("accepting on %d, counter = %d", numConns, connCounter.Load())) - conn, err := l.Accept() - if err == nil && conn != nil { - conn.Close() - } else { - t.Fatal(err) - } - log.Println(fmt.Sprintf("done accepting on %d, counter = %d", numConns, connCounter.Load())) - connCounter.Inc() - } - return - } - go connWatchFunc(lempty, emptyconns, noneconfig, 4) - go connWatchFunc(l1, l1conns, p1config, 5) - go connWatchFunc(l2, l2conns, p2p3config, 6) - go connWatchFunc(l3, l3conns, p3config, 7) - go connWatchFunc(lnone, noneconns, nil, 8) - go connWatchFunc(ldef, defconns, defconfig, 9) - wg.Wait() - - if emptyconns.Load() != 4 || l1conns.Load() != 5 || l2conns.Load() != 6 || l3conns.Load() != 7 || noneconns.Load() != 8 || defconns.Load() != 9 { - t.Fatal("wrong number of conns") - } -} diff --git a/internal/libs/alpnmux/testing.go b/internal/libs/alpnmux/testing.go deleted file mode 100644 index a7fde91098..0000000000 --- a/internal/libs/alpnmux/testing.go +++ /dev/null @@ -1,119 +0,0 @@ -package alpnmux - -import ( - "crypto/ed25519" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" - mathrand "math/rand" - "net" - "testing" - "time" -) - -func getListener(t *testing.T) net.Listener { - addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - listener, err := net.ListenTCP("tcp", addr) - if err != nil { - t.Fatal(err) - } - return listener -} - -func getTestTLS(t *testing.T, protos []string) *tls.Config { - certIPs := []net.IP{ - net.IPv6loopback, - net.ParseIP("127.0.0.1"), - } - - _, caKey, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - caCertTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) - if err != nil { - t.Fatal(err) - } - caCert, err := x509.ParseCertificate(caBytes) - if err != nil { - t.Fatal(err) - } - rootCAs := x509.NewCertPool() - rootCAs.AddCert(caCert) - - // - // Certs generation - // - _, key, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } - certTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) - if err != nil { - t.Fatal(err) - } - certPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - } - certPEM := pem.EncodeToMemory(certPEMBlock) - marshaledKey, err := x509.MarshalPKCS8PrivateKey(key) - if err != nil { - t.Fatal(err) - } - keyPEMBlock := &pem.Block{ - Type: "PRIVATE KEY", - Bytes: marshaledKey, - } - keyPEM := pem.EncodeToMemory(keyPEMBlock) - - tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - t.Fatal(err) - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - RootCAs: rootCAs, - ClientCAs: rootCAs, - ClientAuth: tls.RequestClientCert, - NextProtos: protos, - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS13, - } - - return tlsConfig -} diff --git a/internal/servers/controller/listeners.go b/internal/servers/controller/listeners.go index d7da78a88f..16e5bca17e 100644 --- a/internal/servers/controller/listeners.go +++ b/internal/servers/controller/listeners.go @@ -14,7 +14,6 @@ import ( grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/hashicorp/boundary/internal/cmd/base" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" - "github.com/hashicorp/boundary/internal/libs/alpnmux" "github.com/hashicorp/boundary/internal/servers/controller/handlers/workers" "github.com/hashicorp/go-multierror" "google.golang.org/grpc" @@ -42,7 +41,7 @@ func (c *Controller) startListeners() error { for i := range c.apiListeners { ln := c.apiListeners[i] - apiServers, err := c.configureForAPI(ln) + apiServers, err := c.configureForApi(ln) if err != nil { return fmt.Errorf("failed to configure listener for api mode: %w", err) } @@ -62,7 +61,7 @@ func (c *Controller) startListeners() error { return nil } -func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) { +func (c *Controller) configureForApi(ln *base.ServerListener) ([]func(), error) { apiServers := make([]func(), 0) handler, err := c.apiHandler(HandlerProperties{ @@ -97,39 +96,15 @@ func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) server.IdleTimeout = ln.Config.HTTPIdleTimeout } - switch ln.Config.TLSDisable { - case true: - l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) - if err != nil { - return nil, fmt.Errorf("error getting non-tls listener: %w", err) - } - if l == nil { - return nil, errors.New("could not get non-tls listener") - } - apiServers = append(apiServers, func() { go server.Serve(l) }) - - default: - for _, v := range []string{"", "http/1.1", "h2"} { - l := ln.Mux.GetListener(v) - if l == nil { - return nil, fmt.Errorf("could not get tls proto %q listener", v) - } - apiServers = append(apiServers, func() { go server.Serve(l) }) - } - } + apiServers = append(apiServers, func() { go server.Serve(ln.ApiListener) }) return apiServers, nil } func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error) { - // Clear out in case this is a second start of the controller - ln.Mux.UnregisterProto(alpnmux.DefaultProto) - l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ + l := tls.NewListener(ln.ClusterListener, &tls.Config{ GetConfigForClient: c.validateWorkerTls, }) - if err != nil { - return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err) - } workerReqInterceptor, err := workerRequestInfoInterceptor(c.baseContext, c.conf.Eventer) if err != nil { @@ -153,11 +128,9 @@ func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error pbs.RegisterServerCoordinationServiceServer(workerServer, workerService) pbs.RegisterSessionServiceServer(workerServer, workerService) - interceptor := newInterceptingListener(c, l) - ln.ALPNListener = interceptor ln.GrpcServer = workerServer - return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil + return func() { go ln.GrpcServer.Serve(newInterceptingListener(c, l)) }, nil } func (c *Controller) stopServersAndListeners() error { @@ -183,12 +156,12 @@ func (c *Controller) stopClusterGrpcServerAndListener() error { if c.clusterListener.GrpcServer == nil { return fmt.Errorf("no cluster grpc server") } - if c.clusterListener.Mux == nil { - return fmt.Errorf("no cluster listener mux") + if c.clusterListener.ClusterListener == nil { + return fmt.Errorf("no cluster listener") } c.clusterListener.GrpcServer.GracefulStop() - err := c.clusterListener.Mux.Close() + err := c.clusterListener.ClusterListener.Close() return listenerCloseErrorCheck(c.clusterListener.Config.Type, err) } @@ -204,7 +177,7 @@ func (c *Controller) stopHttpServersAndListeners() error { ln.HTTPServer.Shutdown(ctx) cancel() - err := ln.Mux.Close() // The HTTP Shutdown call should close this, but just in case. + err := ln.ApiListener.Close() // The HTTP Shutdown call should close this, but just in case. err = listenerCloseErrorCheck(ln.Config.Type, err) if err != nil { multierror.Append(closeErrors, err) @@ -231,11 +204,11 @@ func (c *Controller) stopAnyListeners() error { var closeErrors *multierror.Error for i := range c.apiListeners { ln := c.apiListeners[i] - if ln == nil || ln.Mux == nil { + if ln == nil || ln.ApiListener == nil { continue } - err := ln.Mux.Close() + err := ln.ApiListener.Close() err = listenerCloseErrorCheck(ln.Config.Type, err) if err != nil { multierror.Append(closeErrors, err) diff --git a/internal/servers/controller/listeners_test.go b/internal/servers/controller/listeners_test.go index daf0135a01..30a3b769cd 100644 --- a/internal/servers/controller/listeners_test.go +++ b/internal/servers/controller/listeners_test.go @@ -20,7 +20,6 @@ import ( "time" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/libs/alpnmux" "github.com/hashicorp/go-secure-stdlib/base62" "github.com/hashicorp/go-secure-stdlib/configutil/v2" "github.com/hashicorp/go-secure-stdlib/listenerutil" @@ -256,9 +255,9 @@ func TestStartListeners(t *testing.T) { apiAddrs := make([]string, 0) for _, l := range c.apiListeners { - apiAddrs = append(apiAddrs, l.Mux.Addr().String()) + apiAddrs = append(apiAddrs, l.ApiListener.Addr().String()) } - tt.assertions(t, c, apiAddrs, c.clusterListener.Mux.Addr().String()) + tt.assertions(t, c, apiAddrs, c.clusterListener.ClusterListener.Addr().String()) }) } } @@ -297,7 +296,7 @@ func TestStopClusterGrpcServerAndListener(t *testing.T) { } }, expErr: true, - expErrStr: "no cluster listener mux", + expErrStr: "no cluster listener", }, { name: "listener already closed", @@ -308,15 +307,14 @@ func TestStopClusterGrpcServerAndListener(t *testing.T) { return &Controller{ clusterListener: &base.ServerListener{ - ALPNListener: l, - GrpcServer: grpc.NewServer(), - Mux: alpnmux.New(l), - Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ClusterListener: l, + GrpcServer: grpc.NewServer(), + Config: &listenerutil.ListenerConfig{Type: "tcp"}, }, } }, assertions: func(t *testing.T, c *Controller) { - require.ErrorIs(t, c.clusterListener.Mux.Close(), net.ErrClosed) + require.ErrorIs(t, c.clusterListener.ClusterListener.Close(), net.ErrClosed) }, expErr: false, }, @@ -339,15 +337,14 @@ func TestStopClusterGrpcServerAndListener(t *testing.T) { return &Controller{ clusterListener: &base.ServerListener{ - ALPNListener: l, - GrpcServer: grpcServer, - Mux: alpnmux.New(l), - Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ClusterListener: l, + GrpcServer: grpcServer, + Config: &listenerutil.ListenerConfig{Type: "tcp"}, }, } }, assertions: func(t *testing.T, c *Controller) { - require.ErrorIs(t, c.clusterListener.Mux.Close(), net.ErrClosed) + require.ErrorIs(t, c.clusterListener.ClusterListener.Close(), net.ErrClosed) }, expErr: false, }, @@ -421,20 +418,19 @@ func TestStopHttpServersAndListeners(t *testing.T) { baseContext: context.Background(), apiListeners: []*base.ServerListener{ { - ALPNListener: l1, - HTTPServer: s1, - Mux: alpnmux.New(l1), - Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ApiListener: l1, + HTTPServer: s1, + Config: &listenerutil.ListenerConfig{Type: "tcp"}, }, }, } }, assertions: func(t *testing.T, c *Controller) { // Asserts the HTTP Servers are closed. - require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ALPNListener), http.ErrServerClosed) + require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ApiListener), http.ErrServerClosed) // Asserts the underlying listeners are closed. - require.ErrorIs(t, c.apiListeners[0].Mux.Close(), net.ErrClosed) + require.ErrorIs(t, c.apiListeners[0].ApiListener.Close(), net.ErrClosed) }, expErr: false, }, @@ -468,28 +464,26 @@ func TestStopHttpServersAndListeners(t *testing.T) { baseContext: context.Background(), apiListeners: []*base.ServerListener{ { - ALPNListener: l1, - HTTPServer: s1, - Mux: alpnmux.New(l1), - Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ApiListener: l1, + HTTPServer: s1, + Config: &listenerutil.ListenerConfig{Type: "tcp"}, }, { - ALPNListener: l2, - HTTPServer: s2, - Mux: alpnmux.New(l2), - Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ApiListener: l2, + HTTPServer: s2, + Config: &listenerutil.ListenerConfig{Type: "tcp"}, }, }, } }, assertions: func(t *testing.T, c *Controller) { // Asserts the HTTP Servers are closed. - require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ALPNListener), http.ErrServerClosed) - require.ErrorIs(t, c.apiListeners[1].HTTPServer.Serve(c.apiListeners[1].ALPNListener), http.ErrServerClosed) + require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ApiListener), http.ErrServerClosed) + require.ErrorIs(t, c.apiListeners[1].HTTPServer.Serve(c.apiListeners[1].ApiListener), http.ErrServerClosed) // Asserts the underlying listeners are closed. - require.ErrorIs(t, c.apiListeners[0].Mux.Close(), net.ErrClosed) - require.ErrorIs(t, c.apiListeners[1].Mux.Close(), net.ErrClosed) + require.ErrorIs(t, c.apiListeners[0].ApiListener.Close(), net.ErrClosed) + require.ErrorIs(t, c.apiListeners[1].ApiListener.Close(), net.ErrClosed) }, expErr: false, }, @@ -601,10 +595,10 @@ func TestStopAnyListeners(t *testing.T) { expErr: false, }, { - name: "listeners with nil mux", + name: "non-empty listeners with nil listeners", controllerFn: func(t *testing.T) *Controller { return &Controller{apiListeners: []*base.ServerListener{ - {Mux: nil}, {Mux: nil}, {Mux: nil}, + {ClusterListener: nil}, {ClusterListener: nil}, {ClusterListener: nil}, }} }, expErr: false, @@ -624,23 +618,23 @@ func TestStopAnyListeners(t *testing.T) { return &Controller{apiListeners: []*base.ServerListener{ { - Config: &listenerutil.ListenerConfig{Type: "tcp"}, - Mux: alpnmux.New(l1), + Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ApiListener: l1, }, { - Config: &listenerutil.ListenerConfig{Type: "tcp"}, - Mux: alpnmux.New(l2), + Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ApiListener: l2, }, { - Config: &listenerutil.ListenerConfig{Type: "tcp"}, - Mux: alpnmux.New(l3), + Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ApiListener: l3, }, }} }, assertions: func(t *testing.T, c *Controller) { for i := range c.apiListeners { ln := c.apiListeners[i] - require.ErrorIs(t, ln.Mux.Close(), net.ErrClosed) + require.ErrorIs(t, ln.ApiListener.Close(), net.ErrClosed) } }, expErr: false, diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 9037c6ba0a..8797692379 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -219,7 +219,15 @@ func (tc *TestController) addrs(purpose string) []string { addrs := make([]string, 0, len(tc.b.Listeners)) for _, listener := range tc.b.Listeners { if listener.Config.Purpose[0] == purpose { - addr := listener.Mux.Addr() + var addr net.Addr + switch purpose { + case "api": + addr = listener.ApiListener.Addr() + case "cluster": + addr = listener.ClusterListener.Addr() + case "ops": + addr = listener.OpsListener.Addr() + } switch { case strings.HasPrefix(addr.String(), "/"): switch purpose { diff --git a/internal/servers/worker/listeners.go b/internal/servers/worker/listeners.go index 25521aef03..12d9692ad2 100644 --- a/internal/servers/worker/listeners.go +++ b/internal/servers/worker/listeners.go @@ -12,7 +12,6 @@ import ( "time" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/libs/alpnmux" "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/go-multierror" ) @@ -24,7 +23,7 @@ func (w *Worker) startListeners() error { if e == nil { return fmt.Errorf("%s: sys eventer not initialized", op) } - logger, err := e.StandardLogger(w.baseContext, "listeners", event.ErrorType) + logger, err := e.StandardLogger(w.baseContext, "worker.listeners: ", event.ErrorType) if err != nil { return fmt.Errorf("%s: unable to initialize std logger: %w", op, err) } @@ -46,7 +45,7 @@ func (w *Worker) startListeners() error { return nil } -func (w *Worker) configureForWorker(ln *base.ServerListener, log *log.Logger) (func(), error) { +func (w *Worker) configureForWorker(ln *base.ServerListener, logger *log.Logger) (func(), error) { handler, err := w.handler(HandlerProperties{ListenerConfig: ln.Config}) if err != nil { return nil, err @@ -57,7 +56,7 @@ func (w *Worker) configureForWorker(ln *base.ServerListener, log *log.Logger) (f Handler: handler, ReadHeaderTimeout: 10 * time.Second, ReadTimeout: 30 * time.Second, - ErrorLog: log, + ErrorLog: logger, BaseContext: func(net.Listener) context.Context { return cancelCtx }, @@ -77,18 +76,9 @@ func (w *Worker) configureForWorker(ln *base.ServerListener, log *log.Logger) (f server.IdleTimeout = ln.Config.HTTPIdleTimeout } - // Clear out in case this is a second start of the controller - ln.Mux.UnregisterProto(alpnmux.DefaultProto) - ln.Mux.UnregisterProto(alpnmux.NoProto) - l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ + l := tls.NewListener(ln.ProxyListener, &tls.Config{ GetConfigForClient: w.getSessionTls, }) - if err != nil { - return nil, fmt.Errorf("error getting tls listener: %w", err) - } - if l == nil { - return nil, errors.New("could not get tls listener") - } return func() { go server.Serve(l) }, nil } @@ -120,7 +110,7 @@ func (w *Worker) stopHttpServersAndListeners() error { ln.HTTPServer.Shutdown(ctx) cancel() - err := ln.Mux.Close() + err := ln.ProxyListener.Close() err = listenerCloseErrorCheck(ln.Config.Type, err) if err != nil { multierror.Append(closeErrors, err) @@ -136,11 +126,11 @@ func (w *Worker) stopHttpServersAndListeners() error { func (w *Worker) stopAnyListeners() error { var closeErrors *multierror.Error for _, ln := range w.listeners { - if ln == nil || ln.Mux == nil { + if ln == nil || ln.ProxyListener == nil { continue } - err := ln.Mux.Close() + err := ln.ProxyListener.Close() err = listenerCloseErrorCheck(ln.Config.Type, err) if err != nil { multierror.Append(closeErrors, err) diff --git a/internal/servers/worker/listeners_test.go b/internal/servers/worker/listeners_test.go index 46f4a8689d..f8b6e29747 100644 --- a/internal/servers/worker/listeners_test.go +++ b/internal/servers/worker/listeners_test.go @@ -12,12 +12,20 @@ import ( "time" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/libs/alpnmux" "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/stretchr/testify/require" ) func TestStartListeners(t *testing.T) { + testNonTlsRejected := func(t *testing.T, resp *http.Response, err error) { + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "Client sent an HTTP request to an HTTPS server.\n", string(body)) + } + tests := []struct { name string listeners []*listenerutil.ListenerConfig @@ -45,16 +53,16 @@ func TestStartListeners(t *testing.T) { assertions: func(t *testing.T, w *Worker, addrs []string) { require.Len(t, addrs, 2) - _, err := http.Get("http://" + addrs[0] + "/v1/proxy/") - require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation + resp, err := http.Get("http://" + addrs[0] + "/v1/proxy/") + testNonTlsRejected(t, resp, err) cl := http.Client{ Transport: &http.Transport{ Dial: func(network, addr string) (net.Conn, error) { return net.Dial("unix", addrs[1]) }, }, } - _, err = cl.Get("http://anything.domain/v1/proxy/") - require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation + resp, err = cl.Get("http://anything.domain/v1/proxy/") + testNonTlsRejected(t, resp, err) }, }, { @@ -89,11 +97,11 @@ func TestStartListeners(t *testing.T) { assertions: func(t *testing.T, w *Worker, addrs []string) { require.Len(t, addrs, 4) - _, err := http.Get("http://" + addrs[0] + "/v1/proxy/") - require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation + resp, err := http.Get("http://" + addrs[0] + "/v1/proxy/") + testNonTlsRejected(t, resp, err) - _, err = http.Get("http://" + addrs[1] + "/v1/proxy/") - require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation + resp, err = http.Get("http://" + addrs[1] + "/v1/proxy/") + testNonTlsRejected(t, resp, err) for _, proxyAddr := range []string{addrs[2], addrs[3]} { cl := http.Client{ @@ -101,8 +109,8 @@ func TestStartListeners(t *testing.T) { Dial: func(network, addr string) (net.Conn, error) { return net.Dial("unix", proxyAddr) }, }, } - _, err = cl.Get("http://anything.domain/v1/proxy") - require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation + resp, err = cl.Get("http://anything.domain/v1/proxy") + testNonTlsRejected(t, resp, err) } }, }, @@ -130,7 +138,7 @@ func TestStartListeners(t *testing.T) { addrs := make([]string, 0) for _, l := range w.listeners { - addrs = append(addrs, l.Mux.Addr().String()) + addrs = append(addrs, l.ProxyListener.Addr().String()) } if tt.assertions != nil { tt.assertions(t, w, addrs) @@ -204,28 +212,26 @@ func TestStopHttpServersAndListeners(t *testing.T) { baseContext: context.Background(), listeners: []*base.ServerListener{ { - ALPNListener: l1, - HTTPServer: s1, - Mux: alpnmux.New(l1), - Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ProxyListener: l1, + HTTPServer: s1, + Config: &listenerutil.ListenerConfig{Type: "tcp"}, }, { - ALPNListener: l2, - HTTPServer: s2, - Mux: alpnmux.New(l2), - Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ProxyListener: l2, + HTTPServer: s2, + Config: &listenerutil.ListenerConfig{Type: "tcp"}, }, }, } }, assertions: func(t *testing.T, w *Worker) { // Asserts the HTTP Servers are closed. - require.ErrorIs(t, w.listeners[0].HTTPServer.Serve(w.listeners[0].ALPNListener), http.ErrServerClosed) - require.ErrorIs(t, w.listeners[1].HTTPServer.Serve(w.listeners[1].ALPNListener), http.ErrServerClosed) + require.ErrorIs(t, w.listeners[0].HTTPServer.Serve(w.listeners[0].ProxyListener), http.ErrServerClosed) + require.ErrorIs(t, w.listeners[1].HTTPServer.Serve(w.listeners[1].ProxyListener), http.ErrServerClosed) // Asserts the underlying listeners are closed. - require.ErrorIs(t, w.listeners[0].Mux.Close(), net.ErrClosed) - require.ErrorIs(t, w.listeners[1].Mux.Close(), net.ErrClosed) + require.ErrorIs(t, w.listeners[0].ProxyListener.Close(), net.ErrClosed) + require.ErrorIs(t, w.listeners[1].ProxyListener.Close(), net.ErrClosed) }, expErr: false, }, @@ -278,10 +284,10 @@ func TestStopAnyListeners(t *testing.T) { expErr: false, }, { - name: "listeners with nil mux", + name: "non-empty but nil listeners", workerFn: func(t *testing.T) *Worker { return &Worker{listeners: []*base.ServerListener{ - {Mux: nil}, {Mux: nil}, {Mux: nil}, + {ProxyListener: nil}, {ProxyListener: nil}, {ProxyListener: nil}, }} }, expErr: false, @@ -301,23 +307,23 @@ func TestStopAnyListeners(t *testing.T) { return &Worker{listeners: []*base.ServerListener{ { - Config: &listenerutil.ListenerConfig{Type: "tcp"}, - Mux: alpnmux.New(l1), + Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ProxyListener: l1, }, { - Config: &listenerutil.ListenerConfig{Type: "tcp"}, - Mux: alpnmux.New(l2), + Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ProxyListener: l2, }, { - Config: &listenerutil.ListenerConfig{Type: "tcp"}, - Mux: alpnmux.New(l3), + Config: &listenerutil.ListenerConfig{Type: "tcp"}, + ProxyListener: l3, }, }} }, assertions: func(t *testing.T, w *Worker) { for i := range w.listeners { ln := w.listeners[i] - require.ErrorIs(t, ln.Mux.Close(), net.ErrClosed) + require.ErrorIs(t, ln.ProxyListener.Close(), net.ErrClosed) } }, expErr: false, diff --git a/internal/servers/worker/testing.go b/internal/servers/worker/testing.go index e13d3aacc6..4775ed7a33 100644 --- a/internal/servers/worker/testing.go +++ b/internal/servers/worker/testing.go @@ -68,7 +68,7 @@ func (tw *TestWorker) ProxyAddrs() []string { for _, listener := range tw.b.Listeners { if listener.Config.Purpose[0] == "proxy" { - tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr) + tcpAddr, ok := listener.ProxyListener.Addr().(*net.TCPAddr) if !ok { tw.t.Fatal("could not parse address as a TCP addr") }