feature(worker/controller): register UpstreamMessageServiceServer

pull/3251/head
Jim 3 years ago committed by Timothy Messier
parent 20cf521cc0
commit 410bec1e00
No known key found for this signature in database
GPG Key ID: EFD2F184F7600572

@ -4,7 +4,7 @@
package handlers
import (
"github.com/hashicorp/nodeenrollment/types"
"github.com/hashicorp/nodeenrollment"
)
// getOpts - iterate the inbound Options and return a struct
@ -21,16 +21,16 @@ type Option func(*options)
// options = how options are represented
type options struct {
withNodeInfo *types.NodeInformation
withKeyProducer nodeenrollment.X25519KeyProducer
}
func getDefaultOptions() options {
return options{}
}
// WithNodeInfo provides an option types.NodeInformation
func WithNodeInfo(nodeInfo *types.NodeInformation) Option {
// WithKeyProducer provides an option types.NodeInformation
func WithKeyProducer(nodeInfo nodeenrollment.X25519KeyProducer) Option {
return func(o *options) {
o.withNodeInfo = nodeInfo
o.withKeyProducer = nodeInfo
}
}

@ -55,7 +55,7 @@ func TestUpstreamService(t *testing.T) (UpstreamMessageServiceClientProducer, *t
require.NoError(t, err)
// start an upstream controller
testController, err := newControllerUpstreamMessageServiceServer(testCtx, initStorage)
testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage)
require.NoError(t, err)
require.NotNil(t, testController)

@ -91,10 +91,10 @@ type controllerUpstreamMessageServiceServer struct {
var _ pbs.UpstreamMessageServiceServer = (*controllerUpstreamMessageServiceServer)(nil)
// newControllerUpstreamMessageServiceServer creates a new service implementing
// NewControllerUpstreamMessageServiceServer creates a new service implementing
// UpstreamMessageServiceServer, storing values used for the implementing
// functions.
func newControllerUpstreamMessageServiceServer(
func NewControllerUpstreamMessageServiceServer(
ctx context.Context,
storage nodeenrollment.Storage,
) (pbs.UpstreamMessageServiceServer, error) {

@ -111,7 +111,7 @@ func Test_controllerUpstreamMessageServiceServer_UpstreamMessage(t *testing.T) {
nodeInfo, err := types.LoadNodeInformation(testCtx, initStorage, initKeyId)
require.NoError(t, err)
// define a test controller
testController, err := newControllerUpstreamMessageServiceServer(testCtx, initStorage)
testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage)
require.NoError(t, err)
require.NotNil(t, testController)
@ -317,7 +317,7 @@ func Test_SendUpstreamMessage(t *testing.T) {
clientProducer: workerClientProducer,
originatingWorkerId: originatingWorkerId,
req: &pbs.EchoUpstreamMessageRequest{Msg: "ping"},
opt: []Option{WithNodeInfo(workerNodeInfo)},
opt: []Option{WithKeyProducer(workerNodeInfo)},
wantResp: &pbs.EchoUpstreamMessageResponse{Msg: "ping"},
setupHandlers: TestRegisterHandlerFn(t, pbs.MsgType_MSG_TYPE_ECHO, testEncryptedHandler),
},

@ -11,7 +11,6 @@ import (
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/nodeenrollment"
"github.com/hashicorp/nodeenrollment/types"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
@ -93,10 +92,10 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ
return nil, errors.Wrap(ctx, err, op)
}
default:
if opts.withNodeInfo == nil {
if opts.withKeyProducer == nil {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing node information required for encrypting unwrap keys message")
}
req, err = ctMsg(ctx, opts.withNodeInfo, msgType, msg)
req, err = ctMsg(ctx, opts.withKeyProducer, msgType, msg)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
@ -123,7 +122,7 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ
return pt, nil
default:
ct := h.AllocResponse()
if err := nodeenrollment.DecryptMessage(ctx, rawResp.GetCt(), opts.withNodeInfo, ct); err != nil {
if err := nodeenrollment.DecryptMessage(ctx, rawResp.GetCt(), opts.withKeyProducer, ct); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error decrypting unwrap keys response"))
}
return ct, nil
@ -145,9 +144,9 @@ func ptMsg(ctx context.Context, msgType pbs.MsgType, msg proto.Message) (*pbs.Up
}, nil
}
func ctMsg(ctx context.Context, nodeInfo *types.NodeInformation, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) {
func ctMsg(ctx context.Context, keySource nodeenrollment.X25519KeyProducer, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) {
const op = "handlers.encryptMsg"
ct, err := nodeenrollment.EncryptMessage(ctx, msg, nodeInfo)
ct, err := nodeenrollment.EncryptMessage(ctx, msg, keySource)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error encrypting upstream message"))
}

@ -21,6 +21,7 @@ func init() {
registerControllerServerCoordinationService,
registerControllerSessionService,
registerControllerMultihopService,
registerControllerUpstreamMessageService,
)
}
@ -88,3 +89,31 @@ func registerControllerMultihopService(ctx context.Context, c *Controller, serve
multihop.RegisterMultihopServiceServer(server, multihopService)
return nil
}
func registerControllerUpstreamMessageService(ctx context.Context, c *Controller, server *grpc.Server) error {
const op = "controller.registerControllerUpstreamMessageService"
switch {
case nodeenrollment.IsNil(ctx):
return fmt.Errorf("%s: context is nil", op)
case c == nil:
return fmt.Errorf("%s: controller is nil", op)
case server == nil:
return fmt.Errorf("%s: server is nil", op)
}
workerAuthStorage, err := c.WorkerAuthRepoStorageFn()
switch {
case err != nil:
return fmt.Errorf("%s: error fetching worker auth storage: %w", op, err)
case workerAuthStorage == nil:
return fmt.Errorf("%s: worker auth repository storage func is unset", op)
}
upstreamMsgService, err := handlers.NewControllerUpstreamMessageServiceServer(ctx, workerAuthStorage)
if err != nil {
return fmt.Errorf("%s: error creating upstream message service handler: %w", op, err)
}
pbs.RegisterUpstreamMessageServiceServer(server, upstreamMsgService)
return nil
}

@ -0,0 +1,56 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package controller_test
import (
"context"
"crypto/rand"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/daemon/cluster/handlers"
"github.com/hashicorp/boundary/internal/daemon/controller"
"github.com/hashicorp/boundary/internal/daemon/worker"
"github.com/hashicorp/boundary/internal/db"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
)
func Test_Controller_RegisterUpstreamMessageServices(t *testing.T) {
assert, require := assert.New(t), require.New(t)
testCtx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
kmsCache := kms.TestKms(t, conn, wrapper)
require.NoError(kmsCache.CreateKeys(context.Background(), scope.Global.String(), kms.WithRandomReader(rand.Reader)))
logger := hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
conf, err := config.DevController()
require.NoError(err)
c := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
Logger: logger.Named("controller"),
})
t.Cleanup(c.Shutdown)
kmsWorker, pkiWorker, _, _ := worker.NewTestMultihopWorkers(t, logger, c.Context(), c.ClusterAddrs(),
c.Config().WorkerAuthKms, c.Controller().ServersRepoFn, nil, nil, nil, nil)
t.Cleanup(kmsWorker.Shutdown)
t.Cleanup(pkiWorker.Shutdown)
err = handlers.RegisterUpstreamMessageHandler(testCtx, pbs.MsgType_MSG_TYPE_ECHO, &handlers.TestMockUpstreamMessageHandler{})
require.NoError(err)
resp, err := pkiWorker.Worker().SendUpstreamMessage(testCtx, &pbs.EchoUpstreamMessageRequest{Msg: "ping"})
require.NoError(err)
assert.Empty(cmp.Diff(resp, &pbs.EchoUpstreamMessageResponse{Msg: "ping"}, protocmp.Transform()))
}

@ -24,6 +24,7 @@ import (
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/daemon/cluster"
"github.com/hashicorp/boundary/internal/daemon/cluster/handlers"
"github.com/hashicorp/boundary/internal/daemon/worker/internal/metric"
"github.com/hashicorp/boundary/internal/errors"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
@ -264,6 +265,13 @@ func (w *Worker) createClientConn(addr string) error {
w.controllerStatusConn.Store(pbs.NewServerCoordinationServiceClient(cc))
w.controllerMultihopConn.Store(multihop.NewMultihopServiceClient(cc))
var producer handlers.UpstreamMessageServiceClientProducer
producer = func(context.Context) (pbs.UpstreamMessageServiceClient, error) {
return pbs.NewUpstreamMessageServiceClient(cc), nil
}
w.controllerUpstreamMsgConn.Store(&producer)
return nil
}

@ -12,6 +12,7 @@ import (
"testing"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/daemon/cluster/handlers"
"github.com/hashicorp/boundary/internal/daemon/worker/session"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
@ -81,6 +82,12 @@ func TestStartListeners(t *testing.T) {
require.NoError(t, err)
w.baseContext = context.Background()
var dummyClientProducer handlers.UpstreamMessageServiceClientProducer
dummyClientProducer = func(ctx context.Context) (pbs.UpstreamMessageServiceClient, error) {
panic("stubbed producer: not used in this test")
}
w.controllerUpstreamMsgConn.Store(&dummyClientProducer)
manager, err := session.NewManager(pbs.NewSessionServiceClient(w.GrpcClientConn))
require.NoError(t, err)
err = w.startListeners(manager)

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/boundary/internal/daemon/cluster/handlers"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/nodeenrollment"
"github.com/hashicorp/nodeenrollment/multihop"
"google.golang.org/grpc"
@ -20,6 +21,7 @@ func init() {
workerGrpcServiceRegistrationFunctions = append(workerGrpcServiceRegistrationFunctions,
registerWorkerStatusSessionService,
registerWorkerMultihopService,
registerWorkerUpstreamMessageService,
)
}
@ -64,3 +66,29 @@ func registerWorkerMultihopService(ctx context.Context, w *Worker, server *grpc.
multihop.RegisterMultihopServiceServer(server, multihopService)
return nil
}
func registerWorkerUpstreamMessageService(ctx context.Context, w *Worker, server *grpc.Server) error {
const op = "worker.registerWorkerUpstreamMessageService"
switch {
case util.IsNil(ctx):
return fmt.Errorf("%s: context is nil", op)
case w == nil:
return fmt.Errorf("%s: controller is nil", op)
case server == nil:
return fmt.Errorf("%s: server is nil", op)
}
clientProducer := w.controllerUpstreamMsgConn.Load()
switch {
case clientProducer == nil:
return fmt.Errorf("%s: upstream message service client producer is unset", op)
}
upstreamMsgService, err := handlers.NewWorkerUpstreamMessageServiceServer(ctx, *clientProducer)
if err != nil {
return fmt.Errorf("%s: error creating multihop service handler: %w", op, err)
}
pbs.RegisterUpstreamMessageServiceServer(server, upstreamMsgService)
return nil
}

@ -0,0 +1,56 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package worker_test
import (
"context"
"crypto/rand"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/daemon/cluster/handlers"
"github.com/hashicorp/boundary/internal/daemon/controller"
"github.com/hashicorp/boundary/internal/daemon/worker"
"github.com/hashicorp/boundary/internal/db"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
)
func Test_Worker_RegisterUpstreamMessageServices(t *testing.T) {
assert, require := assert.New(t), require.New(t)
testCtx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
kmsCache := kms.TestKms(t, conn, wrapper)
require.NoError(kmsCache.CreateKeys(context.Background(), scope.Global.String(), kms.WithRandomReader(rand.Reader)))
logger := hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
conf, err := config.DevController()
require.NoError(err)
c := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
Logger: logger.Named("controller"),
})
t.Cleanup(c.Shutdown)
kmsWorker, pkiWorker, _, _ := worker.NewTestMultihopWorkers(t, logger, c.Context(), c.ClusterAddrs(),
c.Config().WorkerAuthKms, c.Controller().ServersRepoFn, nil, nil, nil, nil)
t.Cleanup(kmsWorker.Shutdown)
t.Cleanup(pkiWorker.Shutdown)
err = handlers.RegisterUpstreamMessageHandler(testCtx, pbs.MsgType_MSG_TYPE_ECHO, &handlers.TestMockUpstreamMessageHandler{})
require.NoError(err)
resp, err := pkiWorker.Worker().SendUpstreamMessage(testCtx, &pbs.EchoUpstreamMessageRequest{Msg: "ping"})
require.NoError(err)
assert.Empty(cmp.Diff(resp, &pbs.EchoUpstreamMessageResponse{Msg: "ping"}, protocmp.Transform()))
}

@ -20,6 +20,7 @@ import (
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/daemon/cluster"
"github.com/hashicorp/boundary/internal/daemon/cluster/handlers"
"github.com/hashicorp/boundary/internal/daemon/worker/common"
"github.com/hashicorp/boundary/internal/daemon/worker/internal/metric"
"github.com/hashicorp/boundary/internal/daemon/worker/proxy"
@ -123,6 +124,8 @@ type Worker struct {
controllerMultihopConn *atomic.Value
controllerUpstreamMsgConn atomic.Pointer[handlers.UpstreamMessageServiceClientProducer]
proxyListener *base.ServerListener
// Used to generate a random nonce for Controller connections
@ -178,13 +181,14 @@ func New(conf *Config) (*Worker, error) {
initializeReverseGrpcClientCollectors(conf.PrometheusRegisterer)
w := &Worker{
conf: conf,
logger: conf.Logger.Named("worker"),
started: ua.NewBool(false),
controllerStatusConn: new(atomic.Value),
everAuthenticated: ua.NewUint32(authenticationStatusNeverAuthenticated),
lastStatusSuccess: new(atomic.Value),
controllerMultihopConn: new(atomic.Value),
conf: conf,
logger: conf.Logger.Named("worker"),
started: ua.NewBool(false),
controllerStatusConn: new(atomic.Value),
everAuthenticated: ua.NewUint32(authenticationStatusNeverAuthenticated),
lastStatusSuccess: new(atomic.Value),
controllerMultihopConn: new(atomic.Value),
// controllerUpstreamMsgConn: new(atomic.Value),
tags: new(atomic.Value),
updateTags: ua.NewBool(false),
nonceFn: base62.Random,
@ -736,3 +740,18 @@ func (w *Worker) getSessionTls(sessionManager session.Manager) func(hello *tls.C
return tlsConf, nil
}
}
// SendUpstreamMessage facilitates sending upstream messages to the controller.
func (w *Worker) SendUpstreamMessage(ctx context.Context, m proto.Message) (proto.Message, error) {
const op = "worker.(Worker).SendUpstreamMessage"
nodeCreds, err := types.LoadNodeCredentials(w.baseContext, w.WorkerAuthStorage, nodeenrollment.CurrentId, nodeenrollment.WithStorageWrapper(w.conf.WorkerAuthStorageKms))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
initKeyId, err := nodeenrollment.KeyIdFromPkix(nodeCreds.CertificatePublicKeyPkix)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
clientProducer := w.controllerUpstreamMsgConn.Load()
return handlers.SendUpstreamMessage(ctx, *clientProducer, initKeyId, m, handlers.WithKeyProducer(nodeCreds))
}

Loading…
Cancel
Save