From 410bec1e00507a52dd5250786bda9967e5611e43 Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 1 Apr 2023 08:43:04 -0400 Subject: [PATCH] feature(worker/controller): register UpstreamMessageServiceServer --- internal/daemon/cluster/handlers/options.go | 10 ++-- internal/daemon/cluster/handlers/testing.go | 2 +- .../upstream_message_service_controller.go | 4 +- .../handlers/upstream_message_service_test.go | 4 +- .../upstream_message_service_worker.go | 11 ++-- .../daemon/controller/rpc_registration.go | 29 ++++++++++ .../controller/rpc_registration_test.go | 56 +++++++++++++++++++ .../daemon/worker/controller_connection.go | 8 +++ internal/daemon/worker/listeners_test.go | 7 +++ internal/daemon/worker/rpc_registration.go | 28 ++++++++++ .../daemon/worker/rpc_registration_test.go | 56 +++++++++++++++++++ internal/daemon/worker/worker.go | 33 ++++++++--- 12 files changed, 225 insertions(+), 23 deletions(-) create mode 100644 internal/daemon/controller/rpc_registration_test.go create mode 100644 internal/daemon/worker/rpc_registration_test.go diff --git a/internal/daemon/cluster/handlers/options.go b/internal/daemon/cluster/handlers/options.go index 6d6ee1d199..efe4dadc4b 100644 --- a/internal/daemon/cluster/handlers/options.go +++ b/internal/daemon/cluster/handlers/options.go @@ -4,7 +4,7 @@ package handlers import ( - "github.com/hashicorp/nodeenrollment/types" + "github.com/hashicorp/nodeenrollment" ) // getOpts - iterate the inbound Options and return a struct @@ -21,16 +21,16 @@ type Option func(*options) // options = how options are represented type options struct { - withNodeInfo *types.NodeInformation + withKeyProducer nodeenrollment.X25519KeyProducer } func getDefaultOptions() options { return options{} } -// WithNodeInfo provides an option types.NodeInformation -func WithNodeInfo(nodeInfo *types.NodeInformation) Option { +// WithKeyProducer provides an option types.NodeInformation +func WithKeyProducer(nodeInfo nodeenrollment.X25519KeyProducer) Option { return func(o *options) { - o.withNodeInfo = nodeInfo + o.withKeyProducer = nodeInfo } } diff --git a/internal/daemon/cluster/handlers/testing.go b/internal/daemon/cluster/handlers/testing.go index f92ff5b1ad..7e797aa9db 100644 --- a/internal/daemon/cluster/handlers/testing.go +++ b/internal/daemon/cluster/handlers/testing.go @@ -55,7 +55,7 @@ func TestUpstreamService(t *testing.T) (UpstreamMessageServiceClientProducer, *t require.NoError(t, err) // start an upstream controller - testController, err := newControllerUpstreamMessageServiceServer(testCtx, initStorage) + testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage) require.NoError(t, err) require.NotNil(t, testController) diff --git a/internal/daemon/cluster/handlers/upstream_message_service_controller.go b/internal/daemon/cluster/handlers/upstream_message_service_controller.go index dc07f68c71..545587c9ae 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_controller.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_controller.go @@ -91,10 +91,10 @@ type controllerUpstreamMessageServiceServer struct { var _ pbs.UpstreamMessageServiceServer = (*controllerUpstreamMessageServiceServer)(nil) -// newControllerUpstreamMessageServiceServer creates a new service implementing +// NewControllerUpstreamMessageServiceServer creates a new service implementing // UpstreamMessageServiceServer, storing values used for the implementing // functions. -func newControllerUpstreamMessageServiceServer( +func NewControllerUpstreamMessageServiceServer( ctx context.Context, storage nodeenrollment.Storage, ) (pbs.UpstreamMessageServiceServer, error) { diff --git a/internal/daemon/cluster/handlers/upstream_message_service_test.go b/internal/daemon/cluster/handlers/upstream_message_service_test.go index 9c42f14a82..cc9f901dbb 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_test.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_test.go @@ -111,7 +111,7 @@ func Test_controllerUpstreamMessageServiceServer_UpstreamMessage(t *testing.T) { nodeInfo, err := types.LoadNodeInformation(testCtx, initStorage, initKeyId) require.NoError(t, err) // define a test controller - testController, err := newControllerUpstreamMessageServiceServer(testCtx, initStorage) + testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage) require.NoError(t, err) require.NotNil(t, testController) @@ -317,7 +317,7 @@ func Test_SendUpstreamMessage(t *testing.T) { clientProducer: workerClientProducer, originatingWorkerId: originatingWorkerId, req: &pbs.EchoUpstreamMessageRequest{Msg: "ping"}, - opt: []Option{WithNodeInfo(workerNodeInfo)}, + opt: []Option{WithKeyProducer(workerNodeInfo)}, wantResp: &pbs.EchoUpstreamMessageResponse{Msg: "ping"}, setupHandlers: TestRegisterHandlerFn(t, pbs.MsgType_MSG_TYPE_ECHO, testEncryptedHandler), }, diff --git a/internal/daemon/cluster/handlers/upstream_message_service_worker.go b/internal/daemon/cluster/handlers/upstream_message_service_worker.go index 43ca5b3435..e1dbb8f265 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_worker.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_worker.go @@ -11,7 +11,6 @@ import ( pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/nodeenrollment" - "github.com/hashicorp/nodeenrollment/types" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -93,10 +92,10 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ return nil, errors.Wrap(ctx, err, op) } default: - if opts.withNodeInfo == nil { + if opts.withKeyProducer == nil { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing node information required for encrypting unwrap keys message") } - req, err = ctMsg(ctx, opts.withNodeInfo, msgType, msg) + req, err = ctMsg(ctx, opts.withKeyProducer, msgType, msg) if err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -123,7 +122,7 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ return pt, nil default: ct := h.AllocResponse() - if err := nodeenrollment.DecryptMessage(ctx, rawResp.GetCt(), opts.withNodeInfo, ct); err != nil { + if err := nodeenrollment.DecryptMessage(ctx, rawResp.GetCt(), opts.withKeyProducer, ct); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error decrypting unwrap keys response")) } return ct, nil @@ -145,9 +144,9 @@ func ptMsg(ctx context.Context, msgType pbs.MsgType, msg proto.Message) (*pbs.Up }, nil } -func ctMsg(ctx context.Context, nodeInfo *types.NodeInformation, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) { +func ctMsg(ctx context.Context, keySource nodeenrollment.X25519KeyProducer, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) { const op = "handlers.encryptMsg" - ct, err := nodeenrollment.EncryptMessage(ctx, msg, nodeInfo) + ct, err := nodeenrollment.EncryptMessage(ctx, msg, keySource) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error encrypting upstream message")) } diff --git a/internal/daemon/controller/rpc_registration.go b/internal/daemon/controller/rpc_registration.go index 544642d161..bcfb43312b 100644 --- a/internal/daemon/controller/rpc_registration.go +++ b/internal/daemon/controller/rpc_registration.go @@ -21,6 +21,7 @@ func init() { registerControllerServerCoordinationService, registerControllerSessionService, registerControllerMultihopService, + registerControllerUpstreamMessageService, ) } @@ -88,3 +89,31 @@ func registerControllerMultihopService(ctx context.Context, c *Controller, serve multihop.RegisterMultihopServiceServer(server, multihopService) return nil } + +func registerControllerUpstreamMessageService(ctx context.Context, c *Controller, server *grpc.Server) error { + const op = "controller.registerControllerUpstreamMessageService" + + switch { + case nodeenrollment.IsNil(ctx): + return fmt.Errorf("%s: context is nil", op) + case c == nil: + return fmt.Errorf("%s: controller is nil", op) + case server == nil: + return fmt.Errorf("%s: server is nil", op) + } + + workerAuthStorage, err := c.WorkerAuthRepoStorageFn() + switch { + case err != nil: + return fmt.Errorf("%s: error fetching worker auth storage: %w", op, err) + case workerAuthStorage == nil: + return fmt.Errorf("%s: worker auth repository storage func is unset", op) + } + + upstreamMsgService, err := handlers.NewControllerUpstreamMessageServiceServer(ctx, workerAuthStorage) + if err != nil { + return fmt.Errorf("%s: error creating upstream message service handler: %w", op, err) + } + pbs.RegisterUpstreamMessageServiceServer(server, upstreamMsgService) + return nil +} diff --git a/internal/daemon/controller/rpc_registration_test.go b/internal/daemon/controller/rpc_registration_test.go new file mode 100644 index 0000000000..24d6ba60e9 --- /dev/null +++ b/internal/daemon/controller/rpc_registration_test.go @@ -0,0 +1,56 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package controller_test + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/boundary/internal/cmd/config" + "github.com/hashicorp/boundary/internal/daemon/cluster/handlers" + "github.com/hashicorp/boundary/internal/daemon/controller" + "github.com/hashicorp/boundary/internal/daemon/worker" + "github.com/hashicorp/boundary/internal/db" + pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" +) + +func Test_Controller_RegisterUpstreamMessageServices(t *testing.T) { + assert, require := assert.New(t), require.New(t) + testCtx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + kmsCache := kms.TestKms(t, conn, wrapper) + require.NoError(kmsCache.CreateKeys(context.Background(), scope.Global.String(), kms.WithRandomReader(rand.Reader))) + + logger := hclog.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + }) + conf, err := config.DevController() + require.NoError(err) + c := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + Logger: logger.Named("controller"), + }) + t.Cleanup(c.Shutdown) + kmsWorker, pkiWorker, _, _ := worker.NewTestMultihopWorkers(t, logger, c.Context(), c.ClusterAddrs(), + c.Config().WorkerAuthKms, c.Controller().ServersRepoFn, nil, nil, nil, nil) + t.Cleanup(kmsWorker.Shutdown) + t.Cleanup(pkiWorker.Shutdown) + + err = handlers.RegisterUpstreamMessageHandler(testCtx, pbs.MsgType_MSG_TYPE_ECHO, &handlers.TestMockUpstreamMessageHandler{}) + require.NoError(err) + + resp, err := pkiWorker.Worker().SendUpstreamMessage(testCtx, &pbs.EchoUpstreamMessageRequest{Msg: "ping"}) + require.NoError(err) + + assert.Empty(cmp.Diff(resp, &pbs.EchoUpstreamMessageResponse{Msg: "ping"}, protocmp.Transform())) +} diff --git a/internal/daemon/worker/controller_connection.go b/internal/daemon/worker/controller_connection.go index 142c754af4..9da64036b6 100644 --- a/internal/daemon/worker/controller_connection.go +++ b/internal/daemon/worker/controller_connection.go @@ -24,6 +24,7 @@ import ( "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/daemon/cluster" + "github.com/hashicorp/boundary/internal/daemon/cluster/handlers" "github.com/hashicorp/boundary/internal/daemon/worker/internal/metric" "github.com/hashicorp/boundary/internal/errors" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -264,6 +265,13 @@ func (w *Worker) createClientConn(addr string) error { w.controllerStatusConn.Store(pbs.NewServerCoordinationServiceClient(cc)) w.controllerMultihopConn.Store(multihop.NewMultihopServiceClient(cc)) + var producer handlers.UpstreamMessageServiceClientProducer + producer = func(context.Context) (pbs.UpstreamMessageServiceClient, error) { + return pbs.NewUpstreamMessageServiceClient(cc), nil + } + + w.controllerUpstreamMsgConn.Store(&producer) + return nil } diff --git a/internal/daemon/worker/listeners_test.go b/internal/daemon/worker/listeners_test.go index a4ec12e56a..36d874df1c 100644 --- a/internal/daemon/worker/listeners_test.go +++ b/internal/daemon/worker/listeners_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/boundary/internal/daemon/cluster/handlers" "github.com/hashicorp/boundary/internal/daemon/worker/session" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" "github.com/hashicorp/go-secure-stdlib/listenerutil" @@ -81,6 +82,12 @@ func TestStartListeners(t *testing.T) { require.NoError(t, err) w.baseContext = context.Background() + var dummyClientProducer handlers.UpstreamMessageServiceClientProducer + dummyClientProducer = func(ctx context.Context) (pbs.UpstreamMessageServiceClient, error) { + panic("stubbed producer: not used in this test") + } + w.controllerUpstreamMsgConn.Store(&dummyClientProducer) + manager, err := session.NewManager(pbs.NewSessionServiceClient(w.GrpcClientConn)) require.NoError(t, err) err = w.startListeners(manager) diff --git a/internal/daemon/worker/rpc_registration.go b/internal/daemon/worker/rpc_registration.go index 0f995d52bc..ccee878615 100644 --- a/internal/daemon/worker/rpc_registration.go +++ b/internal/daemon/worker/rpc_registration.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/boundary/internal/daemon/cluster/handlers" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" + "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/nodeenrollment" "github.com/hashicorp/nodeenrollment/multihop" "google.golang.org/grpc" @@ -20,6 +21,7 @@ func init() { workerGrpcServiceRegistrationFunctions = append(workerGrpcServiceRegistrationFunctions, registerWorkerStatusSessionService, registerWorkerMultihopService, + registerWorkerUpstreamMessageService, ) } @@ -64,3 +66,29 @@ func registerWorkerMultihopService(ctx context.Context, w *Worker, server *grpc. multihop.RegisterMultihopServiceServer(server, multihopService) return nil } + +func registerWorkerUpstreamMessageService(ctx context.Context, w *Worker, server *grpc.Server) error { + const op = "worker.registerWorkerUpstreamMessageService" + + switch { + case util.IsNil(ctx): + return fmt.Errorf("%s: context is nil", op) + case w == nil: + return fmt.Errorf("%s: controller is nil", op) + case server == nil: + return fmt.Errorf("%s: server is nil", op) + } + + clientProducer := w.controllerUpstreamMsgConn.Load() + switch { + case clientProducer == nil: + return fmt.Errorf("%s: upstream message service client producer is unset", op) + } + + upstreamMsgService, err := handlers.NewWorkerUpstreamMessageServiceServer(ctx, *clientProducer) + if err != nil { + return fmt.Errorf("%s: error creating multihop service handler: %w", op, err) + } + pbs.RegisterUpstreamMessageServiceServer(server, upstreamMsgService) + return nil +} diff --git a/internal/daemon/worker/rpc_registration_test.go b/internal/daemon/worker/rpc_registration_test.go new file mode 100644 index 0000000000..b87e663ec3 --- /dev/null +++ b/internal/daemon/worker/rpc_registration_test.go @@ -0,0 +1,56 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package worker_test + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/boundary/internal/cmd/config" + "github.com/hashicorp/boundary/internal/daemon/cluster/handlers" + "github.com/hashicorp/boundary/internal/daemon/controller" + "github.com/hashicorp/boundary/internal/daemon/worker" + "github.com/hashicorp/boundary/internal/db" + pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" +) + +func Test_Worker_RegisterUpstreamMessageServices(t *testing.T) { + assert, require := assert.New(t), require.New(t) + testCtx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + kmsCache := kms.TestKms(t, conn, wrapper) + require.NoError(kmsCache.CreateKeys(context.Background(), scope.Global.String(), kms.WithRandomReader(rand.Reader))) + + logger := hclog.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + }) + conf, err := config.DevController() + require.NoError(err) + c := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + Logger: logger.Named("controller"), + }) + t.Cleanup(c.Shutdown) + kmsWorker, pkiWorker, _, _ := worker.NewTestMultihopWorkers(t, logger, c.Context(), c.ClusterAddrs(), + c.Config().WorkerAuthKms, c.Controller().ServersRepoFn, nil, nil, nil, nil) + t.Cleanup(kmsWorker.Shutdown) + t.Cleanup(pkiWorker.Shutdown) + + err = handlers.RegisterUpstreamMessageHandler(testCtx, pbs.MsgType_MSG_TYPE_ECHO, &handlers.TestMockUpstreamMessageHandler{}) + require.NoError(err) + + resp, err := pkiWorker.Worker().SendUpstreamMessage(testCtx, &pbs.EchoUpstreamMessageRequest{Msg: "ping"}) + require.NoError(err) + + assert.Empty(cmp.Diff(resp, &pbs.EchoUpstreamMessageResponse{Msg: "ping"}, protocmp.Transform())) +} diff --git a/internal/daemon/worker/worker.go b/internal/daemon/worker/worker.go index 2ee46f6a29..8c495258ca 100644 --- a/internal/daemon/worker/worker.go +++ b/internal/daemon/worker/worker.go @@ -20,6 +20,7 @@ import ( "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/config" "github.com/hashicorp/boundary/internal/daemon/cluster" + "github.com/hashicorp/boundary/internal/daemon/cluster/handlers" "github.com/hashicorp/boundary/internal/daemon/worker/common" "github.com/hashicorp/boundary/internal/daemon/worker/internal/metric" "github.com/hashicorp/boundary/internal/daemon/worker/proxy" @@ -123,6 +124,8 @@ type Worker struct { controllerMultihopConn *atomic.Value + controllerUpstreamMsgConn atomic.Pointer[handlers.UpstreamMessageServiceClientProducer] + proxyListener *base.ServerListener // Used to generate a random nonce for Controller connections @@ -178,13 +181,14 @@ func New(conf *Config) (*Worker, error) { initializeReverseGrpcClientCollectors(conf.PrometheusRegisterer) w := &Worker{ - conf: conf, - logger: conf.Logger.Named("worker"), - started: ua.NewBool(false), - controllerStatusConn: new(atomic.Value), - everAuthenticated: ua.NewUint32(authenticationStatusNeverAuthenticated), - lastStatusSuccess: new(atomic.Value), - controllerMultihopConn: new(atomic.Value), + conf: conf, + logger: conf.Logger.Named("worker"), + started: ua.NewBool(false), + controllerStatusConn: new(atomic.Value), + everAuthenticated: ua.NewUint32(authenticationStatusNeverAuthenticated), + lastStatusSuccess: new(atomic.Value), + controllerMultihopConn: new(atomic.Value), + // controllerUpstreamMsgConn: new(atomic.Value), tags: new(atomic.Value), updateTags: ua.NewBool(false), nonceFn: base62.Random, @@ -736,3 +740,18 @@ func (w *Worker) getSessionTls(sessionManager session.Manager) func(hello *tls.C return tlsConf, nil } } + +// SendUpstreamMessage facilitates sending upstream messages to the controller. +func (w *Worker) SendUpstreamMessage(ctx context.Context, m proto.Message) (proto.Message, error) { + const op = "worker.(Worker).SendUpstreamMessage" + nodeCreds, err := types.LoadNodeCredentials(w.baseContext, w.WorkerAuthStorage, nodeenrollment.CurrentId, nodeenrollment.WithStorageWrapper(w.conf.WorkerAuthStorageKms)) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + initKeyId, err := nodeenrollment.KeyIdFromPkix(nodeCreds.CertificatePublicKeyPkix) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + clientProducer := w.controllerUpstreamMsgConn.Load() + return handlers.SendUpstreamMessage(ctx, *clientProducer, initKeyId, m, handlers.WithKeyProducer(nodeCreds)) +}