diff --git a/internal/daemon/cluster/handlers/testing.go b/internal/daemon/cluster/handlers/testing.go index 7e797aa9db..d5dc59b956 100644 --- a/internal/daemon/cluster/handlers/testing.go +++ b/internal/daemon/cluster/handlers/testing.go @@ -84,15 +84,21 @@ func TestRegisterHandlerFn(t *testing.T, msgType pbs.MsgType, h UpstreamMessageH return func(t *testing.T) { t.Helper() testCtx := context.Background() - var cp sync.Map + var cpHandlerRegistry sync.Map upstreamMessageHandler.Range(func(k, v interface{}) bool { - cp.Store(k, v) + cpHandlerRegistry.Store(k, v) + return true + }) + var cpTypeSpecifierRegister sync.Map + upstreamMessageTypeSpecifier.Range(func(k, v interface{}) bool { + cpTypeSpecifierRegister.Store(k, v) return true }) require.NoError(t, RegisterUpstreamMessageHandler(testCtx, msgType, h)) t.Cleanup(func() { - upstreamMessageHandler = cp + upstreamMessageHandler = cpHandlerRegistry + upstreamMessageTypeSpecifier = cpTypeSpecifierRegister }) } } diff --git a/internal/daemon/cluster/handlers/upstream_message_service_controller.go b/internal/daemon/cluster/handlers/upstream_message_service_controller.go index fab8848a7b..0f0951946e 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_controller.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_controller.go @@ -31,14 +31,7 @@ type UpstreamMessageHandler interface { // using google.golang.org/grpc/status Handler(ctx context.Context, request proto.Message) (response proto.Message, statusErr error) - // Encrypted returns true if the handler request/response should be encrypted - Encrypted() bool - - // AllocRequest will allocate handler specific request proto message - AllocRequest() proto.Message - - // AllocResponse will allocate a handler specific response proto message - AllocResponse() proto.Message + UpstreamMessageTypeSpecifier } // RegisterUpstreamMessageHandler will register an UpstreamMessageHandler for @@ -56,6 +49,9 @@ func RegisterUpstreamMessageHandler(ctx context.Context, msgType pbs.MsgType, h return errors.New(ctx, errors.InvalidParameter, op, "missing handler") } upstreamMessageHandler.Store(msgType, h) + if err := RegisterUpstreamMessageTypeSpecifier(ctx, msgType, h); err != nil { + return errors.Wrap(ctx, err, op) + } return nil } diff --git a/internal/daemon/cluster/handlers/upstream_message_service_test.go b/internal/daemon/cluster/handlers/upstream_message_service_test.go index cc9f901dbb..9efd9c0864 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_test.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_test.go @@ -28,12 +28,13 @@ func Test_RegisterUpstreamMessageHandler(t *testing.T) { // handlers pkg state. testCtx := context.Background() tests := []struct { - name string - msgType pbs.MsgType - h UpstreamMessageHandler - wantErr bool - wantErrMatch *errors.Template - wantErrContains string + name string + msgType pbs.MsgType + h UpstreamMessageHandler + withUpstreamMsgTypeSpecifier bool + wantErr bool + wantErrMatch *errors.Template + wantErrContains string }{ { name: "missing-msg-type", @@ -54,6 +55,12 @@ func Test_RegisterUpstreamMessageHandler(t *testing.T) { msgType: pbs.MsgType_MSG_TYPE_ECHO, h: &TestMockUpstreamMessageHandler{}, }, + { + name: "success-with-type-specifier", + msgType: pbs.MsgType_MSG_TYPE_ECHO, + h: &TestMockUpstreamMessageHandler{}, + withUpstreamMsgTypeSpecifier: true, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { @@ -80,6 +87,11 @@ func Test_RegisterUpstreamMessageHandler(t *testing.T) { require.NoError(err) _, ok := getUpstreamMessageHandler(testCtx, tc.msgType) assert.True(ok) + + if tc.withUpstreamMsgTypeSpecifier { + _, ok := getUpstreamMessageTypeSpecifier(testCtx, tc.msgType) + assert.True(ok) + } }) } } @@ -420,3 +432,64 @@ func Test_SendUpstreamMessage(t *testing.T) { }) } } + +func Test_RegisterUpstreamMessageTypeSpecifier(t *testing.T) { + // IMPORTANT: cannot run with t.Parallel() because it operates on the + // handlers pkg state. + testCtx := context.Background() + tests := []struct { + name string + msgType pbs.MsgType + s UpstreamMessageTypeSpecifier + wantErr bool + wantErrMatch *errors.Template + wantErrContains string + }{ + { + name: "missing-msg-type", + s: &TestMockUpstreamMessageHandler{}, + wantErr: true, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing msg type", + }, + { + name: "missing-type-specifier", + msgType: pbs.MsgType_MSG_TYPE_ECHO, + wantErr: true, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing type specifier", + }, + { + name: "success", + msgType: pbs.MsgType_MSG_TYPE_ECHO, + s: &TestMockUpstreamMessageHandler{}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + var cp sync.Map + upstreamMessageTypeSpecifier.Range(func(k, v interface{}) bool { + cp.Store(k, v) + return true + }) + err := RegisterUpstreamMessageTypeSpecifier(testCtx, tc.msgType, tc.s) + t.Cleanup(func() { + upstreamMessageHandler = cp + }) + if tc.wantErr { + require.Error(err) + if tc.wantErrMatch != nil { + assert.Truef(errors.Match(tc.wantErrMatch, err), "unexpected error: %q", err.Error()) + } + if tc.wantErrContains != "" { + assert.Contains(err.Error(), tc.wantErrContains) + } + return + } + require.NoError(err) + _, ok := getUpstreamMessageTypeSpecifier(testCtx, tc.msgType) + assert.True(ok) + }) + } +} diff --git a/internal/daemon/cluster/handlers/upstream_message_service_worker.go b/internal/daemon/cluster/handlers/upstream_message_service_worker.go index f68d245dc2..3b8c8be2f1 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_worker.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_worker.go @@ -6,9 +6,11 @@ package handlers import ( "context" "fmt" + "sync" "github.com/hashicorp/boundary/internal/errors" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" + "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/nodeenrollment" "google.golang.org/grpc/codes" @@ -63,6 +65,62 @@ func (s *workerUpstreamMessageServiceServer) UpstreamMessage(ctx context.Context return c.UpstreamMessage(ctx, req) } +// UpstreamMessageTypeSpecifier defines an interface for specifying type +// information for an UpstreamMessageRequest(s). +// +// See handlers.SendUpstreamMessage for how this is +// used to send an UpstreamMessageRequest via registered upstream message +// type specifiers. +type UpstreamMessageTypeSpecifier interface { + // Encrypted returns true if the request/response should be encrypted + Encrypted() bool + + // AllocRequest will allocate a type specific request proto message + AllocRequest() proto.Message + + // AllocResponse will allocate a type specific response proto message + AllocResponse() proto.Message +} + +var upstreamMessageTypeSpecifier sync.Map + +// RegisterUpstreamMessageTypeSpecifier will register an +// UpstreamMessageTypeSpecifier for the specified msg name. +// +// See handlers.SendUpstreamMessage for how this is +// used to send UpstreamMessage requests +func RegisterUpstreamMessageTypeSpecifier(ctx context.Context, msgType pbs.MsgType, t UpstreamMessageTypeSpecifier) error { + const op = "handlers.RegisterUpstreamMessageTypeSpecifier" + switch { + case msgType == pbs.MsgType_MSG_TYPE_UNSPECIFIED: + return errors.New(ctx, errors.InvalidParameter, op, "missing msg type") + case util.IsNil(t): + return errors.New(ctx, errors.InvalidParameter, op, "missing type specifier") + } + upstreamMessageTypeSpecifier.Store(msgType, t) + return nil +} + +func getUpstreamMessageTypeSpecifier(ctx context.Context, msgType pbs.MsgType) (UpstreamMessageTypeSpecifier, bool) { + const op = "handlers.getUpstreamMessageTypeSpecifier" + switch { + case msgType == pbs.MsgType_MSG_TYPE_UNSPECIFIED: + event.WriteError(ctx, op, fmt.Errorf("missing msg type")) + return nil, false + } + v, ok := upstreamMessageTypeSpecifier.Load(msgType) + if !ok { + return nil, false + } + + h, ok := v.(UpstreamMessageTypeSpecifier) + if !ok { + event.WriteError(ctx, op, fmt.Errorf("malformed type specifier %q registered as incorrect type %T", msgType.String(), v)) + return nil, false + } + return h, true +} + func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServiceClientProducer, originatingWorkerKeyId string, msg proto.Message, opt ...Option) (proto.Message, error) { const op = "handlers.SendUpstreamMessage" switch { @@ -79,14 +137,14 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ if err != nil { return nil, status.Errorf(codes.Internal, "%s: %v", op, err) } - h, ok := getUpstreamMessageHandler(ctx, msgType) + t, ok := getUpstreamMessageTypeSpecifier(ctx, msgType) if !ok { - return nil, status.Errorf(codes.Unimplemented, "upstream message handler for %q is not implemented", msgType.String()) + return nil, status.Errorf(codes.Unimplemented, "upstream message type specifier for %q is not implemented", msgType.String()) } var req *pbs.UpstreamMessageRequest switch { - case h.Encrypted() == false: + case t.Encrypted() == false: req, err = ptMsg(ctx, msgType, msg) if err != nil { return nil, errors.Wrap(ctx, err, op) @@ -114,14 +172,14 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ } switch { - case h.Encrypted() == false: - pt := h.AllocResponse() + case t.Encrypted() == false: + pt := t.AllocResponse() if err := proto.Unmarshal(rawResp.GetPt(), pt); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error marshaling echo request")) } return pt, nil default: - ct := h.AllocResponse() + ct := t.AllocResponse() if err := nodeenrollment.DecryptMessage(ctx, rawResp.GetCt(), opts.withKeyProducer, ct); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error decrypting unwrap keys response")) }