Support for Unix domain socket listeners (#705)

The client already supports protocol-agnostic transport so handles
unix:/// schemes.
pull/713/head
Jeff Mitchell 6 years ago committed by GitHub
parent fb94971a99
commit 1f48d97327
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,11 +8,18 @@ Canonical reference for changes, improvements, and bugfixes for Boundary.
### Improvements
* controller: Allow API/Cluster listeners to be Unix domain sockets
([Issue](https://github.com/hashicorp/boundary/pull/699))
([PR](https://github.com/hashicorp/boundary/pull/705))
### Bug Fixes
* cli: Fix hyphenation in help output for resources with compound names
([Issue](https://github.com/hashicorp/boundary/issues/686))
([PR](https://github.com/hashicorp/boundary/pull/689))
* controller, worker: Fix listening on IPv6 addresses
([Issue](https://github.com/hashicorp/boundary/issues/701))
([PR](https://github.com/hashicorp/boundary/pull/703))
## v0.1.0

@ -40,11 +40,12 @@ type WorkerAuthInfo struct {
}
// Factory is the factory function to create a listener.
type ListenerFactory func(*configutil.Listener, hclog.Logger, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error)
type ListenerFactory func(string, *configutil.Listener, hclog.Logger, cli.Ui) (string, net.Listener, error)
// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
"tcp": tcpListenerFactory,
"tcp": tcpListenerFactory,
"unix": unixListenerFactory,
}
// New creates a new listener of the given type with the given
@ -55,15 +56,67 @@ func NewListener(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (*alpnm
return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type)
}
return f(l, logger, ui)
}
func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) {
var purpose string
if len(l.Purpose) == 1 {
purpose = l.Purpose[0]
}
switch purpose {
case "cluster":
l.TLSDisable = true
case "proxy":
// TODO: Eventually we'll support bringing your own cert, and we'd only
// want to disable if you aren't actually bringing your own
l.TLSDisable = true
}
finalAddr, ln, err := f(purpose, l, logger, ui)
if err != nil {
return nil, nil, nil, err
}
ln, err = listenerWrapProxy(ln, l)
if err != nil {
return nil, nil, nil, err
}
props := map[string]string{
"addr": finalAddr,
}
if _, ok := os.LookupEnv("BOUNDARY_LOG_CONNECTION_MUXING"); !ok {
logger = nil
}
alpnMux := alpnmux.New(ln, logger)
if l.TLSDisable {
return alpnMux, props, nil, nil
}
// Don't request a client cert unless they've explicitly configured it to do
// so
if !l.TLSRequireAndVerifyClientCert {
l.TLSDisableClientCerts = true
}
tlsConfig, reloadFunc, err := listenerutil.TLSConfig(l, props, ui)
if err != nil {
return nil, nil, nil, err
}
// Register no proto, "http/1.1", and "h2", with same TLS config
if _, err = alpnMux.RegisterProto("", tlsConfig); err != nil {
return nil, nil, nil, err
}
if _, err = alpnMux.RegisterProto("http/1.1", tlsConfig); err != nil {
return nil, nil, nil, err
}
if _, err = alpnMux.RegisterProto("h2", tlsConfig); err != nil {
return nil, nil, nil, err
}
return alpnMux, props, reloadFunc, nil
}
func tcpListenerFactory(purpose string, l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (string, net.Listener, error) {
if l.Address == "" {
switch purpose {
case "cluster":
@ -88,24 +141,15 @@ func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui)
}
host = l.Address
} else {
return nil, nil, nil, fmt.Errorf("error splitting host/port: %w", err)
return "", nil, fmt.Errorf("error splitting host/port: %w", err)
}
}
if host == "" {
return nil, nil, nil, errors.New("could not determine host")
return "", nil, errors.New("could not determine host")
}
if port == "" {
return nil, nil, nil, errors.New("could not determine port")
}
switch purpose {
case "cluster":
l.TLSDisable = true
case "proxy":
// TODO: Eventually we'll support bringing your own cert, and we'd only
// want to disable if you aren't actually bringing your own
l.TLSDisable = true
return "", nil, errors.New("could not determine port")
}
bindProto := "tcp"
@ -120,54 +164,35 @@ func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui)
port = ""
}
finalListenAddr := fmt.Sprintf("%s:%s", host, port)
finalListenAddr := net.JoinHostPort(host, port)
ln, err := net.Listen(bindProto, finalListenAddr)
if err != nil {
return nil, nil, nil, err
return "", nil, err
}
ln = TCPKeepAliveListener{ln.(*net.TCPListener)}
ln, err = listenerWrapProxy(ln, l)
if err != nil {
return nil, nil, nil, err
}
props := map[string]string{
"addr": finalListenAddr,
}
if _, ok := os.LookupEnv("BOUNDARY_LOG_CONNECTION_MUXING"); !ok {
logger = nil
}
alpnMux := alpnmux.New(ln, logger)
if l.TLSDisable {
return alpnMux, props, nil, nil
}
return finalListenAddr, ln, nil
}
// Don't request a client cert unless they've explicitly configured it to do
// so
if !l.TLSRequireAndVerifyClientCert {
l.TLSDisableClientCerts = true
func unixListenerFactory(purpose string, l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (string, net.Listener, error) {
var uConfig *listenerutil.UnixSocketsConfig
if l.SocketMode != "" &&
l.SocketUser != "" &&
l.SocketGroup != "" {
uConfig = &listenerutil.UnixSocketsConfig{
Mode: l.SocketMode,
User: l.SocketUser,
Group: l.SocketGroup,
}
}
tlsConfig, reloadFunc, err := listenerutil.TLSConfig(l, props, ui)
ln, err := listenerutil.UnixSocketListener(l.Address, uConfig)
if err != nil {
return nil, nil, nil, err
}
// Register no proto, "http/1.1", and "h2", with same TLS config
if _, err = alpnMux.RegisterProto("", tlsConfig); err != nil {
return nil, nil, nil, err
}
if _, err = alpnMux.RegisterProto("http/1.1", tlsConfig); err != nil {
return nil, nil, nil, err
}
if _, err = alpnMux.RegisterProto("h2", tlsConfig); err != nil {
return nil, nil, nil, err
return "", nil, err
}
return alpnMux, props, reloadFunc, nil
return l.Address, ln, nil
}
func listenerWrapProxy(ln net.Listener, l *configutil.Listener) (net.Listener, error) {

@ -623,6 +623,6 @@ func (b *Server) SetupWorkerPublicAddress(conf *config.Config, flagValue string)
return fmt.Errorf("Error splitting public adddress host/port: %w", err)
}
}
conf.Worker.PublicAddr = fmt.Sprintf("%s:%s", host, port)
conf.Worker.PublicAddr = net.JoinHostPort(host, port)
return nil
}

@ -123,7 +123,7 @@ func (c *Command) Flags() *base.FlagSets {
Name: "api-listen-address",
Target: &c.flagControllerAPIListenAddr,
EnvVar: "BOUNDARY_DEV_CONTROLLER_API_LISTEN_ADDRESS",
Usage: "Address to bind to for controller \"api\" purpose.",
Usage: "Address to bind to for controller \"api\" purpose. If this begins with a forward slash, it will be assumed to be a Unix domain socket path.",
})
f.StringVar(&base.StringVar{
@ -160,7 +160,7 @@ func (c *Command) Flags() *base.FlagSets {
Name: "cluster-listen-address",
Target: &c.flagControllerClusterListenAddr,
EnvVar: "BOUNDARY_DEV_CONTROLLER_CLUSTER_LISTEN_ADDRESS",
Usage: "Address to bind to for controller \"cluster\" purpose.",
Usage: "Address to bind to for controller \"cluster\" purpose. If this begins with a forward slash, it will be assumed to be a Unix domain socket path.",
})
f.StringVar(&base.StringVar{
@ -296,10 +296,17 @@ func (c *Command) Run(args []string) int {
if c.flagControllerAPIListenAddr != "" {
l.Address = c.flagControllerAPIListenAddr
}
if strings.HasPrefix(l.Address, "/") {
l.Type = "unix"
}
case "cluster":
if c.flagControllerClusterListenAddr != "" {
l.Address = c.flagControllerClusterListenAddr
c.Config.Worker.Controllers = []string{l.Address}
}
if strings.HasPrefix(l.Address, "/") {
l.Type = "unix"
}
case "proxy":

@ -8,6 +8,7 @@ import (
"math"
"net"
"net/http"
"os"
"sync"
"time"
@ -195,7 +196,14 @@ func (c *Controller) stopListeners(serversOnly bool) error {
var retErr *multierror.Error
for _, ln := range c.conf.Listeners {
if err := ln.Mux.Close(); err != nil {
retErr = multierror.Append(retErr, err)
if _, ok := err.(*os.PathError); ok && ln.Config.Type == "unix" {
// The rmListener probably tried to remove the file but it
// didn't exist, ignore the error; this is a conflict
// between rmListener and the default Go behavior of
// removing auto-vivified Unix domain sockets.
} else {
retErr = multierror.Append(retErr, err)
}
}
}
return retErr.ErrorOrNil()

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net"
"strconv"
"strings"
"testing"
"github.com/hashicorp/boundary/api"
@ -150,12 +152,23 @@ func (tc *TestController) addrs(purpose string) []string {
addrs := make([]string, 0, len(tc.b.Listeners))
for _, listener := range tc.b.Listeners {
if listener.Config.Purpose[0] == purpose {
tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr)
if !ok {
tc.t.Fatal("could not parse address as a TCP addr")
addr := listener.Mux.Addr()
switch {
case strings.HasPrefix(addr.String(), "/"):
switch purpose {
case "api":
addrs = append(addrs, fmt.Sprintf("unix://%s", addr.String()))
default:
addrs = append(addrs, addr.String())
}
default:
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
tc.t.Fatal("could not parse address as a TCP addr")
}
addr := fmt.Sprintf("%s%s", prefix, net.JoinHostPort(tcpAddr.IP.String(), strconv.Itoa(tcpAddr.Port)))
addrs = append(addrs, addr)
}
addr := fmt.Sprintf("%s%s:%d", prefix, tcpAddr.IP.String(), tcpAddr.Port)
addrs = append(addrs, addr)
}
}
@ -218,7 +231,12 @@ func (tc *TestController) Shutdown() {
}
type TestControllerOpts struct {
// Config; if not provided a dev one will be created
// ConfigHcl is the HCL to be parsed to generate the initial config.
// Overrides Config if both are set.
ConfigHcl string
// Config; if not provided a dev one will be created, unless ConfigHcl is
// set.
Config *config.Config
// DefaultAuthMethodId is the default auth method ID to use, if set.
@ -313,7 +331,15 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
// Get dev config, or use a provided one
var err error
if opts.Config == nil {
switch {
case opts.ConfigHcl != "":
cfg, err := config.Parse(opts.ConfigHcl)
if err != nil {
t.Fatal(err)
}
opts.Config = cfg
case opts.Config == nil:
opts.Config, err = config.DevController()
if err != nil {
t.Fatal(err)

@ -28,15 +28,20 @@ import (
func (w *Worker) startControllerConnections() error {
initialAddrs := make([]resolver.Address, 0, len(w.conf.RawConfig.Worker.Controllers))
for _, addr := range w.conf.RawConfig.Worker.Controllers {
host, port, err := net.SplitHostPort(addr)
if err != nil && strings.Contains(err.Error(), "missing port in address") {
w.logger.Trace("missing port in controller address, using port 9201", "address", addr)
host, port, err = net.SplitHostPort(fmt.Sprintf("%s:%s", addr, "9201"))
}
if err != nil {
return fmt.Errorf("error parsing controller address: %w", err)
switch {
case strings.HasPrefix(addr, "/"):
initialAddrs = append(initialAddrs, resolver.Address{Addr: addr})
default:
host, port, err := net.SplitHostPort(addr)
if err != nil && strings.Contains(err.Error(), "missing port in address") {
w.logger.Trace("missing port in controller address, using port 9201", "address", addr)
host, port, err = net.SplitHostPort(net.JoinHostPort(addr, "9201"))
}
if err != nil {
return fmt.Errorf("error parsing controller address: %w", err)
}
initialAddrs = append(initialAddrs, resolver.Address{Addr: net.JoinHostPort(host, port)})
}
initialAddrs = append(initialAddrs, resolver.Address{Addr: fmt.Sprintf("%s:%s", host, port)})
}
if len(initialAddrs) == 0 {
@ -60,7 +65,13 @@ func (w Worker) controllerDialerFunc() func(context.Context, string) (net.Conn,
return nil, fmt.Errorf("error creating tls config for worker auth: %w", err)
}
dialer := &net.Dialer{}
nonTlsConn, err := dialer.DialContext(ctx, "tcp", addr)
var nonTlsConn net.Conn
switch {
case strings.HasPrefix(addr, "/"):
nonTlsConn, err = dialer.DialContext(ctx, "unix", addr)
default:
nonTlsConn, err = dialer.DialContext(ctx, "tcp", addr)
}
if err != nil {
return nil, fmt.Errorf("unable to dial to controller: %w", err)
}

@ -42,7 +42,7 @@ func (w *Worker) handleProxy() http.HandlerFunc {
clientIp, clientPort, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
w.logger.Error("unable to understand remote address", "error", err)
w.logger.Error("unable to understand remote address", "error", err, "remote_addr", r.RemoteAddr)
wr.WriteHeader(http.StatusInternalServerError)
return
}

@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"os"
"sync"
"time"
@ -109,7 +110,14 @@ func (w *Worker) stopListeners() error {
if !w.conf.RawConfig.DevController {
for _, ln := range w.conf.Listeners {
if err := ln.Mux.Close(); err != nil {
retErr = multierror.Append(retErr, err)
if _, ok := err.(*os.PathError); ok && ln.Config.Type == "unix" {
// The rmListener probably tried to remove the file but it
// didn't exist, ignore the error; this is a conflict
// between rmListener and the default Go behavior of
// removing auto-vivified Unix domain sockets.
} else {
retErr = multierror.Append(retErr, err)
}
}
}
}

@ -0,0 +1,108 @@
package cluster
import (
"context"
"testing"
"time"
"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/scopes"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/servers/worker"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIPv6Listener(t *testing.T) {
assert, require := assert.New(t), require.New(t)
amId := "ampw_1234567890"
user := "user"
password := "passpass"
logger := hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
conf, err := config.DevController()
require.NoError(err)
for _, l := range conf.Listeners {
switch l.Purpose[0] {
case "api":
l.Address = "[::1]:9200"
case "cluster":
l.Address = "[::1]:9201"
}
}
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
DefaultAuthMethodId: amId,
DefaultLoginName: user,
DefaultPassword: password,
Logger: logger.Named("c1"),
})
defer c1.Shutdown()
expectWorkers := func(c *controller.TestController, workers ...*worker.TestWorker) {
updateTimes := c.Controller().WorkerStatusUpdateTimes()
workerMap := map[string]*worker.TestWorker{}
for _, w := range workers {
workerMap[w.Name()] = w
}
updateTimes.Range(func(k, v interface{}) bool {
require.NotNil(k)
require.NotNil(v)
if workerMap[k.(string)] == nil {
// We don't remove from updateTimes currently so if we're not
// expecting it we'll see an out-of-date entry
return true
}
assert.WithinDuration(time.Now(), v.(time.Time), 35*time.Second)
delete(workerMap, k.(string))
return true
})
assert.Empty(workerMap)
}
expectWorkers(c1)
wconf, err := config.DevWorker()
require.NoError(err)
w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
Config: wconf,
WorkerAuthKms: c1.Config().WorkerAuthKms,
InitialControllers: c1.ClusterAddrs(),
Logger: logger.Named("w1"),
})
defer w1.Shutdown()
time.Sleep(10 * time.Second)
expectWorkers(c1, w1)
require.NoError(w1.Worker().Shutdown(true))
time.Sleep(10 * time.Second)
expectWorkers(c1)
require.NoError(c1.Controller().Shutdown(true))
time.Sleep(10 * time.Second)
require.NoError(c1.Controller().Start())
time.Sleep(10 * time.Second)
expectWorkers(c1, w1)
client, err := api.NewClient(nil)
require.NoError(err)
addrs := c1.ApiAddrs()
require.Len(addrs, 1)
require.NoError(client.SetAddr(addrs[0]))
sc := scopes.NewClient(client)
_, err = sc.List(context.Background(), "global")
require.NoError(err)
}

@ -4,6 +4,7 @@ import (
"testing"
"time"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/servers/worker"
"github.com/hashicorp/go-hclog"
@ -20,7 +21,11 @@ func TestMultiControllerMultiWorkerConnections(t *testing.T) {
Level: hclog.Trace,
})
conf, err := config.DevController()
require.NoError(err)
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
DefaultAuthMethodId: amId,
DefaultLoginName: user,
DefaultPassword: password,

@ -0,0 +1,120 @@
package cluster
import (
"context"
"io/ioutil"
"os"
"path"
"testing"
"time"
"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/scopes"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/servers/worker"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUnixListener(t *testing.T) {
assert, require := assert.New(t), require.New(t)
amId := "ampw_1234567890"
user := "user"
password := "passpass"
logger := hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
conf, err := config.DevController()
require.NoError(err)
tempDir, err := ioutil.TempDir("", "boundary-unix-listener-test")
require.NoError(err)
defer func() {
require.NoError(os.RemoveAll(tempDir))
}()
for _, l := range conf.Listeners {
switch l.Purpose[0] {
case "api":
l.Address = path.Join(tempDir, "api")
l.Type = "unix"
case "cluster":
l.Address = path.Join(tempDir, "cluster")
l.Type = "unix"
}
}
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
DefaultAuthMethodId: amId,
DefaultLoginName: user,
DefaultPassword: password,
Logger: logger.Named("c1"),
})
defer c1.Shutdown()
expectWorkers := func(c *controller.TestController, workers ...*worker.TestWorker) {
updateTimes := c.Controller().WorkerStatusUpdateTimes()
workerMap := map[string]*worker.TestWorker{}
for _, w := range workers {
workerMap[w.Name()] = w
}
updateTimes.Range(func(k, v interface{}) bool {
require.NotNil(k)
require.NotNil(v)
if workerMap[k.(string)] == nil {
// We don't remove from updateTimes currently so if we're not
// expecting it we'll see an out-of-date entry
return true
}
assert.WithinDuration(time.Now(), v.(time.Time), 35*time.Second)
delete(workerMap, k.(string))
return true
})
assert.Empty(workerMap)
}
expectWorkers(c1)
wconf, err := config.DevWorker()
require.NoError(err)
w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
Config: wconf,
WorkerAuthKms: c1.Config().WorkerAuthKms,
InitialControllers: c1.ClusterAddrs(),
Logger: logger.Named("w1"),
})
defer w1.Shutdown()
time.Sleep(10 * time.Second)
expectWorkers(c1, w1)
require.NoError(w1.Worker().Shutdown(true))
time.Sleep(10 * time.Second)
expectWorkers(c1)
require.NoError(c1.Controller().Shutdown(true))
time.Sleep(10 * time.Second)
require.NoError(c1.Controller().Start())
time.Sleep(10 * time.Second)
expectWorkers(c1, w1)
client, err := api.NewClient(nil)
require.NoError(err)
addrs := c1.ApiAddrs()
require.Len(addrs, 1)
require.NoError(client.SetAddr(addrs[0]))
sc := scopes.NewClient(client)
_, err = sc.List(context.Background(), "global")
require.NoError(err)
}
Loading…
Cancel
Save