Remove ALPN Muxer (#1965)

We never ended up using this and it makes it much harder to reason about the listeners.

An explanation: for the moment I separated out all of the listeners in the struct, even though one should suffice, because when I did it that way first I was hitting some really weird and hard to find intermittent behavior where it seemed like the listeners were getting crossed. So I re-did it this way with each listener very explicitly referenced. I'm happy to make a follow-up PR that attempts to combine them all into a single listener without hitting the issues I did, if desired, but at least doing it via a separate PR means that it won't be something with removing the ALPN muxer that is at fault.
pull/1961/head
Jeff Mitchell 4 years ago committed by GitHub
parent ce7292d9c8
commit 07dc3db974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,8 +11,8 @@ import (
// We must import sha512 so that it registers with the runtime so that
// certificates that use it can be parsed.
_ "crypto/sha512"
"crypto/tls"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/hashicorp/go-secure-stdlib/reloadutil"
"github.com/mitchellh/cli"
@ -21,11 +21,13 @@ import (
)
type ServerListener struct {
Mux *alpnmux.ALPNMux
Config *listenerutil.ListenerConfig
HTTPServer *http.Server
GrpcServer *grpc.Server
ALPNListener net.Listener
Config *listenerutil.ListenerConfig
HTTPServer *http.Server
GrpcServer *grpc.Server
ApiListener net.Listener
ClusterListener net.Listener
ProxyListener net.Listener
OpsListener net.Listener
}
type WorkerAuthInfo struct {
@ -47,7 +49,7 @@ var BuiltinListeners = map[string]ListenerFactory{
// New creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) {
func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (net.Listener, map[string]string, reloadutil.ReloadFunc, error) {
f, ok := BuiltinListeners[l.Type]
if !ok {
return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type)
@ -58,8 +60,23 @@ func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, m
}
purpose := l.Purpose[0]
finalAddr, ln, err := f(purpose, l, ui)
if err != nil {
return nil, nil, nil, err
}
ln, err = listenerWrapProxy(ln, l)
if err != nil {
return nil, nil, nil, err
}
props := map[string]string{
"addr": finalAddr,
}
switch purpose {
case "cluster":
// We handle our own cluster authentication
l.TLSDisable = true
case "proxy":
// TODO: Eventually we'll support bringing your own cert, and we'd only
@ -78,24 +95,8 @@ func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, m
}
}
finalAddr, ln, err := f(purpose, l, ui)
if err != nil {
return nil, nil, nil, err
}
ln, err = listenerWrapProxy(ln, l)
if err != nil {
return nil, nil, nil, err
}
props := map[string]string{
"addr": finalAddr,
}
alpnMux := alpnmux.New(ln)
if l.TLSDisable {
return alpnMux, props, nil, nil
return ln, props, nil, nil
}
// Don't request a client cert unless they've explicitly configured it to do
@ -107,23 +108,15 @@ func NewListener(l *listenerutil.ListenerConfig, ui cli.Ui) (*alpnmux.ALPNMux, m
if err != nil {
return nil, nil, nil, err
}
// Register no proto, "http/1.1", and "h2", with same TLS config
if _, err = alpnMux.RegisterProto("", tlsConfig); err != nil {
return nil, nil, nil, err
}
if _, err = alpnMux.RegisterProto("http/1.1", tlsConfig); err != nil {
return nil, nil, nil, err
}
if _, err = alpnMux.RegisterProto("h2", tlsConfig); err != nil {
return nil, nil, nil, err
}
return alpnMux, props, reloadFunc, nil
return tls.NewListener(ln, tlsConfig), props, reloadFunc, nil
}
func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.Ui) (string, net.Listener, error) {
if l.Address == "" {
switch purpose {
case "api":
l.Address = "127.0.0.1:9200"
case "cluster":
l.Address = "127.0.0.1:9201"
case "proxy":
@ -131,7 +124,7 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U
case "ops":
l.Address = "127.0.0.1:9203"
default:
l.Address = "127.0.0.1:9200"
return "", nil, errors.New("no purpose provided for listener and no address given")
}
}
@ -139,6 +132,8 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U
if err != nil {
if strings.Contains(err.Error(), "missing port") {
switch purpose {
case "api":
port = "9200"
case "cluster":
port = "9201"
case "proxy":
@ -146,7 +141,7 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U
case "ops":
port = "9203"
default:
port = "9200"
return "", nil, errors.New("no purpose provided for listener and no port discoverable")
}
host = l.Address
} else {

@ -381,7 +381,18 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allo
// we ignore errors
b.ShutdownFuncs = append(b.ShutdownFuncs, func() error {
for _, ln := range b.Listeners {
ln.Mux.Close()
if ln.ProxyListener != nil {
ln.ProxyListener.Close()
}
if ln.ClusterListener != nil {
ln.ClusterListener.Close()
}
if ln.ApiListener != nil {
ln.ApiListener.Close()
}
if ln.OpsListener != nil {
ln.OpsListener.Close()
}
}
return nil
})
@ -414,7 +425,7 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allo
}
}
lnMux, props, reloadFunc, err := NewListener(lnConfig, ui)
ln, props, reloadFunc, err := NewListener(lnConfig, ui)
if err != nil {
return fmt.Errorf("Error initializing listener of type %s: %w", lnConfig.Type, err)
}
@ -460,10 +471,22 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allo
}
props["max_request_duration"] = lnConfig.MaxRequestDuration.String()
b.Listeners = append(b.Listeners, &ServerListener{
Mux: lnMux,
serverListener := &ServerListener{
Config: lnConfig,
})
}
switch purpose {
case "api":
serverListener.ApiListener = ln
case "cluster":
serverListener.ClusterListener = ln
case "proxy":
serverListener.ProxyListener = ln
case "ops":
serverListener.OpsListener = ln
}
b.Listeners = append(b.Listeners, serverListener)
props["purpose"] = strings.Join(lnConfig.Purpose, ",")

@ -11,7 +11,6 @@ import (
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
@ -51,6 +50,9 @@ func NewServer(l hclog.Logger, c *controller.Controller, listeners ...*base.Serv
if ln.Config.Purpose[0] != "ops" {
continue
}
if ln.OpsListener == nil {
return nil, fmt.Errorf("%s: missing ops listener", op)
}
h, err := createOpsHandler(ln.Config, c)
if err != nil {
@ -60,11 +62,7 @@ func NewServer(l hclog.Logger, c *controller.Controller, listeners ...*base.Serv
b := &opsBundle{ln: ln, h: h}
b.ln.HTTPServer = createHttpServer(l, b.h, b.ln.Config)
funcs, err := getStartFn(b.ln)
if err != nil {
return nil, err
}
b.startFn = funcs
b.startFn = []func(){func() { go b.ln.HTTPServer.Serve(b.ln.OpsListener) }}
bundles = append(bundles, b)
}
@ -89,7 +87,7 @@ func (s *Server) Shutdown() error {
var closeErrors *multierror.Error
for _, b := range s.bundles {
if b == nil || b.ln == nil || b.ln.Config == nil || b.ln.Mux == nil || b.ln.HTTPServer == nil {
if b == nil || b.ln == nil || b.ln.Config == nil || b.ln.OpsListener == nil || b.ln.HTTPServer == nil {
return fmt.Errorf("%s: missing bundle, listener or its fields", op)
}
@ -101,7 +99,7 @@ func (s *Server) Shutdown() error {
multierror.Append(closeErrors, fmt.Errorf("%s: failed to shutdown http server: %w", op, err))
}
err = b.ln.Mux.Close()
err = b.ln.OpsListener.Close()
err = listenerCloseErrorCheck(b.ln.Config.Type, err)
if err != nil {
multierror.Append(closeErrors, fmt.Errorf("%s: failed to close listener mux: %w", op, err))
@ -169,34 +167,6 @@ func createHttpServer(l hclog.Logger, h http.Handler, lncfg *listenerutil.Listen
return s
}
func getStartFn(ln *base.ServerListener) ([]func(), error) {
const op = "getStartFn()"
funcs := make([]func(), 0)
switch ln.Config.TLSDisable {
case true:
l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil)
if err != nil {
return nil, fmt.Errorf("%s: error getting non-tls listener: %w", op, err)
}
if l == nil {
return nil, fmt.Errorf("%s: could not get non-tls listener", op)
}
funcs = append(funcs, func() { go ln.HTTPServer.Serve(l) })
default:
for _, v := range []string{"", "http/1.1", "h2"} {
l := ln.Mux.GetListener(v)
if l == nil {
return nil, fmt.Errorf("%s: could not get tls proto %q listener", op, v)
}
funcs = append(funcs, func() { go ln.HTTPServer.Serve(l) })
}
}
return funcs, nil
}
func listenerCloseErrorCheck(lnType string, err error) error {
if errors.Is(err, net.ErrClosed) {
// Ignore net.ErrClosed - The listener was already closed,

@ -18,7 +18,6 @@ import (
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/servers/controller/handlers/health"
"github.com/hashicorp/go-hclog"
@ -280,7 +279,7 @@ func TestNewServerIntegration(t *testing.T) {
addrs := make([]string, 0, len(s.bundles))
for _, b := range s.bundles {
addrs = append(addrs, b.ln.Mux.Addr().String())
addrs = append(addrs, b.ln.OpsListener.Addr().String())
}
if tt.assertions != nil {
tt.assertions(t, addrs)
@ -334,7 +333,6 @@ func TestShutdown(t *testing.T) {
bundles: []*opsBundle{
{
ln: &base.ServerListener{
Mux: &alpnmux.ALPNMux{},
HTTPServer: &http.Server{},
},
},
@ -368,7 +366,6 @@ func TestShutdown(t *testing.T) {
bundles: []*opsBundle{
{
ln: &base.ServerListener{
Mux: &alpnmux.ALPNMux{},
Config: &listenerutil.ListenerConfig{},
},
},
@ -390,10 +387,9 @@ func TestShutdown(t *testing.T) {
bundles: []*opsBundle{
{
ln: &base.ServerListener{
ALPNListener: l1,
HTTPServer: s1,
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
HTTPServer: s1,
OpsListener: l1,
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
},
},
},
@ -401,10 +397,10 @@ func TestShutdown(t *testing.T) {
},
assertions: func(t *testing.T, s *Server) {
// The HTTP Server must be closed.
require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.OpsListener), http.ErrServerClosed)
// The underlying listener must be closed.
require.ErrorIs(t, s.bundles[0].ln.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, s.bundles[0].ln.OpsListener.Close(), net.ErrClosed)
},
expErr: false,
},
@ -426,10 +422,9 @@ func TestShutdown(t *testing.T) {
bundles: []*opsBundle{
{
ln: &base.ServerListener{
ALPNListener: l1,
HTTPServer: s1,
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
OpsListener: l1,
HTTPServer: s1,
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
},
},
},
@ -437,10 +432,10 @@ func TestShutdown(t *testing.T) {
},
assertions: func(t *testing.T, s *Server) {
// The HTTP Server must be closed.
require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.OpsListener), http.ErrServerClosed)
// The underlying listener must be closed.
require.ErrorIs(t, s.bundles[0].ln.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, s.bundles[0].ln.OpsListener.Close(), net.ErrClosed)
},
},
{
@ -466,18 +461,16 @@ func TestShutdown(t *testing.T) {
bundles: []*opsBundle{
{
ln: &base.ServerListener{
ALPNListener: l1,
HTTPServer: s1,
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
OpsListener: l1,
HTTPServer: s1,
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
},
},
{
ln: &base.ServerListener{
ALPNListener: l2,
HTTPServer: s2,
Mux: alpnmux.New(l2),
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
OpsListener: l2,
HTTPServer: s2,
Config: &listenerutil.ListenerConfig{Type: "tcp", Purpose: []string{"ops"}},
},
},
},
@ -485,12 +478,12 @@ func TestShutdown(t *testing.T) {
},
assertions: func(t *testing.T, s *Server) {
// The HTTP Server must be closed.
require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, s.bundles[1].ln.HTTPServer.Serve(s.bundles[1].ln.ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, s.bundles[0].ln.HTTPServer.Serve(s.bundles[0].ln.OpsListener), http.ErrServerClosed)
require.ErrorIs(t, s.bundles[1].ln.HTTPServer.Serve(s.bundles[1].ln.OpsListener), http.ErrServerClosed)
// The underlying listener must be closed.
require.ErrorIs(t, s.bundles[0].ln.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, s.bundles[1].ln.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, s.bundles[0].ln.OpsListener.Close(), net.ErrClosed)
require.ErrorIs(t, s.bundles[1].ln.OpsListener.Close(), net.ErrClosed)
},
},
}
@ -556,7 +549,7 @@ func TestHealthEndpointLifecycle(t *testing.T) {
opsServer.Start()
// Assert the ops endpoint is up and returning 200 OK.
rsp, err := http.Get("http://" + tc.Config().Listeners[0].Mux.Addr().String() + "/health")
rsp, err := http.Get("http://" + tc.Config().Listeners[0].OpsListener.Addr().String() + "/health")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
@ -564,7 +557,7 @@ func TestHealthEndpointLifecycle(t *testing.T) {
tc.Controller().HealthService.StartServiceUnavailableReplies()
// Assert we're receiving 503 Service Unavailable now instead of 200 OK.
rsp, err = http.Get("http://" + tc.Config().Listeners[0].Mux.Addr().String() + "/health")
rsp, err = http.Get("http://" + tc.Config().Listeners[0].OpsListener.Addr().String() + "/health")
require.NoError(t, err)
require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode)
}

@ -1,273 +0,0 @@
package alpnmux
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync"
"github.com/hashicorp/boundary/internal/observability/event"
)
const (
// NoProto is used when the connection isn't actually TLS
NoProto = "(none)"
// DefaultProto is used when there is an ALPN we don't actually know about.
// If no protos are specified on an incoming TLS connection we will first
// look for a proto of ""; if not found, will use DefaultProto. On a
// connection that has protos defined, we will look for that proto first,
// then DefaultProto.
DefaultProto = "(*)"
)
type bufferedConn struct {
net.Conn
buffer *bufio.Reader
}
func (b *bufferedConn) Read(p []byte) (int, error) {
return b.buffer.Read(p)
}
type muxedListener struct {
connMutex *sync.RWMutex
ctx context.Context
addr net.Addr
proto string
tlsConf *tls.Config
connCh chan net.Conn
closed bool
closeFunc func()
closeOnce *sync.Once
}
type ALPNMux struct {
ctx context.Context
baseLn net.Listener
cancel context.CancelFunc
muxMap *sync.Map
}
func New(baseLn net.Listener) *ALPNMux {
ctx, cancel := context.WithCancel(context.Background())
ret := &ALPNMux{
ctx: ctx,
cancel: cancel,
muxMap: new(sync.Map),
baseLn: baseLn,
}
go ret.accept()
return ret
}
func (l *ALPNMux) Addr() net.Addr {
return l.baseLn.Addr()
}
func (l *ALPNMux) Close() error {
return l.baseLn.Close()
}
func (l *ALPNMux) RegisterProto(proto string, tlsConf *tls.Config) (net.Listener, error) {
const op = "alpnmux.(ALPNMux).RegisterProto"
switch proto {
case NoProto:
if tlsConf != nil {
return nil, errors.New("tls config cannot be non-nil when using NoProto")
}
default:
if tlsConf == nil {
return nil, errors.New("nil tls config given")
}
}
sub := &muxedListener{
connMutex: new(sync.RWMutex),
ctx: l.ctx,
addr: l.baseLn.Addr(),
proto: proto,
tlsConf: tlsConf,
connCh: make(chan net.Conn),
closeOnce: new(sync.Once),
}
_, loaded := l.muxMap.LoadOrStore(proto, sub)
if loaded {
close(sub.connCh)
return nil, fmt.Errorf("proto %q already registered", proto)
}
sub.closeFunc = func() {
go l.UnregisterProto(proto)
}
return sub, nil
}
func (l *ALPNMux) UnregisterProto(proto string) {
const op = "alpnmux.(ALPNMux).UnregisterProto"
val, ok := l.muxMap.Load(proto)
if !ok {
return
}
ml := val.(*muxedListener)
ml.closeOnce.Do(func() {
ml.connMutex.Lock()
defer ml.connMutex.Unlock()
ml.closed = true
close(ml.connCh)
})
l.muxMap.Delete(proto)
}
func (l *ALPNMux) GetListener(proto string) net.Listener {
val, ok := l.muxMap.Load(proto)
if !ok || val == nil {
val, ok = l.muxMap.Load(DefaultProto)
if !ok || val == nil {
return nil
}
}
return val.(*muxedListener)
}
func (l *ALPNMux) getConfigForClient(hello *tls.ClientHelloInfo) (*tls.Config, error) {
const op = "alpnmux.(ALPNMux).getConfigForClient"
var ret *tls.Config
supportedProtos := hello.SupportedProtos
if len(hello.SupportedProtos) == 0 {
supportedProtos = append(supportedProtos, "")
}
for _, proto := range supportedProtos {
val, ok := l.muxMap.Load(proto)
if !ok {
continue
}
ret = val.(*muxedListener).tlsConf
}
if ret == nil {
val, ok := l.muxMap.Load(DefaultProto)
if ok && val != nil {
ret = val.(*muxedListener).tlsConf
}
}
if ret == nil {
return nil, errors.New("no tls configuration available for any client protos")
}
// If the TLS config we found has its own lookup function, chain to it
if ret.GetConfigForClient != nil {
return ret.GetConfigForClient(hello)
}
return ret, nil
}
func (l *ALPNMux) accept() {
const op = "alpnmux.(ALPNMux).accept"
ctx := context.TODO()
baseTLSConf := &tls.Config{
GetConfigForClient: l.getConfigForClient,
}
for {
conn, err := l.baseLn.Accept()
if err != nil {
if strings.Contains(err.Error(), "use of closed network connection") {
l.cancel()
return
}
}
if conn == nil {
continue
}
// Do the rest in a goroutine so that a timeout in e.g. handshaking
// doesn't block acceptance of the next connection
go func() {
bufConn := &bufferedConn{
Conn: conn,
buffer: bufio.NewReader(conn),
}
peeked, err := bufConn.buffer.Peek(3)
if err != nil {
bufConn.Close()
return
}
switch {
// First byte should always be a handshake, second byte a 3, and
// third can be 3 or 1 depending on the implementation
case peeked[0] != 0x16 || peeked[1] != 0x03 || (peeked[2] != 0x03 && peeked[2] != 0x01):
val, ok := l.muxMap.Load(NoProto)
if !ok {
bufConn.Close()
return
}
ml := val.(*muxedListener)
ml.connMutex.RLock()
if !ml.closed {
ml.connCh <- bufConn
}
ml.connMutex.RUnlock()
default:
tlsConn := tls.Server(bufConn, baseTLSConf)
if err := tlsConn.Handshake(); err != nil {
closeErr := tlsConn.Close()
if closeErr != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error handshaking connection", "addr", conn.RemoteAddr(), "close_error", closeErr))
}
return
}
negProto := tlsConn.ConnectionState().NegotiatedProtocol
val, ok := l.muxMap.Load(negProto)
if !ok {
val, ok = l.muxMap.Load(DefaultProto)
if !ok {
tlsConn.Close()
return
}
}
ml := val.(*muxedListener)
ml.connMutex.RLock()
if !ml.closed {
ml.connCh <- tlsConn
}
ml.connMutex.RUnlock()
}
}()
}
}
func (m *muxedListener) Accept() (net.Conn, error) {
for {
select {
case <-m.ctx.Done():
// Wouldn't it be so much better if this error was an exported
// const from Go...
m.closeFunc()
return nil, fmt.Errorf("accept proto %s: use of closed network connection", m.proto)
case conn, ok := <-m.connCh:
if !ok {
// Channel closed
return nil, fmt.Errorf("accept proto %s: use of closed network connection", m.proto)
}
if conn == nil {
return nil, fmt.Errorf("accept proto %s: nil connection received", m.proto)
}
return conn, nil
}
}
}
func (m *muxedListener) Close() error {
m.closeFunc()
return nil
}
func (m *muxedListener) Addr() net.Addr {
return m.addr
}

@ -1,201 +0,0 @@
package alpnmux
import (
"crypto/tls"
"fmt"
"log"
"net"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/go-hclog"
"go.uber.org/atomic"
)
func TestListenCloseErrMsg(t *testing.T) {
listener := getListener(t)
listener.Close()
_, err := listener.Accept()
if !strings.Contains(err.Error(), "use of closed network connection") {
t.Fatal(err)
}
}
func TestRegistrationErrors(t *testing.T) {
listener := getListener(t)
defer listener.Close()
mux := New(listener)
p1config := getTestTLS(t, []string{"p1"})
if _, err := mux.RegisterProto("p1", nil); err.Error() != "nil tls config given" {
t.Fatal(err)
}
l, err := mux.RegisterProto("p1", p1config)
if err != nil {
t.Fatal(err)
}
if _, err := mux.RegisterProto("p1", p1config); err.Error() != `proto "p1" already registered` {
t.Fatal(err)
}
l.Close()
// Unregister is not sync, so need to wait for it to actually be removed
var unregistered bool
for i := 0; i < 5; i++ {
_, ok := mux.muxMap.Load("p1")
if !ok {
unregistered = true
break
}
time.Sleep(100 * time.Millisecond)
}
if !unregistered {
t.Fatal("failed to unregister proto")
}
l, err = mux.RegisterProto("p1", p1config)
if err != nil {
t.Fatal(err)
}
l.Close()
l, err = mux.RegisterProto(NoProto, nil)
if err != nil {
t.Fatal(err)
}
l.Close()
}
func TestListening(t *testing.T) {
event.TestEnableEventing(t, true)
testConfig := event.DefaultEventerConfig()
testLock := &sync.Mutex{}
testLogger := hclog.New(&hclog.LoggerOptions{
Mutex: testLock,
})
err := event.InitSysEventer(testLogger, testLock, "TestListening", event.WithEventerConfig(testConfig))
if err != nil {
t.Fatal(err)
}
listener := getListener(t)
mux := New(listener)
defer mux.Close()
emptyconns := atomic.NewUint32(0)
noneconns := atomic.NewUint32(0)
l1conns := atomic.NewUint32(0)
l2conns := atomic.NewUint32(0)
l3conns := atomic.NewUint32(0)
defconns := atomic.NewUint32(0)
clientCountTracker := atomic.NewUint32(0)
baseconfig := getTestTLS(t, nil)
noneconfig := baseconfig.Clone()
p1config := baseconfig.Clone()
p1config.NextProtos = []string{"p1"}
p2p3config := getTestTLS(t, []string{"p2", "p3"})
p3config := p2p3config.Clone()
p3config.NextProtos = []string{"p3"}
defconfig := baseconfig.Clone()
defconfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
ret := baseconfig.Clone()
ret.NextProtos = []string{fmt.Sprintf("%d", clientCountTracker.Load())}
log.Printf("returning def config with next protos = %v\n", ret.NextProtos)
clientCountTracker.Inc()
return ret, nil
}
lempty, err := mux.RegisterProto("", noneconfig)
if err != nil {
t.Fatal(err)
}
l1, err := mux.RegisterProto("p1", p1config)
if err != nil {
t.Fatal(err)
}
l2, err := mux.RegisterProto("p2", p2p3config)
if err != nil {
t.Fatal(err)
}
l3, err := mux.RegisterProto("p3", p2p3config)
if err != nil {
t.Fatal(err)
}
lnone, err := mux.RegisterProto(NoProto, nil)
if err != nil {
t.Fatal(err)
}
ldef, err := mux.RegisterProto(DefaultProto, defconfig)
if err != nil {
t.Fatal(err)
}
addr := listener.Addr().String()
wg := new(sync.WaitGroup)
wg.Add(6)
connWatchFunc := func(l net.Listener, connCounter *atomic.Uint32, tlsConf *tls.Config, numConns int) {
defer wg.Done()
tlsToUse := tlsConf
go func() {
for i := 0; i < numConns; i++ {
var err error
var conn net.Conn
switch tlsToUse {
case nil:
conn, err = net.Dial("tcp4", addr)
if err != nil {
t.Fatal(err)
}
// We need to send some data here because we won't have any
// from just the TLS handshake
log.Println("defconn")
n, err := conn.Write([]byte("GET "))
if err != nil {
t.Fatal(err)
}
if n != 4 {
t.Fatal(n)
}
log.Println("defconn done")
default:
if connCounter == defconns {
tlsToUse = baseconfig.Clone()
log.Println("FOUND CURR")
tlsToUse.NextProtos = []string{fmt.Sprintf("%d", i)}
}
log.Println(fmt.Sprintf("dialing on %d, counter = %d, protos = %v", numConns, i, tlsToUse.NextProtos))
conn, err = tls.Dial("tcp4", addr, tlsToUse)
if err != nil {
t.Fatal(err)
}
log.Println(fmt.Sprintf("dialing done on %d, counter = %d, protos = %v", numConns, i, tlsToUse.NextProtos))
}
conn.Close()
}
}()
for i := 0; i < numConns; i++ {
log.Println(fmt.Sprintf("accepting on %d, counter = %d", numConns, connCounter.Load()))
conn, err := l.Accept()
if err == nil && conn != nil {
conn.Close()
} else {
t.Fatal(err)
}
log.Println(fmt.Sprintf("done accepting on %d, counter = %d", numConns, connCounter.Load()))
connCounter.Inc()
}
return
}
go connWatchFunc(lempty, emptyconns, noneconfig, 4)
go connWatchFunc(l1, l1conns, p1config, 5)
go connWatchFunc(l2, l2conns, p2p3config, 6)
go connWatchFunc(l3, l3conns, p3config, 7)
go connWatchFunc(lnone, noneconns, nil, 8)
go connWatchFunc(ldef, defconns, defconfig, 9)
wg.Wait()
if emptyconns.Load() != 4 || l1conns.Load() != 5 || l2conns.Load() != 6 || l3conns.Load() != 7 || noneconns.Load() != 8 || defconns.Load() != 9 {
t.Fatal("wrong number of conns")
}
}

@ -1,119 +0,0 @@
package alpnmux
import (
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
mathrand "math/rand"
"net"
"testing"
"time"
)
func getListener(t *testing.T) net.Listener {
addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
t.Fatal(err)
}
return listener
}
func getTestTLS(t *testing.T, protos []string) *tls.Config {
certIPs := []net.IP{
net.IPv6loopback,
net.ParseIP("127.0.0.1"),
}
_, caKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
caCertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
BasicConstraintsValid: true,
IsCA: true,
}
caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
if err != nil {
t.Fatal(err)
}
caCert, err := x509.ParseCertificate(caBytes)
if err != nil {
t.Fatal(err)
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(caCert)
//
// Certs generation
//
_, key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
certTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
if err != nil {
t.Fatal(err)
}
certPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
}
certPEM := pem.EncodeToMemory(certPEMBlock)
marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
t.Fatal(err)
}
keyPEMBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: marshaledKey,
}
keyPEM := pem.EncodeToMemory(keyPEMBlock)
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatal(err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
RootCAs: rootCAs,
ClientCAs: rootCAs,
ClientAuth: tls.RequestClientCert,
NextProtos: protos,
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
}
return tlsConfig
}

@ -14,7 +14,6 @@ import (
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/hashicorp/boundary/internal/cmd/base"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/boundary/internal/servers/controller/handlers/workers"
"github.com/hashicorp/go-multierror"
"google.golang.org/grpc"
@ -42,7 +41,7 @@ func (c *Controller) startListeners() error {
for i := range c.apiListeners {
ln := c.apiListeners[i]
apiServers, err := c.configureForAPI(ln)
apiServers, err := c.configureForApi(ln)
if err != nil {
return fmt.Errorf("failed to configure listener for api mode: %w", err)
}
@ -62,7 +61,7 @@ func (c *Controller) startListeners() error {
return nil
}
func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) {
func (c *Controller) configureForApi(ln *base.ServerListener) ([]func(), error) {
apiServers := make([]func(), 0)
handler, err := c.apiHandler(HandlerProperties{
@ -97,39 +96,15 @@ func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error)
server.IdleTimeout = ln.Config.HTTPIdleTimeout
}
switch ln.Config.TLSDisable {
case true:
l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil)
if err != nil {
return nil, fmt.Errorf("error getting non-tls listener: %w", err)
}
if l == nil {
return nil, errors.New("could not get non-tls listener")
}
apiServers = append(apiServers, func() { go server.Serve(l) })
default:
for _, v := range []string{"", "http/1.1", "h2"} {
l := ln.Mux.GetListener(v)
if l == nil {
return nil, fmt.Errorf("could not get tls proto %q listener", v)
}
apiServers = append(apiServers, func() { go server.Serve(l) })
}
}
apiServers = append(apiServers, func() { go server.Serve(ln.ApiListener) })
return apiServers, nil
}
func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error) {
// Clear out in case this is a second start of the controller
ln.Mux.UnregisterProto(alpnmux.DefaultProto)
l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{
l := tls.NewListener(ln.ClusterListener, &tls.Config{
GetConfigForClient: c.validateWorkerTls,
})
if err != nil {
return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err)
}
workerReqInterceptor, err := workerRequestInfoInterceptor(c.baseContext, c.conf.Eventer)
if err != nil {
@ -153,11 +128,9 @@ func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error
pbs.RegisterServerCoordinationServiceServer(workerServer, workerService)
pbs.RegisterSessionServiceServer(workerServer, workerService)
interceptor := newInterceptingListener(c, l)
ln.ALPNListener = interceptor
ln.GrpcServer = workerServer
return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil
return func() { go ln.GrpcServer.Serve(newInterceptingListener(c, l)) }, nil
}
func (c *Controller) stopServersAndListeners() error {
@ -183,12 +156,12 @@ func (c *Controller) stopClusterGrpcServerAndListener() error {
if c.clusterListener.GrpcServer == nil {
return fmt.Errorf("no cluster grpc server")
}
if c.clusterListener.Mux == nil {
return fmt.Errorf("no cluster listener mux")
if c.clusterListener.ClusterListener == nil {
return fmt.Errorf("no cluster listener")
}
c.clusterListener.GrpcServer.GracefulStop()
err := c.clusterListener.Mux.Close()
err := c.clusterListener.ClusterListener.Close()
return listenerCloseErrorCheck(c.clusterListener.Config.Type, err)
}
@ -204,7 +177,7 @@ func (c *Controller) stopHttpServersAndListeners() error {
ln.HTTPServer.Shutdown(ctx)
cancel()
err := ln.Mux.Close() // The HTTP Shutdown call should close this, but just in case.
err := ln.ApiListener.Close() // The HTTP Shutdown call should close this, but just in case.
err = listenerCloseErrorCheck(ln.Config.Type, err)
if err != nil {
multierror.Append(closeErrors, err)
@ -231,11 +204,11 @@ func (c *Controller) stopAnyListeners() error {
var closeErrors *multierror.Error
for i := range c.apiListeners {
ln := c.apiListeners[i]
if ln == nil || ln.Mux == nil {
if ln == nil || ln.ApiListener == nil {
continue
}
err := ln.Mux.Close()
err := ln.ApiListener.Close()
err = listenerCloseErrorCheck(ln.Config.Type, err)
if err != nil {
multierror.Append(closeErrors, err)

@ -20,7 +20,6 @@ import (
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/go-secure-stdlib/base62"
"github.com/hashicorp/go-secure-stdlib/configutil/v2"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
@ -256,9 +255,9 @@ func TestStartListeners(t *testing.T) {
apiAddrs := make([]string, 0)
for _, l := range c.apiListeners {
apiAddrs = append(apiAddrs, l.Mux.Addr().String())
apiAddrs = append(apiAddrs, l.ApiListener.Addr().String())
}
tt.assertions(t, c, apiAddrs, c.clusterListener.Mux.Addr().String())
tt.assertions(t, c, apiAddrs, c.clusterListener.ClusterListener.Addr().String())
})
}
}
@ -297,7 +296,7 @@ func TestStopClusterGrpcServerAndListener(t *testing.T) {
}
},
expErr: true,
expErrStr: "no cluster listener mux",
expErrStr: "no cluster listener",
},
{
name: "listener already closed",
@ -308,15 +307,14 @@ func TestStopClusterGrpcServerAndListener(t *testing.T) {
return &Controller{
clusterListener: &base.ServerListener{
ALPNListener: l,
GrpcServer: grpc.NewServer(),
Mux: alpnmux.New(l),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ClusterListener: l,
GrpcServer: grpc.NewServer(),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
},
}
},
assertions: func(t *testing.T, c *Controller) {
require.ErrorIs(t, c.clusterListener.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, c.clusterListener.ClusterListener.Close(), net.ErrClosed)
},
expErr: false,
},
@ -339,15 +337,14 @@ func TestStopClusterGrpcServerAndListener(t *testing.T) {
return &Controller{
clusterListener: &base.ServerListener{
ALPNListener: l,
GrpcServer: grpcServer,
Mux: alpnmux.New(l),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ClusterListener: l,
GrpcServer: grpcServer,
Config: &listenerutil.ListenerConfig{Type: "tcp"},
},
}
},
assertions: func(t *testing.T, c *Controller) {
require.ErrorIs(t, c.clusterListener.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, c.clusterListener.ClusterListener.Close(), net.ErrClosed)
},
expErr: false,
},
@ -421,20 +418,19 @@ func TestStopHttpServersAndListeners(t *testing.T) {
baseContext: context.Background(),
apiListeners: []*base.ServerListener{
{
ALPNListener: l1,
HTTPServer: s1,
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ApiListener: l1,
HTTPServer: s1,
Config: &listenerutil.ListenerConfig{Type: "tcp"},
},
},
}
},
assertions: func(t *testing.T, c *Controller) {
// Asserts the HTTP Servers are closed.
require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ApiListener), http.ErrServerClosed)
// Asserts the underlying listeners are closed.
require.ErrorIs(t, c.apiListeners[0].Mux.Close(), net.ErrClosed)
require.ErrorIs(t, c.apiListeners[0].ApiListener.Close(), net.ErrClosed)
},
expErr: false,
},
@ -468,28 +464,26 @@ func TestStopHttpServersAndListeners(t *testing.T) {
baseContext: context.Background(),
apiListeners: []*base.ServerListener{
{
ALPNListener: l1,
HTTPServer: s1,
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ApiListener: l1,
HTTPServer: s1,
Config: &listenerutil.ListenerConfig{Type: "tcp"},
},
{
ALPNListener: l2,
HTTPServer: s2,
Mux: alpnmux.New(l2),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ApiListener: l2,
HTTPServer: s2,
Config: &listenerutil.ListenerConfig{Type: "tcp"},
},
},
}
},
assertions: func(t *testing.T, c *Controller) {
// Asserts the HTTP Servers are closed.
require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, c.apiListeners[1].HTTPServer.Serve(c.apiListeners[1].ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, c.apiListeners[0].HTTPServer.Serve(c.apiListeners[0].ApiListener), http.ErrServerClosed)
require.ErrorIs(t, c.apiListeners[1].HTTPServer.Serve(c.apiListeners[1].ApiListener), http.ErrServerClosed)
// Asserts the underlying listeners are closed.
require.ErrorIs(t, c.apiListeners[0].Mux.Close(), net.ErrClosed)
require.ErrorIs(t, c.apiListeners[1].Mux.Close(), net.ErrClosed)
require.ErrorIs(t, c.apiListeners[0].ApiListener.Close(), net.ErrClosed)
require.ErrorIs(t, c.apiListeners[1].ApiListener.Close(), net.ErrClosed)
},
expErr: false,
},
@ -601,10 +595,10 @@ func TestStopAnyListeners(t *testing.T) {
expErr: false,
},
{
name: "listeners with nil mux",
name: "non-empty listeners with nil listeners",
controllerFn: func(t *testing.T) *Controller {
return &Controller{apiListeners: []*base.ServerListener{
{Mux: nil}, {Mux: nil}, {Mux: nil},
{ClusterListener: nil}, {ClusterListener: nil}, {ClusterListener: nil},
}}
},
expErr: false,
@ -624,23 +618,23 @@ func TestStopAnyListeners(t *testing.T) {
return &Controller{apiListeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{Type: "tcp"},
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ApiListener: l1,
},
{
Config: &listenerutil.ListenerConfig{Type: "tcp"},
Mux: alpnmux.New(l2),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ApiListener: l2,
},
{
Config: &listenerutil.ListenerConfig{Type: "tcp"},
Mux: alpnmux.New(l3),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ApiListener: l3,
},
}}
},
assertions: func(t *testing.T, c *Controller) {
for i := range c.apiListeners {
ln := c.apiListeners[i]
require.ErrorIs(t, ln.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, ln.ApiListener.Close(), net.ErrClosed)
}
},
expErr: false,

@ -219,7 +219,15 @@ func (tc *TestController) addrs(purpose string) []string {
addrs := make([]string, 0, len(tc.b.Listeners))
for _, listener := range tc.b.Listeners {
if listener.Config.Purpose[0] == purpose {
addr := listener.Mux.Addr()
var addr net.Addr
switch purpose {
case "api":
addr = listener.ApiListener.Addr()
case "cluster":
addr = listener.ClusterListener.Addr()
case "ops":
addr = listener.OpsListener.Addr()
}
switch {
case strings.HasPrefix(addr.String(), "/"):
switch purpose {

@ -12,7 +12,6 @@ import (
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/go-multierror"
)
@ -24,7 +23,7 @@ func (w *Worker) startListeners() error {
if e == nil {
return fmt.Errorf("%s: sys eventer not initialized", op)
}
logger, err := e.StandardLogger(w.baseContext, "listeners", event.ErrorType)
logger, err := e.StandardLogger(w.baseContext, "worker.listeners: ", event.ErrorType)
if err != nil {
return fmt.Errorf("%s: unable to initialize std logger: %w", op, err)
}
@ -46,7 +45,7 @@ func (w *Worker) startListeners() error {
return nil
}
func (w *Worker) configureForWorker(ln *base.ServerListener, log *log.Logger) (func(), error) {
func (w *Worker) configureForWorker(ln *base.ServerListener, logger *log.Logger) (func(), error) {
handler, err := w.handler(HandlerProperties{ListenerConfig: ln.Config})
if err != nil {
return nil, err
@ -57,7 +56,7 @@ func (w *Worker) configureForWorker(ln *base.ServerListener, log *log.Logger) (f
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
ErrorLog: log,
ErrorLog: logger,
BaseContext: func(net.Listener) context.Context {
return cancelCtx
},
@ -77,18 +76,9 @@ func (w *Worker) configureForWorker(ln *base.ServerListener, log *log.Logger) (f
server.IdleTimeout = ln.Config.HTTPIdleTimeout
}
// Clear out in case this is a second start of the controller
ln.Mux.UnregisterProto(alpnmux.DefaultProto)
ln.Mux.UnregisterProto(alpnmux.NoProto)
l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{
l := tls.NewListener(ln.ProxyListener, &tls.Config{
GetConfigForClient: w.getSessionTls,
})
if err != nil {
return nil, fmt.Errorf("error getting tls listener: %w", err)
}
if l == nil {
return nil, errors.New("could not get tls listener")
}
return func() { go server.Serve(l) }, nil
}
@ -120,7 +110,7 @@ func (w *Worker) stopHttpServersAndListeners() error {
ln.HTTPServer.Shutdown(ctx)
cancel()
err := ln.Mux.Close()
err := ln.ProxyListener.Close()
err = listenerCloseErrorCheck(ln.Config.Type, err)
if err != nil {
multierror.Append(closeErrors, err)
@ -136,11 +126,11 @@ func (w *Worker) stopHttpServersAndListeners() error {
func (w *Worker) stopAnyListeners() error {
var closeErrors *multierror.Error
for _, ln := range w.listeners {
if ln == nil || ln.Mux == nil {
if ln == nil || ln.ProxyListener == nil {
continue
}
err := ln.Mux.Close()
err := ln.ProxyListener.Close()
err = listenerCloseErrorCheck(ln.Config.Type, err)
if err != nil {
multierror.Append(closeErrors, err)

@ -12,12 +12,20 @@ import (
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/libs/alpnmux"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/stretchr/testify/require"
)
func TestStartListeners(t *testing.T) {
testNonTlsRejected := func(t *testing.T, resp *http.Response, err error) {
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Client sent an HTTP request to an HTTPS server.\n", string(body))
}
tests := []struct {
name string
listeners []*listenerutil.ListenerConfig
@ -45,16 +53,16 @@ func TestStartListeners(t *testing.T) {
assertions: func(t *testing.T, w *Worker, addrs []string) {
require.Len(t, addrs, 2)
_, err := http.Get("http://" + addrs[0] + "/v1/proxy/")
require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation
resp, err := http.Get("http://" + addrs[0] + "/v1/proxy/")
testNonTlsRejected(t, resp, err)
cl := http.Client{
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return net.Dial("unix", addrs[1]) },
},
}
_, err = cl.Get("http://anything.domain/v1/proxy/")
require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation
resp, err = cl.Get("http://anything.domain/v1/proxy/")
testNonTlsRejected(t, resp, err)
},
},
{
@ -89,11 +97,11 @@ func TestStartListeners(t *testing.T) {
assertions: func(t *testing.T, w *Worker, addrs []string) {
require.Len(t, addrs, 4)
_, err := http.Get("http://" + addrs[0] + "/v1/proxy/")
require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation
resp, err := http.Get("http://" + addrs[0] + "/v1/proxy/")
testNonTlsRejected(t, resp, err)
_, err = http.Get("http://" + addrs[1] + "/v1/proxy/")
require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation
resp, err = http.Get("http://" + addrs[1] + "/v1/proxy/")
testNonTlsRejected(t, resp, err)
for _, proxyAddr := range []string{addrs[2], addrs[3]} {
cl := http.Client{
@ -101,8 +109,8 @@ func TestStartListeners(t *testing.T) {
Dial: func(network, addr string) (net.Conn, error) { return net.Dial("unix", proxyAddr) },
},
}
_, err = cl.Get("http://anything.domain/v1/proxy")
require.ErrorIs(t, err, io.EOF) // empty response because of worker tls request validation
resp, err = cl.Get("http://anything.domain/v1/proxy")
testNonTlsRejected(t, resp, err)
}
},
},
@ -130,7 +138,7 @@ func TestStartListeners(t *testing.T) {
addrs := make([]string, 0)
for _, l := range w.listeners {
addrs = append(addrs, l.Mux.Addr().String())
addrs = append(addrs, l.ProxyListener.Addr().String())
}
if tt.assertions != nil {
tt.assertions(t, w, addrs)
@ -204,28 +212,26 @@ func TestStopHttpServersAndListeners(t *testing.T) {
baseContext: context.Background(),
listeners: []*base.ServerListener{
{
ALPNListener: l1,
HTTPServer: s1,
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ProxyListener: l1,
HTTPServer: s1,
Config: &listenerutil.ListenerConfig{Type: "tcp"},
},
{
ALPNListener: l2,
HTTPServer: s2,
Mux: alpnmux.New(l2),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ProxyListener: l2,
HTTPServer: s2,
Config: &listenerutil.ListenerConfig{Type: "tcp"},
},
},
}
},
assertions: func(t *testing.T, w *Worker) {
// Asserts the HTTP Servers are closed.
require.ErrorIs(t, w.listeners[0].HTTPServer.Serve(w.listeners[0].ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, w.listeners[1].HTTPServer.Serve(w.listeners[1].ALPNListener), http.ErrServerClosed)
require.ErrorIs(t, w.listeners[0].HTTPServer.Serve(w.listeners[0].ProxyListener), http.ErrServerClosed)
require.ErrorIs(t, w.listeners[1].HTTPServer.Serve(w.listeners[1].ProxyListener), http.ErrServerClosed)
// Asserts the underlying listeners are closed.
require.ErrorIs(t, w.listeners[0].Mux.Close(), net.ErrClosed)
require.ErrorIs(t, w.listeners[1].Mux.Close(), net.ErrClosed)
require.ErrorIs(t, w.listeners[0].ProxyListener.Close(), net.ErrClosed)
require.ErrorIs(t, w.listeners[1].ProxyListener.Close(), net.ErrClosed)
},
expErr: false,
},
@ -278,10 +284,10 @@ func TestStopAnyListeners(t *testing.T) {
expErr: false,
},
{
name: "listeners with nil mux",
name: "non-empty but nil listeners",
workerFn: func(t *testing.T) *Worker {
return &Worker{listeners: []*base.ServerListener{
{Mux: nil}, {Mux: nil}, {Mux: nil},
{ProxyListener: nil}, {ProxyListener: nil}, {ProxyListener: nil},
}}
},
expErr: false,
@ -301,23 +307,23 @@ func TestStopAnyListeners(t *testing.T) {
return &Worker{listeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{Type: "tcp"},
Mux: alpnmux.New(l1),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ProxyListener: l1,
},
{
Config: &listenerutil.ListenerConfig{Type: "tcp"},
Mux: alpnmux.New(l2),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ProxyListener: l2,
},
{
Config: &listenerutil.ListenerConfig{Type: "tcp"},
Mux: alpnmux.New(l3),
Config: &listenerutil.ListenerConfig{Type: "tcp"},
ProxyListener: l3,
},
}}
},
assertions: func(t *testing.T, w *Worker) {
for i := range w.listeners {
ln := w.listeners[i]
require.ErrorIs(t, ln.Mux.Close(), net.ErrClosed)
require.ErrorIs(t, ln.ProxyListener.Close(), net.ErrClosed)
}
},
expErr: false,

@ -68,7 +68,7 @@ func (tw *TestWorker) ProxyAddrs() []string {
for _, listener := range tw.b.Listeners {
if listener.Config.Purpose[0] == "proxy" {
tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr)
tcpAddr, ok := listener.ProxyListener.Addr().(*net.TCPAddr)
if !ok {
tw.t.Fatal("could not parse address as a TCP addr")
}

Loading…
Cancel
Save