package oidc

import (
	"context"
	"fmt"
	"net/http"
	"strings"

	"github.com/hashicorp/boundary/internal/db"
	dbcommon "github.com/hashicorp/boundary/internal/db/common"
	"github.com/hashicorp/boundary/internal/errors"
	"github.com/hashicorp/boundary/internal/kms"
	"github.com/hashicorp/boundary/internal/oplog"
	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/go-secure-stdlib/strutil"
)

// Field names used in field masks and database update masks when updating
// an AuthMethod.
const (
	OperationalStateField                  = "OperationalState"
	DisableDiscoveredConfigValidationField = "DisableDiscoveredConfigValidation"
	VersionField                           = "Version"
	NameField                              = "Name"
	DescriptionField                       = "Description"
	FilterField                            = "Filter"
	IssuerField                            = "Issuer"
	ClientIdField                          = "ClientId"
	ClientSecretField                      = "ClientSecret"
	CtClientSecretField                    = "CtClientSecret"
	ClientSecretHmacField                  = "ClientSecretHmac"
	MaxAgeField                            = "MaxAge"
	SigningAlgsField                       = "SigningAlgs"
	ApiUrlField                            = "ApiUrl"
	AudClaimsField                         = "AudClaims"
	CertificatesField                      = "Certificates"
	ClaimsScopesField                      = "ClaimsScopes"
	AccountClaimMapsField                  = "AccountClaimMaps"
	TokenClaimsField                       = "TokenClaims"
	UserinfoClaimsField                    = "UserinfoClaims"
)

