mirror of https://github.com/hashicorp/boundary
Remove ALPN Muxer (#1965)
We never ended up using this and it makes it much harder to reason about the listeners. An explanation: for the moment I separated out all of the listeners in the struct, even though one should suffice, because when I did it that way first I was hitting some really weird and hard to find intermittent behavior where it seemed like the listeners were getting crossed. So I re-did it this way with each listener very explicitly referenced. I'm happy to make a follow-up PR that attempts to combine them all into a single listener without hitting the issues I did, if desired, but at least doing it via a separate PR means that it won't be something with removing the ALPN muxer that is at fault.pull/1961/head
parent
ce7292d9c8
commit
07dc3db974
@ -1,273 +0,0 @@
|
||||
package alpnmux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/observability/event"
|
||||
)
|
||||
|
||||
const (
|
||||
// NoProto is used when the connection isn't actually TLS
|
||||
NoProto = "(none)"
|
||||
|
||||
// DefaultProto is used when there is an ALPN we don't actually know about.
|
||||
// If no protos are specified on an incoming TLS connection we will first
|
||||
// look for a proto of ""; if not found, will use DefaultProto. On a
|
||||
// connection that has protos defined, we will look for that proto first,
|
||||
// then DefaultProto.
|
||||
DefaultProto = "(*)"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
buffer *bufio.Reader
|
||||
}
|
||||
|
||||
func (b *bufferedConn) Read(p []byte) (int, error) {
|
||||
return b.buffer.Read(p)
|
||||
}
|
||||
|
||||
type muxedListener struct {
|
||||
connMutex *sync.RWMutex
|
||||
ctx context.Context
|
||||
addr net.Addr
|
||||
proto string
|
||||
tlsConf *tls.Config
|
||||
connCh chan net.Conn
|
||||
closed bool
|
||||
closeFunc func()
|
||||
closeOnce *sync.Once
|
||||
}
|
||||
|
||||
type ALPNMux struct {
|
||||
ctx context.Context
|
||||
baseLn net.Listener
|
||||
cancel context.CancelFunc
|
||||
muxMap *sync.Map
|
||||
}
|
||||
|
||||
func New(baseLn net.Listener) *ALPNMux {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ret := &ALPNMux{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
muxMap: new(sync.Map),
|
||||
baseLn: baseLn,
|
||||
}
|
||||
go ret.accept()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (l *ALPNMux) Addr() net.Addr {
|
||||
return l.baseLn.Addr()
|
||||
}
|
||||
|
||||
func (l *ALPNMux) Close() error {
|
||||
return l.baseLn.Close()
|
||||
}
|
||||
|
||||
func (l *ALPNMux) RegisterProto(proto string, tlsConf *tls.Config) (net.Listener, error) {
|
||||
const op = "alpnmux.(ALPNMux).RegisterProto"
|
||||
switch proto {
|
||||
case NoProto:
|
||||
if tlsConf != nil {
|
||||
return nil, errors.New("tls config cannot be non-nil when using NoProto")
|
||||
}
|
||||
default:
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("nil tls config given")
|
||||
}
|
||||
}
|
||||
sub := &muxedListener{
|
||||
connMutex: new(sync.RWMutex),
|
||||
ctx: l.ctx,
|
||||
addr: l.baseLn.Addr(),
|
||||
proto: proto,
|
||||
tlsConf: tlsConf,
|
||||
connCh: make(chan net.Conn),
|
||||
closeOnce: new(sync.Once),
|
||||
}
|
||||
_, loaded := l.muxMap.LoadOrStore(proto, sub)
|
||||
if loaded {
|
||||
close(sub.connCh)
|
||||
return nil, fmt.Errorf("proto %q already registered", proto)
|
||||
}
|
||||
|
||||
sub.closeFunc = func() {
|
||||
go l.UnregisterProto(proto)
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (l *ALPNMux) UnregisterProto(proto string) {
|
||||
const op = "alpnmux.(ALPNMux).UnregisterProto"
|
||||
val, ok := l.muxMap.Load(proto)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ml := val.(*muxedListener)
|
||||
ml.closeOnce.Do(func() {
|
||||
ml.connMutex.Lock()
|
||||
defer ml.connMutex.Unlock()
|
||||
ml.closed = true
|
||||
close(ml.connCh)
|
||||
})
|
||||
l.muxMap.Delete(proto)
|
||||
}
|
||||
|
||||
func (l *ALPNMux) GetListener(proto string) net.Listener {
|
||||
val, ok := l.muxMap.Load(proto)
|
||||
if !ok || val == nil {
|
||||
val, ok = l.muxMap.Load(DefaultProto)
|
||||
if !ok || val == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return val.(*muxedListener)
|
||||
}
|
||||
|
||||
func (l *ALPNMux) getConfigForClient(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
const op = "alpnmux.(ALPNMux).getConfigForClient"
|
||||
var ret *tls.Config
|
||||
|
||||
supportedProtos := hello.SupportedProtos
|
||||
if len(hello.SupportedProtos) == 0 {
|
||||
supportedProtos = append(supportedProtos, "")
|
||||
}
|
||||
for _, proto := range supportedProtos {
|
||||
val, ok := l.muxMap.Load(proto)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ret = val.(*muxedListener).tlsConf
|
||||
}
|
||||
if ret == nil {
|
||||
val, ok := l.muxMap.Load(DefaultProto)
|
||||
if ok && val != nil {
|
||||
ret = val.(*muxedListener).tlsConf
|
||||
}
|
||||
}
|
||||
if ret == nil {
|
||||
return nil, errors.New("no tls configuration available for any client protos")
|
||||
}
|
||||
|
||||
// If the TLS config we found has its own lookup function, chain to it
|
||||
if ret.GetConfigForClient != nil {
|
||||
return ret.GetConfigForClient(hello)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (l *ALPNMux) accept() {
|
||||
const op = "alpnmux.(ALPNMux).accept"
|
||||
ctx := context.TODO()
|
||||
baseTLSConf := &tls.Config{
|
||||
GetConfigForClient: l.getConfigForClient,
|
||||
}
|
||||
for {
|
||||
conn, err := l.baseLn.Accept()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
l.cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Do the rest in a goroutine so that a timeout in e.g. handshaking
|
||||
// doesn't block acceptance of the next connection
|
||||
go func() {
|
||||
bufConn := &bufferedConn{
|
||||
Conn: conn,
|
||||
buffer: bufio.NewReader(conn),
|
||||
}
|
||||
peeked, err := bufConn.buffer.Peek(3)
|
||||
if err != nil {
|
||||
bufConn.Close()
|
||||
return
|
||||
}
|
||||
switch {
|
||||
// First byte should always be a handshake, second byte a 3, and
|
||||
// third can be 3 or 1 depending on the implementation
|
||||
case peeked[0] != 0x16 || peeked[1] != 0x03 || (peeked[2] != 0x03 && peeked[2] != 0x01):
|
||||
val, ok := l.muxMap.Load(NoProto)
|
||||
if !ok {
|
||||
bufConn.Close()
|
||||
return
|
||||
}
|
||||
ml := val.(*muxedListener)
|
||||
ml.connMutex.RLock()
|
||||
if !ml.closed {
|
||||
ml.connCh <- bufConn
|
||||
}
|
||||
ml.connMutex.RUnlock()
|
||||
|
||||
default:
|
||||
tlsConn := tls.Server(bufConn, baseTLSConf)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
closeErr := tlsConn.Close()
|
||||
if closeErr != nil {
|
||||
event.WriteError(ctx, op, err, event.WithInfoMsg("error handshaking connection", "addr", conn.RemoteAddr(), "close_error", closeErr))
|
||||
}
|
||||
return
|
||||
}
|
||||
negProto := tlsConn.ConnectionState().NegotiatedProtocol
|
||||
val, ok := l.muxMap.Load(negProto)
|
||||
if !ok {
|
||||
val, ok = l.muxMap.Load(DefaultProto)
|
||||
if !ok {
|
||||
tlsConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
ml := val.(*muxedListener)
|
||||
ml.connMutex.RLock()
|
||||
if !ml.closed {
|
||||
ml.connCh <- tlsConn
|
||||
}
|
||||
ml.connMutex.RUnlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *muxedListener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
// Wouldn't it be so much better if this error was an exported
|
||||
// const from Go...
|
||||
m.closeFunc()
|
||||
return nil, fmt.Errorf("accept proto %s: use of closed network connection", m.proto)
|
||||
case conn, ok := <-m.connCh:
|
||||
if !ok {
|
||||
// Channel closed
|
||||
return nil, fmt.Errorf("accept proto %s: use of closed network connection", m.proto)
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("accept proto %s: nil connection received", m.proto)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *muxedListener) Close() error {
|
||||
m.closeFunc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *muxedListener) Addr() net.Addr {
|
||||
return m.addr
|
||||
}
|
||||
@ -1,201 +0,0 @@
|
||||
package alpnmux
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/observability/event"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func TestListenCloseErrMsg(t *testing.T) {
|
||||
listener := getListener(t)
|
||||
listener.Close()
|
||||
_, err := listener.Accept()
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationErrors(t *testing.T) {
|
||||
listener := getListener(t)
|
||||
defer listener.Close()
|
||||
mux := New(listener)
|
||||
p1config := getTestTLS(t, []string{"p1"})
|
||||
if _, err := mux.RegisterProto("p1", nil); err.Error() != "nil tls config given" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l, err := mux.RegisterProto("p1", p1config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := mux.RegisterProto("p1", p1config); err.Error() != `proto "p1" already registered` {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l.Close()
|
||||
// Unregister is not sync, so need to wait for it to actually be removed
|
||||
var unregistered bool
|
||||
for i := 0; i < 5; i++ {
|
||||
_, ok := mux.muxMap.Load("p1")
|
||||
if !ok {
|
||||
unregistered = true
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
if !unregistered {
|
||||
t.Fatal("failed to unregister proto")
|
||||
}
|
||||
l, err = mux.RegisterProto("p1", p1config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l.Close()
|
||||
l, err = mux.RegisterProto(NoProto, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
|
||||
func TestListening(t *testing.T) {
|
||||
event.TestEnableEventing(t, true)
|
||||
testConfig := event.DefaultEventerConfig()
|
||||
testLock := &sync.Mutex{}
|
||||
testLogger := hclog.New(&hclog.LoggerOptions{
|
||||
Mutex: testLock,
|
||||
})
|
||||
err := event.InitSysEventer(testLogger, testLock, "TestListening", event.WithEventerConfig(testConfig))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
listener := getListener(t)
|
||||
|
||||
mux := New(listener)
|
||||
defer mux.Close()
|
||||
|
||||
emptyconns := atomic.NewUint32(0)
|
||||
noneconns := atomic.NewUint32(0)
|
||||
l1conns := atomic.NewUint32(0)
|
||||
l2conns := atomic.NewUint32(0)
|
||||
l3conns := atomic.NewUint32(0)
|
||||
defconns := atomic.NewUint32(0)
|
||||
clientCountTracker := atomic.NewUint32(0)
|
||||
|
||||
baseconfig := getTestTLS(t, nil)
|
||||
noneconfig := baseconfig.Clone()
|
||||
p1config := baseconfig.Clone()
|
||||
p1config.NextProtos = []string{"p1"}
|
||||
p2p3config := getTestTLS(t, []string{"p2", "p3"})
|
||||
p3config := p2p3config.Clone()
|
||||
p3config.NextProtos = []string{"p3"}
|
||||
defconfig := baseconfig.Clone()
|
||||
defconfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
ret := baseconfig.Clone()
|
||||
ret.NextProtos = []string{fmt.Sprintf("%d", clientCountTracker.Load())}
|
||||
log.Printf("returning def config with next protos = %v\n", ret.NextProtos)
|
||||
clientCountTracker.Inc()
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
lempty, err := mux.RegisterProto("", noneconfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l1, err := mux.RegisterProto("p1", p1config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l2, err := mux.RegisterProto("p2", p2p3config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l3, err := mux.RegisterProto("p3", p2p3config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lnone, err := mux.RegisterProto(NoProto, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ldef, err := mux.RegisterProto(DefaultProto, defconfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := listener.Addr().String()
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(6)
|
||||
connWatchFunc := func(l net.Listener, connCounter *atomic.Uint32, tlsConf *tls.Config, numConns int) {
|
||||
defer wg.Done()
|
||||
tlsToUse := tlsConf
|
||||
go func() {
|
||||
for i := 0; i < numConns; i++ {
|
||||
var err error
|
||||
var conn net.Conn
|
||||
switch tlsToUse {
|
||||
case nil:
|
||||
conn, err = net.Dial("tcp4", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// We need to send some data here because we won't have any
|
||||
// from just the TLS handshake
|
||||
log.Println("defconn")
|
||||
n, err := conn.Write([]byte("GET "))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 4 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
log.Println("defconn done")
|
||||
|
||||
default:
|
||||
if connCounter == defconns {
|
||||
tlsToUse = baseconfig.Clone()
|
||||
log.Println("FOUND CURR")
|
||||
tlsToUse.NextProtos = []string{fmt.Sprintf("%d", i)}
|
||||
}
|
||||
log.Println(fmt.Sprintf("dialing on %d, counter = %d, protos = %v", numConns, i, tlsToUse.NextProtos))
|
||||
conn, err = tls.Dial("tcp4", addr, tlsToUse)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
log.Println(fmt.Sprintf("dialing done on %d, counter = %d, protos = %v", numConns, i, tlsToUse.NextProtos))
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
for i := 0; i < numConns; i++ {
|
||||
log.Println(fmt.Sprintf("accepting on %d, counter = %d", numConns, connCounter.Load()))
|
||||
conn, err := l.Accept()
|
||||
if err == nil && conn != nil {
|
||||
conn.Close()
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
log.Println(fmt.Sprintf("done accepting on %d, counter = %d", numConns, connCounter.Load()))
|
||||
connCounter.Inc()
|
||||
}
|
||||
return
|
||||
}
|
||||
go connWatchFunc(lempty, emptyconns, noneconfig, 4)
|
||||
go connWatchFunc(l1, l1conns, p1config, 5)
|
||||
go connWatchFunc(l2, l2conns, p2p3config, 6)
|
||||
go connWatchFunc(l3, l3conns, p3config, 7)
|
||||
go connWatchFunc(lnone, noneconns, nil, 8)
|
||||
go connWatchFunc(ldef, defconns, defconfig, 9)
|
||||
wg.Wait()
|
||||
|
||||
if emptyconns.Load() != 4 || l1conns.Load() != 5 || l2conns.Load() != 6 || l3conns.Load() != 7 || noneconns.Load() != 8 || defconns.Load() != 9 {
|
||||
t.Fatal("wrong number of conns")
|
||||
}
|
||||
}
|
||||
@ -1,119 +0,0 @@
|
||||
package alpnmux
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func getListener(t *testing.T) net.Listener {
|
||||
addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
listener, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return listener
|
||||
}
|
||||
|
||||
func getTestTLS(t *testing.T, protos []string) *tls.Config {
|
||||
certIPs := []net.IP{
|
||||
net.IPv6loopback,
|
||||
net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
|
||||
_, caKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
caCertTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: certIPs,
|
||||
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
caCert, err := x509.ParseCertificate(caBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
//
|
||||
// Certs generation
|
||||
//
|
||||
_, key, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: certIPs,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
}
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(certPEMBlock)
|
||||
marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keyPEMBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: marshaledKey,
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(keyPEMBlock)
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
RootCAs: rootCAs,
|
||||
ClientCAs: rootCAs,
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
NextProtos: protos,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
return tlsConfig
|
||||
}
|
||||
Loading…
Reference in new issue