You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/daemon/controller/interceptor.go

553 lines
18 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package controller
import (
"context"
"crypto/subtle"
"fmt"
"net/http"
"reflect"
"runtime/debug"
"time"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/alias"
commonSrv "github.com/hashicorp/boundary/internal/daemon/common"
"github.com/hashicorp/boundary/internal/daemon/controller/auth"
"github.com/hashicorp/boundary/internal/daemon/controller/common"
"github.com/hashicorp/boundary/internal/daemon/controller/handlers"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
pb "github.com/hashicorp/boundary/internal/gen/controller/api"
authpb "github.com/hashicorp/boundary/internal/gen/controller/auth"
pberrors "github.com/hashicorp/boundary/internal/gen/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/requests"
"github.com/hashicorp/go-uuid"
"github.com/mr-tron/base58"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
const (
	// requestInfoMdKey is the grpc metadata key that carries the
	// base58-encoded, marshalled RequestInfo protobuf produced by the
	// upstream http handler.
	requestInfoMdKey = "request-info"

	// domainErrHeader is the http header used to ship a base58-encoded
	// domain error from the grpc server.
	domainErrHeader = "x-domain-err"

	// apiErrHeader is the http header used to ship a base58-encoded api
	// error from the grpc server.
	apiErrHeader = "x-api-err"
)
// customContextServerStream embeds a grpc.ServerStream and allows the context
// it reports to be replaced, so stream handlers observe a context enriched by
// an interceptor.
type customContextServerStream struct {
	grpc.ServerStream
	customContext context.Context
}

// Context returns the override context when one is set, and falls back to the
// embedded stream's own context otherwise.
func (s *customContextServerStream) Context() context.Context {
	if override := s.customContext; override != nil {
		return override
	}
	return s.ServerStream.Context()
}
// requestCtxUnaryInterceptor builds a unary server interceptor that converts
// the RequestInfo carried in grpc metadata into request-scoped context values
// before the handler runs. The metadata must be set by an upstream http
// handler/middleware, which base58-encodes a marshalled RequestInfo protobuf
// into the requestInfoMdKey header (see: controller.wrapHandlerWithCommonFuncs).
//
// Required dependencies are validated once, when the interceptor is built; an
// error is returned if any of them is missing.
func requestCtxUnaryInterceptor(
	ctx context.Context,
	iamRepoFn common.IamRepoFactory,
	authTokenRepoFn common.AuthTokenRepoFactory,
	serversRepoFn common.ServersRepoFactory,
	passwordAuthRepoFn common.PasswordAuthRepoFactory,
	oidcAuthRepoFn common.OidcAuthRepoFactory,
	ldapAuthRepoFn common.LdapAuthRepoFactory,
	kms *kms.Kms,
	ticket string,
	eventer *event.Eventer,
) (grpc.UnaryServerInterceptor, error) {
	const op = "controller.requestCtxUnaryInterceptor"
	validationErr := sharedRequestInterceptorValidation(ctx, op, iamRepoFn, authTokenRepoFn, serversRepoFn, kms, ticket, eventer)
	if validationErr != nil {
		return nil, validationErr
	}

	interceptor := func(interceptorCtx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		requestCtx, logicErr := sharedRequestInterceptorLogic(
			interceptorCtx, op,
			iamRepoFn, authTokenRepoFn, serversRepoFn,
			passwordAuthRepoFn, oidcAuthRepoFn, ldapAuthRepoFn,
			kms, ticket, eventer,
		)
		if logicErr != nil {
			return nil, logicErr
		}
		return handler(requestCtx, req)
	}
	return interceptor, nil
}
// requestCtxStreamInterceptor builds a stream server interceptor that converts
// the RequestInfo carried in grpc metadata into request-scoped context values
// and exposes them to the handler through a wrapped stream. The metadata must
// be set by an upstream http handler/middleware, which base58-encodes a
// marshalled RequestInfo protobuf into the requestInfoMdKey header (see:
// controller.wrapHandlerWithCommonFuncs).
//
// Required dependencies are validated once, when the interceptor is built; an
// error is returned if any of them is missing.
func requestCtxStreamInterceptor(
	ctx context.Context,
	iamRepoFn common.IamRepoFactory,
	authTokenRepoFn common.AuthTokenRepoFactory,
	serversRepoFn common.ServersRepoFactory,
	passwordAuthRepoFn common.PasswordAuthRepoFactory,
	oidcAuthRepoFn common.OidcAuthRepoFactory,
	ldapAuthRepoFn common.LdapAuthRepoFactory,
	kms *kms.Kms,
	ticket string,
	eventer *event.Eventer,
) (grpc.StreamServerInterceptor, error) {
	const op = "controller.requestCtxStreamInterceptor"
	validationErr := sharedRequestInterceptorValidation(ctx, op, iamRepoFn, authTokenRepoFn, serversRepoFn, kms, ticket, eventer)
	if validationErr != nil {
		return nil, validationErr
	}

	interceptor := func(srv any, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		requestCtx, logicErr := sharedRequestInterceptorLogic(
			stream.Context(), op,
			iamRepoFn, authTokenRepoFn, serversRepoFn,
			passwordAuthRepoFn, oidcAuthRepoFn, ldapAuthRepoFn,
			kms, ticket, eventer,
		)
		if logicErr != nil {
			return logicErr
		}
		// Hand the handler a stream whose Context() reports the enriched ctx.
		return handler(srv, &customContextServerStream{
			ServerStream:  stream,
			customContext: requestCtx,
		})
	}
	return interceptor, nil
}
// sharedRequestInterceptorValidation checks the dependencies shared by the
// unary and stream request-context interceptors. It returns an
// errors.InvalidParameter error, attributed to op, for the first missing
// dependency (checked in the order iam repo, auth token repo, server repo,
// kms, ticket, eventer), or nil when all are present.
func sharedRequestInterceptorValidation(
	ctx context.Context,
	op errors.Op,
	iamRepoFn common.IamRepoFactory,
	authTokenRepoFn common.AuthTokenRepoFactory,
	serversRepoFn common.ServersRepoFactory,
	kms *kms.Kms,
	ticket string,
	eventer *event.Eventer,
) error {
	switch {
	case iamRepoFn == nil:
		return errors.New(ctx, errors.InvalidParameter, op, "missing iam repo function")
	case authTokenRepoFn == nil:
		return errors.New(ctx, errors.InvalidParameter, op, "missing auth token repo function")
	case serversRepoFn == nil:
		return errors.New(ctx, errors.InvalidParameter, op, "missing server repo function")
	case kms == nil:
		return errors.New(ctx, errors.InvalidParameter, op, "missing kms")
	case ticket == "":
		return errors.New(ctx, errors.InvalidParameter, op, "missing ticket")
	case eventer == nil:
		return errors.New(ctx, errors.InvalidParameter, op, "missing eventer")
	default:
		return nil
	}
}
// sharedRequestInterceptorLogic implements the per-request work shared by the
// unary and stream request-context interceptors.
//
// It requires exactly one requestInfoMdKey value in the incoming grpc
// metadata, base58-decodes it and unmarshals it into an authpb.RequestInfo.
// It then verifies, in constant time, that the embedded ticket matches the
// expected ticket.
//
// On success it returns a context derived from interceptorCtx that carries:
//   - the auth verifier context (auth.NewVerifierContextWithAccounts),
//   - a requests.RequestContext with the request path and method,
//   - the event.RequestInfo for downstream handlers,
//   - the eventer.
//
// Every failure returns a nil context and an errors.Internal-coded error
// attributed to op.
func sharedRequestInterceptorLogic(
	interceptorCtx context.Context,
	op errors.Op,
	iamRepoFn common.IamRepoFactory,
	authTokenRepoFn common.AuthTokenRepoFactory,
	serversRepoFn common.ServersRepoFactory,
	passwordAuthRepoFn common.PasswordAuthRepoFactory,
	oidcAuthRepoFn common.OidcAuthRepoFactory,
	ldapAuthRepoFn common.LdapAuthRepoFactory,
	kms *kms.Kms,
	ticket string,
	eventer *event.Eventer,
) (context.Context, error) {
	md, ok := metadata.FromIncomingContext(interceptorCtx)
	if !ok {
		return nil, errors.New(interceptorCtx, errors.Internal, op, "No metadata")
	}
	values := md.Get(requestInfoMdKey)
	if len(values) == 0 {
		return nil, errors.New(interceptorCtx, errors.Internal, op, "Missing request metadata")
	}
	if len(values) > 1 {
		return nil, errors.New(interceptorCtx, errors.Internal, op, fmt.Sprintf("expected 1 value for %s metadata and got %d", requestInfoMdKey, len(values)))
	}
	decoded, err := base58.FastBase58Decoding(values[0])
	if err != nil {
		return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to decode request info"))
	}
	var requestInfo authpb.RequestInfo
	if err := proto.Unmarshal(decoded, &requestInfo); err != nil {
		return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to unmarshal request info"))
	}
	switch {
	case requestInfo.Ticket == "":
		return nil, errors.New(interceptorCtx, errors.Internal, op, "Invalid context (missing ticket)")
	case subtle.ConstantTimeCompare([]byte(requestInfo.Ticket), []byte(ticket)) != 1:
		// Constant-time compare so the ticket cannot be probed via timing.
		return nil, errors.New(interceptorCtx, errors.Internal, op, "Invalid context (bad ticket)")
	}

	ctx := auth.NewVerifierContextWithAccounts(interceptorCtx, iamRepoFn, authTokenRepoFn, serversRepoFn, passwordAuthRepoFn, oidcAuthRepoFn, ldapAuthRepoFn, kms, &requestInfo)

	// Add general request information to the context. The information from
	// the auth verifier context is pretty specifically curated to
	// authentication/authorization verification so this is more
	// general-purpose.
	//
	// We could use requests.NewRequestContext but this saves an immediate
	// lookup.
	ctx = context.WithValue(ctx, requests.ContextRequestInformationKey, &requests.RequestContext{
		Path:   requestInfo.Path,
		Method: requestInfo.Method,
	})

	// This event request info is required by downstream handlers
	info := &event.RequestInfo{
		EventId:  requestInfo.EventId,
		Id:       requestInfo.TraceId,
		PublicId: requestInfo.PublicId,
		Method:   requestInfo.Method,
		Path:     requestInfo.Path,
		ClientIp: requestInfo.ClientIp,
	}

	// Store each derived context in a fresh variable and check the error
	// before using it: on failure the returned context may be nil, and the
	// error must still be reported with the last good context.
	reqInfoCtx, err := event.NewRequestInfoContext(ctx, info)
	if err != nil {
		return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to create context with request info"))
	}
	eventerCtx, err := event.NewEventerContext(reqInfoCtx, eventer)
	if err != nil {
		return nil, errors.Wrap(reqInfoCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to create context with eventer"))
	}
	// NOTE: not convinced we want to wrap every error and turn them into
	// domain errors...
	return eventerCtx, nil
}
// correlationIdInterceptor returns a unary server interceptor with these
// steps:
//   - It requires exactly one globals.CorrelationIdKey value in the incoming
//     grpc metadata.
//   - It checks that the value parses as a UUID.
//   - It stores the value in the request context via
//     event.NewCorrelationIdContext and then invokes the handler.
//
// Missing or duplicated metadata yields an errors.Internal error.
func correlationIdInterceptor(
	_ context.Context,
) grpc.UnaryServerInterceptor {
	const op = "controller.correlationIdInterceptor"
	return func(interceptorCtx context.Context, req any,
		_ *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
	) (any, error) {
		md, ok := metadata.FromIncomingContext(interceptorCtx)
		if !ok {
			return nil, errors.New(interceptorCtx, errors.Internal, op, "no metadata")
		}
		values := md.Get(globals.CorrelationIdKey)
		if len(values) == 0 {
			return nil, errors.New(interceptorCtx, errors.Internal, op, "missing correlation id metadata")
		}
		if len(values) > 1 {
			return nil, errors.New(interceptorCtx, errors.Internal, op, fmt.Sprintf("expected 1 value for %s metadata and got %d", globals.CorrelationIdKey, len(values)))
		}
		correlationId := values[0]
		// Validate the correlationId
		if _, err := uuid.ParseUUID(correlationId); err != nil {
			return nil, errors.Wrap(interceptorCtx, err, op, errors.WithMsg("failed to validated correlation id"))
		}
		// Keep interceptorCtx intact: on failure the returned context may be
		// nil, and the error should still be reported with the request context.
		correlatedCtx, err := event.NewCorrelationIdContext(interceptorCtx, correlationId)
		if err != nil {
			return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to create context with correlation id"))
		}
		return handler(correlatedCtx, req)
	}
}
// errorInterceptor returns a unary server interceptor that runs the handler
// and, when it fails, encodes the failure into a grpc response header so the
// http proxy can reconstruct it.
//
// A domain error (*errors.Err) is converted with errors.ToPbErrors and placed,
// proto-marshalled and base58-encoded, under domainErrHeader. Otherwise, a
// *handlers.ApiError is wrapped in a pberrors.ApiError and placed the same way
// under apiErrHeader. If marshalling or setting the header fails, a
// codes.Internal status is returned instead of the handler error. In every
// case the handler's response value is passed through unchanged.
func errorInterceptor(
	_ context.Context,
) grpc.UnaryServerInterceptor {
	return func(interceptorCtx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		resp, handlerErr := handler(interceptorCtx, req)
		if handlerErr == nil {
			return resp, nil
		}

		var domainErr *errors.Err
		if errors.As(handlerErr, &domainErr) {
			encoded, marshalErr := proto.Marshal(errors.ToPbErrors(domainErr))
			if marshalErr != nil {
				return resp, status.Errorf(codes.Internal, "unable to marshal domain error: %s", marshalErr)
			}
			headerErr := grpc.SetHeader(interceptorCtx, metadata.Pairs(domainErrHeader, base58.FastBase58Encoding(encoded)))
			if headerErr != nil {
				return resp, status.Errorf(codes.Internal, "unable to set domain error header: %s", headerErr)
			}
			return resp, handlerErr
		}

		var apiErr *handlers.ApiError
		if errors.As(handlerErr, &apiErr) {
			encoded, marshalErr := proto.Marshal(&pberrors.ApiError{
				ApiError: apiErr.Inner,
				Status:   apiErr.Status,
			})
			if marshalErr != nil {
				return resp, status.Errorf(codes.Internal, "unable to marshal api error: %s", marshalErr)
			}
			headerErr := grpc.SetHeader(interceptorCtx, metadata.Pairs(apiErrHeader, base58.FastBase58Encoding(encoded)))
			if headerErr != nil {
				return resp, status.Errorf(codes.Internal, "unable to set api error header: %s", headerErr)
			}
			return resp, handlerErr
		}

		return resp, handlerErr
	}
}
// aliasResolutionInterceptor returns a grpc.UnaryServerInterceptor that
// rewrites alias values in the request to their destination ids (via
// alias.ResolveRequestIds) before invoking the handler. Only fields annotated
// with the Aliasable proto option are considered for resolution.
//
// The interceptor fails in three cases:
//   - A non-proto request is rejected with an invalid-argument error.
//   - Errors from aliasRepoFn are returned unchanged.
//   - Any resolution failure is reported as a generic not-found error.
func aliasResolutionInterceptor(
	_ context.Context,
	aliasRepoFn common.AliasRepoFactory,
) grpc.UnaryServerInterceptor {
	return func(interceptorCtx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		msg, isProto := req.(proto.Message)
		if !isProto {
			return nil, handlers.InvalidArgumentErrorf("The request was not a proto.Message.", nil)
		}
		repo, err := aliasRepoFn()
		if err != nil {
			return nil, err
		}
		resolvedCtx, err := alias.ResolveRequestIds(interceptorCtx, msg, repo)
		if err != nil {
			// The caller is not yet authorized here, so hide the underlying
			// resolution failure behind a static not-found error.
			return nil, handlers.NotFoundError()
		}
		return handler(resolvedCtx, req)
	}
}
// statusCodeInterceptor returns a unary server interceptor that turns a
// handler result of (nil, nil) into an empty pb.EmptyResponse and sets the
// http status code to 204 No Content. The returned error is whatever
// handlers.SetStatusCode reports. Any other handler result is passed through
// untouched.
func statusCodeInterceptor(
	_ context.Context,
) grpc.UnaryServerInterceptor {
	return func(interceptorCtx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		resp, handlerErr := handler(interceptorCtx, req)
		if handlerErr != nil || !isNil(resp) {
			return resp, handlerErr
		}
		// A nil response with no error signals a 204 response to the proxy
		// with no response message.
		empty := &pb.EmptyResponse{}
		return empty, handlers.SetStatusCode(interceptorCtx, http.StatusNoContent)
	}
}
// isNil reports whether i is nil. It returns true for an untyped nil
// interface, and for an interface holding a nil value of a nillable kind
// (pointer, map, chan, slice, func, unsafe pointer). Values of any other kind
// (arrays, structs, scalars, ...) can never be nil and report false.
func isNil(i any) bool {
	if i == nil {
		return true
	}
	v := reflect.ValueOf(i)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func, reflect.UnsafePointer:
		return v.IsNil()
	}
	// reflect.Value.IsNil panics for non-nillable kinds such as reflect.Array,
	// so those must not be routed to it.
	return false
}
// eventsRequestInterceptor returns a unary server interceptor that, when the
// request is a proto.Message, writes a clone of it as both an audit event and
// an observation event before invoking the handler. If either write fails,
// the request is aborted with a codes.Internal status and the handler is not
// called. Non-proto requests are passed straight to the handler.
func eventsRequestInterceptor(
	_ context.Context,
) grpc.UnaryServerInterceptor {
	const op = "controller.eventsRequestInterceptor"
	return func(interceptorCtx context.Context,
		req any,
		_ *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (any, error) {
		if msg, ok := req.(proto.Message); ok {
			// Clone the request before writing it to the audit log,
			// in case downstream interceptors modify it.
			clonedMsg := proto.Clone(msg)
			// On failure return no response value: returning req would hand the
			// request message to upstream interceptors as if it were a response.
			if err := event.WriteAudit(interceptorCtx, op, event.WithRequest(&event.Request{Details: clonedMsg})); err != nil {
				return nil, status.Errorf(codes.Internal, "unable to write request msg audit: %s", err)
			}
			if err := event.WriteObservation(interceptorCtx, op, event.WithRequest(&event.Request{Details: clonedMsg})); err != nil {
				return nil, status.Errorf(codes.Internal, "unable to write request msg observation: %s", err)
			}
		}
		return handler(interceptorCtx, req)
	}
}
// eventsResponseInterceptor returns a unary server interceptor that invokes
// the handler and, when the response is a proto.Message, writes a clone of it
// as both an audit event and an observation event. If either write fails, a
// codes.Internal status is returned in place of the handler's result.
// Otherwise the handler's response and error are returned unchanged.
func eventsResponseInterceptor(
	_ context.Context,
) grpc.UnaryServerInterceptor {
	const op = "controller.eventsResponseInterceptor"
	return func(interceptorCtx context.Context,
		req any,
		_ *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (any, error) {
		// call the handler...
		resp, handlerErr := handler(interceptorCtx, req)
		if msg, ok := resp.(proto.Message); ok {
			// Clone the response before writing it to the audit log,
			// in case downstream interceptors modify it.
			clonedMsg := proto.Clone(msg)
			// On failure return no response value; the original returned the
			// request message here, which is never a valid response.
			if err := event.WriteAudit(interceptorCtx, op, event.WithResponse(&event.Response{Details: clonedMsg})); err != nil {
				return nil, status.Errorf(codes.Internal, "unable to write response msg audit: %s", err)
			}
			if err := event.WriteObservation(interceptorCtx, op, event.WithResponse(&event.Response{Details: clonedMsg})); err != nil {
				return nil, status.Errorf(codes.Internal, "unable to write response msg observation: %s", err)
			}
		}
		return resp, handlerErr
	}
}
// requestMaxDurationInterceptor returns a unary server interceptor that runs
// the handler under a context bounded by maxRequestDuration. The derived
// context is cancelled when the handler returns.
//
// A non-positive maxRequestDuration disables the limit. context.WithTimeout
// with such a value yields an already-expired context, which would make every
// request fail immediately.
func requestMaxDurationInterceptor(_ context.Context, maxRequestDuration time.Duration) grpc.UnaryServerInterceptor {
	return func(interceptorCtx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		if maxRequestDuration <= 0 {
			return handler(interceptorCtx, req)
		}
		withTimeout, cancel := context.WithTimeout(interceptorCtx, maxRequestDuration)
		defer cancel()
		return handler(withTimeout, req)
	}
}
// workerRequestInfoInterceptor returns a unary server interceptor that
// attaches the following to the request context before invoking the handler:
//   - an event.RequestInfo holding a freshly generated event id, the trace id
//     from commonSrv.GeneratedTraceId and the full grpc method name,
//   - the eventer.
//
// It returns an error if eventer is nil. Per-request failures are written as
// error events and returned as codes.Internal statuses.
func workerRequestInfoInterceptor(ctx context.Context, eventer *event.Eventer) (grpc.UnaryServerInterceptor, error) {
	const op = "worker.requestInfoInterceptor"
	if eventer == nil {
		return nil, errors.New(ctx, errors.InvalidParameter, op, "missing eventer")
	}
	return func(interceptorCtx context.Context,
		req any,
		srvInfo *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (any, error) {
		id, err := event.NewId(event.IdPrefix)
		if err != nil {
			event.WriteError(interceptorCtx, op, err, event.WithInfoMsg("unable to create id for event", "method", srvInfo.FullMethod))
			return nil, status.Errorf(codes.Internal, "Error creating id for event: %v", err)
		}
		info := &event.RequestInfo{
			EventId: id,
			Id:      commonSrv.GeneratedTraceId(interceptorCtx),
			Method:  srvInfo.FullMethod,
		}
		// Derived contexts go into fresh variables so that, on failure, the
		// error event is written with the last good (non-nil) context.
		reqInfoCtx, err := event.NewRequestInfoContext(interceptorCtx, info)
		if err != nil {
			event.WriteError(interceptorCtx, op, err, event.WithInfoMsg("unable to create context with request info", "method", srvInfo.FullMethod))
			return nil, status.Errorf(codes.Internal, "Error creating context with request info: %v", err)
		}
		eventerCtx, err := event.NewEventerContext(reqInfoCtx, eventer)
		if err != nil {
			event.WriteError(reqInfoCtx, op, err, event.WithInfoMsg("unable to create context with eventer", "method", srvInfo.FullMethod))
			return nil, status.Errorf(codes.Internal, "Error creating context with eventer: %v", err)
		}
		// call the handler...
		return handler(eventerCtx, req)
	}, nil
}
// recoveryHandler returns a grpc_recovery handler for recovered panics. It
// writes an error event describing the panic value, with the current goroutine
// stack attached under the "stack" info key. It then returns a codes.Internal
// status whose message is the formatted panic value.
func recoveryHandler() grpc_recovery.RecoveryHandlerFuncContext {
	const op = "controller.recoveryHandler"
	return func(ctx context.Context, p any) error {
		panicErr := fmt.Errorf("recovered from panic: %v", p)
		stack := string(debug.Stack())
		event.WriteError(ctx, op, panicErr, event.WithInfo("stack", stack))
		return status.Errorf(codes.Internal, "%v", p)
	}
}