// UpdateAuthMethod will retrieve the auth method from the repository,
// and update it based on the field masks provided.
//
// version must match the auth method's current version in the repository,
// otherwise a VersionMismatch error is returned.
//
// The auth method will not be persisted in the repository if the auth
// method's OperationalState is currently anything other than inactive
// (e.g. ActivePublic or ActivePrivate) and the update would have resulted
// in an incomplete/non-operational auth method.
//
// During update, the auth method will be tested/validated against its
// provider's published OIDC discovery document. If this validation
// succeeds, the auth method is persisted in the repository, and the
// written auth method is returned.
//
// fieldMaskPaths provides field_mask.proto paths for fields that should
// be updated. Fields will be set to NULL if the field is a zero value and
// included in fieldMask. Name, Description, Issuer, ClientId, ClientSecret,
// MaxAge and ApiUrl are all updatable fields. The AuthMethod's value objects
// SigningAlgs, AudClaims, Certificates, ClaimsScopes and AccountClaimMaps are
// also updatable, except that account claim maps for the "sub" claim can be
// neither added nor removed. If no updatable fields are included in the
// fieldMaskPaths, then an error is returned.
//
// Options supported:
//
// * WithDryRun: when this option is provided, the auth method is retrieved from
// the repo, updated based on the fieldMask, tested via Repository.ValidateDiscoveryInfo,
// and the results of the update are returned along with any errors reported. The
// updates are not persisted to the repository.
//
// * WithForce: when this option is provided, the auth method is persisted in
// the repository without testing its validity against its provider's published
// OIDC discovery document. Even if this option is provided, the auth method will
// not be persisted in the repository when the update would have resulted in
// an incomplete/non-operational auth method and its OperationalState is not
// inactive.
//
// Note: for a non-dry-run update, the provided am is modified in place: it is
// encrypted with the scope's database wrapper and its
// DisableDiscoveredConfigValidation is set to the WithForce value.
//
// Also, a successful update will invalidate (delete) the Repository's
// cache of the oidc.Provider for the AuthMethod.
func (r *Repository) UpdateAuthMethod(ctx context.Context, am *AuthMethod, version uint32, fieldMaskPaths []string, opt ...Option) (*AuthMethod, int, error) {
	const op = "oidc.(Repository).UpdateAuthMethod"
	if am == nil {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method")
	}
	if am.AuthMethod == nil {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method store")
	}
	if am.PublicId == "" {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing public id")
	}
	if err := validateFieldMask(ctx, fieldMaskPaths); err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	dbMask, nullFields := dbcommon.BuildUpdatePaths(
		map[string]interface{}{
			NameField:             am.Name,
			DescriptionField:      am.Description,
			IssuerField:           am.Issuer,
			ClientIdField:         am.ClientId,
			ClientSecretField:     am.ClientSecret,
			MaxAgeField:           am.MaxAge,
			SigningAlgsField:      am.SigningAlgs,
			ApiUrlField:           am.ApiUrl,
			AudClaimsField:        am.AudClaims,
			CertificatesField:     am.Certificates,
			ClaimsScopesField:     am.ClaimsScopes,
			AccountClaimMapsField: am.AccountClaimMaps,
		},
		fieldMaskPaths,
		nil,
	)
	if len(dbMask) == 0 && len(nullFields) == 0 {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.EmptyFieldMask, op, "empty field mask")
	}

	origAm, err := r.lookupAuthMethod(ctx, am.PublicId)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	if origAm == nil {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("auth method %s", am.PublicId))
	}
	// there's no reason to continue if another controller has already updated this auth method.
	if origAm.Version != version {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.VersionMismatch, op, fmt.Sprintf("update version %d doesn't match db version %d", version, origAm.Version))
	}

	opts := getOpts(opt...)

	// updated is a clone of origAm with the masked fields taken from am; it is
	// only used for the pre-write checks (and as the dry-run result), so it is
	// computed once and shared by them.
	updated := applyUpdate(am, origAm, fieldMaskPaths)

	if opts.withDryRun {
		if err := updated.isComplete(ctx); err != nil {
			return updated, db.NoRowsAffected, err
		}
		err := r.ValidateDiscoveryInfo(ctx, WithAuthMethod(updated))
		return updated, db.NoRowsAffected, err
	}

	// prevent an "active" auth method from being updated in a manner that would create
	// an incomplete and unusable auth method.
	if origAm.OperationalState != string(InactiveState) {
		if err := updated.isComplete(ctx); err != nil {
			return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("update would result in an incomplete auth method"))
		}
	}
	if !opts.withForce {
		if err := r.ValidateDiscoveryInfo(ctx, WithAuthMethod(updated)); err != nil {
			return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
		}
	}

	addAlgs, deleteAlgs, err := valueObjectChanges(ctx, origAm.PublicId, SigningAlgVO, am.SigningAlgs, origAm.SigningAlgs, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addCerts, deleteCerts, err := valueObjectChanges(ctx, origAm.PublicId, CertificateVO, am.Certificates, origAm.Certificates, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addAuds, deleteAuds, err := valueObjectChanges(ctx, origAm.PublicId, AudClaimVO, am.AudClaims, origAm.AudClaims, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addScopes, deleteScopes, err := valueObjectChanges(ctx, origAm.PublicId, ClaimsScopesVO, am.ClaimsScopes, origAm.ClaimsScopes, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addMaps, deleteMaps, err := valueObjectChanges(ctx, origAm.PublicId, AccountClaimMapsVO, am.AccountClaimMaps, origAm.AccountClaimMaps, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}

	// we don't allow updates for "sub" claim maps, because we have no way to
	// determine if the updated "from" claim in the map might create collisions
	// with any existing account's subject.
	for _, changes := range [][]interface{}{addMaps, deleteMaps} {
		for _, rawCm := range changes {
			cm := rawCm.(*AccountClaimMap)
			if cm.ToClaim == string(ToSubClaim) {
				return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("you cannot update account claim map %s=%s for the \"sub\" claim", cm.FromClaim, cm.ToClaim))
			}
		}
	}

	// value objects are written to their own tables below, so they are
	// removed from the masks used to update the auth method row itself.
	isValueObjectField := func(f string) bool {
		switch f {
		case SigningAlgsField, AudClaimsField, CertificatesField, ClaimsScopesField, AccountClaimMapsField:
			return true
		default:
			return false
		}
	}
	var filteredDbMask, filteredNullFields []string
	for _, f := range dbMask {
		if !isValueObjectField(f) {
			filteredDbMask = append(filteredDbMask, f)
		}
	}
	for _, f := range nullFields {
		if !isValueObjectField(f) {
			filteredNullFields = append(filteredNullFields, f)
		}
	}

	voChanges := len(addAlgs) + len(deleteAlgs) +
		len(addCerts) + len(deleteCerts) +
		len(addAuds) + len(deleteAuds) +
		len(addScopes) + len(deleteScopes) +
		len(addMaps) + len(deleteMaps)

	// handle no changes...
	if len(filteredDbMask) == 0 && len(filteredNullFields) == 0 && voChanges == 0 {
		return origAm, db.NoRowsAffected, nil
	}

	// ClientSecret is a bit odd, because it uses the Struct wrapping, we need
	// to add the encrypted fields to the dbMask or nullFields
	if strutil.StrListContains(filteredDbMask, ClientSecretField) {
		filteredDbMask = append(filteredDbMask, CtClientSecretField, ClientSecretHmacField)
	}
	if strutil.StrListContains(filteredNullFields, ClientSecretField) {
		filteredNullFields = append(filteredNullFields, CtClientSecretField, ClientSecretHmacField)
	}

	databaseWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeDatabase)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
	}
	if err := am.encrypt(ctx, databaseWrapper); err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}

	oplogWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeOplog)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper"))
	}

	// we always set this to the current value of opts.withForce
	am.DisableDiscoveredConfigValidation = opts.withForce

	var updatedAm *AuthMethod
	var rowsUpdated int
	_, err = r.writer.DoTx(
		ctx,
		db.StdRetryCnt,
		db.ExpBackoff{},
		func(reader db.Reader, w db.Writer) error {
			// capacity hint: the auth method message plus the value object messages
			msgs := make([]*oplog.Message, 0, 1+voChanges)
			ticket, err := w.GetTicket(ctx, am)
			if err != nil {
				return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get ticket"))
			}
			var authMethodOplogMsg oplog.Message
			switch {
			case len(filteredDbMask) == 0 && len(filteredNullFields) == 0:
				// the auth method's fields are not being updated, just its value objects, so we need to just update the auth
				// method's version.
				updatedAm = am.Clone()
				updatedAm.Version = version + 1
				rowsUpdated, err = w.Update(ctx, updatedAm, []string{VersionField, DisableDiscoveredConfigValidationField}, nil, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method version"))
				}
				if rowsUpdated != 1 {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method version and %d rows updated", rowsUpdated))
				}
			default:
				// build a fresh mask on every attempt: DoTx is given a retry
				// count, and appending to the captured filteredDbMask would add
				// DisableDiscoveredConfigValidationField again on each retry.
				updateMask := make([]string, 0, len(filteredDbMask)+1)
				updateMask = append(updateMask, filteredDbMask...)
				updateMask = append(updateMask, DisableDiscoveredConfigValidationField)
				updatedAm = am.Clone()
				rowsUpdated, err = w.Update(ctx, updatedAm, updateMask, filteredNullFields, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method"))
				}
				if rowsUpdated != 1 {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method and %d rows updated", rowsUpdated))
				}
			}
			msgs = append(msgs, &authMethodOplogMsg)

			if len(deleteAlgs) > 0 {
				deleteAlgOplogMsgs := make([]*oplog.Message, 0, len(deleteAlgs))
				rowsDeleted, err := w.DeleteItems(ctx, deleteAlgs, db.NewOplogMsgs(&deleteAlgOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete signing algorithms"))
				}
				if rowsDeleted != len(deleteAlgs) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("signing algorithms deleted %d did not match request for %d", rowsDeleted, len(deleteAlgs)))
				}
				msgs = append(msgs, deleteAlgOplogMsgs...)
			}
			if len(addAlgs) > 0 {
				addAlgsOplogMsgs := make([]*oplog.Message, 0, len(addAlgs))
				if err := w.CreateItems(ctx, addAlgs, db.NewOplogMsgs(&addAlgsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add signing algorithms"))
				}
				msgs = append(msgs, addAlgsOplogMsgs...)
			}

			if len(deleteCerts) > 0 {
				deleteCertOplogMsgs := make([]*oplog.Message, 0, len(deleteCerts))
				rowsDeleted, err := w.DeleteItems(ctx, deleteCerts, db.NewOplogMsgs(&deleteCertOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete certificates"))
				}
				if rowsDeleted != len(deleteCerts) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("certificates deleted %d did not match request for %d", rowsDeleted, len(deleteCerts)))
				}
				msgs = append(msgs, deleteCertOplogMsgs...)
			}
			if len(addCerts) > 0 {
				addCertsOplogMsgs := make([]*oplog.Message, 0, len(addCerts))
				if err := w.CreateItems(ctx, addCerts, db.NewOplogMsgs(&addCertsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add certificates"))
				}
				msgs = append(msgs, addCertsOplogMsgs...)
			}

			if len(deleteAuds) > 0 {
				deleteAudsOplogMsgs := make([]*oplog.Message, 0, len(deleteAuds))
				rowsDeleted, err := w.DeleteItems(ctx, deleteAuds, db.NewOplogMsgs(&deleteAudsOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete audiences URLs"))
				}
				if rowsDeleted != len(deleteAuds) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("audiences deleted %d did not match request for %d", rowsDeleted, len(deleteAuds)))
				}
				msgs = append(msgs, deleteAudsOplogMsgs...)
			}
			if len(addAuds) > 0 {
				addAudsOplogMsgs := make([]*oplog.Message, 0, len(addAuds))
				if err := w.CreateItems(ctx, addAuds, db.NewOplogMsgs(&addAudsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add audiences URLs"))
				}
				msgs = append(msgs, addAudsOplogMsgs...)
			}

			if len(deleteScopes) > 0 {
				deleteScopesOplogMsgs := make([]*oplog.Message, 0, len(deleteScopes))
				rowsDeleted, err := w.DeleteItems(ctx, deleteScopes, db.NewOplogMsgs(&deleteScopesOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete claims scopes"))
				}
				if rowsDeleted != len(deleteScopes) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("claims scopes deleted %d did not match request for %d", rowsDeleted, len(deleteScopes)))
				}
				msgs = append(msgs, deleteScopesOplogMsgs...)
			}
			if len(addScopes) > 0 {
				addScopesOplogMsgs := make([]*oplog.Message, 0, len(addScopes))
				if err := w.CreateItems(ctx, addScopes, db.NewOplogMsgs(&addScopesOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add claims scopes"))
				}
				msgs = append(msgs, addScopesOplogMsgs...)
			}

			if len(deleteMaps) > 0 {
				deleteMapsOplogMsgs := make([]*oplog.Message, 0, len(deleteMaps))
				rowsDeleted, err := w.DeleteItems(ctx, deleteMaps, db.NewOplogMsgs(&deleteMapsOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete account claim maps"))
				}
				if rowsDeleted != len(deleteMaps) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("account claim maps deleted %d did not match request for %d", rowsDeleted, len(deleteMaps)))
				}
				msgs = append(msgs, deleteMapsOplogMsgs...)
			}
			if len(addMaps) > 0 {
				addMapsOplogMsgs := make([]*oplog.Message, 0, len(addMaps))
				if err := w.CreateItems(ctx, addMaps, db.NewOplogMsgs(&addMapsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add account claim maps"))
				}
				msgs = append(msgs, addMapsOplogMsgs...)
			}

			metadata := updatedAm.oplog(oplog.OpType_OP_TYPE_UPDATE)
			if err := w.WriteOplogEntryWith(ctx, oplogWrapper, ticket, metadata, msgs); err != nil {
				return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog"))
			}

			// we need a new repo, that's using the same reader/writer as this TxHandler
			txRepo := &Repository{
				reader: reader,
				writer: w,
				kms:    r.kms,
				// intentionally not setting the defaultLimit, so we'll get all
				// the account ids without a limit
			}
			updatedAm, err = txRepo.lookupAuthMethod(ctx, updatedAm.PublicId)
			if err != nil {
				return errors.Wrap(ctx, err, op, errors.WithMsg("unable to lookup auth method after update"))
			}
			if updatedAm == nil {
				return errors.New(ctx, errors.RecordNotFound, op, "unable to lookup auth method after update")
			}
			return nil
		},
	)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	providerCache().delete(ctx, updatedAm.PublicId)
	return updatedAm, rowsUpdated, nil
}

