refact: split UpstreamMessageHandler into two interfaces

Originally, UpstreamMessageHandler was used by
controllerUpstreamMessageServiceServer.UpstreamMessage to process
(aka handle) an upstream message and by SendUpstreamMessage to
send an upstream msg.

We split UpstreamMessageHandler into two interfaces so we can
register handlers independently of a senders.
pull/3251/head
Jim 3 years ago committed by Timothy Messier
parent 431b9d25a2
commit 8cd0e8595f
No known key found for this signature in database
GPG Key ID: EFD2F184F7600572

@ -84,15 +84,21 @@ func TestRegisterHandlerFn(t *testing.T, msgType pbs.MsgType, h UpstreamMessageH
return func(t *testing.T) {
t.Helper()
testCtx := context.Background()
var cp sync.Map
var cpHandlerRegistry sync.Map
upstreamMessageHandler.Range(func(k, v interface{}) bool {
cp.Store(k, v)
cpHandlerRegistry.Store(k, v)
return true
})
var cpTypeSpecifierRegister sync.Map
upstreamMessageTypeSpecifier.Range(func(k, v interface{}) bool {
cpTypeSpecifierRegister.Store(k, v)
return true
})
require.NoError(t, RegisterUpstreamMessageHandler(testCtx, msgType, h))
t.Cleanup(func() {
upstreamMessageHandler = cp
upstreamMessageHandler = cpHandlerRegistry
upstreamMessageTypeSpecifier = cpTypeSpecifierRegister
})
}
}

@ -31,14 +31,7 @@ type UpstreamMessageHandler interface {
// using google.golang.org/grpc/status
Handler(ctx context.Context, request proto.Message) (response proto.Message, statusErr error)
// Encrypted returns true if the handler request/response should be encrypted
Encrypted() bool
// AllocRequest will allocate handler specific request proto message
AllocRequest() proto.Message
// AllocResponse will allocate a handler specific response proto message
AllocResponse() proto.Message
UpstreamMessageTypeSpecifier
}
// RegisterUpstreamMessageHandler will register an UpstreamMessageHandler for
@ -56,6 +49,9 @@ func RegisterUpstreamMessageHandler(ctx context.Context, msgType pbs.MsgType, h
return errors.New(ctx, errors.InvalidParameter, op, "missing handler")
}
upstreamMessageHandler.Store(msgType, h)
if err := RegisterUpstreamMessageTypeSpecifier(ctx, msgType, h); err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}

@ -28,12 +28,13 @@ func Test_RegisterUpstreamMessageHandler(t *testing.T) {
// handlers pkg state.
testCtx := context.Background()
tests := []struct {
name string
msgType pbs.MsgType
h UpstreamMessageHandler
wantErr bool
wantErrMatch *errors.Template
wantErrContains string
name string
msgType pbs.MsgType
h UpstreamMessageHandler
withUpstreamMsgTypeSpecifier bool
wantErr bool
wantErrMatch *errors.Template
wantErrContains string
}{
{
name: "missing-msg-type",
@ -54,6 +55,12 @@ func Test_RegisterUpstreamMessageHandler(t *testing.T) {
msgType: pbs.MsgType_MSG_TYPE_ECHO,
h: &TestMockUpstreamMessageHandler{},
},
{
name: "success-with-type-specifier",
msgType: pbs.MsgType_MSG_TYPE_ECHO,
h: &TestMockUpstreamMessageHandler{},
withUpstreamMsgTypeSpecifier: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
@ -80,6 +87,11 @@ func Test_RegisterUpstreamMessageHandler(t *testing.T) {
require.NoError(err)
_, ok := getUpstreamMessageHandler(testCtx, tc.msgType)
assert.True(ok)
if tc.withUpstreamMsgTypeSpecifier {
_, ok := getUpstreamMessageTypeSpecifier(testCtx, tc.msgType)
assert.True(ok)
}
})
}
}
@ -420,3 +432,64 @@ func Test_SendUpstreamMessage(t *testing.T) {
})
}
}
func Test_RegisterUpstreamMessageTypeSpecifier(t *testing.T) {
// IMPORTANT: cannot run with t.Parallel() because it operates on the
// handlers pkg state.
testCtx := context.Background()
tests := []struct {
name string
msgType pbs.MsgType
s UpstreamMessageTypeSpecifier
wantErr bool
wantErrMatch *errors.Template
wantErrContains string
}{
{
name: "missing-msg-type",
s: &TestMockUpstreamMessageHandler{},
wantErr: true,
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "missing msg type",
},
{
name: "missing-type-specifier",
msgType: pbs.MsgType_MSG_TYPE_ECHO,
wantErr: true,
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "missing type specifier",
},
{
name: "success",
msgType: pbs.MsgType_MSG_TYPE_ECHO,
s: &TestMockUpstreamMessageHandler{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var cp sync.Map
upstreamMessageTypeSpecifier.Range(func(k, v interface{}) bool {
cp.Store(k, v)
return true
})
err := RegisterUpstreamMessageTypeSpecifier(testCtx, tc.msgType, tc.s)
t.Cleanup(func() {
upstreamMessageHandler = cp
})
if tc.wantErr {
require.Error(err)
if tc.wantErrMatch != nil {
assert.Truef(errors.Match(tc.wantErrMatch, err), "unexpected error: %q", err.Error())
}
if tc.wantErrContains != "" {
assert.Contains(err.Error(), tc.wantErrContains)
}
return
}
require.NoError(err)
_, ok := getUpstreamMessageTypeSpecifier(testCtx, tc.msgType)
assert.True(ok)
})
}
}

