mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
238 lines
8.3 KiB
238 lines
8.3 KiB
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
|
|
"github.com/hashicorp/boundary/internal/observability/event"
|
|
"github.com/hashicorp/boundary/internal/util"
|
|
"github.com/hashicorp/nodeenrollment"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
// UpstreamMessageServiceClientProducer is a factory that returns a
// pbs.UpstreamMessageServiceClient each time it is called. Implementations
// should be self-healing: if a grpc connection backing a previously returned
// client has closed, a later call should hand back a usable client.
type UpstreamMessageServiceClientProducer func(ctx context.Context) (pbs.UpstreamMessageServiceClient, error)
|
|
|
|
// workerUpstreamMessageServiceServer implements the
|
|
// UpstreamMessageServiceServer for workers and always forwards requests using
|
|
// its clients
|
|
type workerUpstreamMessageServiceServer struct {
|
|
pbs.UnimplementedUpstreamMessageServiceServer
|
|
clientProducer UpstreamMessageServiceClientProducer
|
|
}
|
|
|
|
var _ pbs.UpstreamMessageServiceServer = (*workerUpstreamMessageServiceServer)(nil)
|
|
|
|
// NewWorkerUpstreamMessageServiceServer constructs the worker-side
// UpstreamMessageServiceServer, which forwards every request it receives
// using clients obtained from clientProducer.
//
// It returns an errors.InvalidParameter error when clientProducer is nil.
func NewWorkerUpstreamMessageServiceServer(
	ctx context.Context,
	clientProducer UpstreamMessageServiceClientProducer,
) (*workerUpstreamMessageServiceServer, error) {
	const op = "handlers.NewWorkerUpstreamMessageServiceServer"
	if clientProducer == nil {
		return nil, errors.New(ctx, errors.InvalidParameter, op, "missing client producer")
	}
	srv := &workerUpstreamMessageServiceServer{clientProducer: clientProducer}
	return srv, nil
}
|
|
|
|
// UpstreamMessage implements the grpc method of the same name for workers.
// The request is not inspected. A client is obtained from the server's
// producer, req is forwarded unchanged, and the upstream response and error
// are returned as-is.
//
// A nil req, or a failure to obtain a client, yields a codes.Internal status
// error.
func (s *workerUpstreamMessageServiceServer) UpstreamMessage(ctx context.Context, req *pbs.UpstreamMessageRequest) (*pbs.UpstreamMessageResponse, error) {
	const op = "handlers.(workerUpstreamMessageServiceServer).UpstreamMessage"
	if req == nil {
		return nil, status.Errorf(codes.Internal, "%s: missing request", op)
	}

	upstream, err := s.clientProducer(ctx)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "%s: unable to get client: %v", op, err)
	}
	return upstream.UpstreamMessage(ctx, req)
}
|
|
|
|
// UpstreamMessageTypeSpecifier defines an interface for specifying type
|
|
// information for an UpstreamMessageRequest(s).
|
|
//
|
|
// See handlers.SendUpstreamMessage for how this is
|
|
// used to send an UpstreamMessageRequest via registered upstream message
|
|
// type specifiers.
|
|
type UpstreamMessageTypeSpecifier interface {
|
|
// Encrypted returns true if the request/response should be encrypted
|
|
Encrypted() bool
|
|
|
|
// AllocRequest will allocate a type specific request proto message
|
|
AllocRequest() proto.Message
|
|
|
|
// AllocResponse will allocate a type specific response proto message
|
|
AllocResponse() proto.Message
|
|
}
|
|
|
|
// upstreamMessageTypeSpecifier is the package-level registry mapping a
// pbs.MsgType key to its UpstreamMessageTypeSpecifier value. It is written by
// registerUpstreamMessageTypeSpecifier and read by
// getUpstreamMessageTypeSpecifier. sync.Map permits concurrent access without
// an external lock.
var upstreamMessageTypeSpecifier sync.Map
|
|
|
|
// registerUpstreamMessageTypeSpecifier records t as the
// UpstreamMessageTypeSpecifier for msgType. Any specifier previously stored
// for the same msgType is overwritten.
//
// SendUpstreamMessage relies on these registrations to know how to send a
// given message type. An errors.InvalidParameter error is returned when
// msgType is MSG_TYPE_UNSPECIFIED or when util.IsNil reports t as nil.
func registerUpstreamMessageTypeSpecifier(ctx context.Context, msgType pbs.MsgType, t UpstreamMessageTypeSpecifier) error {
	const op = "handlers.registerUpstreamMessageTypeSpecifier"
	if msgType == pbs.MsgType_MSG_TYPE_UNSPECIFIED {
		return errors.New(ctx, errors.InvalidParameter, op, "missing msg type")
	}
	if util.IsNil(t) {
		return errors.New(ctx, errors.InvalidParameter, op, "missing type specifier")
	}

	upstreamMessageTypeSpecifier.Store(msgType, t)
	return nil
}
|
|
|
|
// getUpstreamMessageTypeSpecifier looks up the UpstreamMessageTypeSpecifier
// registered for msgType. It returns the specifier and true on success, or
// nil and false otherwise.
//
// An unspecified msgType, or a registry value that is not an
// UpstreamMessageTypeSpecifier, is reported via event.WriteError before false
// is returned. A plain "nothing registered" miss is returned silently.
func getUpstreamMessageTypeSpecifier(ctx context.Context, msgType pbs.MsgType) (UpstreamMessageTypeSpecifier, bool) {
	const op = "handlers.getUpstreamMessageTypeSpecifier"
	if msgType == pbs.MsgType_MSG_TYPE_UNSPECIFIED {
		event.WriteError(ctx, op, fmt.Errorf("missing msg type"))
		return nil, false
	}

	raw, found := upstreamMessageTypeSpecifier.Load(msgType)
	if !found {
		return nil, false
	}

	if spec, isSpec := raw.(UpstreamMessageTypeSpecifier); isSpec {
		return spec, true
	}
	event.WriteError(ctx, op, fmt.Errorf("malformed type specifier %q registered as incorrect type %T", msgType.String(), raw))
	return nil, false
}
|
|
|
|
// SendUpstreamMessage wraps msg in an UpstreamMessageRequest, sends it
// upstream, and returns the decoded, type-specific response.
//
// The pbs.MsgType is derived from msg's concrete proto type. The
// UpstreamMessageTypeSpecifier registered for that type decides two things:
// whether the request and response are plaintext or encrypted, and which
// response proto to decode into. Encrypted types require a key producer to be
// supplied through opt (read here as opts.withKeyProducer).
//
// originatingWorkerKeyId is stamped onto the outgoing request. clientProducer
// is called once to obtain the client used for the call.
//
// Errors:
//   - Missing arguments are returned as codes.Internal grpc status errors.
//   - An unrecognized message is also returned as a codes.Internal status error.
//   - A type with no registered specifier is returned as codes.Unimplemented.
//   - Other failures are wrapped with this function's op. These include a
//     missing key producer, building the request, obtaining a client, the
//     upstream call itself, and decoding the response.
func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServiceClientProducer, originatingWorkerKeyId string, msg proto.Message, opt ...Option) (proto.Message, error) {
	const op = "handlers.SendUpstreamMessage"
	switch {
	case clientProducer == nil:
		return nil, status.Errorf(codes.Internal, "%s: missing client producer", op)
	case originatingWorkerKeyId == "":
		return nil, status.Errorf(codes.Internal, "%s: missing originating worker key id", op)
	case util.IsNil(msg):
		return nil, status.Errorf(codes.Internal, "%s: missing message", op)
	}
	opts := getOpts(opt...)

	msgType, err := toMsgType(ctx, msg)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "%s: %v", op, err)
	}
	t, ok := getUpstreamMessageTypeSpecifier(ctx, msgType)
	if !ok {
		return nil, status.Errorf(codes.Unimplemented, "upstream message type specifier for %q is not implemented", msgType.String())
	}
	// Ask once so that request encoding and response decoding always agree.
	encrypted := t.Encrypted()

	var req *pbs.UpstreamMessageRequest
	if encrypted {
		if opts.withKeyProducer == nil {
			return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("missing node information required for encrypting %q message", msgType.String()))
		}
		req, err = ctMsg(ctx, opts.withKeyProducer, msgType, msg)
	} else {
		req, err = ptMsg(ctx, msgType, msg)
	}
	if err != nil {
		return nil, errors.Wrap(ctx, err, op)
	}
	req.OriginatingWorkerKeyId = originatingWorkerKeyId

	c, err := clientProducer(ctx)
	if err != nil {
		return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error getting a client"))
	}

	rawResp, err := c.UpstreamMessage(ctx, req)
	if err != nil {
		return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error from upstream client"))
	}

	// Decode the reply into a fresh type-specific response. Ciphertext is
	// carried in Ct, plaintext in Pt.
	resp := t.AllocResponse()
	if encrypted {
		if err := nodeenrollment.DecryptMessage(ctx, rawResp.GetCt(), opts.withKeyProducer, resp); err != nil {
			return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("error decrypting %q response", msgType.String())))
		}
		return resp, nil
	}
	if err := proto.Unmarshal(rawResp.GetPt(), resp); err != nil {
		return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("error unmarshaling %q response", msgType.String())))
	}
	return resp, nil
}
|
|
|
|
// ptMsg builds a plaintext UpstreamMessageRequest. msg is serialized with
// proto.Marshal and the bytes are carried in the Pt field, tagged with
// msgType. A marshaling failure is returned wrapped with this function's op.
func ptMsg(ctx context.Context, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) {
	const op = "handlers.ptMsg"
	encoded, err := proto.Marshal(msg)
	if err != nil {
		return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error marshaling upstream message"))
	}
	payload := &pbs.UpstreamMessageRequest_Pt{Pt: encoded}
	return &pbs.UpstreamMessageRequest{MsgType: msgType, Message: payload}, nil
}
|
|
|
|
// ctMsg builds an encrypted UpstreamMessageRequest. msg is encrypted with
// nodeenrollment.EncryptMessage using keySource, and the resulting ciphertext
// is carried in the Ct field, tagged with msgType. An encryption failure is
// returned wrapped with this function's op.
func ctMsg(ctx context.Context, keySource nodeenrollment.X25519KeyProducer, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) {
	// op names this function, matching the convention used by the other
	// helpers in this file (previously it said "handlers.encryptMsg").
	const op = "handlers.ctMsg"
	ct, err := nodeenrollment.EncryptMessage(ctx, msg, keySource)
	if err != nil {
		return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error encrypting upstream message"))
	}
	return &pbs.UpstreamMessageRequest{
		MsgType: msgType,
		Message: &pbs.UpstreamMessageRequest_Ct{
			Ct: ct,
		},
	}, nil
}
|
|
|
|
// toMsgType maps a request or response proto to the pbs.MsgType it belongs
// to. The request and response types of each pair map to the same MsgType.
//
// Any other message yields MSG_TYPE_UNSPECIFIED and an
// errors.InvalidParameter error naming the Go type that was passed (a nil
// message is reported as "<nil>").
func toMsgType(ctx context.Context, m proto.Message) (pbs.MsgType, error) {
	const op = "handlers.toMsgType"
	switch m.(type) {
	case *pbs.EchoUpstreamMessageRequest, *pbs.EchoUpstreamMessageResponse:
		return pbs.MsgType_MSG_TYPE_ECHO, nil
	case *pbs.UnwrapKeysRequest, *pbs.UnwrapKeysResponse:
		return pbs.MsgType_MSG_TYPE_UNWRAP_KEYS, nil
	case *pbs.VerifySignatureRequest, *pbs.VerifySignatureResponse:
		return pbs.MsgType_MSG_TYPE_VERIFY_SIGNATURE, nil
	case *pbs.CloseSessionRecordingRequest, *pbs.CloseSessionRecordingResponse:
		return pbs.MsgType_MSG_TYPE_CLOSE_SESSION_RECORDING, nil
	case *pbs.CloseConnectionRecordingRequest, *pbs.CloseConnectionRecordingResponse:
		return pbs.MsgType_MSG_TYPE_CLOSE_CONNECTION_RECORDING, nil
	case *pbs.CreateChannelRecordingRequest, *pbs.CreateChannelRecordingResponse:
		return pbs.MsgType_MSG_TYPE_CREATE_CHANNEL_RECORDING, nil
	default:
		// Report the concrete type name. Formatting the message value itself
		// with %q would dump its struct fields as %!q(...) noise instead of
		// identifying the type.
		return pbs.MsgType_MSG_TYPE_UNSPECIFIED, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%q is an unknown msg type", fmt.Sprintf("%T", m)))
	}
}
|