diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bc51370ad..584a88a249 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,11 +8,18 @@ Canonical reference for changes, improvements, and bugfixes for Boundary. ### Improvements +* controller: Allow API/Cluster listeners to be Unix domain sockets + ([Issue](https://github.com/hashicorp/boundary/pull/699)) + ([PR](https://github.com/hashicorp/boundary/pull/705)) + ### Bug Fixes * cli: Fix hyphenation in help output for resources with compound names ([Issue](https://github.com/hashicorp/boundary/issues/686)) ([PR](https://github.com/hashicorp/boundary/pull/689)) +* controller, worker: Fix listening on IPv6 addresses + ([Issue](https://github.com/hashicorp/boundary/issues/701)) + ([PR](https://github.com/hashicorp/boundary/pull/703)) ## v0.1.0 diff --git a/internal/cmd/base/listener.go b/internal/cmd/base/listener.go index ed4e2c7a86..152b618acb 100644 --- a/internal/cmd/base/listener.go +++ b/internal/cmd/base/listener.go @@ -40,11 +40,12 @@ type WorkerAuthInfo struct { } // Factory is the factory function to create a listener. -type ListenerFactory func(*configutil.Listener, hclog.Logger, cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) +type ListenerFactory func(string, *configutil.Listener, hclog.Logger, cli.Ui) (string, net.Listener, error) // BuiltinListeners is the list of built-in listener types. var BuiltinListeners = map[string]ListenerFactory{ - "tcp": tcpListenerFactory, + "tcp": tcpListenerFactory, + "unix": unixListenerFactory, } // New creates a new listener of the given type with the given @@ -55,15 +56,67 @@ func NewListener(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (*alpnm return nil, nil, nil, fmt.Errorf("unknown listener type: %q", l.Type) } - return f(l, logger, ui) -} - -func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) { var purpose string if len(l.Purpose) == 1 { purpose = l.Purpose[0] } + switch purpose { + case "cluster": + l.TLSDisable = true + case "proxy": + // TODO: Eventually we'll support bringing your own cert, and we'd only + // want to disable if you aren't actually bringing your own + l.TLSDisable = true + } + + finalAddr, ln, err := f(purpose, l, logger, ui) + if err != nil { + return nil, nil, nil, err + } + + ln, err = listenerWrapProxy(ln, l) + if err != nil { + return nil, nil, nil, err + } + + props := map[string]string{ + "addr": finalAddr, + } + + if _, ok := os.LookupEnv("BOUNDARY_LOG_CONNECTION_MUXING"); !ok { + logger = nil + } + alpnMux := alpnmux.New(ln, logger) + + if l.TLSDisable { + return alpnMux, props, nil, nil + } + + // Don't request a client cert unless they've explicitly configured it to do + // so + if !l.TLSRequireAndVerifyClientCert { + l.TLSDisableClientCerts = true + } + tlsConfig, reloadFunc, err := listenerutil.TLSConfig(l, props, ui) + if err != nil { + return nil, nil, nil, err + } + // Register no proto, "http/1.1", and "h2", with same TLS config + if _, err = alpnMux.RegisterProto("", tlsConfig); err != nil { + return nil, nil, nil, err + } + if _, err = alpnMux.RegisterProto("http/1.1", tlsConfig); err != nil { + return nil, nil, nil, err + } + if _, err = alpnMux.RegisterProto("h2", tlsConfig); err != nil { + return nil, nil, nil, err + } + + return alpnMux, props, reloadFunc, nil +} + +func tcpListenerFactory(purpose string, l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (string, net.Listener, error) { if l.Address == "" { switch purpose { case "cluster": @@ -88,24 +141,15 @@ func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) } host = l.Address } else { - return nil, nil, nil, fmt.Errorf("error splitting host/port: %w", err) + return "", nil, fmt.Errorf("error splitting host/port: %w", err) } } if host == "" { - return nil, nil, nil, errors.New("could not determine host") + return "", nil, errors.New("could not determine host") } if port == "" { - return nil, nil, nil, errors.New("could not determine port") - } - - switch purpose { - case "cluster": - l.TLSDisable = true - case "proxy": - // TODO: Eventually we'll support bringing your own cert, and we'd only - // want to disable if you aren't actually bringing your own - l.TLSDisable = true + return "", nil, errors.New("could not determine port") } bindProto := "tcp" @@ -120,54 +164,35 @@ func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) port = "" } - finalListenAddr := fmt.Sprintf("%s:%s", host, port) + finalListenAddr := net.JoinHostPort(host, port) ln, err := net.Listen(bindProto, finalListenAddr) if err != nil { - return nil, nil, nil, err + return "", nil, err } ln = TCPKeepAliveListener{ln.(*net.TCPListener)} - ln, err = listenerWrapProxy(ln, l) - if err != nil { - return nil, nil, nil, err - } - - props := map[string]string{ - "addr": finalListenAddr, - } - - if _, ok := os.LookupEnv("BOUNDARY_LOG_CONNECTION_MUXING"); !ok { - logger = nil - } - alpnMux := alpnmux.New(ln, logger) - - if l.TLSDisable { - return alpnMux, props, nil, nil - } + return finalListenAddr, ln, nil +} - // Don't request a client cert unless they've explicitly configured it to do - // so - if !l.TLSRequireAndVerifyClientCert { - l.TLSDisableClientCerts = true +func unixListenerFactory(purpose string, l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (string, net.Listener, error) { + var uConfig *listenerutil.UnixSocketsConfig + if l.SocketMode != "" && + l.SocketUser != "" && + l.SocketGroup != "" { + uConfig = &listenerutil.UnixSocketsConfig{ + Mode: l.SocketMode, + User: l.SocketUser, + Group: l.SocketGroup, + } } - tlsConfig, reloadFunc, err := listenerutil.TLSConfig(l, props, ui) + ln, err := listenerutil.UnixSocketListener(l.Address, uConfig) if err != nil { - return nil, nil, nil, err - } - // Register no proto, "http/1.1", and "h2", with same TLS config - if _, err = alpnMux.RegisterProto("", tlsConfig); err != nil { - return nil, nil, nil, err - } - if _, err = alpnMux.RegisterProto("http/1.1", tlsConfig); err != nil { - return nil, nil, nil, err - } - if _, err = alpnMux.RegisterProto("h2", tlsConfig); err != nil { - return nil, nil, nil, err + return "", nil, err } - return alpnMux, props, reloadFunc, nil + return l.Address, ln, nil } func listenerWrapProxy(ln net.Listener, l *configutil.Listener) (net.Listener, error) { diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 6f4bef1a77..fd2b0c622b 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -623,6 +623,6 @@ func (b *Server) SetupWorkerPublicAddress(conf *config.Config, flagValue string) return fmt.Errorf("Error splitting public adddress host/port: %w", err) } } - conf.Worker.PublicAddr = fmt.Sprintf("%s:%s", host, port) + conf.Worker.PublicAddr = net.JoinHostPort(host, port) return nil } diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index d286af1bb9..772dd1fcfd 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -123,7 +123,7 @@ func (c *Command) Flags() *base.FlagSets { Name: "api-listen-address", Target: &c.flagControllerAPIListenAddr, EnvVar: "BOUNDARY_DEV_CONTROLLER_API_LISTEN_ADDRESS", - Usage: "Address to bind to for controller \"api\" purpose.", + Usage: "Address to bind to for controller \"api\" purpose. If this begins with a forward slash, it will be assumed to be a Unix domain socket path.", }) f.StringVar(&base.StringVar{ @@ -160,7 +160,7 @@ func (c *Command) Flags() *base.FlagSets { Name: "cluster-listen-address", Target: &c.flagControllerClusterListenAddr, EnvVar: "BOUNDARY_DEV_CONTROLLER_CLUSTER_LISTEN_ADDRESS", - Usage: "Address to bind to for controller \"cluster\" purpose.", + Usage: "Address to bind to for controller \"cluster\" purpose. If this begins with a forward slash, it will be assumed to be a Unix domain socket path.", }) f.StringVar(&base.StringVar{ @@ -296,10 +296,17 @@ func (c *Command) Run(args []string) int { if c.flagControllerAPIListenAddr != "" { l.Address = c.flagControllerAPIListenAddr } + if strings.HasPrefix(l.Address, "/") { + l.Type = "unix" + } case "cluster": if c.flagControllerClusterListenAddr != "" { l.Address = c.flagControllerClusterListenAddr + c.Config.Worker.Controllers = []string{l.Address} + } + if strings.HasPrefix(l.Address, "/") { + l.Type = "unix" } case "proxy": diff --git a/internal/servers/controller/listeners.go b/internal/servers/controller/listeners.go index ac63dc6cb9..267587b7c4 100644 --- a/internal/servers/controller/listeners.go +++ b/internal/servers/controller/listeners.go @@ -8,6 +8,7 @@ import ( "math" "net" "net/http" + "os" "sync" "time" @@ -195,7 +196,14 @@ func (c *Controller) stopListeners(serversOnly bool) error { var retErr *multierror.Error for _, ln := range c.conf.Listeners { if err := ln.Mux.Close(); err != nil { - retErr = multierror.Append(retErr, err) + if _, ok := err.(*os.PathError); ok && ln.Config.Type == "unix" { + // The rmListener probably tried to remove the file but it + // didn't exist, ignore the error; this is a conflict + // between rmListener and the default Go behavior of + // removing auto-vivified Unix domain sockets. + } else { + retErr = multierror.Append(retErr, err) + } } } return retErr.ErrorOrNil() diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index d8b6c82416..997d3023bc 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net" + "strconv" + "strings" "testing" "github.com/hashicorp/boundary/api" @@ -150,12 +152,23 @@ func (tc *TestController) addrs(purpose string) []string { addrs := make([]string, 0, len(tc.b.Listeners)) for _, listener := range tc.b.Listeners { if listener.Config.Purpose[0] == purpose { - tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr) - if !ok { - tc.t.Fatal("could not parse address as a TCP addr") + addr := listener.Mux.Addr() + switch { + case strings.HasPrefix(addr.String(), "/"): + switch purpose { + case "api": + addrs = append(addrs, fmt.Sprintf("unix://%s", addr.String())) + default: + addrs = append(addrs, addr.String()) + } + default: + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + tc.t.Fatal("could not parse address as a TCP addr") + } + addr := fmt.Sprintf("%s%s", prefix, net.JoinHostPort(tcpAddr.IP.String(), strconv.Itoa(tcpAddr.Port))) + addrs = append(addrs, addr) } - addr := fmt.Sprintf("%s%s:%d", prefix, tcpAddr.IP.String(), tcpAddr.Port) - addrs = append(addrs, addr) } } @@ -218,7 +231,12 @@ func (tc *TestController) Shutdown() { } type TestControllerOpts struct { - // Config; if not provided a dev one will be created + // ConfigHcl is the HCL to be parsed to generate the initial config. + // Overrides Config if both are set. + ConfigHcl string + + // Config; if not provided a dev one will be created, unless ConfigHcl is + // set. Config *config.Config // DefaultAuthMethodId is the default auth method ID to use, if set. @@ -313,7 +331,15 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { // Get dev config, or use a provided one var err error - if opts.Config == nil { + switch { + case opts.ConfigHcl != "": + cfg, err := config.Parse(opts.ConfigHcl) + if err != nil { + t.Fatal(err) + } + opts.Config = cfg + + case opts.Config == nil: opts.Config, err = config.DevController() if err != nil { t.Fatal(err) diff --git a/internal/servers/worker/controller_connection.go b/internal/servers/worker/controller_connection.go index b9a68dc2b4..d42df60781 100644 --- a/internal/servers/worker/controller_connection.go +++ b/internal/servers/worker/controller_connection.go @@ -28,15 +28,20 @@ import ( func (w *Worker) startControllerConnections() error { initialAddrs := make([]resolver.Address, 0, len(w.conf.RawConfig.Worker.Controllers)) for _, addr := range w.conf.RawConfig.Worker.Controllers { - host, port, err := net.SplitHostPort(addr) - if err != nil && strings.Contains(err.Error(), "missing port in address") { - w.logger.Trace("missing port in controller address, using port 9201", "address", addr) - host, port, err = net.SplitHostPort(fmt.Sprintf("%s:%s", addr, "9201")) - } - if err != nil { - return fmt.Errorf("error parsing controller address: %w", err) + switch { + case strings.HasPrefix(addr, "/"): + initialAddrs = append(initialAddrs, resolver.Address{Addr: addr}) + default: + host, port, err := net.SplitHostPort(addr) + if err != nil && strings.Contains(err.Error(), "missing port in address") { + w.logger.Trace("missing port in controller address, using port 9201", "address", addr) + host, port, err = net.SplitHostPort(net.JoinHostPort(addr, "9201")) + } + if err != nil { + return fmt.Errorf("error parsing controller address: %w", err) + } + initialAddrs = append(initialAddrs, resolver.Address{Addr: net.JoinHostPort(host, port)}) } - initialAddrs = append(initialAddrs, resolver.Address{Addr: fmt.Sprintf("%s:%s", host, port)}) } if len(initialAddrs) == 0 { @@ -60,7 +65,13 @@ func (w Worker) controllerDialerFunc() func(context.Context, string) (net.Conn, return nil, fmt.Errorf("error creating tls config for worker auth: %w", err) } dialer := &net.Dialer{} - nonTlsConn, err := dialer.DialContext(ctx, "tcp", addr) + var nonTlsConn net.Conn + switch { + case strings.HasPrefix(addr, "/"): + nonTlsConn, err = dialer.DialContext(ctx, "unix", addr) + default: + nonTlsConn, err = dialer.DialContext(ctx, "tcp", addr) + } if err != nil { return nil, fmt.Errorf("unable to dial to controller: %w", err) } diff --git a/internal/servers/worker/handler.go b/internal/servers/worker/handler.go index fcc00dc438..34074ecf10 100644 --- a/internal/servers/worker/handler.go +++ b/internal/servers/worker/handler.go @@ -42,7 +42,7 @@ func (w *Worker) handleProxy() http.HandlerFunc { clientIp, clientPort, err := net.SplitHostPort(r.RemoteAddr) if err != nil { - w.logger.Error("unable to understand remote address", "error", err) + w.logger.Error("unable to understand remote address", "error", err, "remote_addr", r.RemoteAddr) wr.WriteHeader(http.StatusInternalServerError) return } diff --git a/internal/servers/worker/listeners.go b/internal/servers/worker/listeners.go index 9ea9f9613d..d27d4551f7 100644 --- a/internal/servers/worker/listeners.go +++ b/internal/servers/worker/listeners.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "os" "sync" "time" @@ -109,7 +110,14 @@ func (w *Worker) stopListeners() error { if !w.conf.RawConfig.DevController { for _, ln := range w.conf.Listeners { if err := ln.Mux.Close(); err != nil { - retErr = multierror.Append(retErr, err) + if _, ok := err.(*os.PathError); ok && ln.Config.Type == "unix" { + // The rmListener probably tried to remove the file but it + // didn't exist, ignore the error; this is a conflict + // between rmListener and the default Go behavior of + // removing auto-vivified Unix domain sockets. + } else { + retErr = multierror.Append(retErr, err) + } } } } diff --git a/internal/tests/cluster/ipv6_listener_test.go b/internal/tests/cluster/ipv6_listener_test.go new file mode 100644 index 0000000000..12828add6f --- /dev/null +++ b/internal/tests/cluster/ipv6_listener_test.go @@ -0,0 +1,108 @@ +package cluster + +import ( + "context" + "testing" + "time" + + "github.com/hashicorp/boundary/api" + "github.com/hashicorp/boundary/api/scopes" + "github.com/hashicorp/boundary/internal/cmd/config" + "github.com/hashicorp/boundary/internal/servers/controller" + "github.com/hashicorp/boundary/internal/servers/worker" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIPv6Listener(t *testing.T) { + assert, require := assert.New(t), require.New(t) + amId := "ampw_1234567890" + user := "user" + password := "passpass" + logger := hclog.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + }) + + conf, err := config.DevController() + require.NoError(err) + + for _, l := range conf.Listeners { + switch l.Purpose[0] { + case "api": + l.Address = "[::1]:9200" + + case "cluster": + l.Address = "[::1]:9201" + } + } + + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + DefaultAuthMethodId: amId, + DefaultLoginName: user, + DefaultPassword: password, + Logger: logger.Named("c1"), + }) + defer c1.Shutdown() + + expectWorkers := func(c *controller.TestController, workers ...*worker.TestWorker) { + updateTimes := c.Controller().WorkerStatusUpdateTimes() + workerMap := map[string]*worker.TestWorker{} + for _, w := range workers { + workerMap[w.Name()] = w + } + updateTimes.Range(func(k, v interface{}) bool { + require.NotNil(k) + require.NotNil(v) + if workerMap[k.(string)] == nil { + // We don't remove from updateTimes currently so if we're not + // expecting it we'll see an out-of-date entry + return true + } + assert.WithinDuration(time.Now(), v.(time.Time), 35*time.Second) + delete(workerMap, k.(string)) + return true + }) + assert.Empty(workerMap) + } + + expectWorkers(c1) + + wconf, err := config.DevWorker() + require.NoError(err) + + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + Config: wconf, + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: c1.ClusterAddrs(), + Logger: logger.Named("w1"), + }) + defer w1.Shutdown() + + time.Sleep(10 * time.Second) + expectWorkers(c1, w1) + + require.NoError(w1.Worker().Shutdown(true)) + time.Sleep(10 * time.Second) + expectWorkers(c1) + + require.NoError(c1.Controller().Shutdown(true)) + time.Sleep(10 * time.Second) + + require.NoError(c1.Controller().Start()) + time.Sleep(10 * time.Second) + expectWorkers(c1, w1) + + client, err := api.NewClient(nil) + require.NoError(err) + + addrs := c1.ApiAddrs() + require.Len(addrs, 1) + + require.NoError(client.SetAddr(addrs[0])) + + sc := scopes.NewClient(client) + _, err = sc.List(context.Background(), "global") + require.NoError(err) +} diff --git a/internal/tests/cluster/multi_controller_worker_test.go b/internal/tests/cluster/multi_controller_worker_test.go index 01e2a0b5bc..a76fcc3be0 100644 --- a/internal/tests/cluster/multi_controller_worker_test.go +++ b/internal/tests/cluster/multi_controller_worker_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/hashicorp/boundary/internal/cmd/config" "github.com/hashicorp/boundary/internal/servers/controller" "github.com/hashicorp/boundary/internal/servers/worker" "github.com/hashicorp/go-hclog" @@ -20,7 +21,11 @@ func TestMultiControllerMultiWorkerConnections(t *testing.T) { Level: hclog.Trace, }) + conf, err := config.DevController() + require.NoError(err) + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, DefaultAuthMethodId: amId, DefaultLoginName: user, DefaultPassword: password, diff --git a/internal/tests/cluster/unix_listener_test.go b/internal/tests/cluster/unix_listener_test.go new file mode 100644 index 0000000000..735a99c1f9 --- /dev/null +++ b/internal/tests/cluster/unix_listener_test.go @@ -0,0 +1,120 @@ +package cluster + +import ( + "context" + "io/ioutil" + "os" + "path" + "testing" + "time" + + "github.com/hashicorp/boundary/api" + "github.com/hashicorp/boundary/api/scopes" + "github.com/hashicorp/boundary/internal/cmd/config" + "github.com/hashicorp/boundary/internal/servers/controller" + "github.com/hashicorp/boundary/internal/servers/worker" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnixListener(t *testing.T) { + assert, require := assert.New(t), require.New(t) + amId := "ampw_1234567890" + user := "user" + password := "passpass" + logger := hclog.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + }) + + conf, err := config.DevController() + require.NoError(err) + + tempDir, err := ioutil.TempDir("", "boundary-unix-listener-test") + require.NoError(err) + + defer func() { + require.NoError(os.RemoveAll(tempDir)) + }() + + for _, l := range conf.Listeners { + switch l.Purpose[0] { + case "api": + l.Address = path.Join(tempDir, "api") + l.Type = "unix" + + case "cluster": + l.Address = path.Join(tempDir, "cluster") + l.Type = "unix" + } + } + + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + DefaultAuthMethodId: amId, + DefaultLoginName: user, + DefaultPassword: password, + Logger: logger.Named("c1"), + }) + defer c1.Shutdown() + + expectWorkers := func(c *controller.TestController, workers ...*worker.TestWorker) { + updateTimes := c.Controller().WorkerStatusUpdateTimes() + workerMap := map[string]*worker.TestWorker{} + for _, w := range workers { + workerMap[w.Name()] = w + } + updateTimes.Range(func(k, v interface{}) bool { + require.NotNil(k) + require.NotNil(v) + if workerMap[k.(string)] == nil { + // We don't remove from updateTimes currently so if we're not + // expecting it we'll see an out-of-date entry + return true + } + assert.WithinDuration(time.Now(), v.(time.Time), 35*time.Second) + delete(workerMap, k.(string)) + return true + }) + assert.Empty(workerMap) + } + + expectWorkers(c1) + + wconf, err := config.DevWorker() + require.NoError(err) + + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + Config: wconf, + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: c1.ClusterAddrs(), + Logger: logger.Named("w1"), + }) + defer w1.Shutdown() + + time.Sleep(10 * time.Second) + expectWorkers(c1, w1) + + require.NoError(w1.Worker().Shutdown(true)) + time.Sleep(10 * time.Second) + expectWorkers(c1) + + require.NoError(c1.Controller().Shutdown(true)) + time.Sleep(10 * time.Second) + + require.NoError(c1.Controller().Start()) + time.Sleep(10 * time.Second) + expectWorkers(c1, w1) + + client, err := api.NewClient(nil) + require.NoError(err) + + addrs := c1.ApiAddrs() + require.Len(addrs, 1) + + require.NoError(client.SetAddr(addrs[0])) + + sc := scopes.NewClient(client) + _, err = sc.List(context.Background(), "global") + require.NoError(err) +}