From 8d48a8515cb8c14aebf58f04c1d38a9b87db48b7 Mon Sep 17 00:00:00 2001 From: Hugo Vieira Date: Thu, 10 Mar 2022 12:19:53 +0000 Subject: [PATCH 1/2] refact(controller): Validate listener configuration in New With this change, the Controller is now aware of its own listeners in a more codified and structured way. The reasoning behind these changes are as follows: 1. Currently, we're iterating over these listeners in multiple spots along this initialization flow to check each listeners' purpose. By doing this once when we create a Controller and essentially storing the result, we can clean up the downstream code to use those results. 2. Initial effort to centralize Controller validation. At the moment, we are doing a lot of validation all over the `cmd` and `config` packages. From a conceptual perspective, Controller-related validation should be performed in the function that is responsible for constructing the Controller object itself. This enables any developer to look in a single place for all the constraints that we place on creating a Controller. --- internal/servers/controller/controller.go | 29 ++++ .../servers/controller/controller_test.go | 152 ++++++++++++++++++ internal/servers/controller/cors_test.go | 56 ++++--- 3 files changed, 211 insertions(+), 26 deletions(-) diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index 9c81c2f5a8..d3bae5bcda 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -52,6 +52,9 @@ type Controller struct { workerAuthCache *sync.Map + apiListeners []*base.ServerListener + clusterListener *base.ServerListener + // Used for testing and tracking worker health workerStatusUpdateTimes *sync.Map @@ -125,6 +128,32 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { } } + clusterListeners := make([]*base.ServerListener, 0) + for i := range conf.Listeners { + l := conf.Listeners[i] + if l == nil || l.Config == nil || l.Config.Purpose == nil { + continue + } + if len(l.Config.Purpose) != 1 { + return nil, fmt.Errorf("found listener with multiple purposes %q", strings.Join(l.Config.Purpose, ",")) + } + switch l.Config.Purpose[0] { + case "api": + c.apiListeners = append(c.apiListeners, l) + case "cluster": + clusterListeners = append(clusterListeners, l) + } + } + if len(c.apiListeners) == 0 { + return nil, fmt.Errorf("no api listeners found") + } + if len(clusterListeners) != 1 { + // in the future, we might pick the cluster that is exposed to the outside + // instead of limiting it to one. + return nil, fmt.Errorf("exactly one cluster listener is required") + } + c.clusterListener = clusterListeners[0] + var pluginLogger hclog.Logger for _, enabledPlugin := range c.enabledPlugins { if pluginLogger == nil { diff --git a/internal/servers/controller/controller_test.go b/internal/servers/controller/controller_test.go index c3bef2e3ba..1866b1a951 100644 --- a/internal/servers/controller/controller_test.go +++ b/internal/servers/controller/controller_test.go @@ -4,8 +4,10 @@ import ( "context" "testing" + "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/stretchr/testify/require" ) @@ -32,3 +34,153 @@ func TestController_New(t *testing.T) { require.NoError(err) }) } + +func TestControllerNewListenerConfig(t *testing.T) { + tests := []struct { + name string + listeners []*base.ServerListener + assertions func(t *testing.T, c *Controller) + expErr bool + expErrMsg string + }{ + { + name: "valid listener configuration", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + assertions: func(t *testing.T, c *Controller) { + require.Len(t, c.apiListeners, 2) + require.NotNil(t, c.clusterListener) + }, + }, + { + name: "listeners are required", + listeners: []*base.ServerListener{}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - not nil", + listeners: []*base.ServerListener{nil, nil}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - with config", + listeners: []*base.ServerListener{{}, {}}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - with purposes", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{Purpose: nil}, + }, + { + Config: &listenerutil.ListenerConfig{Purpose: nil}, + }, + }, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "both api and cluster listeners are required", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + }, + expErr: true, + expErrMsg: "exactly one cluster listener is required", + }, + { + name: "both api and cluster listeners are required 2", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "only one cluster listener is allowed", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: "exactly one cluster listener is required", + }, + { + name: "only one purpose is allowed per listener", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api", "cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: `found listener with multiple purposes "api,cluster"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + tc := &TestController{ + t: t, + ctx: ctx, + cancel: cancel, + opts: nil, + } + conf := TestControllerConfig(t, ctx, tc, nil) + conf.Listeners = tt.listeners + + c, err := New(ctx, conf) + if tt.expErr { + require.EqualError(t, err, tt.expErrMsg) + require.Nil(t, c) + return + } + + require.NoError(t, err) + require.NotNil(t, c) + tt.assertions(t, c) + }) + } +} diff --git a/internal/servers/controller/cors_test.go b/internal/servers/controller/cors_test.go index 7110446fdc..63d8543084 100644 --- a/internal/servers/controller/cors_test.go +++ b/internal/servers/controller/cors_test.go @@ -20,50 +20,54 @@ const corsTestConfig = ` disable_mlock = true telemetry { - prometheus_retention_time = "24h" - disable_hostname = true + prometheus_retention_time = "24h" + disable_hostname = true } kms "aead" { - purpose = "root" - aead_type = "aes-gcm" - key = "09iqFxRJNYsl/b8CQxjnGw==" - key_id = "global_root" + purpose = "root" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" + key_id = "global_root" } kms "aead" { - purpose = "worker-auth" - aead_type = "aes-gcm" - key = "09iqFxRJNYsl/b8CQxjnGw==" - key_id = "global_worker-auth" + purpose = "worker-auth" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" + key_id = "global_worker-auth" } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = false + purpose = "cluster" } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = [] + purpose = "api" + tls_disable = true + cors_enabled = false } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = ["foobar.com", "barfoo.com"] + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = [] } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = ["*"] - cors_allowed_headers = ["x-foobar"] + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["foobar.com", "barfoo.com"] +} + +listener "tcp" { + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["*"] + cors_allowed_headers = ["x-foobar"] } ` From b5ef468d333f83073c258e6043b418707ae66008 Mon Sep 17 00:00:00 2001 From: Hugo Vieira Date: Wed, 9 Mar 2022 20:41:15 +0000 Subject: [PATCH 2/2] refact(controller): Ensure deterministic listeners start order Previously, we initialized the listeners on the order they were provided to us on the config file. In the future, we'll want to add new listener purposes that depend on existing ones, so the init order needs to be deterministic. With this change, we'll initialize components in this order: 1. gRPC Server for Controller API. 2. All `api` HTTP Servers & gRPC Gateway Muxes. 3. `cluster` gRPC Server. This commit introduces some renames for the sake of clarity as well as improved test coverage: - Removes `gatewayMux` reference from Controller object. This object was misleading and is now uneeded. - Removed uneeded `context.Context` arguments. - Separates gRPC Server and gRPC Gateway concerns, both in functionality and in naming. These two were conflated before. - Added unit tests on `startListeners`. --- internal/servers/controller/controller.go | 11 +- internal/servers/controller/gateway.go | 22 +- internal/servers/controller/handler.go | 190 +++++---- internal/servers/controller/listeners.go | 245 ++++++----- internal/servers/controller/listeners_test.go | 400 ++++++++++++++++++ 5 files changed, 631 insertions(+), 237 deletions(-) create mode 100644 internal/servers/controller/listeners_test.go diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index d3bae5bcda..4a104da979 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/auth/password" "github.com/hashicorp/boundary/internal/authtoken" @@ -58,11 +57,9 @@ type Controller struct { // Used for testing and tracking worker health workerStatusUpdateTimes *sync.Map - // grpc gateway server - gatewayServer *grpc.Server - gatewayTicket string - gatewayListener gatewayListener - gatewayMux *runtime.ServeMux + apiGrpcServer *grpc.Server + apiGrpcServerListener grpcServerListener + apiGrpcGatewayTicket string // Repo factory methods AuthTokenRepoFn common.AuthTokenRepoFactory @@ -293,7 +290,7 @@ func (c *Controller) Start() error { if err := c.registerJobs(); err != nil { return fmt.Errorf("error registering jobs: %w", err) } - if err := c.startListeners(c.baseContext); err != nil { + if err := c.startListeners(); err != nil { return fmt.Errorf("error starting controller listeners: %w", err) } if err := c.scheduler.Start(c.baseContext, c.schedulerWg); err != nil { diff --git a/internal/servers/controller/gateway.go b/internal/servers/controller/gateway.go index c0fc6eac64..005e188034 100644 --- a/internal/servers/controller/gateway.go +++ b/internal/servers/controller/gateway.go @@ -19,20 +19,14 @@ import ( "google.golang.org/grpc/test/bufconn" ) -// newGatewayListener will create an in-memory listener -func newGatewayListener() (gatewayListener, string) { - buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available - return bufconn.Listen(int(buffer)), "" -} - const gatewayTarget = "" -type gatewayListener interface { +type grpcServerListener interface { net.Listener Dial() (net.Conn, error) } -func gatewayDialOptions(lis gatewayListener) []grpc.DialOption { +func gatewayDialOptions(lis grpcServerListener) []grpc.DialOption { return []grpc.DialOption{ grpc.WithInsecure(), grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { @@ -41,7 +35,7 @@ func gatewayDialOptions(lis gatewayListener) []grpc.DialOption { } } -func newGatewayMux() *runtime.ServeMux { +func newGrpcGatewayMux() *runtime.ServeMux { return runtime.NewServeMux( runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.HTTPBodyMarshaler{ Marshaler: handlers.JSONMarshaler(), @@ -51,7 +45,13 @@ func newGatewayMux() *runtime.ServeMux { ) } -func newGatewayServer( +// newGrpcServerListener will create an in-memory listener for the gRPC server. +func newGrpcServerListener() grpcServerListener { + buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available + return bufconn.Listen(int(buffer)) +} + +func newGrpcServer( ctx context.Context, iamRepoFn common.IamRepoFactory, authTokenRepoFn common.AuthTokenRepoFactory, @@ -59,7 +59,7 @@ func newGatewayServer( kms *kms.Kms, eventer *event.Eventer, ) (*grpc.Server, string, error) { - const op = "controller.newGatewayServer" + const op = "controller.newGrpcServer" ticket, err := db.NewPrivateId("gwticket") if err != nil { return nil, "", errors.Wrap(ctx, err, op, errors.WithMsg("unable to generate gateway ticket")) diff --git a/internal/servers/controller/handler.go b/internal/servers/controller/handler.go index dae3312514..eab9ce9578 100644 --- a/internal/servers/controller/handler.go +++ b/internal/servers/controller/handler.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/gen/controller/api/services" @@ -39,6 +40,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/mr-tron/base58" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" @@ -50,17 +52,17 @@ type HandlerProperties struct { CancelCtx context.Context } -// Handler returns an http.Handler for the services. This can be used on +// apiHandler returns an http.Handler for the services. This can be used on // its own to mount the Controller API within another web server. -func (c *Controller) handler(props HandlerProperties) (http.Handler, error) { - // Create the muxer to handle the actual endpoints +func (c *Controller) apiHandler(props HandlerProperties) (http.Handler, error) { mux := http.NewServeMux() - h, err := handleGrpcGateway(c, props) + grpcGwMux := newGrpcGatewayMux() + err := registerGrpcGatewayEndpoints(props.CancelCtx, grpcGwMux, gatewayDialOptions(c.apiGrpcServerListener)...) if err != nil { return nil, err } - mux.Handle("/v1/", h) + mux.Handle("/v1/", grpcGwMux) mux.Handle("/", handleUi(c)) corsWrappedHandler := wrapHandlerWithCors(mux, props) @@ -68,103 +70,74 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) { callbackInterceptingHandler := wrapHandlerWithCallbackInterceptor(commonWrappedHandler, c) printablePathCheckHandler := cleanhttp.PrintablePathCheckHandler(callbackInterceptingHandler, nil) eventsHandler, err := common.WrapWithEventsHandler(printablePathCheckHandler, c.conf.Eventer, c.kms, props.ListenerConfig) - if err != nil { - return nil, err - } - return eventsHandler, nil + return eventsHandler, err } -func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, error) { - // Register*ServiceHandlerServer methods ignore the passed in ctx. Using it - // now however in case this changes in the future. - ctx := props.CancelCtx - currentServices := c.gatewayServer.GetServiceInfo() - dialOptions := gatewayDialOptions(c.gatewayListener) +func (c *Controller) registerGrpcServices(s *grpc.Server) error { + // We have to check against the current services because the gRPC lib treats a duplicate + // register call as an error and os.Exits. + currentServices := s.GetServiceInfo() if _, ok := currentServices[services.HostCatalogService_ServiceDesc.ServiceName]; !ok { hcs, err := host_catalogs.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn, c.HostPluginRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host catalog handler service: %w", err) - } - services.RegisterHostCatalogServiceServer(c.gatewayServer, hcs) - if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host catalog service handler: %w", err) + return fmt.Errorf("failed to create host catalog handler service: %w", err) } + services.RegisterHostCatalogServiceServer(s, hcs) } if _, ok := currentServices[services.HostSetService_ServiceDesc.ServiceName]; !ok { hss, err := host_sets.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host set handler service: %w", err) - } - services.RegisterHostSetServiceServer(c.gatewayServer, hss) - if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host set service handler: %w", err) + return fmt.Errorf("failed to create host set handler service: %w", err) } + services.RegisterHostSetServiceServer(s, hss) } if _, ok := currentServices[services.HostService_ServiceDesc.ServiceName]; !ok { hs, err := hosts.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host handler service: %w", err) - } - services.RegisterHostServiceServer(c.gatewayServer, hs) - if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host service handler: %w", err) + return fmt.Errorf("failed to create host handler service: %w", err) } + services.RegisterHostServiceServer(s, hs) } if _, ok := currentServices[services.AccountService_ServiceDesc.ServiceName]; !ok { accts, err := accounts.NewService(c.PasswordAuthRepoFn, c.OidcRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create account handler service: %w", err) - } - services.RegisterAccountServiceServer(c.gatewayServer, accts) - if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register account service handler: %w", err) + return fmt.Errorf("failed to create account handler service: %w", err) } + services.RegisterAccountServiceServer(s, accts) } if _, ok := currentServices[services.AuthMethodService_ServiceDesc.ServiceName]; !ok { authMethods, err := authmethods.NewService(c.kms, c.PasswordAuthRepoFn, c.OidcRepoFn, c.IamRepoFn, c.AuthTokenRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create auth method handler service: %w", err) - } - services.RegisterAuthMethodServiceServer(c.gatewayServer, authMethods) - if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register auth method service handler: %w", err) + return fmt.Errorf("failed to create auth method handler service: %w", err) } + services.RegisterAuthMethodServiceServer(s, authMethods) } if _, ok := currentServices[services.AuthTokenService_ServiceDesc.ServiceName]; !ok { authtoks, err := authtokens.NewService(c.AuthTokenRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create auth token handler service: %w", err) - } - services.RegisterAuthTokenServiceServer(c.gatewayServer, authtoks) - if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register auth token service handler: %w", err) + return fmt.Errorf("failed to create auth token handler service: %w", err) } + services.RegisterAuthTokenServiceServer(s, authtoks) } if _, ok := currentServices[services.ScopeService_ServiceDesc.ServiceName]; !ok { os, err := scopes.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create scope handler service: %w", err) - } - services.RegisterScopeServiceServer(c.gatewayServer, os) - if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register scope service handler: %w", err) + return fmt.Errorf("failed to create scope handler service: %w", err) } + services.RegisterScopeServiceServer(s, os) } if _, ok := currentServices[services.UserService_ServiceDesc.ServiceName]; !ok { us, err := users.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create user handler service: %w", err) - } - services.RegisterUserServiceServer(c.gatewayServer, us) - if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register user service handler: %w", err) + return fmt.Errorf("failed to create user handler service: %w", err) } + services.RegisterUserServiceServer(s, us) } if _, ok := currentServices[services.TargetService_ServiceDesc.ServiceName]; !ok { ts, err := targets.NewService( - ctx, + c.baseContext, c.kms, c.TargetRepoFn, c.IamRepoFn, @@ -174,75 +147,106 @@ func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, er c.StaticHostRepoFn, c.VaultCredentialRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create target handler service: %w", err) - } - services.RegisterTargetServiceServer(c.gatewayServer, ts) - if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register target service handler: %w", err) + return fmt.Errorf("failed to create target handler service: %w", err) } + services.RegisterTargetServiceServer(s, ts) } if _, ok := currentServices[services.GroupService_ServiceDesc.ServiceName]; !ok { gs, err := groups.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create group handler service: %w", err) - } - services.RegisterGroupServiceServer(c.gatewayServer, gs) - if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register group service handler: %w", err) + return fmt.Errorf("failed to create group handler service: %w", err) } + services.RegisterGroupServiceServer(s, gs) } if _, ok := currentServices[services.RoleService_ServiceDesc.ServiceName]; !ok { rs, err := roles.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create role handler service: %w", err) - } - services.RegisterRoleServiceServer(c.gatewayServer, rs) - if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register role service handler: %w", err) + return fmt.Errorf("failed to create role handler service: %w", err) } + services.RegisterRoleServiceServer(s, rs) } if _, ok := currentServices[services.SessionService_ServiceDesc.ServiceName]; !ok { ss, err := sessions.NewService(c.SessionRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create session handler service: %w", err) - } - services.RegisterSessionServiceServer(c.gatewayServer, ss) - if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register session service handler: %w", err) + return fmt.Errorf("failed to create session handler service: %w", err) } + services.RegisterSessionServiceServer(s, ss) } if _, ok := currentServices[services.ManagedGroupService_ServiceDesc.ServiceName]; !ok { mgs, err := managed_groups.NewService(c.OidcRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create managed groups handler service: %w", err) - } - services.RegisterManagedGroupServiceServer(c.gatewayServer, mgs) - if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register managed groups service handler: %w", err) + return fmt.Errorf("failed to create managed groups handler service: %w", err) } + services.RegisterManagedGroupServiceServer(s, mgs) } if _, ok := currentServices[services.CredentialStoreService_ServiceDesc.ServiceName]; !ok { cs, err := credentialstores.NewService(c.VaultCredentialRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create credential store handler service: %w", err) - } - services.RegisterCredentialStoreServiceServer(c.gatewayServer, cs) - if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register credential store service handler: %w", err) + return fmt.Errorf("failed to create credential store handler service: %w", err) } + services.RegisterCredentialStoreServiceServer(s, cs) } if _, ok := currentServices[services.CredentialLibraryService_ServiceDesc.ServiceName]; !ok { cl, err := credentiallibraries.NewService(c.VaultCredentialRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create credential library handler service: %w", err) - } - services.RegisterCredentialLibraryServiceServer(c.gatewayServer, cl) - if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register credential library service handler: %w", err) + return fmt.Errorf("failed to create credential library handler service: %w", err) } + services.RegisterCredentialLibraryServiceServer(s, cl) + } + + return nil +} + +func registerGrpcGatewayEndpoints(ctx context.Context, gwMux *runtime.ServeMux, dialOptions ...grpc.DialOption) error { + // Register*ServiceHandlerServer methods ignore the passed in context. + // Passing it in anyways in case this changes in the future. + if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host catalog service handler: %w", err) + } + if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host set service handler: %w", err) + } + if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host service handler: %w", err) + } + if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register account service handler: %w", err) + } + if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register auth method service handler: %w", err) + } + if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register auth token service handler: %w", err) + } + if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register scope service handler: %w", err) + } + if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register user service handler: %w", err) + } + if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register target service handler: %w", err) + } + if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register group service handler: %w", err) + } + if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register role service handler: %w", err) + } + if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register session service handler: %w", err) + } + if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register managed groups service handler: %w", err) + } + if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register credential store service handler: %w", err) + } + if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register credential library service handler: %w", err) } - return c.gatewayMux, nil + return nil } func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler { @@ -301,7 +305,7 @@ func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProp // Serialize the request info to send it across the wire to the // grpc-gateway via an http header - requestInfo.Ticket = c.gatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy + requestInfo.Ticket = c.apiGrpcGatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy marshalledRequestInfo, err := proto.Marshal(&requestInfo) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling request info")) diff --git a/internal/servers/controller/listeners.go b/internal/servers/controller/listeners.go index 2079018ded..8e5852870d 100644 --- a/internal/servers/controller/listeners.go +++ b/internal/servers/controller/listeners.go @@ -22,151 +22,144 @@ import ( "google.golang.org/grpc" ) -func (c *Controller) startListeners(ctx context.Context) error { +func (c *Controller) startListeners() error { servers := make([]func(), 0, len(c.conf.Listeners)) - configureForAPI := func(ln *base.ServerListener) error { - var err error - if c.gatewayServer, c.gatewayTicket, err = newGatewayServer(ctx, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer); err != nil { - return err - } - c.gatewayMux = newGatewayMux() + grpcServer, gwTicket, err := newGrpcServer(c.baseContext, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer) + if err != nil { + return fmt.Errorf("failed to create new grpc server: %w", err) + } + c.apiGrpcServer = grpcServer + c.apiGrpcGatewayTicket = gwTicket + + err = c.registerGrpcServices(c.apiGrpcServer) + if err != nil { + return fmt.Errorf("failed to register grpc services: %w", err) + } + + c.apiGrpcServerListener = newGrpcServerListener() + servers = append(servers, func() { + go c.apiGrpcServer.Serve(c.apiGrpcServerListener) + }) - handler, err := c.handler(HandlerProperties{ - ListenerConfig: ln.Config, - CancelCtx: c.baseContext, - }) + for i := range c.apiListeners { + ln := c.apiListeners[i] + apiServers, err := c.configureForAPI(ln) if err != nil { - return err + return fmt.Errorf("failed to configure listener for api mode: %w", err) } + servers = append(servers, apiServers...) + } - // Resolve it here to avoid race conditions if the base context is - // replaced - cancelCtx := c.baseContext - - server := &http.Server{ - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - ReadTimeout: 30 * time.Second, - IdleTimeout: 5 * time.Minute, - ErrorLog: c.logger.StandardLogger(nil), - BaseContext: func(net.Listener) context.Context { - return cancelCtx - }, - } - ln.HTTPServer = server + clusterServer, err := c.configureForCluster(c.clusterListener) + if err != nil { + return fmt.Errorf("failed to configure listener for cluster mode: %w", err) + } + servers = append(servers, clusterServer) - if ln.Config.HTTPReadHeaderTimeout > 0 { - server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout - } - if ln.Config.HTTPReadTimeout > 0 { - server.ReadTimeout = ln.Config.HTTPReadTimeout - } - if ln.Config.HTTPWriteTimeout > 0 { - server.WriteTimeout = ln.Config.HTTPWriteTimeout - } - if ln.Config.HTTPIdleTimeout > 0 { - server.IdleTimeout = ln.Config.HTTPIdleTimeout - } + for _, s := range servers { + s() + } - switch ln.Config.TLSDisable { - case true: - l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) - if err != nil { - return fmt.Errorf("error getting non-tls listener: %w", err) - } - if l == nil { - return errors.New("could not get non-tls listener") - } - servers = append(servers, func() { - go server.Serve(l) - }) - - default: - protos := []string{"", "http/1.1", "h2"} - for _, v := range protos { - l := ln.Mux.GetListener(v) - if l == nil { - return fmt.Errorf("could not get tls proto %q listener", v) - } - servers = append(servers, func() { - go server.Serve(l) - }) - } - } + return nil +} - return nil +func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) { + apiServers := make([]func(), 0) + + handler, err := c.apiHandler(HandlerProperties{ + ListenerConfig: ln.Config, + CancelCtx: c.baseContext, + }) + if err != nil { + return nil, err } - configureForCluster := func(ln *base.ServerListener) error { - // Clear out in case this is a second start of the controller - ln.Mux.UnregisterProto(alpnmux.DefaultProto) - l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ - GetConfigForClient: c.validateWorkerTls, - }) + cancelCtx := c.baseContext // Resolve to avoid race conditions if the base context is replaced. + server := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: c.logger.StandardLogger(nil), + BaseContext: func(net.Listener) context.Context { return cancelCtx }, + } + ln.HTTPServer = server + + if ln.Config.HTTPReadHeaderTimeout > 0 { + server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout + } + if ln.Config.HTTPReadTimeout > 0 { + server.ReadTimeout = ln.Config.HTTPReadTimeout + } + if ln.Config.HTTPWriteTimeout > 0 { + server.WriteTimeout = ln.Config.HTTPWriteTimeout + } + if ln.Config.HTTPIdleTimeout > 0 { + server.IdleTimeout = ln.Config.HTTPIdleTimeout + } + + switch ln.Config.TLSDisable { + case true: + l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) if err != nil { - return fmt.Errorf("error getting sub-listener for worker proto: %w", err) + return nil, fmt.Errorf("error getting non-tls listener: %w", err) + } + if l == nil { + return nil, errors.New("could not get non-tls listener") } + apiServers = append(apiServers, func() { go server.Serve(l) }) - workerReqInterceptor, err := workerRequestInfoInterceptor(ctx, c.conf.Eventer) - if err != nil { - return fmt.Errorf("error getting sub-listener for worker proto: %w", err) + default: + for _, v := range []string{"", "http/1.1", "h2"} { + l := ln.Mux.GetListener(v) + if l == nil { + return nil, fmt.Errorf("could not get tls proto %q listener", v) + } + apiServers = append(apiServers, func() { go server.Serve(l) }) } - workerServer := grpc.NewServer( - grpc.MaxRecvMsgSize(math.MaxInt32), - grpc.MaxSendMsgSize(math.MaxInt32), - grpc.UnaryInterceptor( - grpc_middleware.ChainUnaryServer( - workerReqInterceptor, - auditRequestInterceptor(ctx), // before we get started, audit the request - auditResponseInterceptor(ctx), // as we finish, audit the response - ), - ), - ) - workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn, - c.workerStatusUpdateTimes, c.kms) - pbs.RegisterServerCoordinationServiceServer(workerServer, workerService) - pbs.RegisterSessionServiceServer(workerServer, workerService) - - interceptor := newInterceptingListener(c, l) - ln.ALPNListener = interceptor - ln.GrpcServer = workerServer - - servers = append(servers, func() { - go workerServer.Serve(interceptor) - }) - return nil } - c.gatewayListener, _ = newGatewayListener() - servers = append(servers, func() { - go c.gatewayServer.Serve(c.gatewayListener) - }) + return apiServers, nil +} - for _, ln := range c.conf.Listeners { - var err error - for _, purpose := range ln.Config.Purpose { - switch purpose { - case "api": - err = configureForAPI(ln) - case "cluster": - err = configureForCluster(ln) - case "proxy": - // Do nothing, in a dev mode we might see it here - default: - err = fmt.Errorf("unknown listener purpose %q", purpose) - } - if err != nil { - return err - } - } +func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error) { + // Clear out in case this is a second start of the controller + ln.Mux.UnregisterProto(alpnmux.DefaultProto) + l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ + GetConfigForClient: c.validateWorkerTls, + }) + if err != nil { + return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err) } - for _, s := range servers { - s() + workerReqInterceptor, err := workerRequestInfoInterceptor(c.baseContext, c.conf.Eventer) + if err != nil { + return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err) } - return nil + workerServer := grpc.NewServer( + grpc.MaxRecvMsgSize(math.MaxInt32), + grpc.MaxSendMsgSize(math.MaxInt32), + grpc.UnaryInterceptor( + grpc_middleware.ChainUnaryServer( + workerReqInterceptor, + auditRequestInterceptor(c.baseContext), // before we get started, audit the request + auditResponseInterceptor(c.baseContext), // as we finish, audit the response + ), + ), + ) + + workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn, + c.workerStatusUpdateTimes, c.kms) + pbs.RegisterServerCoordinationServiceServer(workerServer, workerService) + pbs.RegisterSessionServiceServer(workerServer, workerService) + + interceptor := newInterceptingListener(c, l) + ln.ALPNListener = interceptor + ln.GrpcServer = workerServer + + return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil } func (c *Controller) stopListeners(serversOnly bool) error { @@ -194,7 +187,7 @@ func (c *Controller) stopListeners(serversOnly bool) error { }() } - if c.gatewayServer != nil { + if c.apiGrpcServer != nil { serverWg.Add(1) go func() { defer serverWg.Done() @@ -202,9 +195,9 @@ func (c *Controller) stopListeners(serversOnly bool) error { defer shutdownKillCancel() go func() { <-shutdownKill.Done() - c.gatewayServer.Stop() + c.apiGrpcServer.Stop() }() - c.gatewayServer.GracefulStop() + c.apiGrpcServer.GracefulStop() }() } diff --git a/internal/servers/controller/listeners_test.go b/internal/servers/controller/listeners_test.go new file mode 100644 index 0000000000..e674882b86 --- /dev/null +++ b/internal/servers/controller/listeners_test.go @@ -0,0 +1,400 @@ +package controller + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "net/http" + "os" + "strconv" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/hashicorp/go-secure-stdlib/configutil/v2" + "github.com/hashicorp/go-secure-stdlib/listenerutil" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +func TestStartListeners(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) + listeners []*listenerutil.ListenerConfig + assertions func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) + }{ + { + name: "one api, one cluster listener", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + rsp, err := http.Get("http://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "multiple api, one cluster listeners", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for _, apiAddr := range apiAddrs { + rsp, err := http.Get("http://" + apiAddr + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + } + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "one api (tls), one cluster listener", + setup: testApiTlsSetup(t, "./boundary-startlistenerstest.cert", "./boundary-startlistenerstest.key"), + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest.cert", + TLSKeyFile: "./boundary-startlistenerstest.key", + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + client := testTlsHttpClient(t, "./boundary-startlistenerstest.cert") + rsp, err := client.Get("https://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "multiple api (tls), one cluster listener", + setup: func(t *testing.T) { + testApiTlsSetup(t, "./boundary-startlistenerstest0.cert", "./boundary-startlistenerstest0.key")(t) + testApiTlsSetup(t, "./boundary-startlistenerstest1.cert", "./boundary-startlistenerstest1.key")(t) + }, + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest0.cert", + TLSKeyFile: "./boundary-startlistenerstest0.key", + }, + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest1.cert", + TLSKeyFile: "./boundary-startlistenerstest1.key", + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for i, apiAddr := range apiAddrs { + client := testTlsHttpClient(t, "./boundary-startlistenerstest"+strconv.Itoa(i)+".cert") + rsp, err := client.Get("https://" + apiAddr + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + } + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "one api, one cluster listener on unix sockets", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"cluster"}, + Address: "/tmp/boundary-listener-test-cluster.sock", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + conn, err := net.Dial("unix", apiAddrs[0]) + require.NoError(t, err) + + cl := http.Client{ + Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { return conn, nil }, + }, + } + rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.NoError(t, conn.Close()) + + clusterGrpcDialNoError(t, c, "unix", clusterAddr) + }, + }, + { + name: "multiple api, one cluster listener on unix sockets", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api2.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"cluster"}, + Address: "/tmp/boundary-listener-test-cluster.sock", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for _, apiAddr := range apiAddrs { + conn, err := net.Dial("unix", apiAddr) + require.NoError(t, err) + + cl := http.Client{ + Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { return conn, nil }, + }, + } + rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.NoError(t, conn.Close()) + } + + clusterGrpcDialNoError(t, c, "unix", clusterAddr) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup(t) + } + + ctx, cancel := context.WithCancel(context.Background()) + + tc := &TestController{t: t, ctx: ctx, cancel: cancel, opts: &TestControllerOpts{}} + t.Cleanup(func() { tc.Shutdown() }) + + conf := TestControllerConfig(t, ctx, tc, nil) + conf.RawConfig.SharedConfig = &configutil.SharedConfig{Listeners: tt.listeners, DisableMlock: true} + + err := conf.SetupListeners(nil, conf.RawConfig.SharedConfig, []string{"api", "cluster"}) + require.NoError(t, err) + + c, err := New(ctx, conf) + require.NoError(t, err) + + c.baseContext = ctx + c.baseCancel = cancel + + err = c.startListeners() + require.NoError(t, err) + + apiAddrs := make([]string, 0) + for _, l := range c.apiListeners { + apiAddrs = append(apiAddrs, l.Mux.Addr().String()) + } + tt.assertions(t, c, apiAddrs, c.clusterListener.Mux.Addr().String()) + }) + } +} + +func clusterGrpcDialNoError(t *testing.T, c *Controller, network, addr string) { + grpcConn, err := grpc.Dial(addr, + grpc.WithInsecure(), + grpc.WithBlock(), + grpc.WithTimeout(5*time.Second), + grpc.WithContextDialer(clusterTestDialer(t, c, network)), + ) + require.NoError(t, err) + require.NoError(t, grpcConn.Close()) +} + +func testApiTlsSetup(t *testing.T, certPath, keyPath string) func(t *testing.T) { + return func(t *testing.T) { + t.Cleanup(func() { + require.NoError(t, os.Remove(certPath)) + require.NoError(t, os.Remove(keyPath)) + }) + + certBytes, _, priv := createTestCert(t) + + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})) + require.NoError(t, certFile.Close()) + + keyFile, err := os.Create(keyPath) + require.NoError(t, err) + + marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv) + require.NoError(t, err) + + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey})) + require.NoError(t, keyFile.Close()) + } +} + +func testTlsHttpClient(t *testing.T, filePath string) *http.Client { + f, err := os.Open(filePath) + require.NoError(t, err) + + certBytes, err := ioutil.ReadAll(f) + require.NoError(t, err) + require.NoError(t, f.Close()) + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(certBytes) + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: certPool}, + }, + } +} + +func createTestCert(t *testing.T) ([]byte, ed25519.PublicKey, ed25519.PrivateKey) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, + SerialNumber: big.NewInt(0), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(5 * time.Minute), + BasicConstraintsValid: true, + IsCA: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"/tmp/boundary-listener-test-cluster.sock"}, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv) + require.NoError(t, err) + + return certBytes, pub, priv +} + +func clusterTestDialer(t *testing.T, c *Controller, network string) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + certBytes, _, priv := createTestCert(t) + + marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv) + require.NoError(t, err) + + nonce, err := base62.Random(20) + require.NoError(t, err) + + info := base.WorkerAuthInfo{ + CertPEM: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), + KeyPEM: pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey}), + ConnectionNonce: nonce, + } + infoBytes, err := json.Marshal(info) + require.NoError(t, err) + + encInfo, err := c.conf.WorkerAuthKms.Encrypt(ctx, infoBytes, nil) + require.NoError(t, err) + + protoEncInfo, err := proto.Marshal(encInfo) + require.NoError(t, err) + + b64alpn := base64.RawStdEncoding.EncodeToString(protoEncInfo) + + var nextProtos []string + var count int + for i := 0; i < len(b64alpn); i += 230 { + end := i + 230 + if end > len(b64alpn) { + end = len(b64alpn) + } + nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end])) + count++ + } + + cert, err := x509.ParseCertificate(certBytes) + require.NoError(t, err) + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(cert) + + tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM) + require.NoError(t, err) + + conn, err := tls.Dial(network, addr, &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: rootCAs, + NextProtos: nextProtos, + MinVersion: tls.VersionTLS13, + }) + require.NoError(t, err) + + _, err = conn.Write([]byte(nonce)) + require.NoError(t, err) + + return conn, nil + } +}