refact(controller): Ensure deterministic listeners start order

Previously, we initialized the listeners on the order they were provided
to us on the config file. In the future, we'll want to add new listener
purposes that depend on existing ones, so the init order needs to be
deterministic.

With this change, we'll initialize components in this order:
1. gRPC Server for Controller API.
2. All `api` HTTP Servers & gRPC Gateway Muxes.
3. `cluster` gRPC Server.

This commit introduces some renames for the sake of clarity as well as
improved test coverage:

- Removes `gatewayMux` reference from Controller object. This object was
misleading and is now uneeded.
- Removed uneeded `context.Context` arguments.
- Separates gRPC Server and gRPC Gateway concerns, both in functionality
and in naming. These two were conflated before.
- Added unit tests on `startListeners`.
pull/1907/head
Hugo Vieira 4 years ago
parent 8d48a8515c
commit b5ef468d33

@ -7,7 +7,6 @@ import (
"strings"
"sync"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/hashicorp/boundary/internal/auth/oidc"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/authtoken"
@ -58,11 +57,9 @@ type Controller struct {
// Used for testing and tracking worker health
workerStatusUpdateTimes *sync.Map
// grpc gateway server
gatewayServer *grpc.Server
gatewayTicket string
gatewayListener gatewayListener
gatewayMux *runtime.ServeMux
apiGrpcServer *grpc.Server
apiGrpcServerListener grpcServerListener
apiGrpcGatewayTicket string
// Repo factory methods
AuthTokenRepoFn common.AuthTokenRepoFactory
@ -293,7 +290,7 @@ func (c *Controller) Start() error {
if err := c.registerJobs(); err != nil {
return fmt.Errorf("error registering jobs: %w", err)
}
if err := c.startListeners(c.baseContext); err != nil {
if err := c.startListeners(); err != nil {
return fmt.Errorf("error starting controller listeners: %w", err)
}
if err := c.scheduler.Start(c.baseContext, c.schedulerWg); err != nil {

@ -19,20 +19,14 @@ import (
"google.golang.org/grpc/test/bufconn"
)
// newGatewayListener will create an in-memory listener
func newGatewayListener() (gatewayListener, string) {
buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available
return bufconn.Listen(int(buffer)), ""
}
const gatewayTarget = ""
type gatewayListener interface {
type grpcServerListener interface {
net.Listener
Dial() (net.Conn, error)
}
func gatewayDialOptions(lis gatewayListener) []grpc.DialOption {
func gatewayDialOptions(lis grpcServerListener) []grpc.DialOption {
return []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
@ -41,7 +35,7 @@ func gatewayDialOptions(lis gatewayListener) []grpc.DialOption {
}
}
func newGatewayMux() *runtime.ServeMux {
func newGrpcGatewayMux() *runtime.ServeMux {
return runtime.NewServeMux(
runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.HTTPBodyMarshaler{
Marshaler: handlers.JSONMarshaler(),
@ -51,7 +45,13 @@ func newGatewayMux() *runtime.ServeMux {
)
}
func newGatewayServer(
// newGrpcServerListener will create an in-memory listener for the gRPC server.
func newGrpcServerListener() grpcServerListener {
buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available
return bufconn.Listen(int(buffer))
}
func newGrpcServer(
ctx context.Context,
iamRepoFn common.IamRepoFactory,
authTokenRepoFn common.AuthTokenRepoFactory,
@ -59,7 +59,7 @@ func newGatewayServer(
kms *kms.Kms,
eventer *event.Eventer,
) (*grpc.Server, string, error) {
const op = "controller.newGatewayServer"
const op = "controller.newGrpcServer"
ticket, err := db.NewPrivateId("gwticket")
if err != nil {
return nil, "", errors.Wrap(ctx, err, op, errors.WithMsg("unable to generate gateway ticket"))

@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth/oidc"
"github.com/hashicorp/boundary/internal/gen/controller/api/services"
@ -39,6 +40,7 @@ import (
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/mr-tron/base58"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/proto"
@ -50,17 +52,17 @@ type HandlerProperties struct {
CancelCtx context.Context
}
// Handler returns an http.Handler for the services. This can be used on
// apiHandler returns an http.Handler for the services. This can be used on
// its own to mount the Controller API within another web server.
func (c *Controller) handler(props HandlerProperties) (http.Handler, error) {
// Create the muxer to handle the actual endpoints
func (c *Controller) apiHandler(props HandlerProperties) (http.Handler, error) {
mux := http.NewServeMux()
h, err := handleGrpcGateway(c, props)
grpcGwMux := newGrpcGatewayMux()
err := registerGrpcGatewayEndpoints(props.CancelCtx, grpcGwMux, gatewayDialOptions(c.apiGrpcServerListener)...)
if err != nil {
return nil, err
}
mux.Handle("/v1/", h)
mux.Handle("/v1/", grpcGwMux)
mux.Handle("/", handleUi(c))
corsWrappedHandler := wrapHandlerWithCors(mux, props)
@ -68,103 +70,74 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) {
callbackInterceptingHandler := wrapHandlerWithCallbackInterceptor(commonWrappedHandler, c)
printablePathCheckHandler := cleanhttp.PrintablePathCheckHandler(callbackInterceptingHandler, nil)
eventsHandler, err := common.WrapWithEventsHandler(printablePathCheckHandler, c.conf.Eventer, c.kms, props.ListenerConfig)
if err != nil {
return nil, err
}
return eventsHandler, nil
return eventsHandler, err
}
func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, error) {
// Register*ServiceHandlerServer methods ignore the passed in ctx. Using it
// now however in case this changes in the future.
ctx := props.CancelCtx
currentServices := c.gatewayServer.GetServiceInfo()
dialOptions := gatewayDialOptions(c.gatewayListener)
func (c *Controller) registerGrpcServices(s *grpc.Server) error {
// We have to check against the current services because the gRPC lib treats a duplicate
// register call as an error and os.Exits.
currentServices := s.GetServiceInfo()
if _, ok := currentServices[services.HostCatalogService_ServiceDesc.ServiceName]; !ok {
hcs, err := host_catalogs.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn, c.HostPluginRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create host catalog handler service: %w", err)
}
services.RegisterHostCatalogServiceServer(c.gatewayServer, hcs)
if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register host catalog service handler: %w", err)
return fmt.Errorf("failed to create host catalog handler service: %w", err)
}
services.RegisterHostCatalogServiceServer(s, hcs)
}
if _, ok := currentServices[services.HostSetService_ServiceDesc.ServiceName]; !ok {
hss, err := host_sets.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create host set handler service: %w", err)
}
services.RegisterHostSetServiceServer(c.gatewayServer, hss)
if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register host set service handler: %w", err)
return fmt.Errorf("failed to create host set handler service: %w", err)
}
services.RegisterHostSetServiceServer(s, hss)
}
if _, ok := currentServices[services.HostService_ServiceDesc.ServiceName]; !ok {
hs, err := hosts.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create host handler service: %w", err)
}
services.RegisterHostServiceServer(c.gatewayServer, hs)
if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register host service handler: %w", err)
return fmt.Errorf("failed to create host handler service: %w", err)
}
services.RegisterHostServiceServer(s, hs)
}
if _, ok := currentServices[services.AccountService_ServiceDesc.ServiceName]; !ok {
accts, err := accounts.NewService(c.PasswordAuthRepoFn, c.OidcRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create account handler service: %w", err)
}
services.RegisterAccountServiceServer(c.gatewayServer, accts)
if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register account service handler: %w", err)
return fmt.Errorf("failed to create account handler service: %w", err)
}
services.RegisterAccountServiceServer(s, accts)
}
if _, ok := currentServices[services.AuthMethodService_ServiceDesc.ServiceName]; !ok {
authMethods, err := authmethods.NewService(c.kms, c.PasswordAuthRepoFn, c.OidcRepoFn, c.IamRepoFn, c.AuthTokenRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create auth method handler service: %w", err)
}
services.RegisterAuthMethodServiceServer(c.gatewayServer, authMethods)
if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register auth method service handler: %w", err)
return fmt.Errorf("failed to create auth method handler service: %w", err)
}
services.RegisterAuthMethodServiceServer(s, authMethods)
}
if _, ok := currentServices[services.AuthTokenService_ServiceDesc.ServiceName]; !ok {
authtoks, err := authtokens.NewService(c.AuthTokenRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create auth token handler service: %w", err)
}
services.RegisterAuthTokenServiceServer(c.gatewayServer, authtoks)
if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register auth token service handler: %w", err)
return fmt.Errorf("failed to create auth token handler service: %w", err)
}
services.RegisterAuthTokenServiceServer(s, authtoks)
}
if _, ok := currentServices[services.ScopeService_ServiceDesc.ServiceName]; !ok {
os, err := scopes.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create scope handler service: %w", err)
}
services.RegisterScopeServiceServer(c.gatewayServer, os)
if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register scope service handler: %w", err)
return fmt.Errorf("failed to create scope handler service: %w", err)
}
services.RegisterScopeServiceServer(s, os)
}
if _, ok := currentServices[services.UserService_ServiceDesc.ServiceName]; !ok {
us, err := users.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create user handler service: %w", err)
}
services.RegisterUserServiceServer(c.gatewayServer, us)
if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register user service handler: %w", err)
return fmt.Errorf("failed to create user handler service: %w", err)
}
services.RegisterUserServiceServer(s, us)
}
if _, ok := currentServices[services.TargetService_ServiceDesc.ServiceName]; !ok {
ts, err := targets.NewService(
ctx,
c.baseContext,
c.kms,
c.TargetRepoFn,
c.IamRepoFn,
@ -174,75 +147,106 @@ func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, er
c.StaticHostRepoFn,
c.VaultCredentialRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create target handler service: %w", err)
}
services.RegisterTargetServiceServer(c.gatewayServer, ts)
if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register target service handler: %w", err)
return fmt.Errorf("failed to create target handler service: %w", err)
}
services.RegisterTargetServiceServer(s, ts)
}
if _, ok := currentServices[services.GroupService_ServiceDesc.ServiceName]; !ok {
gs, err := groups.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create group handler service: %w", err)
}
services.RegisterGroupServiceServer(c.gatewayServer, gs)
if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register group service handler: %w", err)
return fmt.Errorf("failed to create group handler service: %w", err)
}
services.RegisterGroupServiceServer(s, gs)
}
if _, ok := currentServices[services.RoleService_ServiceDesc.ServiceName]; !ok {
rs, err := roles.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create role handler service: %w", err)
}
services.RegisterRoleServiceServer(c.gatewayServer, rs)
if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register role service handler: %w", err)
return fmt.Errorf("failed to create role handler service: %w", err)
}
services.RegisterRoleServiceServer(s, rs)
}
if _, ok := currentServices[services.SessionService_ServiceDesc.ServiceName]; !ok {
ss, err := sessions.NewService(c.SessionRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create session handler service: %w", err)
}
services.RegisterSessionServiceServer(c.gatewayServer, ss)
if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register session service handler: %w", err)
return fmt.Errorf("failed to create session handler service: %w", err)
}
services.RegisterSessionServiceServer(s, ss)
}
if _, ok := currentServices[services.ManagedGroupService_ServiceDesc.ServiceName]; !ok {
mgs, err := managed_groups.NewService(c.OidcRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create managed groups handler service: %w", err)
}
services.RegisterManagedGroupServiceServer(c.gatewayServer, mgs)
if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register managed groups service handler: %w", err)
return fmt.Errorf("failed to create managed groups handler service: %w", err)
}
services.RegisterManagedGroupServiceServer(s, mgs)
}
if _, ok := currentServices[services.CredentialStoreService_ServiceDesc.ServiceName]; !ok {
cs, err := credentialstores.NewService(c.VaultCredentialRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create credential store handler service: %w", err)
}
services.RegisterCredentialStoreServiceServer(c.gatewayServer, cs)
if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register credential store service handler: %w", err)
return fmt.Errorf("failed to create credential store handler service: %w", err)
}
services.RegisterCredentialStoreServiceServer(s, cs)
}
if _, ok := currentServices[services.CredentialLibraryService_ServiceDesc.ServiceName]; !ok {
cl, err := credentiallibraries.NewService(c.VaultCredentialRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create credential library handler service: %w", err)
}
services.RegisterCredentialLibraryServiceServer(c.gatewayServer, cl)
if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register credential library service handler: %w", err)
return fmt.Errorf("failed to create credential library handler service: %w", err)
}
services.RegisterCredentialLibraryServiceServer(s, cl)
}
return nil
}
func registerGrpcGatewayEndpoints(ctx context.Context, gwMux *runtime.ServeMux, dialOptions ...grpc.DialOption) error {
// Register*ServiceHandlerServer methods ignore the passed in context.
// Passing it in anyways in case this changes in the future.
if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register host catalog service handler: %w", err)
}
if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register host set service handler: %w", err)
}
if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register host service handler: %w", err)
}
if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register account service handler: %w", err)
}
if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register auth method service handler: %w", err)
}
if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register auth token service handler: %w", err)
}
if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register scope service handler: %w", err)
}
if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register user service handler: %w", err)
}
if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register target service handler: %w", err)
}
if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register group service handler: %w", err)
}
if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register role service handler: %w", err)
}
if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register session service handler: %w", err)
}
if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register managed groups service handler: %w", err)
}
if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register credential store service handler: %w", err)
}
if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register credential library service handler: %w", err)
}
return c.gatewayMux, nil
return nil
}
func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler {
@ -301,7 +305,7 @@ func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProp
// Serialize the request info to send it across the wire to the
// grpc-gateway via an http header
requestInfo.Ticket = c.gatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy
requestInfo.Ticket = c.apiGrpcGatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy
marshalledRequestInfo, err := proto.Marshal(&requestInfo)
if err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling request info"))

@ -22,151 +22,144 @@ import (
"google.golang.org/grpc"
)
func (c *Controller) startListeners(ctx context.Context) error {
func (c *Controller) startListeners() error {
servers := make([]func(), 0, len(c.conf.Listeners))
configureForAPI := func(ln *base.ServerListener) error {
var err error
if c.gatewayServer, c.gatewayTicket, err = newGatewayServer(ctx, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer); err != nil {
return err
}
c.gatewayMux = newGatewayMux()
grpcServer, gwTicket, err := newGrpcServer(c.baseContext, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer)
if err != nil {
return fmt.Errorf("failed to create new grpc server: %w", err)
}
c.apiGrpcServer = grpcServer
c.apiGrpcGatewayTicket = gwTicket
err = c.registerGrpcServices(c.apiGrpcServer)
if err != nil {
return fmt.Errorf("failed to register grpc services: %w", err)
}
c.apiGrpcServerListener = newGrpcServerListener()
servers = append(servers, func() {
go c.apiGrpcServer.Serve(c.apiGrpcServerListener)
})
handler, err := c.handler(HandlerProperties{
ListenerConfig: ln.Config,
CancelCtx: c.baseContext,
})
for i := range c.apiListeners {
ln := c.apiListeners[i]
apiServers, err := c.configureForAPI(ln)
if err != nil {
return err
return fmt.Errorf("failed to configure listener for api mode: %w", err)
}
servers = append(servers, apiServers...)
}
// Resolve it here to avoid race conditions if the base context is
// replaced
cancelCtx := c.baseContext
server := &http.Server{
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: c.logger.StandardLogger(nil),
BaseContext: func(net.Listener) context.Context {
return cancelCtx
},
}
ln.HTTPServer = server
clusterServer, err := c.configureForCluster(c.clusterListener)
if err != nil {
return fmt.Errorf("failed to configure listener for cluster mode: %w", err)
}
servers = append(servers, clusterServer)
if ln.Config.HTTPReadHeaderTimeout > 0 {
server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout
}
if ln.Config.HTTPReadTimeout > 0 {
server.ReadTimeout = ln.Config.HTTPReadTimeout
}
if ln.Config.HTTPWriteTimeout > 0 {
server.WriteTimeout = ln.Config.HTTPWriteTimeout
}
if ln.Config.HTTPIdleTimeout > 0 {
server.IdleTimeout = ln.Config.HTTPIdleTimeout
}
for _, s := range servers {
s()
}
switch ln.Config.TLSDisable {
case true:
l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil)
if err != nil {
return fmt.Errorf("error getting non-tls listener: %w", err)
}
if l == nil {
return errors.New("could not get non-tls listener")
}
servers = append(servers, func() {
go server.Serve(l)
})
default:
protos := []string{"", "http/1.1", "h2"}
for _, v := range protos {
l := ln.Mux.GetListener(v)
if l == nil {
return fmt.Errorf("could not get tls proto %q listener", v)
}
servers = append(servers, func() {
go server.Serve(l)
})
}
}
return nil
}
return nil
func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) {
apiServers := make([]func(), 0)
handler, err := c.apiHandler(HandlerProperties{
ListenerConfig: ln.Config,
CancelCtx: c.baseContext,
})
if err != nil {
return nil, err
}
configureForCluster := func(ln *base.ServerListener) error {
// Clear out in case this is a second start of the controller
ln.Mux.UnregisterProto(alpnmux.DefaultProto)
l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{
GetConfigForClient: c.validateWorkerTls,
})
cancelCtx := c.baseContext // Resolve to avoid race conditions if the base context is replaced.
server := &http.Server{
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: c.logger.StandardLogger(nil),
BaseContext: func(net.Listener) context.Context { return cancelCtx },
}
ln.HTTPServer = server
if ln.Config.HTTPReadHeaderTimeout > 0 {
server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout
}
if ln.Config.HTTPReadTimeout > 0 {
server.ReadTimeout = ln.Config.HTTPReadTimeout
}
if ln.Config.HTTPWriteTimeout > 0 {
server.WriteTimeout = ln.Config.HTTPWriteTimeout
}
if ln.Config.HTTPIdleTimeout > 0 {
server.IdleTimeout = ln.Config.HTTPIdleTimeout
}
switch ln.Config.TLSDisable {
case true:
l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil)
if err != nil {
return fmt.Errorf("error getting sub-listener for worker proto: %w", err)
return nil, fmt.Errorf("error getting non-tls listener: %w", err)
}
if l == nil {
return nil, errors.New("could not get non-tls listener")
}
apiServers = append(apiServers, func() { go server.Serve(l) })
workerReqInterceptor, err := workerRequestInfoInterceptor(ctx, c.conf.Eventer)
if err != nil {
return fmt.Errorf("error getting sub-listener for worker proto: %w", err)
default:
for _, v := range []string{"", "http/1.1", "h2"} {
l := ln.Mux.GetListener(v)
if l == nil {
return nil, fmt.Errorf("could not get tls proto %q listener", v)
}
apiServers = append(apiServers, func() { go server.Serve(l) })
}
workerServer := grpc.NewServer(
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
workerReqInterceptor,
auditRequestInterceptor(ctx), // before we get started, audit the request
auditResponseInterceptor(ctx), // as we finish, audit the response
),
),
)
workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn,
c.workerStatusUpdateTimes, c.kms)
pbs.RegisterServerCoordinationServiceServer(workerServer, workerService)
pbs.RegisterSessionServiceServer(workerServer, workerService)
interceptor := newInterceptingListener(c, l)
ln.ALPNListener = interceptor
ln.GrpcServer = workerServer
servers = append(servers, func() {
go workerServer.Serve(interceptor)
})
return nil
}
c.gatewayListener, _ = newGatewayListener()
servers = append(servers, func() {
go c.gatewayServer.Serve(c.gatewayListener)
})
return apiServers, nil
}
for _, ln := range c.conf.Listeners {
var err error
for _, purpose := range ln.Config.Purpose {
switch purpose {
case "api":
err = configureForAPI(ln)
case "cluster":
err = configureForCluster(ln)
case "proxy":
// Do nothing, in a dev mode we might see it here
default:
err = fmt.Errorf("unknown listener purpose %q", purpose)
}
if err != nil {
return err
}
}
func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error) {
// Clear out in case this is a second start of the controller
ln.Mux.UnregisterProto(alpnmux.DefaultProto)
l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{
GetConfigForClient: c.validateWorkerTls,
})
if err != nil {
return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err)
}
for _, s := range servers {
s()
workerReqInterceptor, err := workerRequestInfoInterceptor(c.baseContext, c.conf.Eventer)
if err != nil {
return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err)
}
return nil
workerServer := grpc.NewServer(
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
workerReqInterceptor,
auditRequestInterceptor(c.baseContext), // before we get started, audit the request
auditResponseInterceptor(c.baseContext), // as we finish, audit the response
),
),
)
workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn,
c.workerStatusUpdateTimes, c.kms)
pbs.RegisterServerCoordinationServiceServer(workerServer, workerService)
pbs.RegisterSessionServiceServer(workerServer, workerService)
interceptor := newInterceptingListener(c, l)
ln.ALPNListener = interceptor
ln.GrpcServer = workerServer
return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil
}
func (c *Controller) stopListeners(serversOnly bool) error {
@ -194,7 +187,7 @@ func (c *Controller) stopListeners(serversOnly bool) error {
}()
}
if c.gatewayServer != nil {
if c.apiGrpcServer != nil {
serverWg.Add(1)
go func() {
defer serverWg.Done()
@ -202,9 +195,9 @@ func (c *Controller) stopListeners(serversOnly bool) error {
defer shutdownKillCancel()
go func() {
<-shutdownKill.Done()
c.gatewayServer.Stop()
c.apiGrpcServer.Stop()
}()
c.gatewayServer.GracefulStop()
c.apiGrpcServer.GracefulStop()
}()
}

@ -0,0 +1,400 @@
package controller
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"net/http"
"os"
"strconv"
"testing"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/go-secure-stdlib/base62"
"github.com/hashicorp/go-secure-stdlib/configutil/v2"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
func TestStartListeners(t *testing.T) {
tests := []struct {
name string
setup func(t *testing.T)
listeners []*listenerutil.ListenerConfig
assertions func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string)
}{
{
name: "one api, one cluster listener",
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSDisable: true,
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
rsp, err := http.Get("http://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "multiple api, one cluster listeners",
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSDisable: true,
},
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSDisable: true,
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
for _, apiAddr := range apiAddrs {
rsp, err := http.Get("http://" + apiAddr + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "one api (tls), one cluster listener",
setup: testApiTlsSetup(t, "./boundary-startlistenerstest.cert", "./boundary-startlistenerstest.key"),
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSCertFile: "./boundary-startlistenerstest.cert",
TLSKeyFile: "./boundary-startlistenerstest.key",
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
client := testTlsHttpClient(t, "./boundary-startlistenerstest.cert")
rsp, err := client.Get("https://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "multiple api (tls), one cluster listener",
setup: func(t *testing.T) {
testApiTlsSetup(t, "./boundary-startlistenerstest0.cert", "./boundary-startlistenerstest0.key")(t)
testApiTlsSetup(t, "./boundary-startlistenerstest1.cert", "./boundary-startlistenerstest1.key")(t)
},
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSCertFile: "./boundary-startlistenerstest0.cert",
TLSKeyFile: "./boundary-startlistenerstest0.key",
},
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSCertFile: "./boundary-startlistenerstest1.cert",
TLSKeyFile: "./boundary-startlistenerstest1.key",
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
for i, apiAddr := range apiAddrs {
client := testTlsHttpClient(t, "./boundary-startlistenerstest"+strconv.Itoa(i)+".cert")
rsp, err := client.Get("https://" + apiAddr + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "one api, one cluster listener on unix sockets",
listeners: []*listenerutil.ListenerConfig{
{
Type: "unix",
Purpose: []string{"api"},
Address: "/tmp/boundary-listener-test-api.sock",
TLSDisable: true,
},
{
Type: "unix",
Purpose: []string{"cluster"},
Address: "/tmp/boundary-listener-test-cluster.sock",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
conn, err := net.Dial("unix", apiAddrs[0])
require.NoError(t, err)
cl := http.Client{
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return conn, nil },
},
}
rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
require.NoError(t, conn.Close())
clusterGrpcDialNoError(t, c, "unix", clusterAddr)
},
},
{
name: "multiple api, one cluster listener on unix sockets",
listeners: []*listenerutil.ListenerConfig{
{
Type: "unix",
Purpose: []string{"api"},
Address: "/tmp/boundary-listener-test-api.sock",
TLSDisable: true,
},
{
Type: "unix",
Purpose: []string{"api"},
Address: "/tmp/boundary-listener-test-api2.sock",
TLSDisable: true,
},
{
Type: "unix",
Purpose: []string{"cluster"},
Address: "/tmp/boundary-listener-test-cluster.sock",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
for _, apiAddr := range apiAddrs {
conn, err := net.Dial("unix", apiAddr)
require.NoError(t, err)
cl := http.Client{
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return conn, nil },
},
}
rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
require.NoError(t, conn.Close())
}
clusterGrpcDialNoError(t, c, "unix", clusterAddr)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup(t)
}
ctx, cancel := context.WithCancel(context.Background())
tc := &TestController{t: t, ctx: ctx, cancel: cancel, opts: &TestControllerOpts{}}
t.Cleanup(func() { tc.Shutdown() })
conf := TestControllerConfig(t, ctx, tc, nil)
conf.RawConfig.SharedConfig = &configutil.SharedConfig{Listeners: tt.listeners, DisableMlock: true}
err := conf.SetupListeners(nil, conf.RawConfig.SharedConfig, []string{"api", "cluster"})
require.NoError(t, err)
c, err := New(ctx, conf)
require.NoError(t, err)
c.baseContext = ctx
c.baseCancel = cancel
err = c.startListeners()
require.NoError(t, err)
apiAddrs := make([]string, 0)
for _, l := range c.apiListeners {
apiAddrs = append(apiAddrs, l.Mux.Addr().String())
}
tt.assertions(t, c, apiAddrs, c.clusterListener.Mux.Addr().String())
})
}
}
func clusterGrpcDialNoError(t *testing.T, c *Controller, network, addr string) {
grpcConn, err := grpc.Dial(addr,
grpc.WithInsecure(),
grpc.WithBlock(),
grpc.WithTimeout(5*time.Second),
grpc.WithContextDialer(clusterTestDialer(t, c, network)),
)
require.NoError(t, err)
require.NoError(t, grpcConn.Close())
}
func testApiTlsSetup(t *testing.T, certPath, keyPath string) func(t *testing.T) {
return func(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, os.Remove(certPath))
require.NoError(t, os.Remove(keyPath))
})
certBytes, _, priv := createTestCert(t)
certFile, err := os.Create(certPath)
require.NoError(t, err)
require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}))
require.NoError(t, certFile.Close())
keyFile, err := os.Create(keyPath)
require.NoError(t, err)
marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv)
require.NoError(t, err)
require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey}))
require.NoError(t, keyFile.Close())
}
}
func testTlsHttpClient(t *testing.T, filePath string) *http.Client {
f, err := os.Open(filePath)
require.NoError(t, err)
certBytes, err := ioutil.ReadAll(f)
require.NoError(t, err)
require.NoError(t, f.Close())
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(certBytes)
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: certPool},
},
}
}
func createTestCert(t *testing.T) ([]byte, ed25519.PublicKey, ed25519.PrivateKey) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
SerialNumber: big.NewInt(0),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(5 * time.Minute),
BasicConstraintsValid: true,
IsCA: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"/tmp/boundary-listener-test-cluster.sock"},
}
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv)
require.NoError(t, err)
return certBytes, pub, priv
}
func clusterTestDialer(t *testing.T, c *Controller, network string) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
certBytes, _, priv := createTestCert(t)
marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv)
require.NoError(t, err)
nonce, err := base62.Random(20)
require.NoError(t, err)
info := base.WorkerAuthInfo{
CertPEM: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}),
KeyPEM: pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey}),
ConnectionNonce: nonce,
}
infoBytes, err := json.Marshal(info)
require.NoError(t, err)
encInfo, err := c.conf.WorkerAuthKms.Encrypt(ctx, infoBytes, nil)
require.NoError(t, err)
protoEncInfo, err := proto.Marshal(encInfo)
require.NoError(t, err)
b64alpn := base64.RawStdEncoding.EncodeToString(protoEncInfo)
var nextProtos []string
var count int
for i := 0; i < len(b64alpn); i += 230 {
end := i + 230
if end > len(b64alpn) {
end = len(b64alpn)
}
nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end]))
count++
}
cert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)
rootCAs := x509.NewCertPool()
rootCAs.AddCert(cert)
tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
require.NoError(t, err)
conn, err := tls.Dial(network, addr, &tls.Config{
Certificates: []tls.Certificate{tlsCert},
RootCAs: rootCAs,
NextProtos: nextProtos,
MinVersion: tls.VersionTLS13,
})
require.NoError(t, err)
_, err = conn.Write([]byte(nonce))
require.NoError(t, err)
return conn, nil
}
}
Loading…
Cancel
Save