You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/daemon/controller/handlers/authmethods/oidc.go

511 lines
18 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package authmethods
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strings"
"github.com/hashicorp/boundary/internal/auth/oidc"
oidcstore "github.com/hashicorp/boundary/internal/auth/oidc/store"
"github.com/hashicorp/boundary/internal/daemon/controller/auth"
"github.com/hashicorp/boundary/internal/daemon/controller/handlers"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services"
"github.com/hashicorp/boundary/internal/types/action"
pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/authmethods"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"google.golang.org/grpc/codes"
)
const (
// Values of the authenticate request's command that select which step of
// the OIDC flow is executed (see authenticateOidc).
startCommand = "start"
callbackCommand = "callback"
tokenCommand = "token"
// Token request/response fields.
statusField = "status"
// Field names, expressed as paths under "attributes". Within this file
// they are used as keys of the bad-fields map returned with invalid
// argument errors (see validateAuthenticateOidcRequest).
// NOTE(review): several of these are not referenced in this file and are
// presumably used by other files of the package; confirm before removing.
issuerField = "attributes.issuer"
clientSecretField = "attributes.client_secret"
clientIdField = "attributes.client_id"
clientSecretHmacField = "attributes.client_secret_hmac"
stateField = "attributes.state"
callbackUrlField = "attributes.callback_url"
apiUrlPrefixField = "attributes.api_url_prefix"
idpCaCertsField = "attributes.idp_ca_certs"
signingAlgorithmField = "attributes.signing_algorithms"
disableDiscoveredConfigValidationField = "attributes.disable_discovered_config_validation"
roundtripPayloadAttributesField = "attributes.roundtrip_payload"
codeField = "attributes.code"
claimsScopesField = "attributes.claims_scopes"
accountClaimMapsField = "attributes.account_claim_maps"
promptsField = "attributes.prompts"
)
// oidcMaskManager translates the paths of an API update mask into the mask
// handed to the oidc repository (see updateOidcInRepo).
var oidcMaskManager handlers.MaskManager

// init builds the package-level mask manager for OIDC auth methods and
// registers the set of actions allowed on an individual OIDC auth method.
// It panics if the mask manager cannot be constructed.
func init() {
	maskDst := handlers.MaskDestination{&oidcstore.AuthMethod{}}
	maskSrc := handlers.MaskSource{&pb.AuthMethod{}, &pb.OidcAuthMethodAttributes{}}

	var err error
	oidcMaskManager, err = handlers.NewMaskManager(context.Background(), maskDst, maskSrc)
	if err != nil {
		panic(err)
	}

	IdActions[oidc.Subtype] = action.NewActionSet(
		action.NoOp,
		action.Read,
		action.Update,
		action.Delete,
		action.ChangeState,
		action.Authenticate,
	)
}
// oidcState enumerates the operational states of an OIDC auth method. The
// zero value is deliberately left unnamed and renders as "unknown".
type oidcState uint

const (
	_ oidcState = iota
	inactiveState
	privateState
	publicState
)

// oidcStateMap maps the string form of each named state back to its
// oidcState value.
var oidcStateMap = map[string]oidcState{
	inactiveState.String(): inactiveState,
	privateState.String():  privateState,
	publicState.String():   publicState,
}

