From b5ef468d333f83073c258e6043b418707ae66008 Mon Sep 17 00:00:00 2001 From: Hugo Vieira Date: Wed, 9 Mar 2022 20:41:15 +0000 Subject: [PATCH] refact(controller): Ensure deterministic listeners start order Previously, we initialized the listeners in the order they were provided to us in the config file. In the future, we'll want to add new listener purposes that depend on existing ones, so the init order needs to be deterministic. With this change, we'll initialize components in this order: 1. gRPC Server for Controller API. 2. All `api` HTTP Servers & gRPC Gateway Muxes. 3. `cluster` gRPC Server. This commit introduces some renames for the sake of clarity as well as improved test coverage: - Removes `gatewayMux` reference from Controller object. This object was misleading and is now unneeded. - Removes unneeded `context.Context` arguments. - Separates gRPC Server and gRPC Gateway concerns, both in functionality and in naming. These two were conflated before. - Adds unit tests on `startListeners`. 
--- internal/servers/controller/controller.go | 11 +- internal/servers/controller/gateway.go | 22 +- internal/servers/controller/handler.go | 190 +++++---- internal/servers/controller/listeners.go | 245 ++++++----- internal/servers/controller/listeners_test.go | 400 ++++++++++++++++++ 5 files changed, 631 insertions(+), 237 deletions(-) create mode 100644 internal/servers/controller/listeners_test.go diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index d3bae5bcda..4a104da979 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/auth/password" "github.com/hashicorp/boundary/internal/authtoken" @@ -58,11 +57,9 @@ type Controller struct { // Used for testing and tracking worker health workerStatusUpdateTimes *sync.Map - // grpc gateway server - gatewayServer *grpc.Server - gatewayTicket string - gatewayListener gatewayListener - gatewayMux *runtime.ServeMux + apiGrpcServer *grpc.Server + apiGrpcServerListener grpcServerListener + apiGrpcGatewayTicket string // Repo factory methods AuthTokenRepoFn common.AuthTokenRepoFactory @@ -293,7 +290,7 @@ func (c *Controller) Start() error { if err := c.registerJobs(); err != nil { return fmt.Errorf("error registering jobs: %w", err) } - if err := c.startListeners(c.baseContext); err != nil { + if err := c.startListeners(); err != nil { return fmt.Errorf("error starting controller listeners: %w", err) } if err := c.scheduler.Start(c.baseContext, c.schedulerWg); err != nil { diff --git a/internal/servers/controller/gateway.go b/internal/servers/controller/gateway.go index c0fc6eac64..005e188034 100644 --- a/internal/servers/controller/gateway.go +++ b/internal/servers/controller/gateway.go @@ -19,20 +19,14 @@ import ( 
"google.golang.org/grpc/test/bufconn" ) -// newGatewayListener will create an in-memory listener -func newGatewayListener() (gatewayListener, string) { - buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available - return bufconn.Listen(int(buffer)), "" -} - const gatewayTarget = "" -type gatewayListener interface { +type grpcServerListener interface { net.Listener Dial() (net.Conn, error) } -func gatewayDialOptions(lis gatewayListener) []grpc.DialOption { +func gatewayDialOptions(lis grpcServerListener) []grpc.DialOption { return []grpc.DialOption{ grpc.WithInsecure(), grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { @@ -41,7 +35,7 @@ func gatewayDialOptions(lis gatewayListener) []grpc.DialOption { } } -func newGatewayMux() *runtime.ServeMux { +func newGrpcGatewayMux() *runtime.ServeMux { return runtime.NewServeMux( runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.HTTPBodyMarshaler{ Marshaler: handlers.JSONMarshaler(), @@ -51,7 +45,13 @@ func newGatewayMux() *runtime.ServeMux { ) } -func newGatewayServer( +// newGrpcServerListener will create an in-memory listener for the gRPC server. 
+func newGrpcServerListener() grpcServerListener { + buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available + return bufconn.Listen(int(buffer)) +} + +func newGrpcServer( ctx context.Context, iamRepoFn common.IamRepoFactory, authTokenRepoFn common.AuthTokenRepoFactory, @@ -59,7 +59,7 @@ func newGatewayServer( kms *kms.Kms, eventer *event.Eventer, ) (*grpc.Server, string, error) { - const op = "controller.newGatewayServer" + const op = "controller.newGrpcServer" ticket, err := db.NewPrivateId("gwticket") if err != nil { return nil, "", errors.Wrap(ctx, err, op, errors.WithMsg("unable to generate gateway ticket")) diff --git a/internal/servers/controller/handler.go b/internal/servers/controller/handler.go index dae3312514..eab9ce9578 100644 --- a/internal/servers/controller/handler.go +++ b/internal/servers/controller/handler.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/gen/controller/api/services" @@ -39,6 +40,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/mr-tron/base58" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" @@ -50,17 +52,17 @@ type HandlerProperties struct { CancelCtx context.Context } -// Handler returns an http.Handler for the services. This can be used on +// apiHandler returns an http.Handler for the services. This can be used on // its own to mount the Controller API within another web server. 
-func (c *Controller) handler(props HandlerProperties) (http.Handler, error) { - // Create the muxer to handle the actual endpoints +func (c *Controller) apiHandler(props HandlerProperties) (http.Handler, error) { mux := http.NewServeMux() - h, err := handleGrpcGateway(c, props) + grpcGwMux := newGrpcGatewayMux() + err := registerGrpcGatewayEndpoints(props.CancelCtx, grpcGwMux, gatewayDialOptions(c.apiGrpcServerListener)...) if err != nil { return nil, err } - mux.Handle("/v1/", h) + mux.Handle("/v1/", grpcGwMux) mux.Handle("/", handleUi(c)) corsWrappedHandler := wrapHandlerWithCors(mux, props) @@ -68,103 +70,74 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) { callbackInterceptingHandler := wrapHandlerWithCallbackInterceptor(commonWrappedHandler, c) printablePathCheckHandler := cleanhttp.PrintablePathCheckHandler(callbackInterceptingHandler, nil) eventsHandler, err := common.WrapWithEventsHandler(printablePathCheckHandler, c.conf.Eventer, c.kms, props.ListenerConfig) - if err != nil { - return nil, err - } - return eventsHandler, nil + return eventsHandler, err } -func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, error) { - // Register*ServiceHandlerServer methods ignore the passed in ctx. Using it - // now however in case this changes in the future. - ctx := props.CancelCtx - currentServices := c.gatewayServer.GetServiceInfo() - dialOptions := gatewayDialOptions(c.gatewayListener) +func (c *Controller) registerGrpcServices(s *grpc.Server) error { + // We have to check against the current services because the gRPC lib treats a duplicate + // register call as an error and os.Exits. 
+ currentServices := s.GetServiceInfo() if _, ok := currentServices[services.HostCatalogService_ServiceDesc.ServiceName]; !ok { hcs, err := host_catalogs.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn, c.HostPluginRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host catalog handler service: %w", err) - } - services.RegisterHostCatalogServiceServer(c.gatewayServer, hcs) - if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host catalog service handler: %w", err) + return fmt.Errorf("failed to create host catalog handler service: %w", err) } + services.RegisterHostCatalogServiceServer(s, hcs) } if _, ok := currentServices[services.HostSetService_ServiceDesc.ServiceName]; !ok { hss, err := host_sets.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host set handler service: %w", err) - } - services.RegisterHostSetServiceServer(c.gatewayServer, hss) - if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host set service handler: %w", err) + return fmt.Errorf("failed to create host set handler service: %w", err) } + services.RegisterHostSetServiceServer(s, hss) } if _, ok := currentServices[services.HostService_ServiceDesc.ServiceName]; !ok { hs, err := hosts.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create host handler service: %w", err) - } - services.RegisterHostServiceServer(c.gatewayServer, hs) - if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register host service handler: %w", err) + return fmt.Errorf("failed to create host handler service: %w", err) } + 
services.RegisterHostServiceServer(s, hs) } if _, ok := currentServices[services.AccountService_ServiceDesc.ServiceName]; !ok { accts, err := accounts.NewService(c.PasswordAuthRepoFn, c.OidcRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create account handler service: %w", err) - } - services.RegisterAccountServiceServer(c.gatewayServer, accts) - if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register account service handler: %w", err) + return fmt.Errorf("failed to create account handler service: %w", err) } + services.RegisterAccountServiceServer(s, accts) } if _, ok := currentServices[services.AuthMethodService_ServiceDesc.ServiceName]; !ok { authMethods, err := authmethods.NewService(c.kms, c.PasswordAuthRepoFn, c.OidcRepoFn, c.IamRepoFn, c.AuthTokenRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create auth method handler service: %w", err) - } - services.RegisterAuthMethodServiceServer(c.gatewayServer, authMethods) - if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register auth method service handler: %w", err) + return fmt.Errorf("failed to create auth method handler service: %w", err) } + services.RegisterAuthMethodServiceServer(s, authMethods) } if _, ok := currentServices[services.AuthTokenService_ServiceDesc.ServiceName]; !ok { authtoks, err := authtokens.NewService(c.AuthTokenRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create auth token handler service: %w", err) - } - services.RegisterAuthTokenServiceServer(c.gatewayServer, authtoks) - if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register auth token service handler: %w", err) + return fmt.Errorf("failed to create 
auth token handler service: %w", err) } + services.RegisterAuthTokenServiceServer(s, authtoks) } if _, ok := currentServices[services.ScopeService_ServiceDesc.ServiceName]; !ok { os, err := scopes.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create scope handler service: %w", err) - } - services.RegisterScopeServiceServer(c.gatewayServer, os) - if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register scope service handler: %w", err) + return fmt.Errorf("failed to create scope handler service: %w", err) } + services.RegisterScopeServiceServer(s, os) } if _, ok := currentServices[services.UserService_ServiceDesc.ServiceName]; !ok { us, err := users.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create user handler service: %w", err) - } - services.RegisterUserServiceServer(c.gatewayServer, us) - if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register user service handler: %w", err) + return fmt.Errorf("failed to create user handler service: %w", err) } + services.RegisterUserServiceServer(s, us) } if _, ok := currentServices[services.TargetService_ServiceDesc.ServiceName]; !ok { ts, err := targets.NewService( - ctx, + c.baseContext, c.kms, c.TargetRepoFn, c.IamRepoFn, @@ -174,75 +147,106 @@ func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, er c.StaticHostRepoFn, c.VaultCredentialRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create target handler service: %w", err) - } - services.RegisterTargetServiceServer(c.gatewayServer, ts) - if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register target service handler: %w", err) + return fmt.Errorf("failed to 
create target handler service: %w", err) } + services.RegisterTargetServiceServer(s, ts) } if _, ok := currentServices[services.GroupService_ServiceDesc.ServiceName]; !ok { gs, err := groups.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create group handler service: %w", err) - } - services.RegisterGroupServiceServer(c.gatewayServer, gs) - if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register group service handler: %w", err) + return fmt.Errorf("failed to create group handler service: %w", err) } + services.RegisterGroupServiceServer(s, gs) } if _, ok := currentServices[services.RoleService_ServiceDesc.ServiceName]; !ok { rs, err := roles.NewService(c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create role handler service: %w", err) - } - services.RegisterRoleServiceServer(c.gatewayServer, rs) - if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register role service handler: %w", err) + return fmt.Errorf("failed to create role handler service: %w", err) } + services.RegisterRoleServiceServer(s, rs) } if _, ok := currentServices[services.SessionService_ServiceDesc.ServiceName]; !ok { ss, err := sessions.NewService(c.SessionRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create session handler service: %w", err) - } - services.RegisterSessionServiceServer(c.gatewayServer, ss) - if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register session service handler: %w", err) + return fmt.Errorf("failed to create session handler service: %w", err) } + services.RegisterSessionServiceServer(s, ss) } if _, ok := currentServices[services.ManagedGroupService_ServiceDesc.ServiceName]; !ok 
{ mgs, err := managed_groups.NewService(c.OidcRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create managed groups handler service: %w", err) - } - services.RegisterManagedGroupServiceServer(c.gatewayServer, mgs) - if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register managed groups service handler: %w", err) + return fmt.Errorf("failed to create managed groups handler service: %w", err) } + services.RegisterManagedGroupServiceServer(s, mgs) } if _, ok := currentServices[services.CredentialStoreService_ServiceDesc.ServiceName]; !ok { cs, err := credentialstores.NewService(c.VaultCredentialRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create credential store handler service: %w", err) - } - services.RegisterCredentialStoreServiceServer(c.gatewayServer, cs) - if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register credential store service handler: %w", err) + return fmt.Errorf("failed to create credential store handler service: %w", err) } + services.RegisterCredentialStoreServiceServer(s, cs) } if _, ok := currentServices[services.CredentialLibraryService_ServiceDesc.ServiceName]; !ok { cl, err := credentiallibraries.NewService(c.VaultCredentialRepoFn, c.IamRepoFn) if err != nil { - return nil, fmt.Errorf("failed to create credential library handler service: %w", err) - } - services.RegisterCredentialLibraryServiceServer(c.gatewayServer, cl) - if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil { - return nil, fmt.Errorf("failed to register credential library service handler: %w", err) + return fmt.Errorf("failed to create credential library handler service: %w", err) } + 
services.RegisterCredentialLibraryServiceServer(s, cl) + } + + return nil +} + +func registerGrpcGatewayEndpoints(ctx context.Context, gwMux *runtime.ServeMux, dialOptions ...grpc.DialOption) error { + // Register*ServiceHandlerServer methods ignore the passed in context. + // Passing it in anyways in case this changes in the future. + if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host catalog service handler: %w", err) + } + if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host set service handler: %w", err) + } + if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register host service handler: %w", err) + } + if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register account service handler: %w", err) + } + if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register auth method service handler: %w", err) + } + if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register auth token service handler: %w", err) + } + if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register scope service handler: %w", err) + } + if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register user service handler: %w", err) + } + if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, 
dialOptions); err != nil { + return fmt.Errorf("failed to register target service handler: %w", err) + } + if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register group service handler: %w", err) + } + if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register role service handler: %w", err) + } + if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register session service handler: %w", err) + } + if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register managed groups service handler: %w", err) + } + if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register credential store service handler: %w", err) + } + if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { + return fmt.Errorf("failed to register credential library service handler: %w", err) } - return c.gatewayMux, nil + return nil } func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler { @@ -301,7 +305,7 @@ func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProp // Serialize the request info to send it across the wire to the // grpc-gateway via an http header - requestInfo.Ticket = c.gatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy + requestInfo.Ticket = c.apiGrpcGatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy marshalledRequestInfo, err := 
proto.Marshal(&requestInfo) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling request info")) diff --git a/internal/servers/controller/listeners.go b/internal/servers/controller/listeners.go index 2079018ded..8e5852870d 100644 --- a/internal/servers/controller/listeners.go +++ b/internal/servers/controller/listeners.go @@ -22,151 +22,144 @@ import ( "google.golang.org/grpc" ) -func (c *Controller) startListeners(ctx context.Context) error { +func (c *Controller) startListeners() error { servers := make([]func(), 0, len(c.conf.Listeners)) - configureForAPI := func(ln *base.ServerListener) error { - var err error - if c.gatewayServer, c.gatewayTicket, err = newGatewayServer(ctx, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer); err != nil { - return err - } - c.gatewayMux = newGatewayMux() + grpcServer, gwTicket, err := newGrpcServer(c.baseContext, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer) + if err != nil { + return fmt.Errorf("failed to create new grpc server: %w", err) + } + c.apiGrpcServer = grpcServer + c.apiGrpcGatewayTicket = gwTicket + + err = c.registerGrpcServices(c.apiGrpcServer) + if err != nil { + return fmt.Errorf("failed to register grpc services: %w", err) + } + + c.apiGrpcServerListener = newGrpcServerListener() + servers = append(servers, func() { + go c.apiGrpcServer.Serve(c.apiGrpcServerListener) + }) - handler, err := c.handler(HandlerProperties{ - ListenerConfig: ln.Config, - CancelCtx: c.baseContext, - }) + for i := range c.apiListeners { + ln := c.apiListeners[i] + apiServers, err := c.configureForAPI(ln) if err != nil { - return err + return fmt.Errorf("failed to configure listener for api mode: %w", err) } + servers = append(servers, apiServers...) 
+ } - // Resolve it here to avoid race conditions if the base context is - // replaced - cancelCtx := c.baseContext - - server := &http.Server{ - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - ReadTimeout: 30 * time.Second, - IdleTimeout: 5 * time.Minute, - ErrorLog: c.logger.StandardLogger(nil), - BaseContext: func(net.Listener) context.Context { - return cancelCtx - }, - } - ln.HTTPServer = server + clusterServer, err := c.configureForCluster(c.clusterListener) + if err != nil { + return fmt.Errorf("failed to configure listener for cluster mode: %w", err) + } + servers = append(servers, clusterServer) - if ln.Config.HTTPReadHeaderTimeout > 0 { - server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout - } - if ln.Config.HTTPReadTimeout > 0 { - server.ReadTimeout = ln.Config.HTTPReadTimeout - } - if ln.Config.HTTPWriteTimeout > 0 { - server.WriteTimeout = ln.Config.HTTPWriteTimeout - } - if ln.Config.HTTPIdleTimeout > 0 { - server.IdleTimeout = ln.Config.HTTPIdleTimeout - } + for _, s := range servers { + s() + } - switch ln.Config.TLSDisable { - case true: - l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) - if err != nil { - return fmt.Errorf("error getting non-tls listener: %w", err) - } - if l == nil { - return errors.New("could not get non-tls listener") - } - servers = append(servers, func() { - go server.Serve(l) - }) - - default: - protos := []string{"", "http/1.1", "h2"} - for _, v := range protos { - l := ln.Mux.GetListener(v) - if l == nil { - return fmt.Errorf("could not get tls proto %q listener", v) - } - servers = append(servers, func() { - go server.Serve(l) - }) - } - } + return nil +} - return nil +func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) { + apiServers := make([]func(), 0) + + handler, err := c.apiHandler(HandlerProperties{ + ListenerConfig: ln.Config, + CancelCtx: c.baseContext, + }) + if err != nil { + return nil, err } - configureForCluster := func(ln *base.ServerListener) error { - // 
Clear out in case this is a second start of the controller - ln.Mux.UnregisterProto(alpnmux.DefaultProto) - l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ - GetConfigForClient: c.validateWorkerTls, - }) + cancelCtx := c.baseContext // Resolve to avoid race conditions if the base context is replaced. + server := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: c.logger.StandardLogger(nil), + BaseContext: func(net.Listener) context.Context { return cancelCtx }, + } + ln.HTTPServer = server + + if ln.Config.HTTPReadHeaderTimeout > 0 { + server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout + } + if ln.Config.HTTPReadTimeout > 0 { + server.ReadTimeout = ln.Config.HTTPReadTimeout + } + if ln.Config.HTTPWriteTimeout > 0 { + server.WriteTimeout = ln.Config.HTTPWriteTimeout + } + if ln.Config.HTTPIdleTimeout > 0 { + server.IdleTimeout = ln.Config.HTTPIdleTimeout + } + + switch ln.Config.TLSDisable { + case true: + l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil) if err != nil { - return fmt.Errorf("error getting sub-listener for worker proto: %w", err) + return nil, fmt.Errorf("error getting non-tls listener: %w", err) + } + if l == nil { + return nil, errors.New("could not get non-tls listener") } + apiServers = append(apiServers, func() { go server.Serve(l) }) - workerReqInterceptor, err := workerRequestInfoInterceptor(ctx, c.conf.Eventer) - if err != nil { - return fmt.Errorf("error getting sub-listener for worker proto: %w", err) + default: + for _, v := range []string{"", "http/1.1", "h2"} { + l := ln.Mux.GetListener(v) + if l == nil { + return nil, fmt.Errorf("could not get tls proto %q listener", v) + } + apiServers = append(apiServers, func() { go server.Serve(l) }) } - workerServer := grpc.NewServer( - grpc.MaxRecvMsgSize(math.MaxInt32), - grpc.MaxSendMsgSize(math.MaxInt32), - grpc.UnaryInterceptor( - grpc_middleware.ChainUnaryServer( - 
workerReqInterceptor, - auditRequestInterceptor(ctx), // before we get started, audit the request - auditResponseInterceptor(ctx), // as we finish, audit the response - ), - ), - ) - workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn, - c.workerStatusUpdateTimes, c.kms) - pbs.RegisterServerCoordinationServiceServer(workerServer, workerService) - pbs.RegisterSessionServiceServer(workerServer, workerService) - - interceptor := newInterceptingListener(c, l) - ln.ALPNListener = interceptor - ln.GrpcServer = workerServer - - servers = append(servers, func() { - go workerServer.Serve(interceptor) - }) - return nil } - c.gatewayListener, _ = newGatewayListener() - servers = append(servers, func() { - go c.gatewayServer.Serve(c.gatewayListener) - }) + return apiServers, nil +} - for _, ln := range c.conf.Listeners { - var err error - for _, purpose := range ln.Config.Purpose { - switch purpose { - case "api": - err = configureForAPI(ln) - case "cluster": - err = configureForCluster(ln) - case "proxy": - // Do nothing, in a dev mode we might see it here - default: - err = fmt.Errorf("unknown listener purpose %q", purpose) - } - if err != nil { - return err - } - } +func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error) { + // Clear out in case this is a second start of the controller + ln.Mux.UnregisterProto(alpnmux.DefaultProto) + l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{ + GetConfigForClient: c.validateWorkerTls, + }) + if err != nil { + return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err) } - for _, s := range servers { - s() + workerReqInterceptor, err := workerRequestInfoInterceptor(c.baseContext, c.conf.Eventer) + if err != nil { + return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err) } - return nil + workerServer := grpc.NewServer( + grpc.MaxRecvMsgSize(math.MaxInt32), + grpc.MaxSendMsgSize(math.MaxInt32), + 
grpc.UnaryInterceptor( + grpc_middleware.ChainUnaryServer( + workerReqInterceptor, + auditRequestInterceptor(c.baseContext), // before we get started, audit the request + auditResponseInterceptor(c.baseContext), // as we finish, audit the response + ), + ), + ) + + workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn, + c.workerStatusUpdateTimes, c.kms) + pbs.RegisterServerCoordinationServiceServer(workerServer, workerService) + pbs.RegisterSessionServiceServer(workerServer, workerService) + + interceptor := newInterceptingListener(c, l) + ln.ALPNListener = interceptor + ln.GrpcServer = workerServer + + return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil } func (c *Controller) stopListeners(serversOnly bool) error { @@ -194,7 +187,7 @@ func (c *Controller) stopListeners(serversOnly bool) error { }() } - if c.gatewayServer != nil { + if c.apiGrpcServer != nil { serverWg.Add(1) go func() { defer serverWg.Done() @@ -202,9 +195,9 @@ func (c *Controller) stopListeners(serversOnly bool) error { defer shutdownKillCancel() go func() { <-shutdownKill.Done() - c.gatewayServer.Stop() + c.apiGrpcServer.Stop() }() - c.gatewayServer.GracefulStop() + c.apiGrpcServer.GracefulStop() }() } diff --git a/internal/servers/controller/listeners_test.go b/internal/servers/controller/listeners_test.go new file mode 100644 index 0000000000..e674882b86 --- /dev/null +++ b/internal/servers/controller/listeners_test.go @@ -0,0 +1,400 @@ +package controller + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "net/http" + "os" + "strconv" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/hashicorp/go-secure-stdlib/configutil/v2" + "github.com/hashicorp/go-secure-stdlib/listenerutil" + 
"github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +func TestStartListeners(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) + listeners []*listenerutil.ListenerConfig + assertions func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) + }{ + { + name: "one api, one cluster listener", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + rsp, err := http.Get("http://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "multiple api, one cluster listeners", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSDisable: true, + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for _, apiAddr := range apiAddrs { + rsp, err := http.Get("http://" + apiAddr + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + } + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "one api (tls), one cluster listener", + setup: testApiTlsSetup(t, "./boundary-startlistenerstest.cert", "./boundary-startlistenerstest.key"), + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest.cert", + 
TLSKeyFile: "./boundary-startlistenerstest.key", + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + client := testTlsHttpClient(t, "./boundary-startlistenerstest.cert") + rsp, err := client.Get("https://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "multiple api (tls), one cluster listener", + setup: func(t *testing.T) { + testApiTlsSetup(t, "./boundary-startlistenerstest0.cert", "./boundary-startlistenerstest0.key")(t) + testApiTlsSetup(t, "./boundary-startlistenerstest1.cert", "./boundary-startlistenerstest1.key")(t) + }, + listeners: []*listenerutil.ListenerConfig{ + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest0.cert", + TLSKeyFile: "./boundary-startlistenerstest0.key", + }, + { + Type: "tcp", + Purpose: []string{"api"}, + Address: "127.0.0.1:0", + TLSCertFile: "./boundary-startlistenerstest1.cert", + TLSKeyFile: "./boundary-startlistenerstest1.key", + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + Address: "127.0.0.1:0", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for i, apiAddr := range apiAddrs { + client := testTlsHttpClient(t, "./boundary-startlistenerstest"+strconv.Itoa(i)+".cert") + rsp, err := client.Get("https://" + apiAddr + "/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + } + + clusterGrpcDialNoError(t, c, "tcp", clusterAddr) + }, + }, + { + name: "one api, one cluster listener on unix sockets", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api.sock", + TLSDisable: true, + }, + { + 
Type: "unix", + Purpose: []string{"cluster"}, + Address: "/tmp/boundary-listener-test-cluster.sock", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + conn, err := net.Dial("unix", apiAddrs[0]) + require.NoError(t, err) + + cl := http.Client{ + Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { return conn, nil }, + }, + } + rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.NoError(t, conn.Close()) + + clusterGrpcDialNoError(t, c, "unix", clusterAddr) + }, + }, + { + name: "multiple api, one cluster listener on unix sockets", + listeners: []*listenerutil.ListenerConfig{ + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"api"}, + Address: "/tmp/boundary-listener-test-api2.sock", + TLSDisable: true, + }, + { + Type: "unix", + Purpose: []string{"cluster"}, + Address: "/tmp/boundary-listener-test-cluster.sock", + }, + }, + assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) { + for _, apiAddr := range apiAddrs { + conn, err := net.Dial("unix", apiAddr) + require.NoError(t, err) + + cl := http.Client{ + Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { return conn, nil }, + }, + } + rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global") + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.NoError(t, conn.Close()) + } + + clusterGrpcDialNoError(t, c, "unix", clusterAddr) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup(t) + } + + ctx, cancel := context.WithCancel(context.Background()) + + tc := &TestController{t: t, ctx: ctx, cancel: cancel, opts: &TestControllerOpts{}} 
+ t.Cleanup(func() { tc.Shutdown() }) + + conf := TestControllerConfig(t, ctx, tc, nil) + conf.RawConfig.SharedConfig = &configutil.SharedConfig{Listeners: tt.listeners, DisableMlock: true} + + err := conf.SetupListeners(nil, conf.RawConfig.SharedConfig, []string{"api", "cluster"}) + require.NoError(t, err) + + c, err := New(ctx, conf) + require.NoError(t, err) + + c.baseContext = ctx + c.baseCancel = cancel + + err = c.startListeners() + require.NoError(t, err) + + apiAddrs := make([]string, 0) + for _, l := range c.apiListeners { + apiAddrs = append(apiAddrs, l.Mux.Addr().String()) + } + tt.assertions(t, c, apiAddrs, c.clusterListener.Mux.Addr().String()) + }) + } +} + +func clusterGrpcDialNoError(t *testing.T, c *Controller, network, addr string) { + grpcConn, err := grpc.Dial(addr, + grpc.WithInsecure(), + grpc.WithBlock(), + grpc.WithTimeout(5*time.Second), + grpc.WithContextDialer(clusterTestDialer(t, c, network)), + ) + require.NoError(t, err) + require.NoError(t, grpcConn.Close()) +} + +func testApiTlsSetup(t *testing.T, certPath, keyPath string) func(t *testing.T) { + return func(t *testing.T) { + t.Cleanup(func() { + require.NoError(t, os.Remove(certPath)) + require.NoError(t, os.Remove(keyPath)) + }) + + certBytes, _, priv := createTestCert(t) + + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})) + require.NoError(t, certFile.Close()) + + keyFile, err := os.Create(keyPath) + require.NoError(t, err) + + marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv) + require.NoError(t, err) + + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey})) + require.NoError(t, keyFile.Close()) + } +} + +func testTlsHttpClient(t *testing.T, filePath string) *http.Client { + f, err := os.Open(filePath) + require.NoError(t, err) + + certBytes, err := ioutil.ReadAll(f) + require.NoError(t, err) + require.NoError(t, f.Close()) + + 
certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(certBytes) + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: certPool}, + }, + } +} + +func createTestCert(t *testing.T) ([]byte, ed25519.PublicKey, ed25519.PrivateKey) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, + SerialNumber: big.NewInt(0), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(5 * time.Minute), + BasicConstraintsValid: true, + IsCA: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"/tmp/boundary-listener-test-cluster.sock"}, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv) + require.NoError(t, err) + + return certBytes, pub, priv +} + +func clusterTestDialer(t *testing.T, c *Controller, network string) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + certBytes, _, priv := createTestCert(t) + + marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv) + require.NoError(t, err) + + nonce, err := base62.Random(20) + require.NoError(t, err) + + info := base.WorkerAuthInfo{ + CertPEM: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), + KeyPEM: pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey}), + ConnectionNonce: nonce, + } + infoBytes, err := json.Marshal(info) + require.NoError(t, err) + + encInfo, err := c.conf.WorkerAuthKms.Encrypt(ctx, infoBytes, nil) + require.NoError(t, err) + + protoEncInfo, err := proto.Marshal(encInfo) + require.NoError(t, err) + + b64alpn := base64.RawStdEncoding.EncodeToString(protoEncInfo) + + var nextProtos []string + var count int 
+ for i := 0; i < len(b64alpn); i += 230 { + end := i + 230 + if end > len(b64alpn) { + end = len(b64alpn) + } + nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end])) + count++ + } + + cert, err := x509.ParseCertificate(certBytes) + require.NoError(t, err) + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(cert) + + tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM) + require.NoError(t, err) + + conn, err := tls.Dial(network, addr, &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: rootCAs, + NextProtos: nextProtos, + MinVersion: tls.VersionTLS13, + }) + require.NoError(t, err) + + _, err = conn.Write([]byte(nonce)) + require.NoError(t, err) + + return conn, nil + } +}