// voName represents the names of auth method value objects
type voName string

// Names of the auth method value objects. Each name is identical to the
// corresponding field mask name (e.g. SigningAlgVO and SigningAlgsField are
// both "SigningAlgs"), which lets valueObjectChanges look a value object up
// directly in a field mask.
const (
	SigningAlgVO       voName = "SigningAlgs"
	CertificateVO      voName = "Certificates"
	AudClaimVO         voName = "AudClaims"
	ClaimsScopesVO     voName = "ClaimsScopes"
	AccountClaimMapsVO voName = "AccountClaimMaps"
)

// validVoName decides if the name is valid
func validVoName(name voName) bool {
	switch name {
	case SigningAlgVO, CertificateVO, AudClaimVO, ClaimsScopesVO, AccountClaimMapsVO:
		return true
	default:
		return false
	}
}

// factoryFunc defines a func type for value object factories
type factoryFunc func(ctx context.Context, publicId string, i interface{}) (interface{}, error)

// supportedFactories are the currently supported factoryFunc for value objects.
// Each factory formats i with %s and builds the matching value object for the
// auth method identified by publicId.
var supportedFactories = map[voName]factoryFunc{
	SigningAlgVO: func(ctx context.Context, publicId string, i interface{}) (interface{}, error) {
		str := fmt.Sprintf("%s", i)
		return NewSigningAlg(ctx, publicId, Alg(str))
	},
	CertificateVO: func(ctx context.Context, publicId string, i interface{}) (interface{}, error) {
		str := fmt.Sprintf("%s", i)
		return NewCertificate(ctx, publicId, str)
	},
	AudClaimVO: func(ctx context.Context, publicId string, i interface{}) (interface{}, error) {
		str := fmt.Sprintf("%s", i)
		return NewAudClaim(ctx, publicId, str)
	},
	ClaimsScopesVO: func(ctx context.Context, publicId string, i interface{}) (interface{}, error) {
		str := fmt.Sprintf("%s", i)
		return NewClaimsScope(ctx, publicId, str)
	},
	AccountClaimMapsVO: func(ctx context.Context, publicId string, i interface{}) (interface{}, error) {
		const op = "oidc.AccountClaimMapsFactory"
		str := fmt.Sprintf("%s", i)
		acm, err := ParseAccountClaimMaps(ctx, str)
		if err != nil {
			return nil, errors.Wrap(ctx, err, op)
		}
		// each value object must describe exactly one claim map; an empty
		// parse result would otherwise fall through with a zero-value ClaimMap.
		if len(acm) != 1 {
			return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unable to parse account claim map %s", str))
		}
		var m ClaimMap
		for _, v := range acm {
			m = v
		}
		to, err := ConvertToAccountToClaim(ctx, m.To)
		if err != nil {
			return nil, errors.Wrap(ctx, err, op)
		}
		return NewAccountClaimMap(ctx, publicId, m.From, to)
	},
}