// String returns the textual name of the state. The zero value and any value
// outside the declared range yield "unknown" rather than panicking on an
// out-of-range index.
func (o oidcState) String() string {
	names := [...]string{
		"unknown",
		"inactive",
		"active-private",
		"active-public",
	}
	if o >= oidcState(len(names)) {
		return names[0]
	}
	return names[o]
}
// createOidcInRepo converts the API auth method into its storage form and
// persists it through the oidc repository, returning the stored result.
// A nil AuthMethod is only ever returned together with a non-nil error.
func (s Service) createOidcInRepo(ctx context.Context, scopeId string, item *pb.AuthMethod) (*oidc.AuthMethod, error) {
	// The dry-run and force flags only matter for updates, so they are dropped.
	toCreate, _, _, convErr := toStorageOidcAuthMethod(ctx, scopeId, item)
	if convErr != nil {
		return nil, convErr
	}

	oidcRepo, repoErr := s.oidcRepoFn()
	if repoErr != nil {
		return nil, repoErr
	}

	created, createErr := oidcRepo.CreateAuthMethod(ctx, toCreate)
	if createErr != nil {
		return nil, fmt.Errorf("unable to create auth method: %w", createErr)
	}
	return created, nil
}
// updateOidcInRepo applies the requested update to an existing OIDC auth
// method. It returns the resulting auth method, whether the request was a dry
// run, and any error. An update mask that translates to no storage fields is
// rejected as an invalid argument, and a not-found error is returned when a
// non-dry-run update touched no rows and produced no auth method.
func (s Service) updateOidcInRepo(ctx context.Context, scopeId string, req *pbs.UpdateAuthMethodRequest) (*oidc.AuthMethod, bool, error) {
	item := req.GetItem()
	am, dryRun, forced, err := toStorageOidcAuthMethod(ctx, scopeId, item)
	if err != nil {
		return nil, dryRun, err
	}
	am.PublicId = req.GetId()

	dbMask := oidcMaskManager.Translate(req.GetUpdateMask().GetPaths())
	if len(dbMask) == 0 {
		return nil, dryRun, handlers.InvalidArgumentErrorf("No valid fields included in the update mask.", map[string]string{"update_mask": "No valid fields provided in the update mask."})
	}

	repo, err := s.oidcRepoFn()
	if err != nil {
		return nil, dryRun, err
	}

	// Forward the force and dry-run flags from the request attributes.
	var updateOpts []oidc.Option
	if forced {
		updateOpts = append(updateOpts, oidc.WithForce())
	}
	if dryRun {
		updateOpts = append(updateOpts, oidc.WithDryRun())
	}

	updated, rowsUpdated, err := repo.UpdateAuthMethod(ctx, am, item.GetVersion(), dbMask, updateOpts...)
	switch {
	case err != nil:
		return nil, dryRun, fmt.Errorf("unable to update auth method: %w", err)
	case rowsUpdated == 0 && !dryRun && updated == nil:
		return nil, dryRun, handlers.NotFoundErrorf("AuthMethod %q doesn't exist or incorrect version provided.", req.GetId())
	}
	return updated, dryRun, nil
}
// authenticateOidc dispatches an OIDC authenticate request to the handler for
// its command (start, callback or token). A nil request or nil auth results
// yields an InvalidParameter error. A command that matches none of the known
// ones produces a response that only echoes the command.
func (s Service) authenticateOidc(ctx context.Context, req *pbs.AuthenticateRequest, authResults *auth.VerifyResults) (*pbs.AuthenticateResponse, error) {
	const op = "authmethod_service.(Service).authenticateOidc"
	switch {
	case req == nil:
		return nil, errors.New(ctx, errors.InvalidParameter, op, "Nil request.")
	case authResults == nil:
		return nil, errors.New(ctx, errors.InvalidParameter, op, "Nil auth results.")
	}

	cmd := req.GetCommand()
	if cmd == startCommand {
		return s.authenticateOidcStart(ctx, req)
	}
	if cmd == callbackCommand {
		return s.authenticateOidcCallback(ctx, req)
	}
	if cmd == tokenCommand {
		return s.authenticateOidcToken(ctx, req, authResults)
	}
	return &pbs.AuthenticateResponse{Command: cmd}, nil
}
// authenticateOidcStart begins an OIDC authentication flow for the auth method
// named in the request and returns the authentication URL together with the
// token ID produced by oidc.StartAuth.
//
// If the start attributes carry a cached roundtrip payload, it is forwarded to
// oidc.StartAuth. Errors from oidc.StartAuth are mapped to API errors:
// an inactive auth method becomes FailedPrecondition, a missing record becomes
// NotFound, an invalid parameter becomes InvalidArgument, and anything else is
// written to the event log and reported as a generic Internal error.
func (s Service) authenticateOidcStart(ctx context.Context, req *pbs.AuthenticateRequest) (*pbs.AuthenticateResponse, error) {
	const op = "authmethod_service.(Service).authenticateOidcStart"
	if req == nil {
		return nil, handlers.InvalidArgumentErrorf("Nil request.", nil)
	}

	var opts []oidc.Option
	attrs := req.GetOidcStartAttributes()
	if attrs.GetCachedRoundtripPayload() != "" {
		opts = append(opts, oidc.WithRoundtripPayload(attrs.GetCachedRoundtripPayload()))
	}

	authUrl, tokenId, err := oidc.StartAuth(ctx, s.oidcRepoFn, req.GetAuthMethodId(), opts...)
	switch {
	case errors.Match(errors.T(errors.AuthMethodInactive), err):
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.FailedPrecondition, "Cannot start authentication against an inactive OIDC auth method")
	case errors.Match(errors.T(errors.RecordNotFound), err):
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.NotFound, "Auth method %s was not found", req.GetAuthMethodId())
	case errors.Match(errors.T(errors.InvalidParameter), err):
		// The error text is data, not a format string: pass it through "%s" so
		// any '%' it contains is not interpreted as a formatting verb.
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.InvalidArgument, "%s", err.Error())
	case err != nil:
		event.WriteError(ctx, op, err, event.WithInfoMsg("error starting the oidc authentication flow"))
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Error generating parameters for starting the OIDC flow. See the controller's log for more information.")
	}

	return &pbs.AuthenticateResponse{
		Command: req.GetCommand(),
		Attrs: &pbs.AuthenticateResponse_OidcAuthMethodAuthenticateStartResponse{
			OidcAuthMethodAuthenticateStartResponse: &pb.OidcAuthMethodAuthenticateStartResponse{
				AuthUrl: authUrl.String(),
				TokenId: tokenId,
			},
		},
	}, nil
}
// authenticateOidcCallback behaves differently than other service methods.
// Because of the way it is called by the end user, it should only return an
// error if we are unable to look up the auth method or the request parameters
// were invalid. All other errors are returned through the response as a
// FinalRedirectUrl pointing at an endpoint that can properly show the error
// details.
//
// On success the response carries the final redirect URL returned by
// oidc.Callback.
func (s Service) authenticateOidcCallback(ctx context.Context, req *pbs.AuthenticateRequest) (*pbs.AuthenticateResponse, error) {
	const op = "authmethod_service.(Service).authenticateOidcCallback"
	// TODO: Return all errors (including the validate request based errors
	// in the redirect URL once we start looking at the url used for this
	// request instead of requiring the API URL to be set on the auth method.
	if req == nil {
		return nil, handlers.InvalidArgumentErrorf("Nil request.", nil)
	}

	repo, err := s.oidcRepoFn()
	if err != nil {
		// The error text is data, not a format string: pass it through "%s" so
		// any '%' it contains is not interpreted as a formatting verb.
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "%s", err.Error())
	}
	am, err := repo.LookupAuthMethod(ctx, req.GetAuthMethodId())
	if err != nil {
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "%s", err.Error())
	}
	if am == nil {
		return nil, handlers.NotFoundErrorf("Auth method %s not found.", req.GetAuthMethodId())
	}
	if am.GetApiUrl() == "" {
		return nil, handlers.InvalidArgumentErrorf("Auth method doesn't have API URL defined.", nil)
	}

	// errResponse turns an error into a successful callback response whose
	// redirect URL carries the JSON-marshaled API error in its "error" query
	// parameter.
	errRedirectBase := fmt.Sprintf(oidc.AuthenticationErrorsEndpoint, am.GetApiUrl())
	errResponse := func(err error) (*pbs.AuthenticateResponse, error) {
		pbErr := handlers.ToApiError(err)
		out, marshalErr := handlers.JSONMarshaler().Marshal(pbErr)
		if marshalErr != nil {
			event.WriteError(ctx, op, marshalErr, event.WithInfoMsg("unable to marshal the error for callback"))
			return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "unable to marshal the error for callback")
		}
		u := make(url.Values)
		u.Add("error", string(out))
		errRedirect := fmt.Sprintf("%s?%s", errRedirectBase, u.Encode())
		return &pbs.AuthenticateResponse{
			Command: callbackCommand,
			Attrs: &pbs.AuthenticateResponse_OidcAuthMethodAuthenticateCallbackResponse{
				OidcAuthMethodAuthenticateCallbackResponse: &pb.OidcAuthMethodAuthenticateCallbackResponse{
					FinalRedirectUrl: errRedirect,
				},
			},
		}, nil
	}

	attrs := req.GetOidcAuthMethodAuthenticateCallbackRequest()
	// An error reported in the callback request itself is relayed through the
	// error redirect.
	if attrs.GetError() != "" {
		err := errors.Wrap(ctx, fmt.Errorf("Error: %q, Details: %q", attrs.GetError(), attrs.GetErrorDescription()), op, errors.WithCode(errors.OidcProviderCallbackError))
		return errResponse(err)
	}

	finalRedirectUrl, err := oidc.Callback(
		ctx,
		s.oidcRepoFn,
		oidc.IamRepoFactory(s.iamRepoFn),
		s.atRepoFn,
		am,
		attrs.GetState(),
		attrs.GetCode())
	if err != nil {
		return errResponse(errors.New(ctx, errors.InvalidParameter, op, "Callback validation failed.", errors.WithWrap(err)))
	}

	return &pbs.AuthenticateResponse{
		Command: req.GetCommand(),
		Attrs: &pbs.AuthenticateResponse_OidcAuthMethodAuthenticateCallbackResponse{
			OidcAuthMethodAuthenticateCallbackResponse: &pb.OidcAuthMethodAuthenticateCallbackResponse{
				FinalRedirectUrl: finalRedirectUrl,
			},
		},
	}, nil
}
// authenticateOidcToken handles the "token" command of the OIDC flow. It
// looks up the token for the supplied token ID via oidc.TokenRequest.
//
// When no token is returned, the response carries only the status "unknown".
// Otherwise the token is converted to its API form and wrapped into the
// authenticate response. Forbidden and expired-attempt errors are both
// reported as forbidden; every other failure is logged and surfaced as a
// generic Internal error.
func (s Service) authenticateOidcToken(ctx context.Context, req *pbs.AuthenticateRequest, authResults *auth.VerifyResults) (*pbs.AuthenticateResponse, error) {
	const op = "authmethod_service.(Service).authenticateOidcToken"
	switch {
	case req == nil:
		return nil, handlers.InvalidArgumentErrorf("Nil request.", nil)
	case authResults == nil:
		return nil, handlers.InvalidArgumentErrorf("Nil auth results.", nil)
	}

	attrs := req.GetOidcAuthMethodAuthenticateTokenRequest()
	if attrs == nil {
		return nil, handlers.InvalidArgumentErrorf("Nil request attributes.", nil)
	}
	if attrs.TokenId == "" {
		return nil, handlers.InvalidArgumentErrorf("Empty token ID in request attributes.", nil)
	}

	token, err := oidc.TokenRequest(ctx, s.kms, s.atRepoFn, req.GetAuthMethodId(), attrs.TokenId)
	if err != nil {
		if errors.Match(errors.T(errors.Forbidden), err) || errors.Match(errors.T(errors.AuthAttemptExpired), err) {
			return nil, handlers.ForbiddenError()
		}
		event.WriteError(ctx, op, err, event.WithInfoMsg("error generating parameters for token request"))
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Error generating parameters for token request. See the controller's log for more information.")
	}

	// No token and no error: report an "unknown" status to the caller.
	if token == nil {
		pending := &pb.OidcAuthMethodAuthenticateTokenResponse{Status: "unknown"}
		return &pbs.AuthenticateResponse{
			Command: req.Command,
			Attrs: &pbs.AuthenticateResponse_OidcAuthMethodAuthenticateTokenResponse{
				OidcAuthMethodAuthenticateTokenResponse: pending,
			},
		}, nil
	}

	apiToken, err := s.ConvertInternalAuthTokenToApiAuthToken(ctx, token)
	if err != nil {
		event.WriteError(ctx, op, err, event.WithInfoMsg("error converting response to proper format."))
		return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Error converting response to proper format. See the controller's log for more information.")
	}
	return s.convertToAuthenticateResponse(ctx, req, authResults, apiToken)
}
// validateAuthenticateOidcRequest checks the command-specific attributes of an
// OIDC authenticate request and returns an invalid argument error listing the
// offending fields, or nil when the request is acceptable.
//
//   - start: attributes are optional. When present, any client-supplied cached
//     roundtrip payload is discarded and, if a roundtrip payload was given, its
//     JSON encoding is cached on the attributes for later use.
//   - callback: attributes are required, state must be set, and either a code
//     or an error must be set.
//   - token: the token type (falling back to the deprecated field) must be
//     empty, "token" or "cookie", compared case-insensitively after trimming.
//   - any other command is rejected.
//
// Note that the start command mutates the request's attributes.
func validateAuthenticateOidcRequest(_ context.Context, req *pbs.AuthenticateRequest) error {
	badFields := make(map[string]string)
	switch req.GetCommand() {
	case startCommand:
		if attrs := req.GetOidcStartAttributes(); attrs != nil {
			// Ensure we pay no attention to cache information provided by the client
			attrs.CachedRoundtripPayload = ""
			if payload := attrs.GetRoundtripPayload(); payload != nil {
				m, err := json.Marshal(payload.AsMap())
				if err != nil {
					// We don't know what's in this payload so we swallow the
					// error, as it could be something sensitive.
					badFields[roundtripPayloadAttributesField] = "Unable to marshal given value as JSON."
				} else {
					// Cache for later
					attrs.CachedRoundtripPayload = string(m)
				}
			}
		}
	case callbackCommand:
		attrs := req.GetOidcAuthMethodAuthenticateCallbackRequest()
		if attrs == nil {
			badFields["attributes"] = "Attributes field not supplied request"
			return handlers.InvalidArgumentErrorf("This is a required field.", badFields)
		}
		// A callback carrying an error need not carry a code.
		if attrs.GetCode() == "" && attrs.GetError() == "" {
			badFields[codeField] = "Code field not supplied in callback request."
		}
		if attrs.GetState() == "" {
			badFields[stateField] = "State field not supplied in callback request."
		}
	case tokenCommand:
		tokenType := req.GetType()
		if tokenType == "" {
			// Fall back to deprecated field if type is not set
			tokenType = req.GetTokenType()
		}
		tType := strings.ToLower(strings.TrimSpace(tokenType))
		if tType != "" && tType != "token" && tType != "cookie" {
			badFields[tokenTypeField] = `The only accepted types are "token" and "cookie".`
		}
	default:
		badFields[commandField] = "Invalid command for this auth method type."
	}
	if len(badFields) > 0 {
		return handlers.InvalidArgumentErrorf("Invalid fields provided in request.", badFields)
	}
	return nil
}
// toStorageOidcAuthMethod converts an API auth method into the oidc storage
// representation for the given scope.
//
// Besides the converted auth method it returns the dry-run flag and the
// "disable discovered config validation" flag (reported as forced) read from
// the OIDC attributes. A nil input yields an InvalidParameter error. On any
// error the returned auth method is nil and both flags are false.
//
// Only attributes that are actually set are turned into options. A max age
// that is set to 0 is passed on as -1.
func toStorageOidcAuthMethod(ctx context.Context, scopeId string, in *pb.AuthMethod) (out *oidc.AuthMethod, dryRun, forced bool, err error) {
	const op = "authmethod_service.toStorageOidcAuthMethod"
	if in == nil {
		return nil, false, false, errors.New(ctx, errors.InvalidParameter, op, "nil auth method.")
	}
	attrs := in.GetOidcAuthMethodsAttributes()

	var opts []oidc.Option
	if name := in.GetName(); name != nil {
		opts = append(opts, oidc.WithName(name.GetValue()))
	}
	if desc := in.GetDescription(); desc != nil {
		opts = append(opts, oidc.WithDescription(desc.GetValue()))
	}

	if rawIssuer := strings.TrimSpace(attrs.GetIssuer().GetValue()); rawIssuer != "" {
		// Keep only what precedes ".well-known/" so a full discovery URL is
		// reduced to the issuer, leaving the trailing "/" in place.
		base := strings.SplitN(rawIssuer, ".well-known/", 2)[0]
		normalized, normErr := parseutil.NormalizeAddr(base)
		if normErr != nil {
			return nil, false, false, errors.Wrap(ctx, normErr, op, errors.WithMsg("cannot normalize issuer"), errors.WithCode(errors.InvalidParameter))
		}
		issuerUrl, parseErr := url.Parse(normalized)
		if parseErr != nil {
			return nil, false, false, errors.Wrap(ctx, parseErr, op, errors.WithMsg("cannot parse issuer"), errors.WithCode(errors.InvalidParameter))
		}
		opts = append(opts, oidc.WithIssuer(issuerUrl))
	}

	if rawApiUrl := strings.TrimSpace(attrs.GetApiUrlPrefix().GetValue()); rawApiUrl != "" {
		normalized, normErr := parseutil.NormalizeAddr(rawApiUrl)
		if normErr != nil {
			return nil, false, false, errors.Wrap(ctx, normErr, op, errors.WithMsg("cannot normalize api_url_prefix"), errors.WithCode(errors.InvalidParameter))
		}
		parsedApiUrl, parseErr := url.Parse(normalized)
		if parseErr != nil {
			return nil, false, false, errors.Wrap(ctx, parseErr, op, errors.WithMsg("cannot parse api_url_prefix"), errors.WithCode(errors.InvalidParameter))
		}
		opts = append(opts, oidc.WithApiUrl(parsedApiUrl))
	}

	if ma := attrs.GetMaxAge(); ma != nil {
		switch v := ma.GetValue(); v {
		case 0:
			// An explicit zero is handed to storage as -1.
			opts = append(opts, oidc.WithMaxAge(-1))
		default:
			opts = append(opts, oidc.WithMaxAge(int(v)))
		}
	}

	if algs := attrs.GetSigningAlgorithms(); len(algs) > 0 {
		signAlgs := make([]oidc.Alg, 0, len(algs))
		for _, alg := range algs {
			signAlgs = append(signAlgs, oidc.Alg(alg))
		}
		opts = append(opts, oidc.WithSigningAlgs(signAlgs...))
	}

	if rawPrompts := attrs.GetPrompts(); len(rawPrompts) > 0 {
		prompts := make([]oidc.PromptParam, 0, len(rawPrompts))
		for _, p := range rawPrompts {
			prompts = append(prompts, oidc.PromptParam(p))
		}
		opts = append(opts, oidc.WithPrompts(prompts...))
	}

	if auds := attrs.GetAllowedAudiences(); len(auds) > 0 {
		opts = append(opts, oidc.WithAudClaims(auds...))
	}

	if pems := attrs.GetIdpCaCerts(); len(pems) > 0 {
		certs, certErr := oidc.ParseCertificates(ctx, pems...)
		if certErr != nil {
			return nil, false, false, certErr
		}
		opts = append(opts, oidc.WithCertificates(certs...))
	}

	if scopes := attrs.GetClaimsScopes(); len(scopes) > 0 {
		opts = append(opts, oidc.WithClaimsScopes(scopes...))
	}

	if rawMaps := attrs.GetAccountClaimMaps(); len(rawMaps) > 0 {
		claimsMap := make(map[string]oidc.AccountToClaim, len(rawMaps))
		for _, raw := range rawMaps {
			parsed, parseErr := oidc.ParseAccountClaimMaps(ctx, raw)
			if parseErr != nil {
				return nil, false, false, errors.Wrap(ctx, parseErr, op)
			}
			if len(parsed) > 1 {
				return nil, false, false, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unable to parse account claim map %s", raw))
			}
			// At most one entry remains at this point; pick it up (or keep the
			// zero value when nothing was parsed).
			var entry oidc.ClaimMap
			for _, e := range parsed {
				entry = e
			}
			to, convErr := oidc.ConvertToAccountToClaim(ctx, entry.To)
			if convErr != nil {
				return nil, false, false, errors.Wrap(ctx, convErr, op)
			}
			claimsMap[entry.From] = to
		}
		opts = append(opts, oidc.WithAccountClaimMap(claimsMap))
	}

	clientId := attrs.GetClientId().GetValue()
	clientSecret := oidc.ClientSecret(attrs.GetClientSecret().GetValue())
	am, newErr := oidc.NewAuthMethod(ctx, scopeId, clientId, clientSecret, opts...)
	if newErr != nil {
		return nil, false, false, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to build auth method: %v.", newErr)
	}
	return am, attrs.GetDryRun(), attrs.GetDisableDiscoveredConfigValidation(), nil
}