// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package controller import ( "bytes" "context" "encoding/json" "errors" "fmt" "io/ioutil" "net/http" "net/textproto" "os" "strings" "time" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/daemon/common" "github.com/hashicorp/boundary/internal/daemon/controller/auth" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/accounts" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/authmethods" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/authtokens" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/credentiallibraries" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/credentials" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/credentialstores" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/groups" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/health" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/host_catalogs" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/host_sets" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/hosts" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/managed_groups" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/policies" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/roles" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/scopes" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/session_recordings" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/sessions" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/storage_buckets" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/targets" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/users" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/workers" "github.com/hashicorp/boundary/internal/daemon/controller/internal/metric" "github.com/hashicorp/boundary/internal/event" "github.com/hashicorp/boundary/internal/gen/controller/api/services" authpb "github.com/hashicorp/boundary/internal/gen/controller/auth" opsservices "github.com/hashicorp/boundary/internal/gen/ops/services" "github.com/hashicorp/boundary/internal/ratelimit" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/mr-tron/base58" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" "github.com/hashicorp/boundary/internal/daemon/controller/handlers" ) type HandlerProperties struct { ListenerConfig *listenerutil.ListenerConfig CancelCtx context.Context } const uiPath = "/" // createMuxWithEndpoints performs all response logic for boundary, using isUiRequest // for unified logic between responses and headers. func createMuxWithEndpoints(c *Controller, props HandlerProperties) (http.Handler, func(req *http.Request) bool, error) { grpcGwMux := newGrpcGatewayMux() if err := registerGrpcGatewayEndpoints(props.CancelCtx, grpcGwMux, gatewayDialOptions(c.apiGrpcServerListener)...); err != nil { return nil, nil, err } mux := http.NewServeMux() mux.Handle("/v1/", ratelimit.Handler(c.baseContext, c.getRateLimiter, grpcGwMux)) mux.Handle(uiPath, handleUi(c)) isUiRequest := func(req *http.Request) bool { _, p := mux.Handler(req) // check to see if the matched pattern is for the ui return p == uiPath } return mux, isUiRequest, nil } // apiHandler returns an http.Handler for the services. This can be used on // its own to mount the Controller API within another web server. func (c *Controller) apiHandler(props HandlerProperties) (http.Handler, error) { mux, isUiRequest, err := createMuxWithEndpoints(c, props) if err != nil { return nil, err } corsWrappedHandler := wrapHandlerWithCors(mux, props) commonWrappedHandler := wrapHandlerWithCommonFuncs(corsWrappedHandler, c, props) callbackInterceptingHandler := wrapHandlerWithCallbackInterceptor(commonWrappedHandler, c) printablePathCheckHandler := cleanhttp.PrintablePathCheckHandler(callbackInterceptingHandler, nil) eventsHandler, err := common.WrapWithEventsHandler(c.baseContext, printablePathCheckHandler, c.conf.Eventer, c.kms, props.ListenerConfig) if err != nil { return nil, err } metricsHandler := metric.InstrumentApiHandler(eventsHandler) // This wrap MUST be performed last. If you add a new wrapper, do so above. return listenerutil.WrapCustomHeadersHandler(metricsHandler, props.ListenerConfig, isUiRequest), nil } // GetHealthHandler returns a gRPC Gateway mux that is registered against the // controller's gRPC health service to make it accessible from an HTTP API. func (c *Controller) GetHealthHandler(lcfg *listenerutil.ListenerConfig) (http.Handler, error) { const op = "controller.(Controller).GetHealthHandler" if lcfg == nil { return nil, fmt.Errorf("%s: received nil listener config", op) } healthGrpcGwMux := newGrpcGatewayMux() err := registerHealthGrpcGatewayEndpoint(c.baseContext, healthGrpcGwMux, gatewayDialOptions(c.apiGrpcServerListener)...) if err != nil { return nil, fmt.Errorf("%s: failed to register health service handler: %w", op, err) } wrapped := wrapHandlerWithCommonFuncs(healthGrpcGwMux, c, HandlerProperties{lcfg, c.baseContext}) return common.WrapWithEventsHandler(c.baseContext, wrapped, c.conf.Eventer, c.kms, lcfg) } func registerHealthGrpcGatewayEndpoint(ctx context.Context, gwMux *runtime.ServeMux, dialOptions ...grpc.DialOption) error { return opsservices.RegisterHealthServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions) } func (c *Controller) registerGrpcServices(s *grpc.Server) error { // We have to check against the current services because the gRPC lib treats a duplicate // register call as an error and os.Exits. currentServices := s.GetServiceInfo() if _, ok := currentServices[services.HostCatalogService_ServiceDesc.ServiceName]; !ok { hcs, err := host_catalogs.NewService( c.baseContext, c.StaticHostRepoFn, c.PluginHostRepoFn, c.PluginRepoFn, c.IamRepoFn, c.HostCatalogRepoFn, c.conf.RawConfig.Controller.MaxPageSize, ) if err != nil { return fmt.Errorf("failed to create host catalog handler service: %w", err) } services.RegisterHostCatalogServiceServer(s, hcs) } if _, ok := currentServices[services.HostSetService_ServiceDesc.ServiceName]; !ok { hss, err := host_sets.NewService(c.baseContext, c.StaticHostRepoFn, c.PluginHostRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create host set handler service: %w", err) } services.RegisterHostSetServiceServer(s, hss) } if _, ok := currentServices[services.HostService_ServiceDesc.ServiceName]; !ok { hs, err := hosts.NewService(c.baseContext, c.StaticHostRepoFn, c.PluginHostRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create host handler service: %w", err) } services.RegisterHostServiceServer(s, hs) } if _, ok := currentServices[services.AccountService_ServiceDesc.ServiceName]; !ok { accts, err := accounts.NewService(c.baseContext, c.PasswordAuthRepoFn, c.OidcRepoFn, c.LdapRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create account handler service: %w", err) } services.RegisterAccountServiceServer(s, accts) } if _, ok := currentServices[services.AuthMethodService_ServiceDesc.ServiceName]; !ok { authMethods, err := authmethods.NewService( c.baseContext, c.kms, c.PasswordAuthRepoFn, c.OidcRepoFn, c.IamRepoFn, c.AuthTokenRepoFn, c.LdapRepoFn, c.AuthMethodRepoFn, c.conf.RawConfig.Controller.MaxPageSize, ) if err != nil { return fmt.Errorf("failed to create auth method handler service: %w", err) } services.RegisterAuthMethodServiceServer(s, authMethods) } if _, ok := currentServices[services.AuthTokenService_ServiceDesc.ServiceName]; !ok { authtoks, err := authtokens.NewService(c.baseContext, c.AuthTokenRepoFn, c.IamRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create auth token handler service: %w", err) } services.RegisterAuthTokenServiceServer(s, authtoks) } if _, ok := currentServices[services.ScopeService_ServiceDesc.ServiceName]; !ok { os, err := scopes.NewServiceFn(c.baseContext, c.IamRepoFn, c.kms, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create scope handler service: %w", err) } services.RegisterScopeServiceServer(s, os) } if _, ok := currentServices[services.UserService_ServiceDesc.ServiceName]; !ok { us, err := users.NewService(c.baseContext, c.IamRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create user handler service: %w", err) } services.RegisterUserServiceServer(s, us) } if _, ok := currentServices[services.StorageBucketService_ServiceDesc.ServiceName]; !ok { sbs, err := storage_buckets.NewServiceFn( c.baseContext, c.PluginStorageBucketRepoFn, c.IamRepoFn, c.PluginRepoFn, c.conf.RawConfig.Controller.MaxPageSize, c.ControllerExtension) if err != nil { return fmt.Errorf("failed to create storage bucket handler service: %w", err) } services.RegisterStorageBucketServiceServer(s, sbs) } if _, ok := currentServices[services.PolicyService_ServiceDesc.ServiceName]; !ok { ps, err := policies.NewServiceFn( c.baseContext, c.IamRepoFn, c.ControllerExtension, ) if err != nil { return fmt.Errorf("failed to create policy handler service: %w", err) } services.RegisterPolicyServiceServer(s, ps) } if _, ok := currentServices[services.SessionRecordingService_ServiceDesc.ServiceName]; !ok { srs, err := session_recordings.NewServiceFn( c.baseContext, c.IamRepoFn, c.workerStatusGracePeriod, c.kms, c.conf.RawConfig.Controller.MaxPageSize, c.ControllerExtension) if err != nil { return fmt.Errorf("failed to create session recording handler service: %w", err) } services.RegisterSessionRecordingServiceServer(s, srs) } if _, ok := currentServices[services.TargetService_ServiceDesc.ServiceName]; !ok { ts, err := targets.NewService( c.baseContext, c.kms, c.TargetRepoFn, c.IamRepoFn, c.ServersRepoFn, c.SessionRepoFn, c.PluginHostRepoFn, c.StaticHostRepoFn, c.VaultCredentialRepoFn, c.StaticCredentialRepoFn, c.downstreamWorkers, c.workerStatusGracePeriod, c.conf.RawConfig.Controller.MaxPageSize, c.ControllerExtension, ) if err != nil { return fmt.Errorf("failed to create target handler service: %w", err) } services.RegisterTargetServiceServer(s, ts) } if _, ok := currentServices[services.GroupService_ServiceDesc.ServiceName]; !ok { gs, err := groups.NewService(c.baseContext, c.IamRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create group handler service: %w", err) } services.RegisterGroupServiceServer(s, gs) } if _, ok := currentServices[services.RoleService_ServiceDesc.ServiceName]; !ok { rs, err := roles.NewService(c.baseContext, c.IamRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create role handler service: %w", err) } services.RegisterRoleServiceServer(s, rs) } if _, ok := currentServices[services.SessionService_ServiceDesc.ServiceName]; !ok { ss, err := sessions.NewService(c.baseContext, c.SessionRepoFn, c.IamRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create session handler service: %w", err) } services.RegisterSessionServiceServer(s, ss) } if _, ok := currentServices[services.ManagedGroupService_ServiceDesc.ServiceName]; !ok { mgs, err := managed_groups.NewService(c.baseContext, c.OidcRepoFn, c.LdapRepoFn, c.conf.RawConfig.Controller.MaxPageSize) if err != nil { return fmt.Errorf("failed to create managed groups handler service: %w", err) } services.RegisterManagedGroupServiceServer(s, mgs) } if _, ok := currentServices[services.CredentialStoreService_ServiceDesc.ServiceName]; !ok { cs, err := credentialstores.NewService( c.baseContext, c.IamRepoFn, c.VaultCredentialRepoFn, c.StaticCredentialRepoFn, c.CredentialStoreRepoFn, c.conf.RawConfig.Controller.MaxPageSize, ) if err != nil { return fmt.Errorf("failed to create credential store handler service: %w", err) } services.RegisterCredentialStoreServiceServer(s, cs) } if _, ok := currentServices[services.CredentialLibraryService_ServiceDesc.ServiceName]; !ok { cl, err := credentiallibraries.NewService( c.baseContext, c.IamRepoFn, c.VaultCredentialRepoFn, c.conf.RawConfig.Controller.MaxPageSize, ) if err != nil { return fmt.Errorf("failed to create credential library handler service: %w", err) } services.RegisterCredentialLibraryServiceServer(s, cl) } if _, ok := currentServices[services.WorkerService_ServiceDesc.ServiceName]; !ok { ws, err := workers.NewService(c.baseContext, c.ServersRepoFn, c.IamRepoFn, c.WorkerAuthRepoStorageFn, c.downstreamWorkers) if err != nil { return fmt.Errorf("failed to create worker handler service: %w", err) } services.RegisterWorkerServiceServer(s, ws) } if _, ok := currentServices[services.CredentialService_ServiceDesc.ServiceName]; !ok { c, err := credentials.NewService( c.baseContext, c.IamRepoFn, c.StaticCredentialRepoFn, c.conf.RawConfig.Controller.MaxPageSize, ) if err != nil { return fmt.Errorf("failed to create credential handler service: %w", err) } services.RegisterCredentialServiceServer(s, c) } if _, ok := currentServices[opsservices.HealthService_ServiceDesc.ServiceName]; !ok { hs := health.NewService() opsservices.RegisterHealthServiceServer(s, hs) c.HealthService = hs } return nil } func registerGrpcGatewayEndpoints(ctx context.Context, gwMux *runtime.ServeMux, dialOptions ...grpc.DialOption) error { // Register*ServiceHandlerServer methods ignore the passed in context. // Passing it in anyways in case this changes in the future. if err := services.RegisterHostCatalogServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register host catalog service handler: %w", err) } if err := services.RegisterHostSetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register host set service handler: %w", err) } if err := services.RegisterHostServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register host service handler: %w", err) } if err := services.RegisterAccountServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register account service handler: %w", err) } if err := services.RegisterAuthMethodServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register auth method service handler: %w", err) } if err := services.RegisterAuthTokenServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register auth token service handler: %w", err) } if err := services.RegisterScopeServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register scope service handler: %w", err) } if err := services.RegisterUserServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register user service handler: %w", err) } if err := services.RegisterTargetServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register target service handler: %w", err) } if err := services.RegisterGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register group service handler: %w", err) } if err := services.RegisterRoleServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register role service handler: %w", err) } if err := services.RegisterSessionServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register session service handler: %w", err) } if err := services.RegisterManagedGroupServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register managed groups service handler: %w", err) } if err := services.RegisterCredentialStoreServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register credential store service handler: %w", err) } if err := services.RegisterCredentialLibraryServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register credential library service handler: %w", err) } if err := services.RegisterWorkerServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register worker service handler: %w", err) } if err := services.RegisterCredentialServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register credential service handler: %w", err) } if err := services.RegisterSessionRecordingServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register session recording service handler: %w", err) } if err := services.RegisterStorageBucketServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register storage bucket handler: %w", err) } if err := services.RegisterSessionRecordingServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register session recording handler: %w", err) } if err := services.RegisterPolicyServiceHandlerFromEndpoint(ctx, gwMux, gatewayTarget, dialOptions); err != nil { return fmt.Errorf("failed to register policy handler: %w", err) } return nil } func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler { const op = "controller.wrapHandlerWithCommonFuncs" var maxRequestDuration time.Duration var maxRequestSize int64 if props.ListenerConfig != nil { maxRequestDuration = props.ListenerConfig.MaxRequestDuration maxRequestSize = props.ListenerConfig.MaxRequestSize } if maxRequestDuration == 0 { maxRequestDuration = globals.DefaultMaxRequestDuration } if maxRequestSize == 0 { maxRequestSize = globals.DefaultMaxRequestSize } disableAuthzFailures := c.conf.DisableAuthorizationFailures || (c.conf.RawConfig.DevController && os.Getenv("BOUNDARY_DEV_SKIP_AUTHZ") != "") if disableAuthzFailures { event.WriteSysEvent(context.TODO(), op, "AUTHORIZATION CHECKING DISABLED") } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Set the Cache-Control header for all responses returned w.Header().Set("Cache-Control", "no-store") // Start with the request context and our timeout ctx, cancelFunc := context.WithTimeout(r.Context(), maxRequestDuration) defer cancelFunc() // Add a size limiter if desired if maxRequestSize > 0 { ctx = context.WithValue(ctx, globals.ContextMaxRequestSizeTypeKey, maxRequestSize) } // Add values for authn/authz checking requestInfo := authpb.RequestInfo{ Path: r.URL.Path, Method: r.Method, DisableAuthzFailures: disableAuthzFailures, } requestInfo.PublicId, requestInfo.EncryptedToken, requestInfo.TokenFormat = auth.GetTokenFromRequest(ctx, c.kms, r) ctx = context.WithValue(ctx, globals.ContextAuthTokenPublicIdKey, requestInfo.PublicId) if info, ok := event.RequestInfoFromContext(ctx); ok { // piggyback some eventing fields with the auth info proto message requestInfo.EventId = info.EventId requestInfo.TraceId = info.Id requestInfo.ClientIp = info.ClientIp } else { w.WriteHeader(http.StatusInternalServerError) event.WriteError(ctx, op, errors.New("unable to read event request info from context")) return } // Serialize the request info to send it across the wire to the // grpc-gateway via an http header requestInfo.Ticket = c.apiGrpcGatewayTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy marshalledRequestInfo, err := proto.Marshal(&requestInfo) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling request info")) w.WriteHeader(http.StatusInternalServerError) return } // Use the default grpc-gateway mapping rule to pass the request info as // metadata. // See: https://pkg.go.dev/github.com/grpc-ecosystem/grpc-gateway/runtime#DefaultHeaderMatcher r.Header.Set("Grpc-Metadata-"+requestInfoMdKey, base58.FastBase58Encoding(marshalledRequestInfo)) // Set the context back on the request r = r.Clone(ctx) h.ServeHTTP(w, r) }) } func wrapHandlerWithCors(h http.Handler, props HandlerProperties) http.Handler { allowedMethods := []string{ http.MethodDelete, http.MethodGet, http.MethodOptions, http.MethodPost, http.MethodPatch, } allowedOrigins := props.ListenerConfig.CorsAllowedOrigins allowedHeaders := append([]string{ "Content-Type", "X-Requested-With", "Authorization", }, props.ListenerConfig.CorsAllowedHeaders...) allowedResponseHeaders := strings.Join([]string{ "Retry-After", "RateLimit", "RateLimit-Policy", }, ", ") return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if props.ListenerConfig.CorsEnabled == nil || !*props.ListenerConfig.CorsEnabled { h.ServeHTTP(w, req) return } origin := req.Header.Get("Origin") if origin == "" { // Serve directly h.ServeHTTP(w, req) return } // Check origin var valid bool switch { case len(allowedOrigins) == 0: // not valid case len(allowedOrigins) == 1 && allowedOrigins[0] == "*": valid = true // When allowed origins is "*" we want to return that rather than // round-tripping any user-specified value origin = "*" default: valid = strutil.StrListContains(allowedOrigins, origin) } if !valid { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) err := handlers.ApiErrorWithCodeAndMessage(codes.PermissionDenied, "origin forbidden") enc := json.NewEncoder(w) _ = enc.Encode(err) return } if req.Method == http.MethodOptions && !strutil.StrListContains(allowedMethods, req.Header.Get("Access-Control-Request-Method")) { w.WriteHeader(http.StatusMethodNotAllowed) return } w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") w.Header().Set("Access-Control-Expose-Headers", allowedResponseHeaders) // Apply headers for preflight requests if req.Method == http.MethodOptions { w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) w.Header().Set("Access-Control-Max-Age", "300") w.WriteHeader(http.StatusNoContent) return } h.ServeHTTP(w, req) }) } type cmdAttrs struct { Command string `json:"command,omitempty"` Attributes any `json:"attributes,omitempty"` } func wrapHandlerWithCallbackInterceptor(h http.Handler, c *Controller) http.Handler { logCallbackErrors := os.Getenv("BOUNDARY_LOG_CALLBACK_ERRORS") != "" return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { const op = "controller.wrapHandlerWithCallbackInterceptor" ctx := req.Context() var err error id, err := event.NewId(event.IdField) if err != nil { w.WriteHeader(http.StatusInternalServerError) event.WriteError(ctx, op, err, event.WithInfoMsg("unable to create id for event", "method", req.Method, "url", req.URL.RequestURI())) return } info := &event.RequestInfo{ EventId: id, Id: common.GeneratedTraceId(ctx), PublicId: "unknown", Method: req.Method, Path: req.URL.RequestURI(), } ctx, err = event.NewRequestInfoContext(ctx, info) if err != nil { w.WriteHeader(http.StatusInternalServerError) event.WriteError(req.Context(), op, err, event.WithInfoMsg("unable to create context with request info", "method", req.Method, "url", req.URL.RequestURI())) return } // If this doesn't have a callback suffix on a supported action, serve // normally if !strings.HasSuffix(req.URL.Path, ":authenticate:callback") { h.ServeHTTP(w, req) return } req.URL.Path = strings.TrimSuffix(req.URL.Path, ":callback") // How we get the parameters changes based on the method. Right now only // GET is supported with query args, but this can support POST with JSON // or URL-encoded args. In those cases, the MIME type would have to be // checked; for URL-encoded it'd use ParseForm like Get, and for JSON // you'd use a json.RawMessage for Attributes consisting of the body. Or // something very similar to that. var useForm bool switch req.Method { case http.MethodGet: if err := req.ParseForm(); err != nil { if logCallbackErrors && c != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("callback error")) } w.WriteHeader(http.StatusBadRequest) return } useForm = true } attrs := &cmdAttrs{ Command: "callback", } switch { case useForm: if len(req.Form) > 0 { values := make(map[string]any, len(req.Form)) // This won't handle repeated values. That's fine, at least for now. // We can address that if needed, which seems unlikely. for k := range req.Form { values[k] = req.Form.Get(k) } if strings.HasSuffix(req.URL.Path, "oidc:authenticate") { if s, ok := values["state"].(string); ok { stateWrapper, err := oidc.UnwrapMessage(context.Background(), s) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling state")) w.WriteHeader(http.StatusInternalServerError) return } if stateWrapper.AuthMethodId == "" { event.WriteError(ctx, op, err, event.WithInfoMsg("missing auth method id")) w.WriteHeader(http.StatusInternalServerError) return } stripped := strings.TrimSuffix(req.URL.Path, "oidc:authenticate") req.URL.Path = fmt.Sprintf("%s%s:authenticate", stripped, stateWrapper.AuthMethodId) } else { event.WriteError(ctx, op, errors.New("missing state parameter")) w.WriteHeader(http.StatusInternalServerError) return } } attrs.Attributes = values } attrBytes, err := json.Marshal(attrs) if err != nil { if logCallbackErrors && c != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling json")) } w.WriteHeader(http.StatusInternalServerError) return } // If there is any existing body, close it as we're going to replace // it. It shouldn't be populated in this code path, but you never // know. if req.Body != nil { if err := req.Body.Close(); err != nil { if logCallbackErrors && c != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error closing original request body")) } } } bytesReader := bytes.NewReader(attrBytes) req.Body = ioutil.NopCloser(bytesReader) req.ContentLength = int64(bytesReader.Len()) req.Header.Set(textproto.CanonicalMIMEHeaderKey("content-type"), "application/json") req.Method = http.MethodPost } h.ServeHTTP(w, req) }) } /* func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { headers, headersOK := r.Header[textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")] if !headersOK || len(headers) == 0 { if !rejectNotPresent { h.ServeHTTP(w, r) return } respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present")) return } host, port, err := net.SplitHostPort(r.RemoteAddr) if err != nil { // If not rejecting treat it like we just don't have a valid // header because we can't do a comparison against an address we // can't understand if !rejectNotPresent { h.ServeHTTP(w, r) return } respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client hostport: {{err}}", err)) return } addr, err := sockaddr.NewIPAddr(host) if err != nil { // We treat this the same as the case above if !rejectNotPresent { h.ServeHTTP(w, r) return } respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client address: {{err}}", err)) return } var found bool for _, authz := range authorizedAddrs { if authz.Contains(addr) { found = true break } } if !found { // If we didn't find it and aren't configured to reject, simply // don't trust it if !rejectNonAuthz { h.ServeHTTP(w, r) return } respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection")) return } // At this point we have at least one value and it's authorized // Split comma separated ones, which are common. This brings it in line // to the multiple-header case. var acc []string for _, header := range headers { vals := strings.Split(header, ",") for _, v := range vals { acc = append(acc, strings.TrimSpace(v)) } } indexToUse := len(acc) - 1 - hopSkips if indexToUse < 0 { // This is likely an error in either configuration or other // infrastructure. We could either deny the request, or we // could simply not trust the value. Denying the request is // "safer" since if this logic is configured at all there may // be an assumption it can always be trusted. Given that we can // deny accepting the request at all if it's not from an // authorized address, if we're at this point the address is // authorized (or we've turned off explicit rejection) and we // should assume that what comes in should be properly // formatted. respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers))) return } r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port) h.ServeHTTP(w, r) return }) } */