Merge pull request #1907 from hashicorp/hugoamvieira-controller-start-listeners

refactor(controller): `startListeners` improvements
pull/1941/head
Hugo 4 years ago committed by GitHub
commit d438a6d2dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,6 @@ import (
"strings"
"sync"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/hashicorp/boundary/internal/auth/oidc"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/authtoken"
@ -52,14 +51,15 @@ type Controller struct {
workerAuthCache *sync.Map
apiListeners []*base.ServerListener
clusterListener *base.ServerListener
// Used for testing and tracking worker health
workerStatusUpdateTimes *sync.Map
// grpc gateway server
gatewayServer *grpc.Server
gatewayTicket string
gatewayListener gatewayListener
gatewayMux *runtime.ServeMux
apiGrpcServer *grpc.Server
apiGrpcServerListener grpcServerListener
apiGrpcGatewayTicket string
// Repo factory methods
AuthTokenRepoFn common.AuthTokenRepoFactory
@ -125,6 +125,32 @@ func New(ctx context.Context, conf *Config) (*Controller, error) {
}
}
clusterListeners := make([]*base.ServerListener, 0)
for i := range conf.Listeners {
l := conf.Listeners[i]
if l == nil || l.Config == nil || l.Config.Purpose == nil {
continue
}
if len(l.Config.Purpose) != 1 {
return nil, fmt.Errorf("found listener with multiple purposes %q", strings.Join(l.Config.Purpose, ","))
}
switch l.Config.Purpose[0] {
case "api":
c.apiListeners = append(c.apiListeners, l)
case "cluster":
clusterListeners = append(clusterListeners, l)
}
}
if len(c.apiListeners) == 0 {
return nil, fmt.Errorf("no api listeners found")
}
if len(clusterListeners) != 1 {
// in the future, we might pick the cluster that is exposed to the outside
// instead of limiting it to one.
return nil, fmt.Errorf("exactly one cluster listener is required")
}
c.clusterListener = clusterListeners[0]
var pluginLogger hclog.Logger
for _, enabledPlugin := range c.enabledPlugins {
if pluginLogger == nil {
@ -264,7 +290,7 @@ func (c *Controller) Start() error {
if err := c.registerJobs(); err != nil {
return fmt.Errorf("error registering jobs: %w", err)
}
if err := c.startListeners(c.baseContext); err != nil {
if err := c.startListeners(); err != nil {
return fmt.Errorf("error starting controller listeners: %w", err)
}
if err := c.scheduler.Start(c.baseContext, c.schedulerWg); err != nil {

@ -4,8 +4,10 @@ import (
"context"
"testing"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/stretchr/testify/require"
)
@ -32,3 +34,153 @@ func TestController_New(t *testing.T) {
require.NoError(err)
})
}
func TestControllerNewListenerConfig(t *testing.T) {
tests := []struct {
name string
listeners []*base.ServerListener
assertions func(t *testing.T, c *Controller)
expErr bool
expErrMsg string
}{
{
name: "valid listener configuration",
listeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"api"},
},
},
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"api"},
},
},
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"cluster"},
},
},
},
assertions: func(t *testing.T, c *Controller) {
require.Len(t, c.apiListeners, 2)
require.NotNil(t, c.clusterListener)
},
},
{
name: "listeners are required",
listeners: []*base.ServerListener{},
expErr: true,
expErrMsg: "no api listeners found",
},
{
name: "listeners are required - not nil",
listeners: []*base.ServerListener{nil, nil},
expErr: true,
expErrMsg: "no api listeners found",
},
{
name: "listeners are required - with config",
listeners: []*base.ServerListener{{}, {}},
expErr: true,
expErrMsg: "no api listeners found",
},
{
name: "listeners are required - with purposes",
listeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{Purpose: nil},
},
{
Config: &listenerutil.ListenerConfig{Purpose: nil},
},
},
expErr: true,
expErrMsg: "no api listeners found",
},
{
name: "both api and cluster listeners are required",
listeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"api"},
},
},
},
expErr: true,
expErrMsg: "exactly one cluster listener is required",
},
{
name: "both api and cluster listeners are required 2",
listeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"cluster"},
},
},
},
expErr: true,
expErrMsg: "no api listeners found",
},
{
name: "only one cluster listener is allowed",
listeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"api"},
},
},
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"cluster"},
},
},
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"cluster"},
},
},
},
expErr: true,
expErrMsg: "exactly one cluster listener is required",
},
{
name: "only one purpose is allowed per listener",
listeners: []*base.ServerListener{
{
Config: &listenerutil.ListenerConfig{
Purpose: []string{"api", "cluster"},
},
},
},
expErr: true,
expErrMsg: `found listener with multiple purposes "api,cluster"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
tc := &TestController{
t: t,
ctx: ctx,
cancel: cancel,
opts: nil,
}
conf := TestControllerConfig(t, ctx, tc, nil)
conf.Listeners = tt.listeners
c, err := New(ctx, conf)
if tt.expErr {
require.EqualError(t, err, tt.expErrMsg)
require.Nil(t, c)
return
}
require.NoError(t, err)
require.NotNil(t, c)
tt.assertions(t, c)
})
}
}

@ -20,50 +20,54 @@ const corsTestConfig = `
disable_mlock = true
telemetry {
prometheus_retention_time = "24h"
disable_hostname = true
prometheus_retention_time = "24h"
disable_hostname = true
}
kms "aead" {
purpose = "root"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
key_id = "global_root"
purpose = "root"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
key_id = "global_root"
}
kms "aead" {
purpose = "worker-auth"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
key_id = "global_worker-auth"
purpose = "worker-auth"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
key_id = "global_worker-auth"
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = false
purpose = "cluster"
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = []
purpose = "api"
tls_disable = true
cors_enabled = false
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["foobar.com", "barfoo.com"]
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = []
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["*"]
cors_allowed_headers = ["x-foobar"]
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["foobar.com", "barfoo.com"]
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["*"]
cors_allowed_headers = ["x-foobar"]
}
`

@ -19,20 +19,14 @@ import (
"google.golang.org/grpc/test/bufconn"
)
// newGatewayListener will create an in-memory listener
func newGatewayListener() (gatewayListener, string) {
buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available
return bufconn.Listen(int(buffer)), ""
}
const gatewayTarget = ""
type gatewayListener interface {
type grpcServerListener interface {
net.Listener
Dial() (net.Conn, error)
}
func gatewayDialOptions(lis gatewayListener) []grpc.DialOption {
func gatewayDialOptions(lis grpcServerListener) []grpc.DialOption {
return []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
@ -41,7 +35,7 @@ func gatewayDialOptions(lis gatewayListener) []grpc.DialOption {
}
}
func newGatewayMux() *runtime.ServeMux {
func newGrpcGatewayMux() *runtime.ServeMux {
return runtime.NewServeMux(
runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.HTTPBodyMarshaler{
Marshaler: handlers.JSONMarshaler(),
@ -51,7 +45,13 @@ func newGatewayMux() *runtime.ServeMux {
)
}
func newGatewayServer(
// newGrpcServerListener will create an in-memory listener for the gRPC server.
func newGrpcServerListener() grpcServerListener {
buffer := globals.DefaultMaxRequestSize // seems like a reasonable size for the ring buffer, but then happily change the size if more info becomes available
return bufconn.Listen(int(buffer))
}
func newGrpcServer(
ctx context.Context,
iamRepoFn common.IamRepoFactory,
authTokenRepoFn common.AuthTokenRepoFactory,
@ -59,7 +59,7 @@ func newGatewayServer(
kms *kms.Kms,
eventer *event.Eventer,
) (*grpc.Server, string, error) {
const op = "controller.newGatewayServer"
const op = "controller.newGrpcServer"
ticket, err := db.NewPrivateId("gwticket")
if err != nil {
return nil, "", errors.Wrap(ctx, err, op, errors.WithMsg("unable to generate gateway ticket"))

@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth/oidc"
"github.com/hashicorp/boundary/internal/gen/controller/api/services"
@ -39,6 +40,7 @@ import (
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/mr-tron/base58"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/proto"
@ -50,17 +52,17 @@ type HandlerProperties struct {
CancelCtx context.Context
}
// Handler returns an http.Handler for the services. This can be used on
// apiHandler returns an http.Handler for the services. This can be used on
// its own to mount the Controller API within another web server.
func (c *Controller) handler(props HandlerProperties) (http.Handler, error) {
// Create the muxer to handle the actual endpoints
func (c *Controller) apiHandler(props HandlerProperties) (http.Handler, error) {
mux := http.NewServeMux()
h, err := handleGrpcGateway(c, props)
grpcGwMux := newGrpcGatewayMux()
err := registerGrpcGatewayEndpoints(props.CancelCtx, grpcGwMux, gatewayDialOptions(c.apiGrpcServerListener)...)
if err != nil {
return nil, err
}
mux.Handle("/v1/", h)
mux.Handle("/v1/", grpcGwMux)
mux.Handle("/", handleUi(c))
corsWrappedHandler := wrapHandlerWithCors(mux, props)
@ -68,103 +70,74 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) {
callbackInterceptingHandler := wrapHandlerWithCallbackInterceptor(commonWrappedHandler, c)
printablePathCheckHandler := cleanhttp.PrintablePathCheckHandler(callbackInterceptingHandler, nil)
eventsHandler, err := common.WrapWithEventsHandler(printablePathCheckHandler, c.conf.Eventer, c.kms, props.ListenerConfig)
if err != nil {
return nil, err
}
return eventsHandler, nil
return eventsHandler, err
}
func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, error) {
// Register*ServiceHandlerServer methods ignore the passed in ctx. Using it
// now however in case this changes in the future.
ctx := props.CancelCtx
currentServices := c.gatewayServer.GetServiceInfo()
dialOptions := gatewayDialOptions(c.gatewayListener)
func (c *Controller) registerGrpcServices(s *grpc.Server) error {
// We have to check against the current services because the gRPC lib treats a duplicate
// register call as an error and os.Exits.
currentServices := s.GetServiceInfo()
if _, ok := currentServices[services.HostCatalogService_ServiceDesc.ServiceName]; !ok {
hcs, err := host_catalogs.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn, c.HostPluginRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create host catalog handler service: %w", err)
}
services.RegisterHostCatalogServiceServer(c.gatewayServer, hcs)
if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register host catalog service handler: %w", err)
return fmt.Errorf("failed to create host catalog handler service: %w", err)
}
services.RegisterHostCatalogServiceServer(s, hcs)
}
if _, ok := currentServices[services.HostSetService_ServiceDesc.ServiceName]; !ok {
hss, err := host_sets.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create host set handler service: %w", err)
}
services.RegisterHostSetServiceServer(c.gatewayServer, hss)
if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register host set service handler: %w", err)
return fmt.Errorf("failed to create host set handler service: %w", err)
}
services.RegisterHostSetServiceServer(s, hss)
}
if _, ok := currentServices[services.HostService_ServiceDesc.ServiceName]; !ok {
hs, err := hosts.NewService(c.StaticHostRepoFn, c.PluginHostRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create host handler service: %w", err)
}
services.RegisterHostServiceServer(c.gatewayServer, hs)
if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register host service handler: %w", err)
return fmt.Errorf("failed to create host handler service: %w", err)
}
services.RegisterHostServiceServer(s, hs)
}
if _, ok := currentServices[services.AccountService_ServiceDesc.ServiceName]; !ok {
accts, err := accounts.NewService(c.PasswordAuthRepoFn, c.OidcRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create account handler service: %w", err)
}
services.RegisterAccountServiceServer(c.gatewayServer, accts)
if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register account service handler: %w", err)
return fmt.Errorf("failed to create account handler service: %w", err)
}
services.RegisterAccountServiceServer(s, accts)
}
if _, ok := currentServices[services.AuthMethodService_ServiceDesc.ServiceName]; !ok {
authMethods, err := authmethods.NewService(c.kms, c.PasswordAuthRepoFn, c.OidcRepoFn, c.IamRepoFn, c.AuthTokenRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create auth method handler service: %w", err)
}
services.RegisterAuthMethodServiceServer(c.gatewayServer, authMethods)
if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register auth method service handler: %w", err)
return fmt.Errorf("failed to create auth method handler service: %w", err)
}
services.RegisterAuthMethodServiceServer(s, authMethods)
}
if _, ok := currentServices[services.AuthTokenService_ServiceDesc.ServiceName]; !ok {
authtoks, err := authtokens.NewService(c.AuthTokenRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create auth token handler service: %w", err)
}
services.RegisterAuthTokenServiceServer(c.gatewayServer, authtoks)
if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register auth token service handler: %w", err)
return fmt.Errorf("failed to create auth token handler service: %w", err)
}
services.RegisterAuthTokenServiceServer(s, authtoks)
}
if _, ok := currentServices[services.ScopeService_ServiceDesc.ServiceName]; !ok {
os, err := scopes.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create scope handler service: %w", err)
}
services.RegisterScopeServiceServer(c.gatewayServer, os)
if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register scope service handler: %w", err)
return fmt.Errorf("failed to create scope handler service: %w", err)
}
services.RegisterScopeServiceServer(s, os)
}
if _, ok := currentServices[services.UserService_ServiceDesc.ServiceName]; !ok {
us, err := users.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create user handler service: %w", err)
}
services.RegisterUserServiceServer(c.gatewayServer, us)
if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register user service handler: %w", err)
return fmt.Errorf("failed to create user handler service: %w", err)
}
services.RegisterUserServiceServer(s, us)
}
if _, ok := currentServices[services.TargetService_ServiceDesc.ServiceName]; !ok {
ts, err := targets.NewService(
ctx,
c.baseContext,
c.kms,
c.TargetRepoFn,
c.IamRepoFn,
@ -174,75 +147,106 @@ func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, er
c.StaticHostRepoFn,
c.VaultCredentialRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create target handler service: %w", err)
}
services.RegisterTargetServiceServer(c.gatewayServer, ts)
if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register target service handler: %w", err)
return fmt.Errorf("failed to create target handler service: %w", err)
}
services.RegisterTargetServiceServer(s, ts)
}
if _, ok := currentServices[services.GroupService_ServiceDesc.ServiceName]; !ok {
gs, err := groups.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create group handler service: %w", err)
}
services.RegisterGroupServiceServer(c.gatewayServer, gs)
if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register group service handler: %w", err)
return fmt.Errorf("failed to create group handler service: %w", err)
}
services.RegisterGroupServiceServer(s, gs)
}
if _, ok := currentServices[services.RoleService_ServiceDesc.ServiceName]; !ok {
rs, err := roles.NewService(c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create role handler service: %w", err)
}
services.RegisterRoleServiceServer(c.gatewayServer, rs)
if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register role service handler: %w", err)
return fmt.Errorf("failed to create role handler service: %w", err)
}
services.RegisterRoleServiceServer(s, rs)
}
if _, ok := currentServices[services.SessionService_ServiceDesc.ServiceName]; !ok {
ss, err := sessions.NewService(c.SessionRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create session handler service: %w", err)
}
services.RegisterSessionServiceServer(c.gatewayServer, ss)
if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register session service handler: %w", err)
return fmt.Errorf("failed to create session handler service: %w", err)
}
services.RegisterSessionServiceServer(s, ss)
}
if _, ok := currentServices[services.ManagedGroupService_ServiceDesc.ServiceName]; !ok {
mgs, err := managed_groups.NewService(c.OidcRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create managed groups handler service: %w", err)
}
services.RegisterManagedGroupServiceServer(c.gatewayServer, mgs)
if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register managed groups service handler: %w", err)
return fmt.Errorf("failed to create managed groups handler service: %w", err)
}
services.RegisterManagedGroupServiceServer(s, mgs)
}
if _, ok := currentServices[services.CredentialStoreService_ServiceDesc.ServiceName]; !ok {
cs, err := credentialstores.NewService(c.VaultCredentialRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create credential store handler service: %w", err)
}
services.RegisterCredentialStoreServiceServer(c.gatewayServer, cs)
if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register credential store service handler: %w", err)
return fmt.Errorf("failed to create credential store handler service: %w", err)
}
services.RegisterCredentialStoreServiceServer(s, cs)
}
if _, ok := currentServices[services.CredentialLibraryService_ServiceDesc.ServiceName]; !ok {
cl, err := credentiallibraries.NewService(c.VaultCredentialRepoFn, c.IamRepoFn)
if err != nil {
return nil, fmt.Errorf("failed to create credential library handler service: %w", err)
}
services.RegisterCredentialLibraryServiceServer(c.gatewayServer, cl)
if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, c.gatewayMux, gatewayTarget, dialOptions); err != nil {
return nil, fmt.Errorf("failed to register credential library service handler: %w", err)
return fmt.Errorf("failed to create credential library handler service: %w", err)
}
services.RegisterCredentialLibraryServiceServer(s, cl)
}
return nil
}
func registerGrpcGatewayEndpoints(ctx context.Context, gwMux *runtime.ServeMux, dialOptions ...grpc.DialOption) error {
// Register*ServiceHandlerServer methods ignore the passed in context.
// Passing it in anyways in case this changes in the future.
if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register host catalog service handler: %w", err)
}
if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register host set service handler: %w", err)
}
if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register host service handler: %w", err)
}
if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register account service handler: %w", err)
}
if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register auth method service handler: %w", err)
}
if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register auth token service handler: %w", err)
}
if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register scope service handler: %w", err)
}
if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register user service handler: %w", err)
}
if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register target service handler: %w", err)
}
if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register group service handler: %w", err)
}
if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register role service handler: %w", err)
}
if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register session service handler: %w", err)
}
if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register managed groups service handler: %w", err)
}
if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register credential store service handler: %w", err)
}
if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil {
return fmt.Errorf("failed to register credential library service handler: %w", err)
}
return c.gatewayMux, nil
return nil
}
func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler {
@ -301,7 +305,7 @@ func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProp
// Serialize the request info to send it across the wire to the
// grpc-gateway via an http header
requestInfo.Ticket = c.gatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy
requestInfo.Ticket = c.apiGrpcGatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy
marshalledRequestInfo, err := proto.Marshal(&requestInfo)
if err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling request info"))

@ -22,151 +22,144 @@ import (
"google.golang.org/grpc"
)
func (c *Controller) startListeners(ctx context.Context) error {
func (c *Controller) startListeners() error {
servers := make([]func(), 0, len(c.conf.Listeners))
configureForAPI := func(ln *base.ServerListener) error {
var err error
if c.gatewayServer, c.gatewayTicket, err = newGatewayServer(ctx, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer); err != nil {
return err
}
c.gatewayMux = newGatewayMux()
grpcServer, gwTicket, err := newGrpcServer(c.baseContext, c.IamRepoFn, c.AuthTokenRepoFn, c.ServersRepoFn, c.kms, c.conf.Eventer)
if err != nil {
return fmt.Errorf("failed to create new grpc server: %w", err)
}
c.apiGrpcServer = grpcServer
c.apiGrpcGatewayTicket = gwTicket
err = c.registerGrpcServices(c.apiGrpcServer)
if err != nil {
return fmt.Errorf("failed to register grpc services: %w", err)
}
c.apiGrpcServerListener = newGrpcServerListener()
servers = append(servers, func() {
go c.apiGrpcServer.Serve(c.apiGrpcServerListener)
})
handler, err := c.handler(HandlerProperties{
ListenerConfig: ln.Config,
CancelCtx: c.baseContext,
})
for i := range c.apiListeners {
ln := c.apiListeners[i]
apiServers, err := c.configureForAPI(ln)
if err != nil {
return err
return fmt.Errorf("failed to configure listener for api mode: %w", err)
}
servers = append(servers, apiServers...)
}
// Resolve it here to avoid race conditions if the base context is
// replaced
cancelCtx := c.baseContext
server := &http.Server{
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: c.logger.StandardLogger(nil),
BaseContext: func(net.Listener) context.Context {
return cancelCtx
},
}
ln.HTTPServer = server
clusterServer, err := c.configureForCluster(c.clusterListener)
if err != nil {
return fmt.Errorf("failed to configure listener for cluster mode: %w", err)
}
servers = append(servers, clusterServer)
if ln.Config.HTTPReadHeaderTimeout > 0 {
server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout
}
if ln.Config.HTTPReadTimeout > 0 {
server.ReadTimeout = ln.Config.HTTPReadTimeout
}
if ln.Config.HTTPWriteTimeout > 0 {
server.WriteTimeout = ln.Config.HTTPWriteTimeout
}
if ln.Config.HTTPIdleTimeout > 0 {
server.IdleTimeout = ln.Config.HTTPIdleTimeout
}
for _, s := range servers {
s()
}
switch ln.Config.TLSDisable {
case true:
l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil)
if err != nil {
return fmt.Errorf("error getting non-tls listener: %w", err)
}
if l == nil {
return errors.New("could not get non-tls listener")
}
servers = append(servers, func() {
go server.Serve(l)
})
default:
protos := []string{"", "http/1.1", "h2"}
for _, v := range protos {
l := ln.Mux.GetListener(v)
if l == nil {
return fmt.Errorf("could not get tls proto %q listener", v)
}
servers = append(servers, func() {
go server.Serve(l)
})
}
}
return nil
}
return nil
func (c *Controller) configureForAPI(ln *base.ServerListener) ([]func(), error) {
apiServers := make([]func(), 0)
handler, err := c.apiHandler(HandlerProperties{
ListenerConfig: ln.Config,
CancelCtx: c.baseContext,
})
if err != nil {
return nil, err
}
configureForCluster := func(ln *base.ServerListener) error {
// Clear out in case this is a second start of the controller
ln.Mux.UnregisterProto(alpnmux.DefaultProto)
l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{
GetConfigForClient: c.validateWorkerTls,
})
cancelCtx := c.baseContext // Resolve to avoid race conditions if the base context is replaced.
server := &http.Server{
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: c.logger.StandardLogger(nil),
BaseContext: func(net.Listener) context.Context { return cancelCtx },
}
ln.HTTPServer = server
if ln.Config.HTTPReadHeaderTimeout > 0 {
server.ReadHeaderTimeout = ln.Config.HTTPReadHeaderTimeout
}
if ln.Config.HTTPReadTimeout > 0 {
server.ReadTimeout = ln.Config.HTTPReadTimeout
}
if ln.Config.HTTPWriteTimeout > 0 {
server.WriteTimeout = ln.Config.HTTPWriteTimeout
}
if ln.Config.HTTPIdleTimeout > 0 {
server.IdleTimeout = ln.Config.HTTPIdleTimeout
}
switch ln.Config.TLSDisable {
case true:
l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil)
if err != nil {
return fmt.Errorf("error getting sub-listener for worker proto: %w", err)
return nil, fmt.Errorf("error getting non-tls listener: %w", err)
}
if l == nil {
return nil, errors.New("could not get non-tls listener")
}
apiServers = append(apiServers, func() { go server.Serve(l) })
workerReqInterceptor, err := workerRequestInfoInterceptor(ctx, c.conf.Eventer)
if err != nil {
return fmt.Errorf("error getting sub-listener for worker proto: %w", err)
default:
for _, v := range []string{"", "http/1.1", "h2"} {
l := ln.Mux.GetListener(v)
if l == nil {
return nil, fmt.Errorf("could not get tls proto %q listener", v)
}
apiServers = append(apiServers, func() { go server.Serve(l) })
}
workerServer := grpc.NewServer(
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
workerReqInterceptor,
auditRequestInterceptor(ctx), // before we get started, audit the request
auditResponseInterceptor(ctx), // as we finish, audit the response
),
),
)
workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn,
c.workerStatusUpdateTimes, c.kms)
pbs.RegisterServerCoordinationServiceServer(workerServer, workerService)
pbs.RegisterSessionServiceServer(workerServer, workerService)
interceptor := newInterceptingListener(c, l)
ln.ALPNListener = interceptor
ln.GrpcServer = workerServer
servers = append(servers, func() {
go workerServer.Serve(interceptor)
})
return nil
}
c.gatewayListener, _ = newGatewayListener()
servers = append(servers, func() {
go c.gatewayServer.Serve(c.gatewayListener)
})
return apiServers, nil
}
for _, ln := range c.conf.Listeners {
var err error
for _, purpose := range ln.Config.Purpose {
switch purpose {
case "api":
err = configureForAPI(ln)
case "cluster":
err = configureForCluster(ln)
case "proxy":
// Do nothing, in a dev mode we might see it here
default:
err = fmt.Errorf("unknown listener purpose %q", purpose)
}
if err != nil {
return err
}
}
func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error) {
// Clear out in case this is a second start of the controller
ln.Mux.UnregisterProto(alpnmux.DefaultProto)
l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{
GetConfigForClient: c.validateWorkerTls,
})
if err != nil {
return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err)
}
for _, s := range servers {
s()
workerReqInterceptor, err := workerRequestInfoInterceptor(c.baseContext, c.conf.Eventer)
if err != nil {
return nil, fmt.Errorf("error getting sub-listener for worker proto: %w", err)
}
return nil
workerServer := grpc.NewServer(
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
workerReqInterceptor,
auditRequestInterceptor(c.baseContext), // before we get started, audit the request
auditResponseInterceptor(c.baseContext), // as we finish, audit the response
),
),
)
workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn,
c.workerStatusUpdateTimes, c.kms)
pbs.RegisterServerCoordinationServiceServer(workerServer, workerService)
pbs.RegisterSessionServiceServer(workerServer, workerService)
interceptor := newInterceptingListener(c, l)
ln.ALPNListener = interceptor
ln.GrpcServer = workerServer
return func() { go ln.GrpcServer.Serve(ln.ALPNListener) }, nil
}
func (c *Controller) stopListeners(serversOnly bool) error {
@ -194,7 +187,7 @@ func (c *Controller) stopListeners(serversOnly bool) error {
}()
}
if c.gatewayServer != nil {
if c.apiGrpcServer != nil {
serverWg.Add(1)
go func() {
defer serverWg.Done()
@ -202,9 +195,9 @@ func (c *Controller) stopListeners(serversOnly bool) error {
defer shutdownKillCancel()
go func() {
<-shutdownKill.Done()
c.gatewayServer.Stop()
c.apiGrpcServer.Stop()
}()
c.gatewayServer.GracefulStop()
c.apiGrpcServer.GracefulStop()
}()
}

@ -0,0 +1,400 @@
package controller
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"net/http"
"os"
"strconv"
"testing"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/go-secure-stdlib/base62"
"github.com/hashicorp/go-secure-stdlib/configutil/v2"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
func TestStartListeners(t *testing.T) {
tests := []struct {
name string
setup func(t *testing.T)
listeners []*listenerutil.ListenerConfig
assertions func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string)
}{
{
name: "one api, one cluster listener",
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSDisable: true,
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
rsp, err := http.Get("http://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "multiple api, one cluster listeners",
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSDisable: true,
},
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSDisable: true,
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
for _, apiAddr := range apiAddrs {
rsp, err := http.Get("http://" + apiAddr + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "one api (tls), one cluster listener",
setup: testApiTlsSetup(t, "./boundary-startlistenerstest.cert", "./boundary-startlistenerstest.key"),
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSCertFile: "./boundary-startlistenerstest.cert",
TLSKeyFile: "./boundary-startlistenerstest.key",
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
client := testTlsHttpClient(t, "./boundary-startlistenerstest.cert")
rsp, err := client.Get("https://" + apiAddrs[0] + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "multiple api (tls), one cluster listener",
setup: func(t *testing.T) {
testApiTlsSetup(t, "./boundary-startlistenerstest0.cert", "./boundary-startlistenerstest0.key")(t)
testApiTlsSetup(t, "./boundary-startlistenerstest1.cert", "./boundary-startlistenerstest1.key")(t)
},
listeners: []*listenerutil.ListenerConfig{
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSCertFile: "./boundary-startlistenerstest0.cert",
TLSKeyFile: "./boundary-startlistenerstest0.key",
},
{
Type: "tcp",
Purpose: []string{"api"},
Address: "127.0.0.1:0",
TLSCertFile: "./boundary-startlistenerstest1.cert",
TLSKeyFile: "./boundary-startlistenerstest1.key",
},
{
Type: "tcp",
Purpose: []string{"cluster"},
Address: "127.0.0.1:0",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
for i, apiAddr := range apiAddrs {
client := testTlsHttpClient(t, "./boundary-startlistenerstest"+strconv.Itoa(i)+".cert")
rsp, err := client.Get("https://" + apiAddr + "/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
clusterGrpcDialNoError(t, c, "tcp", clusterAddr)
},
},
{
name: "one api, one cluster listener on unix sockets",
listeners: []*listenerutil.ListenerConfig{
{
Type: "unix",
Purpose: []string{"api"},
Address: "/tmp/boundary-listener-test-api.sock",
TLSDisable: true,
},
{
Type: "unix",
Purpose: []string{"cluster"},
Address: "/tmp/boundary-listener-test-cluster.sock",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
conn, err := net.Dial("unix", apiAddrs[0])
require.NoError(t, err)
cl := http.Client{
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return conn, nil },
},
}
rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
require.NoError(t, conn.Close())
clusterGrpcDialNoError(t, c, "unix", clusterAddr)
},
},
{
name: "multiple api, one cluster listener on unix sockets",
listeners: []*listenerutil.ListenerConfig{
{
Type: "unix",
Purpose: []string{"api"},
Address: "/tmp/boundary-listener-test-api.sock",
TLSDisable: true,
},
{
Type: "unix",
Purpose: []string{"api"},
Address: "/tmp/boundary-listener-test-api2.sock",
TLSDisable: true,
},
{
Type: "unix",
Purpose: []string{"cluster"},
Address: "/tmp/boundary-listener-test-cluster.sock",
},
},
assertions: func(t *testing.T, c *Controller, apiAddrs []string, clusterAddr string) {
for _, apiAddr := range apiAddrs {
conn, err := net.Dial("unix", apiAddr)
require.NoError(t, err)
cl := http.Client{
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return conn, nil },
},
}
rsp, err := cl.Get("http://randomdomain.boundary/v1/auth-methods?scope_id=global")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)
require.NoError(t, conn.Close())
}
clusterGrpcDialNoError(t, c, "unix", clusterAddr)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup(t)
}
ctx, cancel := context.WithCancel(context.Background())
tc := &TestController{t: t, ctx: ctx, cancel: cancel, opts: &TestControllerOpts{}}
t.Cleanup(func() { tc.Shutdown() })
conf := TestControllerConfig(t, ctx, tc, nil)
conf.RawConfig.SharedConfig = &configutil.SharedConfig{Listeners: tt.listeners, DisableMlock: true}
err := conf.SetupListeners(nil, conf.RawConfig.SharedConfig, []string{"api", "cluster"})
require.NoError(t, err)
c, err := New(ctx, conf)
require.NoError(t, err)
c.baseContext = ctx
c.baseCancel = cancel
err = c.startListeners()
require.NoError(t, err)
apiAddrs := make([]string, 0)
for _, l := range c.apiListeners {
apiAddrs = append(apiAddrs, l.Mux.Addr().String())
}
tt.assertions(t, c, apiAddrs, c.clusterListener.Mux.Addr().String())
})
}
}
func clusterGrpcDialNoError(t *testing.T, c *Controller, network, addr string) {
grpcConn, err := grpc.Dial(addr,
grpc.WithInsecure(),
grpc.WithBlock(),
grpc.WithTimeout(5*time.Second),
grpc.WithContextDialer(clusterTestDialer(t, c, network)),
)
require.NoError(t, err)
require.NoError(t, grpcConn.Close())
}
func testApiTlsSetup(t *testing.T, certPath, keyPath string) func(t *testing.T) {
return func(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, os.Remove(certPath))
require.NoError(t, os.Remove(keyPath))
})
certBytes, _, priv := createTestCert(t)
certFile, err := os.Create(certPath)
require.NoError(t, err)
require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}))
require.NoError(t, certFile.Close())
keyFile, err := os.Create(keyPath)
require.NoError(t, err)
marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv)
require.NoError(t, err)
require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey}))
require.NoError(t, keyFile.Close())
}
}
func testTlsHttpClient(t *testing.T, filePath string) *http.Client {
f, err := os.Open(filePath)
require.NoError(t, err)
certBytes, err := ioutil.ReadAll(f)
require.NoError(t, err)
require.NoError(t, f.Close())
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(certBytes)
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: certPool},
},
}
}
func createTestCert(t *testing.T) ([]byte, ed25519.PublicKey, ed25519.PrivateKey) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
SerialNumber: big.NewInt(0),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(5 * time.Minute),
BasicConstraintsValid: true,
IsCA: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"/tmp/boundary-listener-test-cluster.sock"},
}
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv)
require.NoError(t, err)
return certBytes, pub, priv
}
func clusterTestDialer(t *testing.T, c *Controller, network string) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
certBytes, _, priv := createTestCert(t)
marshaledKey, err := x509.MarshalPKCS8PrivateKey(priv)
require.NoError(t, err)
nonce, err := base62.Random(20)
require.NoError(t, err)
info := base.WorkerAuthInfo{
CertPEM: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}),
KeyPEM: pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: marshaledKey}),
ConnectionNonce: nonce,
}
infoBytes, err := json.Marshal(info)
require.NoError(t, err)
encInfo, err := c.conf.WorkerAuthKms.Encrypt(ctx, infoBytes, nil)
require.NoError(t, err)
protoEncInfo, err := proto.Marshal(encInfo)
require.NoError(t, err)
b64alpn := base64.RawStdEncoding.EncodeToString(protoEncInfo)
var nextProtos []string
var count int
for i := 0; i < len(b64alpn); i += 230 {
end := i + 230
if end > len(b64alpn) {
end = len(b64alpn)
}
nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end]))
count++
}
cert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)
rootCAs := x509.NewCertPool()
rootCAs.AddCert(cert)
tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
require.NoError(t, err)
conn, err := tls.Dial(network, addr, &tls.Config{
Certificates: []tls.Certificate{tlsCert},
RootCAs: rootCAs,
NextProtos: nextProtos,
MinVersion: tls.VersionTLS13,
})
require.NoError(t, err)
_, err = conn.Write([]byte(nonce))
require.NoError(t, err)
return conn, nil
}
}
Loading…
Cancel
Save