// valueObjectChanges takes the new and old lists of VOs (value objects) and,
// using the dbMask/nullFields, returns the lists of VOs that need to be
// added and deleted in order to reconcile the auth method's value objects.
// It returns no changes when valueObjectName is in neither dbMask nor
// nullFields, and an error when either list contains duplicates.
func valueObjectChanges( ctx context.Context, publicId string, valueObjectName voName, newVOs, oldVOs, dbMask, nullFields []string, ) (add []interface{}, del []interface{}, e error) { const op = "valueObjectChanges" if publicId == "" { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing public id") } if !validVoName(valueObjectName) { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid value object name: %s", valueObjectName)) } if !strutil.StrListContains(dbMask, string(valueObjectName)) && !strutil.StrListContains(nullFields, string(valueObjectName)) { return nil, nil, nil } if len(strutil.RemoveDuplicates(newVOs, false)) != len(newVOs) { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate new %s", valueObjectName)) } if len(strutil.RemoveDuplicates(oldVOs, false)) != len(oldVOs) { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate old %s", valueObjectName)) } factory, ok := supportedFactories[valueObjectName] if !ok { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unsupported factory for value object: %s", valueObjectName)) } if len(dbMask) == 0 && len(nullFields) == 0 { return nil, nil, nil } foundVOs := map[string]bool{} for _, a := range oldVOs { foundVOs[a] = true } var adds []interface{} var deletes []interface{} if strutil.StrListContains(nullFields, string(valueObjectName)) { deletes = make([]interface{}, 0, len(oldVOs)) for _, v := range oldVOs { deleteObj, err := factory(ctx, publicId, v) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } deletes = append(deletes, deleteObj) delete(foundVOs, v) } } if strutil.StrListContains(dbMask, string(valueObjectName)) { adds = make([]interface{}, 0, len(newVOs)) for _, v := range newVOs { if _, ok := foundVOs[v]; ok { delete(foundVOs, v) continue } obj, err := factory(ctx, publicId, v) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } adds = 
append(adds, obj) delete(foundVOs, v) } } if len(foundVOs) > 0 { for v := range foundVOs { obj, err := factory(ctx, publicId, v) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } deletes = append(deletes, obj) delete(foundVOs, v) } } return adds, deletes, nil } // validateFieldMask check the field mask to ensure all the fields are updatable func validateFieldMask(ctx context.Context, fieldMaskPaths []string) error { const op = "validateFieldMask" for _, f := range fieldMaskPaths { switch { case strings.EqualFold(NameField, f): case strings.EqualFold(DescriptionField, f): case strings.EqualFold(IssuerField, f): case strings.EqualFold(ClientIdField, f): case strings.EqualFold(ClientSecretField, f): case strings.EqualFold(MaxAgeField, f): case strings.EqualFold(SigningAlgsField, f): case strings.EqualFold(ApiUrlField, f): case strings.EqualFold(AudClaimsField, f): case strings.EqualFold(CertificatesField, f): case strings.EqualFold(ClaimsScopesField, f): case strings.EqualFold(AccountClaimMapsField, f): default: return errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid field mask: %s", f)) } } return nil } // applyUpdate takes the new and applies it to the orig using the field masks func applyUpdate(new, orig *AuthMethod, fieldMaskPaths []string) *AuthMethod { cp := orig.Clone() for _, f := range fieldMaskPaths { switch f { case NameField: cp.Name = new.Name case DescriptionField: cp.Description = new.Description case IssuerField: cp.Issuer = new.Issuer case ClientIdField: cp.ClientId = new.ClientId case ClientSecretField: cp.ClientSecret = new.ClientSecret case MaxAgeField: cp.MaxAge = new.MaxAge case ApiUrlField: cp.ApiUrl = new.ApiUrl case SigningAlgsField: switch { case len(new.SigningAlgs) == 0: cp.SigningAlgs = nil default: cp.SigningAlgs = make([]string, 0, len(new.SigningAlgs)) cp.SigningAlgs = append(cp.SigningAlgs, new.SigningAlgs...) 
} case AudClaimsField: switch { case len(new.AudClaims) == 0: cp.AudClaims = nil default: cp.AudClaims = make([]string, 0, len(new.AudClaims)) cp.AudClaims = append(cp.AudClaims, new.AudClaims...) } case CertificatesField: switch { case len(new.Certificates) == 0: cp.Certificates = nil default: cp.Certificates = make([]string, 0, len(new.Certificates)) cp.Certificates = append(cp.Certificates, new.Certificates...) } case ClaimsScopesField: switch { case len(new.ClaimsScopes) == 0: cp.ClaimsScopes = nil default: cp.ClaimsScopes = make([]string, 0, len(new.ClaimsScopes)) cp.ClaimsScopes = append(cp.ClaimsScopes, new.ClaimsScopes...) } case AccountClaimMapsField: switch { case len(new.AccountClaimMaps) == 0: cp.AccountClaimMaps = nil default: cp.AccountClaimMaps = make([]string, 0, len(new.AccountClaimMaps)) cp.AccountClaimMaps = append(cp.AccountClaimMaps, new.AccountClaimMaps...) } } } return cp } // ValidateDiscoveryInfo will test/validate the provided AuthMethod against // the info from it's discovery URL. // // It will verify that all required fields for a working AuthMethod have values. // // If the AuthMethod is complete, ValidateDiscoveryInfo retrieves the auth // method's OpenID Configuration document. The values in the AuthMethod // (and associated data) are validated with the retrieved document. The issuer and // id token signing algorithm in the configuration are validated with the // retrieved document. ValidateDiscoveryInfo also verifies the authorization, token, // and userinfo endpoints by connecting to each and uses any certificates in the // configuration as trust anchors to confirm connectivity. // // Options supported are: WithPublicId, WithAuthMethod func (r *Repository) ValidateDiscoveryInfo(ctx context.Context, opt ...Option) error { const op = "oidc.(Repository).ValidateDiscoveryInfo" opts := getOpts(opt...) 
var am *AuthMethod switch { case opts.withPublicId != "": var err error am, err = r.lookupAuthMethod(ctx, opts.withPublicId) if err != nil { return errors.Wrap(ctx, err, op) } if am == nil { return errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("unable to lookup auth method %s", opts.withPublicId)) } case opts.withAuthMethod != nil: am = opts.withAuthMethod default: return errors.New(ctx, errors.InvalidParameter, op, "neither WithPublicId(...) nor WithAuthMethod(...) options were provided") } // FYI: once converted to an oidc.Provider, any certs configured will be used as trust anchors for all HTTP requests provider, err := convertToProvider(ctx, am) if err != nil && am.OperationalState == string(InactiveState) { return nil } if err != nil { return errors.Wrap(ctx, err, op) } info, err := provider.DiscoveryInfo(ctx) if err != nil { return errors.Wrap(ctx, err, op) } var result *multierror.Error if info.Issuer != am.Issuer { result = multierror.Append(result, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("auth method issuer doesn't match discovery issuer: expected %s and got %s", am.Issuer, info.Issuer))) } for _, a := range am.SigningAlgs { if !strutil.StrListContains(info.IdTokenSigningAlgsSupported, a) { result = multierror.Append(result, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("auth method signing alg is not in discovered supported algs: expected %s and got %s", a, info.IdTokenSigningAlgsSupported))) } } providerClient, err := provider.HTTPClient() if err != nil { result = multierror.Append(result, errors.New(ctx, errors.Unknown, op, "unable to get oidc http client", errors.WithWrap(err))) return result.ErrorOrNil() } // we need to prevent redirects during these tests... 
we don't want to have // redirects going to the controller's callback (aka the configured provider's callback) providerClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } // test JWKs URL statusCode, err := pingEndpoint(ctx, providerClient, "JWKs", "GET", info.JWKSURL) if err != nil { result = multierror.Append(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to verify JWKs endpoint: %s", info.JWKSURL), errors.WithWrap(err))) } if statusCode != 200 { result = multierror.Append(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("non-200 status (%d) from JWKs endpoint: %s", statusCode, info.JWKSURL), errors.WithWrap(err))) } // test Auth URL if _, err := pingEndpoint(ctx, providerClient, "AuthURL", "GET", info.AuthURL); err != nil { result = multierror.Append(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to verify authorize endpoint: %s", info.AuthURL), errors.WithWrap(err))) } // test Token URL if _, err := pingEndpoint(ctx, providerClient, "TokenURL", "POST", info.TokenURL); err != nil { result = multierror.Append(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to verify token endpoint: %s", info.TokenURL), errors.WithWrap(err))) } // we're not verifying the UserInfo URL, since it's not a required dependency. 
return result.ErrorOrNil() } type HTTPClient interface { Do(req *http.Request) (*http.Response, error) } // pingEndpoint will make an attempted http request, return status code and errors func pingEndpoint(ctx context.Context, client HTTPClient, endpointType, method, url string) (int, error) { const op = "oidc.pingEndpoint" req, err := http.NewRequestWithContext(ctx, method, url, nil) if err != nil { return 0, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to create %s http request", endpointType), errors.WithWrap(err)) } resp, err := client.Do(req) if err != nil { return 0, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("request to %s endpoint failed", endpointType), errors.WithWrap(err)) } return resp.StatusCode, nil }