diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index 9c81c2f5a8..4a104da979 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/auth/password" "github.com/hashicorp/boundary/internal/authtoken" @@ -52,14 +51,15 @@ type Controller struct { workerAuthCache *sync.Map + apiListeners []*base.ServerListener + clusterListener *base.ServerListener + // Used for testing and tracking worker health workerStatusUpdateTimes *sync.Map - // grpc gateway server - gatewayServer *grpc.Server - gatewayTicket string - gatewayListener gatewayListener - gatewayMux *runtime.ServeMux + apiGrpcServer *grpc.Server + apiGrpcServerListener grpcServerListener + apiGrpcGatewayTicket string // Repo factory methods AuthTokenRepoFn common.AuthTokenRepoFactory @@ -125,6 +125,32 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { } } + clusterListeners := make([]*base.ServerListener, 0) + for i := range conf.Listeners { + l := conf.Listeners[i] + if l == nil || l.Config == nil || l.Config.Purpose == nil { + continue + } + if len(l.Config.Purpose) != 1 { + return nil, fmt.Errorf("found listener with multiple purposes %q", strings.Join(l.Config.Purpose, ",")) + } + switch l.Config.Purpose[0] { + case "api": + c.apiListeners = append(c.apiListeners, l) + case "cluster": + clusterListeners = append(clusterListeners, l) + } + } + if len(c.apiListeners) == 0 { + return nil, fmt.Errorf("no api listeners found") + } + if len(clusterListeners) != 1 { + // in the future, we might pick the cluster that is exposed to the outside + // instead of limiting it to one. + return nil, fmt.Errorf("exactly one cluster listener is required") + } + c.clusterListener = clusterListeners[0] + var pluginLogger hclog.Logger for _, enabledPlugin := range c.enabledPlugins { if pluginLogger == nil { @@ -264,7 +290,7 @@ func (c *Controller) Start() error { if err := c.registerJobs(); err != nil { return fmt.Errorf("error registering jobs: %w", err) } - if err := c.startListeners(c.baseContext); err != nil { + if err := c.startListeners(); err != nil { return fmt.Errorf("error starting controller listeners: %w", err) } if err := c.scheduler.Start(c.baseContext, c.schedulerWg); err != nil { diff --git a/internal/servers/controller/controller_test.go b/internal/servers/controller/controller_test.go index c3bef2e3ba..1866b1a951 100644 --- a/internal/servers/controller/controller_test.go +++ b/internal/servers/controller/controller_test.go @@ -4,8 +4,10 @@ import ( "context" "testing" + "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/stretchr/testify/require" ) @@ -32,3 +34,153 @@ func TestController_New(t *testing.T) { require.NoError(err) }) } + +func TestControllerNewListenerConfig(t *testing.T) { + tests := []struct { + name string + listeners []*base.ServerListener + assertions func(t *testing.T, c *Controller) + expErr bool + expErrMsg string + }{ + { + name: "valid listener configuration", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + assertions: func(t *testing.T, c *Controller) { + require.Len(t, c.apiListeners, 2) + require.NotNil(t, c.clusterListener) + }, + }, + { + name: "listeners are required", + listeners: []*base.ServerListener{}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - not nil", + listeners: []*base.ServerListener{nil, nil}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - with config", + listeners: []*base.ServerListener{{}, {}}, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "listeners are required - with purposes", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{Purpose: nil}, + }, + { + Config: &listenerutil.ListenerConfig{Purpose: nil}, + }, + }, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "both api and cluster listeners are required", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + }, + expErr: true, + expErrMsg: "exactly one cluster listener is required", + }, + { + name: "both api and cluster listeners are required 2", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: "no api listeners found", + }, + { + name: "only one cluster listener is allowed", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: "exactly one cluster listener is required", + }, + { + name: "only one purpose is allowed per listener", + listeners: []*base.ServerListener{ + { + Config: &listenerutil.ListenerConfig{ + Purpose: []string{"api", "cluster"}, + }, + }, + }, + expErr: true, + expErrMsg: `found listener with multiple purposes "api,cluster"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + tc := &TestController{ + t: t, + ctx: ctx, + cancel: cancel, + opts: nil, + } + conf := TestControllerConfig(t, ctx, tc, nil) + conf.Listeners = tt.listeners + + c, err := New(ctx, conf) + if tt.expErr { + require.EqualError(t, err, tt.expErrMsg) + require.Nil(t, c) + return + } + + require.NoError(t, err) + require.NotNil(t, c) + tt.assertions(t, c) + }) + } +} diff --git a/internal/servers/controller/cors_test.go b/internal/servers/controller/cors_test.go index 7110446fdc..63d8543084 100644 --- a/internal/servers/controller/cors_test.go +++ b/internal/servers/controller/cors_test.go @@ -20,50 +20,54 @@ const corsTestConfig = ` disable_mlock = true telemetry { - prometheus_retention_time = "24h" - disable_hostname = true + prometheus_retention_time = "24h" + disable_hostname = true } kms "aead" { - purpose = "root" - aead_type = "aes-gcm" - key = "09iqFxRJNYsl/b8CQxjnGw==" - key_id = "global_root" + purpose = "root" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" + key_id = "global_root" } kms "aead" { - purpose = "worker-auth" - aead_type = "aes-gcm" - key = "09iqFxRJNYsl/b8CQxjnGw==" - key_id = "global_worker-auth" + purpose = "worker-auth" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" + key_id = "global_worker-auth" } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = false + purpose = "cluster" } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = [] + purpose = "api" + tls_disable = true + cors_enabled = false } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = ["foobar.com", "barfoo.com"] + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = [] } listener "tcp" { - purpose = "api" - tls_disable = true - cors_enabled = true - cors_allowed_origins = ["*"] - cors_allowed_headers = ["x-foobar"] + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["foobar.com", "barfoo.com"] +} + +listener "tcp" { + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["*"] + cors_allowed_headers = ["x-foobar"] } ` diff --git a/internal/servers/controller/gateway.go b/internal/servers/controller/gateway.go index c0fc6eac64..005e188034 100644 --- a/internal/servers/controller/gateway.go +++ b/internal/servers/controller/gateway.go @@ -19,20 +19,14 @@ import ( "google.golang.org/grpc/test/bufconn" ) -// newGatewayListener will create an in-memory listener -func newGatewayListener() (gatewayListener, string) { - buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available - return bufconn.Listen(int(buffer)), "" -} - const gatewayTarget = "" -type gatewayListener interface { +type grpcServerListener interface { net.Listener Dial() (net.Conn, error) } -func gatewayDialOptions(lis gatewayListener) []grpc.DialOption { +func gatewayDialOptions(lis grpcServerListener) []grpc.DialOption { return []grpc.DialOption{ grpc.WithInsecure(), grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { @@ -41,7 +35,7 @@ func gatewayDialOptions(lis gatewayListener) []grpc.DialOption { } } -func newGatewayMux() *runtime.ServeMux { +func newGrpcGatewayMux() *runtime.ServeMux { return runtime.NewServeMux( runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.HTTPBodyMarshaler{ Marshaler: handlers.JSONMarshaler(), @@ -51,7 +45,13 @@ func newGatewayMux() *runtime.ServeMux { ) } -func newGatewayServer( +// newGrpcServerListener will create an in-memory listener for the gRPC server. +func newGrpcServerListener() grpcServerListener { + buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available + return bufconn.Listen(int(buffer)) +} + +func newGrpcServer( ctx context.Context, iamRepoFn common.IamRepoFactory, authTokenRepoFn common.AuthTokenRepoFactory, @@ -59,7 +59,7 @@ func newGatewayServer( kms *kms.Kms, eventer *event.Eventer, ) (*grpc.Server, string, error) { - const op = "controller.newGatewayServer" + const op = "controller.newGrpcServer" ticket, err := db.NewPrivateId("gwticket") if err != nil { return nil, "", errors.Wrap(ctx, err, op, errors.WithMsg("unable to generate gateway ticket")) diff --git a/internal/servers/controller/handler.go b/internal/servers/controller/handler.go index dae3312514..eab9ce9578 100644 --- a/internal/servers/controller/handler.go +++ b/internal/servers/controller/handler.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/gen/controller/api/services" @@ -39,6 +40,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/mr-tron/base58" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" @@ -50,17 +52,17 @@ type HandlerProperties struct { CancelCtx context.Context } -// Handler returns an http.Handler for the services. This can be used on +// apiHandler returns an http.Handler for the services. This can be used on // its own to mount the Controller API within another web server. -func (c *Controller) handler(props HandlerProperties) (http.Handler, error) { - // Create the muxer to handle the actual endpoints +func (c *Controller) apiHandler(props HandlerProperties) (http.Handler, error) { mux := http.NewServeMux() - h, err := handleGrpcGateway(c, props) + grpcGwMux := newGrpcGatewayMux() + err := registerGrpcGatewayEndpoints(props.CancelCtx, grpcGwMux, gatewayDialOptions(c.apiGrpcServerListener)...) if err != nil { return nil, err } - mux.Handle("/v1/", h) + mux.Handle("/v1/", grpcGwMux) mux.Handle("/", handleUi(c)) corsWrappedHandler := wrapHandlerWithCors(mux, props) @@ -68,103 +70,74 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) { callbackInterceptingHandler := wrapHandlerWithCallbackInterceptor(commonWrappedHandler, c) printablePathCheckHandler := cleanhttp.PrintablePathCheckHandler(callbackInterceptingHandler, nil) eventsHandler, err := common.WrapWithEventsHandler(printablePathCheckHandler, c.conf.Eventer, c.kms, props.ListenerConfig) - if err != nil { - return nil, err - } - return eventsHandler, nil + return eventsHandler, err } -func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, error) { - // Register*ServiceHandlerServer methods ignore the passed in ctx. Using it - // now however in case this changes in the future. - ctx := props.CancelCtx - currentServices := c.gatewayServer.GetServiceInfo() - dialOptions := gatewayDialOptions(c.gatewayListener) +func (c *Controller) registerGrpcServices(s *grpc.Server) error { + // We have to check against the current services because the gRPC lib treats a duplicate + // register call as an error and os.Exits. + currentServices := s.GetServiceInfo() if _, ok := currentServices[services.HostCatalogService_ServiceDesc.ServiceName]; !ok { hcs, err := host_catalogs.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn, c.HostPluginRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host catalog handler service: %w", err) - } - services.RegisterHostCatalogServiceServer(c.gatewayServer, hcs) - if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host catalog service handler: %w", err) + return fmt.Errorf("failed to create host catalog handler service: %w", err) } + services.RegisterHostCatalogServiceServer(s, hcs) } if _, ok := currentServices[services.HostSetService_ServiceDesc.ServiceName]; !ok { hss, err := host_sets.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host set handler service: %w", err) - } - services.RegisterHostSetServiceServer(c.gatewayServer, hss) - if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host set service handler: %w", err) + return fmt.Errorf("failed to create host set handler service: %w", err) } + services.RegisterHostSetServiceServer(s, hss) } if _, ok := currentServices[services.HostService_ServiceDesc.ServiceName]; !ok { hs, err := hosts.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host handler service: %w", err) - } - services.RegisterHostServiceServer(c.gatewayServer, hs) - if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host service handler: %w", err) + return fmt.Errorf("failed to create host handler service: %w", err) } + services.RegisterHostServiceServer(s, hs) } if _, ok := currentServices[services.AccountService_ServiceDesc.ServiceName]; !ok { accts, err := accounts.NewService(c.PasswordAuthRepoFn, c.OidcRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create account handler service: %w", err) - } - services.RegisterAccountServiceServer(c.gatewayServer, accts) - if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register account service handler: %w", err) + return fmt.Errorf("failed to create account handler service: %w", err) } + services.RegisterAccountServiceServer(s, accts) } if _, ok := currentServices[services.AuthMethodService_ServiceDesc.ServiceName]; !ok { authMethods, err := authmethods.NewService(c.kms, c.PasswordAuthRepoFn, c.OidcRepoFn, c.IamRepoFn, c.AuthTokenRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create auth method handler service: %w", err) - } - services.RegisterAuthMethodServiceServer(c.gatewayServer, authMethods) - if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register auth method service handler: %w", err) + return fmt.Errorf("failed to create auth method handler service: %w", err) } + services.RegisterAuthMethodServiceServer(s, authMethods) } if _, ok := currentServices[services.AuthTokenService_ServiceDesc.ServiceName]; !ok { authtoks, err := authtokens.NewService(c.AuthTokenRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create auth token handler service: %w", err) - } - services.RegisterAuthTokenServiceServer(c.gatewayServer, authtoks) - if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register auth token service handler: %w", err) + return fmt.Errorf("failed to create auth token handler service: %w", err) } + services.RegisterAuthTokenServiceServer(s, authtoks) } if _, ok := currentServices[services.ScopeService_ServiceDesc.ServiceName]; !ok { os, err := scopes.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create scope handler service: %w", err) - } - services.RegisterScopeServiceServer(c.gatewayServer, os) - if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register scope service handler: %w", err) + return fmt.Errorf("failed to create scope handler service: %w", err) } + services.RegisterScopeServiceServer(s, os) } if _, ok := currentServices[services.UserService_ServiceDesc.ServiceName]; !ok { us, err := users.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create user handler service: %w", err) - } - services.RegisterUserServiceServer(c.gatewayServer, us) - if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register user service handler: %w", err) + return fmt.Errorf("failed to create user handler service: %w", err) } + services.RegisterUserServiceServer(s, us) } if _, ok := currentServices[services.TargetService_ServiceDesc.ServiceName]; !ok { ts, err := targets.NewService( - ctx, + c.baseContext, c.kms, c.TargetRepoFn, c.IamRepoFn, @@ -174,75 +147,106 @@ func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, er c.StaticHostRepoFn, c.VaultCredentialRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create target handler service: %w", err) - } - services.RegisterTargetServiceServer(c.gatewayServer, ts) - if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register target service handler: %w", err) + return fmt.Errorf("failed to create target handler service: %w", err) } + services.RegisterTargetServiceServer(s, ts) } if _, ok := currentServices[services.GroupService_ServiceDesc.ServiceName]; !ok { gs, err := groups.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create group handler service: %w", err) - } - services.RegisterGroupServiceServer(c.gatewayServer, gs) - if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register group service handler: %w", err) + return fmt.Errorf("failed to create group handler service: %w", err) } + services.RegisterGroupServiceServer(s, gs) } if _, ok := currentServices[services.RoleService_ServiceDesc.ServiceName]; !ok { rs, err := roles.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create role handler service: %w", err) - } - services.RegisterRoleServiceServer(c.gatewayServer, rs) - if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register role service handler: %w", err) + return fmt.Errorf("failed to create role handler service: %w", err) } + services.RegisterRoleServiceServer(s, rs) } if _, ok := currentServices[services.SessionService_ServiceDesc.ServiceName]; !ok { ss, err := sessions.NewService(c.SessionRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create session handler service: %w", err) - } - services.RegisterSessionServiceServer(c.gatewayServer, ss) - if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register session service handler: %w", err) + return fmt.Errorf("failed to create session handler service: %w", err) } + services.RegisterSessionServiceServer(s, ss) } if _, ok := currentServices[services.ManagedGroupService_ServiceDesc.ServiceName]; !ok { mgs, err := managed_groups.NewService(c.OidcRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create managed groups handler service: %w", err) - } - services.RegisterManagedGroupServiceServer(c.gatewayServer, mgs) - if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register managed groups service handler: %w", err) + return fmt.Errorf("failed to create managed groups handler service: %w", err) } + services.RegisterManagedGroupServiceServer(s, mgs) } if _, ok := currentServices[services.CredentialStoreService_ServiceDesc.ServiceName]; !ok { cs, err := credentialstores.NewService(c.VaultCredentialRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create credential store handler service: %w", err) - } - services.RegisterCredentialStoreServiceServer(c.gatewayServer, cs) - if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register credential store service handler: %w", err) + return fmt.Errorf("failed to create credential store handler service: %w", err) } + services.RegisterCredentialStoreServiceServer(s, cs) } if _, ok := currentServices[services.CredentialLibraryService_ServiceDesc.ServiceName]; !ok { cl, err := credentiallibraries.NewService(c.VaultCredentialRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create credential library handler service: %w", err) - } - services.RegisterCredentialLibraryServiceServer(c.gatewayServer, cl) - if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register credential library service handler: %w", err) + return fmt.Errorf("failed to create credential library handler service: %w", err) } + services.RegisterCredentialLibraryServiceServer(s, cl) + } + + return nil +} + +func registerGrpcGatewayEndpoints(ctx context.Context, gwMux *runtime.ServeMux, dialOptions ...grpc.DialOption) error { + // Register*ServiceHandlerServer methods ignore the passed in context. + // Passing it in anyways in case this changes in the future. + if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host catalog service handler: %w", err) + } + if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host set service handler: %w", err) + } + if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host service handler: %w", err) + } + if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register account service handler: %w", err) + } + if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register auth method service handler: %w", err) + } + if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register auth token service handler: %w", err) + } + if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register scope service handler: %w", err) + } + if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register user service handler: %w", err) + } + if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register target service handler: %w", err) + } + if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register group service handler: %w", err) + } + if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register role service handler: %w", err) + } + if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register session service handler: %w", err) + } + if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register managed groups service handler: %w", err) + } + if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register credential store service handler: %w", err) + } + if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register credential library service handler: %w", err) } - return c.gatewayMux, nil + return nil } func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler { @@ -301,7 +305,7 @@ func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProp // Serialize the request info to send it across the wire to the // grpc-gateway via an http header - requestInfo.Ticket = c.gatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy + requestInfo.Ticket = c.apiGrpcGatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy marshalledRequestInfo, err := proto.Marshal(&requestInfo) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling request info")) diff --git a/internal/servers/controller/listeners.go b/internal/servers/controller/listeners.go index 2079018ded..8e5852870d 100644 --- a/internal/servers/controller/listeners.go +++ b/internal/servers/controller/listeners.go @@ -22,151 +22,144 @@ import ( "google.golang.org/grpc" ) -func (c *Controller) startListeners(ctx context.Context) error { +func (c *Controller) startListeners() error { servers := make([]func(), 0, len(c.conf.Listeners)) - configureForAPI := func(ln *base.ServerListener) error { - var err error - if c.gatewayServer, c.gatewayTicket, err = newGatewayServer(ctx, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer); err != nil { - return err - } - c.gatewayMux = newGatewayMux() + grpcServer, gwTicket, err := newGrpcServer(c.baseContext, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer) + if err != nil { + return fmt.Errorf("failed to create new grpc server: %w", err) + } + c.apiGrpcServer = grpcServer + c.apiGrpcGatewayTicket = gwTicket + + err = c.registerGrpcServices(c.apiGrpcServer) + if err != nil { + return fmt.Errorf("failed to register grpc services: %w", err) + } + + c.apiGrpcServerListener = newGrpcServerListener() + servers = append(servers, func() { + go c.apiGrpcServer.Serve(c.apiGrpcServerListener) + }) - handler, err := c.handler(HandlerProperties{ - ListenerConfig: ln.Config, - CancelCtx: c.baseContext, - }) + for i := range c.apiListeners { + ln := c.apiListeners[i] + apiServers, err := c.configureForAPI(ln) if err != nil { - return err + return fmt.Errorf("failed to configure listener for api mode: %w", err) } + servers = append(servers, apiServers...) + } - // Resolve it here to avoid race conditions if the base context is - // replaced - cancelCtx := c.baseContext - - server := &http.Server{ - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - ReadTimeout: 30 * time.Second, - IdleTimeout: 5 * time.Minute, - ErrorLog: c.logger.StandardLogger(nil), - BaseContext: func(net.Listener) context.Context { - return cancelCtx - }, - } - ln.HTTPServer = server + clusterServer, err := c.configureForCluster(c.clusterListener) + if err != nil { + return fmt.Errorf("failed to configure listener for cluster mode: %w", err) + } + servers = append(servers, clusterServer) - if ln.Config.HTTPReadHeaderTimeout > 0 { - server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout - } - if ln.Config.HTTPReadTimeout > 0 { - server.ReadTimeout = ln.Config.HTTPReadTimeout - } - if ln.Config.HTTPWriteTimeout > 0 { - server.WriteTimeout = ln.Config.HTTPWriteTimeout - } - if ln.Config.HTTPIdleTimeout > 0 { - server.IdleTimeout = ln.Config.HTTPIdleTimeout - } + for _, s := range servers { + s() + } - switch ln.Config.TLSDisable { - case true: - l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) - if err != nil { - return fmt.Errorf("error getting non-tls listener: %w", err) - } - if l == nil { - return errors.New("could not get non-tls listener") - } - servers = append(servers, func() { - go server.Serve(l) - }) - - default: - protos := []string{"", "http/1.1", "h2"} - for _, v := range protos { - l := ln.Mux.GetListener(v) - if l == nil { - return fmt.Errorf("could not get tls proto %q listener", v) - } - servers = append(servers, func() { - go server.Serve(l) - }) - } - } + return nil +} - return nil +func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) { + apiServers := make([]func(), 0) + + handler, err := c.apiHandler(HandlerProperties{ + ListenerConfig: ln.Config, + CancelCtx: c.baseContext, + }) + if err != nil { + return nil, err } - configureForCluster := func(ln *base.ServerListener) error { - // Clear out in case this is a second start of the controller - ln.Mux.UnregisterProto(alpnmux.DefaultProto) - l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ - GetConfigForClient: c.validateWorkerTls, - }) + cancelCtx := c.baseContext // Resolve to avoid race conditions if the base context is replaced. + server := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: c.logger.StandardLogger(nil), + BaseContext: func(net.Listener) context.Context { return cancelCtx }, + } + ln.HTTPServer = server + + if ln.Config.HTTPReadHeaderTimeout > 0 { + server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout + } + if ln.Config.HTTPReadTimeout > 0 { + server.ReadTimeout = ln.Config.HTTPReadTimeout + } + if ln.Config.HTTPWriteTimeout > 0 { + server.WriteTimeout = ln.Config.HTTPWriteTimeout + } + if ln.Config.HTTPIdleTimeout > 0 { + server.IdleTimeout = ln.Config.HTTPIdleTimeout + } + + switch ln.Config.TLSDisable { + case true: + l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) if err != nil { - return fmt.Errorf("error getting sub-listener for worker proto: %w", err) + return nil, fmt.Errorf("error getting non-tls listener: %w", err) + } + if l == nil { + return nil, errors.New("could not get non-tls listener") } + apiServers = append(apiServers, func() { go server.Serve(l) }) - workerReqInterceptor, err := workerRequestInfoInterceptor(ctx, c.conf.Eventer) - if err != nil { - return fmt.Errorf("error getting sub-listener for worker proto: %w", err) + default: + for _, v := range []string{"", "http/1.1", "h2"} { + l := ln.Mux.GetListener(v) + if l == nil { + return nil, fmt.Errorf("could not get tls proto %q listener", v) + } + apiServers = append(apiServers, func() { go server.Serve(l) }) } - workerServer := grpc.NewServer( - grpc.MaxRecvMsgSize(math.MaxInt32), - grpc.MaxSendMsgSize(math.MaxInt32), - grpc.UnaryInterceptor( - grpc_middleware.ChainUnaryServer( - workerReqInterceptor, - auditRequestInterceptor(ctx), // before we get started, audit the request - auditResponseInterceptor(ctx), // as we finish, audit the response - ), - ), - ) - workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn, - c.workerStatusUpdateTimes, c.kms) - pbs.RegisterServerCoordinationServiceServer(workerServer, workerService) - pbs.RegisterSessionServiceServer(workerServer, workerService) - - interceptor := newInterceptingListener(c, l) - ln.ALPNListener = interceptor - ln.GrpcServer = workerServer - - servers = append(servers, func() { - go workerServer.Serve(interceptor) - }) - return nil } - c.gatewayListener, _ = newGatewayListener() - servers = append(servers, func() { - go c.gatewayServer.Serve(c.gatewayListener) - }) + return apiServers, nil +} - for _, ln := range c.conf.Listeners { - var err error - for _, purpose := range ln.Config.Purpose { - switch purpose { - case "api": - err = configureForAPI(ln) - case "cluster": - err = configureForCluster(ln) - case "proxy": - // Do nothing, in a dev mode we might see it here - default: - err = fmt.Errorf("unknown listener purpose %q", purpose) - } - if err != nil { - return err - } - } +func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error) { + // Clear out in case this is a second start of the controller + ln.Mux.UnregisterProto(alpnmux.DefaultProto) + l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ + GetConfigForClient: c.validateWorkerTls, + }) + if err != nil { + return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err) } - for _, s := range servers { - s() + workerReqInterceptor, err := workerRequestInfoInterceptor(c.baseContext, c.conf.Eventer) + if err != nil { + return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err) } - return nil + workerServer := grpc.NewServer( + grpc.MaxRecvMsgSize(math.MaxInt32), + grpc.MaxSendMsgSize(math.MaxInt32), + grpc.UnaryInterceptor( + grpc_middleware.ChainUnaryServer( + workerReqInterceptor, + auditRequestInterceptor(c.baseContext), // before we get started, audit the request + auditResponseInterceptor(c.baseContext), // as we finish, audit the response + ), + ), + ) + + workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn, + c.workerStatusUpdateTimes, c.kms) + pbs.RegisterServerCoordinationServiceServer(workerServer, workerService) + pbs.RegisterSessionServiceServer(workerServer, workerService) + + interceptor := newInterceptingListener(c, l) + ln.ALPNListener = interceptor + ln.GrpcServer = workerServer + + return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil } func (c *Controller) stopListeners(serversOnly bool) error { @@ -194,7 +187,7 @@ func (c *Controller) stopListeners(serversOnly bool) error { }() } - if c.gatewayServer != nil { + if c.apiGrpcServer != nil { serverWg.Add(1) go func() { defer serverWg.Done() @@ -202,9 +195,9 @@ func (c *Controller) stopListeners(serversOnly bool) error { defer shutdownKillCancel() go func() { <-shutdownKill.Done() - c.gatewayServer.Stop() + c.apiGrpcServer.Stop() }() - c.gatewayServer.GracefulStop() + c.apiGrpcServer.GracefulStop() }() } diff --git a/internal/servers/controller/listeners_test.go b/internal/servers/controller/listeners_test.go new file mode 100644 index 0000000000..e674882b86 --- /dev/null +++ b/internal/servers/controller/listeners_test.go @@ -0,0 +1,400 @@ +package controller + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "net/http" + "os" + "strconv" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/hashicorp/go-secure-stdlib/configutil/v2" + "github.com/hashicorp/go-secure-stdlib/listenerutil" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +func TestStartListeners(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) + listeners []*listenerutil.ListenerConfig + assertions func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) + }{ + { + name: "one api, one cluster listener", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + rsp, err := http.Get("http://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "multiple api, one cluster listeners", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for _, apiAddr := range apiAddrs { + rsp, err := http.Get("http://" + apiAddr + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + } + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "one api (tls), one cluster listener", + setup: testApiTlsSetup(t, "./boundary-startlistenerstest.cert", "./boundary-startlistenerstest.key"), + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest.cert", + TLSKeyFile: "./boundary-startlistenerstest.key", + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + client := testTlsHttpClient(t, "./boundary-startlistenerstest.cert") + rsp, err := client.Get("https://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "multiple api (tls), one cluster listener", + setup: func(t *testing.T) { + testApiTlsSetup(t, "./boundary-startlistenerstest0.cert", "./boundary-startlistenerstest0.key")(t) + testApiTlsSetup(t, "./boundary-startlistenerstest1.cert", "./boundary-startlistenerstest1.key")(t) + }, + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest0.cert", + TLSKeyFile: "./boundary-startlistenerstest0.key", + }, + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest1.cert", + TLSKeyFile: "./boundary-startlistenerstest1.key", + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for i, apiAddr := range apiAddrs { + client := testTlsHttpClient(t, "./boundary-startlistenerstest"+strconv.Itoa(i)+".cert") + rsp, err := client.Get("https://" + apiAddr + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + } + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "one api, one cluster listener on unix sockets", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"cluster"}, + Address: "/tmp/boundary-listener-test-cluster.sock", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + conn, err := net.Dial("unix", apiAddrs[0]) + require.NoError(t, err) + + cl := http.Client{ + Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { return conn, nil }, + }, + } + rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.NoError(t, conn.Close()) + + clusterGrpcDialNoError(t, c, "unix", clusterAddr) + }, + }, + { + name: "multiple api, one cluster listener on unix sockets", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api2.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"cluster"}, + Address: "/tmp/boundary-listener-test-cluster.sock", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for _, apiAddr := range apiAddrs { + conn, err := net.Dial("unix", apiAddr) + require.NoError(t, err) + + cl := http.Client{ + Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { return conn, nil }, + }, + } + rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.NoError(t, conn.Close()) + } + + clusterGrpcDialNoError(t, c, "unix", clusterAddr) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup(t) + } + + ctx, cancel := context.WithCancel(context.Background()) + + tc := &TestController{t: t, ctx: ctx, cancel: cancel, opts: &TestControllerOpts{}} + t.Cleanup(func() { tc.Shutdown() }) + + conf := TestControllerConfig(t, ctx, tc, nil) + conf.RawConfig.SharedConfig = &configutil.SharedConfig{Listeners: tt.listeners, DisableMlock: true} + + err := conf.SetupListeners(nil, conf.RawConfig.SharedConfig, []string{"api", "cluster"}) + require.NoError(t, err) + + c, err := New(ctx, conf) + require.NoError(t, err) + + c.baseContext = ctx + c.baseCancel = cancel + + err = c.startListeners() + require.NoError(t, err) + + apiAddrs := make([]string, 0) + for _, l := range c.apiListeners { + apiAddrs = append(apiAddrs, l.Mux.Addr().String()) + } + tt.assertions(t, c, apiAddrs, c.clusterListener.Mux.Addr().String()) + }) + } +} + +func clusterGrpcDialNoError(t *testing.T, c *Controller, network, addr string) { + grpcConn, err := grpc.Dial(addr, + grpc.WithInsecure(), + grpc.WithBlock(), + grpc.WithTimeout(5*time.Second), + grpc.WithContextDialer(clusterTestDialer(t, c, network)), + ) + require.NoError(t, err) + require.NoError(t, grpcConn.Close()) +} + +func testApiTlsSetup(t *testing.T, certPath, keyPath string) func(t *testing.T) { + return func(t *testing.T) { + t.Cleanup(func() { + require.NoError(t, os.Remove(certPath)) + require.NoError(t, os.Remove(keyPath)) + }) + + certBytes, _, priv := createTestCert(t) + + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})) + require.NoError(t, certFile.Close()) + + keyFile, err := os.Create(keyPath) + require.NoError(t, err) + + marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv) + require.NoError(t, err) + + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey})) + require.NoError(t, keyFile.Close()) + } +} + +func testTlsHttpClient(t *testing.T, filePath string) *http.Client { + f, err := os.Open(filePath) + require.NoError(t, err) + + certBytes, err := ioutil.ReadAll(f) + require.NoError(t, err) + require.NoError(t, f.Close()) + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(certBytes) + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: certPool}, + }, + } +} + +func createTestCert(t *testing.T) ([]byte, ed25519.PublicKey, ed25519.PrivateKey) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, + SerialNumber: big.NewInt(0), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(5 * time.Minute), + BasicConstraintsValid: true, + IsCA: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"/tmp/boundary-listener-test-cluster.sock"}, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv) + require.NoError(t, err) + + return certBytes, pub, priv +} + +func clusterTestDialer(t *testing.T, c *Controller, network string) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + certBytes, _, priv := createTestCert(t) + + marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv) + require.NoError(t, err) + + nonce, err := base62.Random(20) + require.NoError(t, err) + + info := base.WorkerAuthInfo{ + CertPEM: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), + KeyPEM: pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey}), + ConnectionNonce: nonce, + } + infoBytes, err := json.Marshal(info) + require.NoError(t, err) + + encInfo, err := c.conf.WorkerAuthKms.Encrypt(ctx, infoBytes, nil) + require.NoError(t, err) + + protoEncInfo, err := proto.Marshal(encInfo) + require.NoError(t, err) + + b64alpn := base64.RawStdEncoding.EncodeToString(protoEncInfo) + + var nextProtos []string + var count int + for i := 0; i < len(b64alpn); i += 230 { + end := i + 230 + if end > len(b64alpn) { + end = len(b64alpn) + } + nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end])) + count++ + } + + cert, err := x509.ParseCertificate(certBytes) + require.NoError(t, err) + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(cert) + + tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM) + require.NoError(t, err) + + conn, err := tls.Dial(network, addr, &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: rootCAs, + NextProtos: nextProtos, + MinVersion: tls.VersionTLS13, + }) + require.NoError(t, err) + + _, err = conn.Write([]byte(nonce)) + require.NoError(t, err) + + return conn, nil + } +}