From 8d48a8515cb8c14aebf58f04c1d38a9b87db48b7 Mon Sep 17 00:00:00 2001 From: Hugo Vieira Date: Thu, 10 Mar 2022 12:19:53 +0000 Subject: [PATCH] refact(controller): Validate listener configuration in New With this change, the Controller is now aware of its own listeners in a more codified and structured way. The reasoning behind these changes are as follows: 1. Currently, we're iterating over these listeners in multiple spots along this initialization flow to check each listeners' purpose. By doing this once when we create a Controller and essentially storing the result, we can clean up the downstream code to use those results. 2. Initial effort to centralize Controller validation. At the moment, we are doing a lot of validation all over the `cmd` and `config` packages. From a conceptual perspective, Controller-related validation should be performed in the function that is responsible for constructing the Controller object itself. This enables any developer to look in a single place for all the constraints that we place on creating a Controller. --- internal/servers/controller/controller.go | 29 ++++ .../servers/controller/controller_test.go | 152 ++++++++++++++++++ internal/servers/controller/cors_test.go | 56 ++++--- 3 files changed, 211 insertions(+), 26 deletions(-) diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index 9c81c2f5a8..d3bae5bcda 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -52,6 +52,9 @@ type Controller struct { workerAuthCache *sync.Map + apiListeners []*base.ServerListener + clusterListener *base.ServerListener + // Used for testing and tracking worker health workerStatusUpdateTimes *sync.Map @@ -125,6 +128,32 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { } } + clusterListeners := make([]*base.ServerListener, 0) + for i := range conf.Listeners { + l := conf.Listeners[i] + if l == nil || l.Config == nil || l.Config.Purpose == nil { + continue + } + if len(l.Config.Purpose) != 1 { + return nil, fmt.Errorf("found listener with multiple purposes %q", strings.Join(l.Config.Purpose, ",")) + } + switch l.Config.Purpose[0] { + case "api": + c.apiListeners = append(c.apiListeners, l) + case "cluster": + clusterListeners = append(clusterListeners, l) + } + } + if len(c.apiListeners) == 0 { + return nil, fmt.Errorf("no api listeners found") + } + if len(clusterListeners) != 1 { + // in the future, we might pick the cluster that is exposed to the outside + // instead of limiting it to one. + return nil, fmt.Errorf("exactly one cluster listener is required") + } + c.clusterListener = clusterListeners[0] + var pluginLogger hclog.Logger for _, enabledPlugin := range c.enabledPlugins { if pluginLogger == nil { diff --git a/internal/servers/controller/controller_test.go b/internal/servers/controller/controller_test.go index c3bef2e3ba..1866b1a951 100644 --- a/internal/servers/controller/controller_test.go +++ b/internal/servers/controller/controller_test.go @@ -4,8 +4,10 @@ import ( "context" "testing" + "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/stretchr/testify/require" ) @@ -32,3 +34,153 @@ func TestController_New(t *testing.T) { require.NoError(err) }) } + +func TestControllerNewListenerConfig(t *testing.T) { + tests := []struct { + name string + listeners []*base.ServerListener + assertions func(t *testing.T, c *Controller) + expErr bool + expErrMsg string + }{ + { + name: "valid listener configuration", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + assertions: func(t *testing.T, c *Controller) { + require.Len(t, c.apiListeners, 2) + require.NotNil(t, c.clusterListener) + }, + }, + { + name: "listeners are required", + listeners: []*base.ServerListener{}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - not nil", + listeners: []*base.ServerListener{nil, nil}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - with config", + listeners: []*base.ServerListener{{}, {}}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - with purposes", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{Purpose: nil}, + }, + { + Config: &listenerutil.ListenerConfig{Purpose: nil}, + }, + }, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "both api and cluster listeners are required", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + }, + expErr: true, + expErrMsg: "exactly one cluster listener is required", + }, + { + name: "both api and cluster listeners are required 2", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "only one cluster listener is allowed", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: "exactly one cluster listener is required", + }, + { + name: "only one purpose is allowed per listener", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api", "cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: `found listener with multiple purposes "api,cluster"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + tc := &TestController{ + t: t, + ctx: ctx, + cancel: cancel, + opts: nil, + } + conf := TestControllerConfig(t, ctx, tc, nil) + conf.Listeners = tt.listeners + + c, err := New(ctx, conf) + if tt.expErr { + require.EqualError(t, err, tt.expErrMsg) + require.Nil(t, c) + return + } + + require.NoError(t, err) + require.NotNil(t, c) + tt.assertions(t, c) + }) + } +} diff --git a/internal/servers/controller/cors_test.go b/internal/servers/controller/cors_test.go index 7110446fdc..63d8543084 100644 --- a/internal/servers/controller/cors_test.go +++ b/internal/servers/controller/cors_test.go @@ -20,50 +20,54 @@ const corsTestConfig = ` disable_mlock = true telemetry { - prometheus_retention_time = "24h" - disable_hostname = true + prometheus_retention_time = "24h" + disable_hostname = true } kms "aead" { - purpose = "root" - aead_type = "aes-gcm" - key = "09iqFxRJNYsl/b8CQxjnGw==" - key_id = "global_root" + purpose = "root" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" + key_id = "global_root" } kms "aead" { - purpose = "worker-auth" - aead_type = "aes-gcm" - key = "09iqFxRJNYsl/b8CQxjnGw==" - key_id = "global_worker-auth" + purpose = "worker-auth" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" + key_id = "global_worker-auth" } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = false + purpose = "cluster" } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = [] + purpose = "api" + tls_disable = true + cors_enabled = false } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = ["foobar.com", "barfoo.com"] + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = [] } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = ["*"] - cors_allowed_headers = ["x-foobar"] + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["foobar.com", "barfoo.com"] +} + +listener "tcp" { + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["*"] + cors_allowed_headers = ["x-foobar"] } `