@ -6,9 +6,11 @@ package handlers
import (
"context"
"fmt"
"sync"
"github.com/hashicorp/boundary/internal/errors"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/nodeenrollment"
"google.golang.org/grpc/codes"
@ -63,6 +65,62 @@ func (s *workerUpstreamMessageServiceServer) UpstreamMessage(ctx context.Context
return c.UpstreamMessage(ctx, req)
}
// UpstreamMessageTypeSpecifier defines an interface for specifying type
// information for an UpstreamMessageRequest(s).
//
// See handlers.SendUpstreamMessage for how this is
// used to send an UpstreamMessageRequest via registered upstream message
// type specifiers.
type UpstreamMessageTypeSpecifier interface {
// Encrypted returns true if the request/response should be encrypted
Encrypted() bool
// AllocRequest will allocate a type specific request proto message
AllocRequest() proto.Message
// AllocResponse will allocate a type specific response proto message
AllocResponse() proto.Message
}
var upstreamMessageTypeSpecifier sync.Map
// RegisterUpstreamMessageTypeSpecifier will register an
// UpstreamMessageTypeSpecifier for the specified msg name.
//
// See handlers.SendUpstreamMessage for how this is
// used to send UpstreamMessage requests
func RegisterUpstreamMessageTypeSpecifier(ctx context.Context, msgType pbs.MsgType, t UpstreamMessageTypeSpecifier) error {
const op = "handlers.RegisterUpstreamMessageTypeSpecifier"
switch {
case msgType == pbs.MsgType_MSG_TYPE_UNSPECIFIED:
return errors.New(ctx, errors.InvalidParameter, op, "missing msg type")
case util.IsNil(t):
return errors.New(ctx, errors.InvalidParameter, op, "missing type specifier")
}
upstreamMessageTypeSpecifier.Store(msgType, t)
return nil
}
func getUpstreamMessageTypeSpecifier(ctx context.Context, msgType pbs.MsgType) (UpstreamMessageTypeSpecifier, bool) {
const op = "handlers.getUpstreamMessageTypeSpecifier"
switch {
case msgType == pbs.MsgType_MSG_TYPE_UNSPECIFIED:
event.WriteError(ctx, op, fmt.Errorf("missing msg type"))
return nil, false
}
v, ok := upstreamMessageTypeSpecifier.Load(msgType)
if !ok {
return nil, false
}
h, ok := v.(UpstreamMessageTypeSpecifier)
if !ok {
event.WriteError(ctx, op, fmt.Errorf("malformed type specifier %q registered as incorrect type %T", msgType.String(), v))
return nil, false
}
return h, true
}
func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServiceClientProducer, originatingWorkerKeyId string, msg proto.Message, opt ...Option) (proto.Message, error) {
const op = "handlers.SendUpstreamMessage"
switch {
@ -79,14 +137,14 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ
if err != nil {
return nil, status.Errorf(codes.Internal, "%s: %v", op, err)
}
h, ok := getUpstreamMessageHandler(ctx, msgType)
t, ok := getUpstreamMessageTypeSpecifier(ctx, msgType)
if !ok {
return nil, status.Errorf(codes.Unimplemented, "upstream message handler for %q is not implemented", msgType.String())
return nil, status.Errorf(codes.Unimplemented, "upstream message type specifier for %q is not implemented", msgType.String())
}
var req *pbs.UpstreamMessageRequest
switch {
case h.Encrypted() == false:
case t.Encrypted() == false:
req, err = ptMsg(ctx, msgType, msg)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
@ -114,14 +172,14 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ
}
switch {
case h.Encrypted() == false:
pt := h.AllocResponse()
case t.Encrypted() == false:
pt := t.AllocResponse()
if err := proto.Unmarshal(rawResp.GetPt(), pt); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error marshaling echo request"))
}
return pt, nil
default:
ct := h.AllocResponse()
ct := t.AllocResponse()
if err := nodeenrollment.DecryptMessage(ctx, rawResp.GetCt(), opts.withKeyProducer, ct); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error decrypting unwrap keys response"))
}

Loading…
Cancel
Save