You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/auth/oidc/repository_auth_method_upda...

858 lines
32 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package oidc
import (
"context"
stderrors "errors"
"fmt"
"net/http"
"strings"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/go-dbw"
"github.com/hashicorp/go-secure-stdlib/strutil"
)
// Names of AuthMethod fields as used in update field masks, db update
// paths and the update logic in this file.
const (
	// State and bookkeeping fields; not accepted by validateFieldMask.
	OperationalStateField                  = "OperationalState"
	DisableDiscoveredConfigValidationField = "DisableDiscoveredConfigValidation"
	VersionField                           = "Version"

	// Descriptive fields.
	NameField        = "Name"
	DescriptionField = "Description"
	FilterField      = "Filter"

	// Scalar OIDC configuration fields.
	IssuerField   = "Issuer"
	ClientIdField = "ClientId"
	MaxAgeField   = "MaxAge"
	ApiUrlField   = "ApiUrl"

	// The client secret plus the fields that are written alongside it when
	// the secret is updated or cleared.
	ClientSecretField     = "ClientSecret"
	CtClientSecretField   = "CtClientSecret"
	ClientSecretHmacField = "ClientSecretHmac"
	KeyIdField            = "KeyId"

	// Value-object (list) fields, stored as separate rows rather than as
	// auth method columns.
	SigningAlgsField      = "SigningAlgs"
	AudClaimsField        = "AudClaims"
	CertificatesField     = "Certificates"
	ClaimsScopesField     = "ClaimsScopes"
	AccountClaimMapsField = "AccountClaimMaps"
	PromptsField          = "Prompts"

	// Additional field names (not referenced by the update path in this file).
	TokenClaimsField    = "TokenClaims"
	UserinfoClaimsField = "UserinfoClaims"
)
// UpdateAuthMethod will retrieve the auth method from the repository,
// and update it based on the field masks provided.
//
// The auth method will not be persisted in the repository if the auth
// method's current OperationalState is anything other than InactiveState
// (e.g. ActivePublic or ActivePrivate) and the update would have resulted
// in an incomplete/non-operational auth method.
//
// Unless WithForce is provided, the updated auth method is tested/validated
// against its provider's published OIDC discovery document before it is
// written. If this validation succeeds, the auth method is persisted in the
// repository, and the written auth method is returned along with the number
// of auth method rows updated. If the requested update results in no
// effective changes, the original auth method is returned with
// db.NoRowsAffected and nothing is written.
//
// fieldMaskPaths provides field_mask.proto paths for fields that should
// be updated. Fields will be set to NULL if the field is a
// zero value and included in fieldMask. Name, Description, Issuer,
// ClientId, ClientSecret, MaxAge and ApiUrl are updatable fields. The
// AuthMethod's value objects SigningAlgs, AudClaims, Certificates,
// ClaimsScopes, AccountClaimMaps and Prompts are also updatable, except that
// account claim maps to the "sub" claim can be neither added nor removed.
// If no updatable fields are included in the fieldMaskPaths, then an error
// is returned.
//
// version must match the stored auth method's version; otherwise a
// VersionMismatch error is returned.
//
// NOTE: am is mutated by this call: am.encrypt(...) is invoked on it and its
// DisableDiscoveredConfigValidation is overwritten with the WithForce value.
//
// Options supported:
//
// * WithDryRun: when this option is provided, the auth method is retrieved from
// the repo, updated based on the fieldMask, checked for completeness, tested
// via Repository.ValidateDiscoveryInfo, and the updated (unpersisted) auth
// method is returned along with any error reported. The updates are not
// persisted to the repository. A dry run returns before value objects are
// reconciled, so problems detected at that stage (e.g. duplicate values or
// "sub" claim map changes) are not reported by a dry run.
//
// * WithForce: when this option is provided, the auth method is persisted in
// the repository without testing its validity against its provider's published
// OIDC discovery document. Even if this option is provided, the auth method will
// not be persisted in the repository when the update would have resulted in
// an incomplete/non-operational auth method and its OperationalState is not
// InactiveState.
//
// Also, a successful update will invalidate (delete) the Repository's
// cache of the oidc.Provider for the AuthMethod.
func (r *Repository) UpdateAuthMethod(ctx context.Context, am *AuthMethod, version uint32, fieldMaskPaths []string, opt ...Option) (*AuthMethod, int, error) {
	const op = "oidc.(Repository).UpdateAuthMethod"
	if am == nil {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method")
	}
	if am.AuthMethod == nil {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method store")
	}
	if am.PublicId == "" {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing public id")
	}
	if err := validateFieldMask(ctx, fieldMaskPaths); err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	// Split the masked fields into those to set (dbMask) and those whose new
	// value is a zero value and must be set to NULL (nullFields).
	dbMask, nullFields := dbw.BuildUpdatePaths(
		map[string]any{
			NameField:             am.Name,
			DescriptionField:      am.Description,
			IssuerField:           am.Issuer,
			ClientIdField:         am.ClientId,
			ClientSecretField:     am.ClientSecret,
			MaxAgeField:           am.MaxAge,
			SigningAlgsField:      am.SigningAlgs,
			ApiUrlField:           am.ApiUrl,
			AudClaimsField:        am.AudClaims,
			CertificatesField:     am.Certificates,
			ClaimsScopesField:     am.ClaimsScopes,
			AccountClaimMapsField: am.AccountClaimMaps,
			PromptsField:          am.Prompts,
		},
		fieldMaskPaths,
		nil,
	)
	if len(dbMask) == 0 && len(nullFields) == 0 {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.EmptyFieldMask, op, "empty field mask")
	}
	origAm, err := r.lookupAuthMethod(ctx, am.PublicId)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	if origAm == nil {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("auth method %s", am.PublicId))
	}
	// there's no reason to continue if another controller has already updated this auth method.
	if origAm.Version != version {
		return nil, db.NoRowsAffected, errors.New(ctx, errors.VersionMismatch, op, fmt.Sprintf("update version %d doesn't match db version %d", version, origAm.Version))
	}
	opts := getOpts(opt...)
	if opts.withDryRun {
		// Dry run: report the would-be result without writing anything.
		updated := applyUpdate(am, origAm, fieldMaskPaths)
		if err := updated.isComplete(ctx); err != nil {
			return updated, db.NoRowsAffected, err
		}
		err := r.ValidateDiscoveryInfo(ctx, WithAuthMethod(updated))
		return updated, db.NoRowsAffected, err
	}
	// prevent an "active" auth method from being updated in a manner that would create
	// an incomplete and unusable auth method.
	if origAm.OperationalState != string(InactiveState) {
		if err := applyUpdate(am, origAm, fieldMaskPaths).isComplete(ctx); err != nil {
			return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("update would result in an incomplete auth method"))
		}
	}
	if !opts.withForce {
		if err := r.ValidateDiscoveryInfo(ctx, WithAuthMethod(applyUpdate(am, origAm, fieldMaskPaths))); err != nil {
			return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
		}
	}
	// Compute, for each value-object collection, which rows must be added and
	// which deleted. valueObjectChanges returns no changes for a collection
	// whose field is in neither dbMask nor nullFields.
	aa, ad, err := valueObjectChanges(ctx, origAm.PublicId, SigningAlgVO, am.SigningAlgs, origAm.SigningAlgs, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addAlgs := []*SigningAlg{}
	for _, a := range aa {
		addAlgs = append(addAlgs, a.(*SigningAlg))
	}
	deleteAlgs := []*SigningAlg{}
	for _, a := range ad {
		deleteAlgs = append(deleteAlgs, a.(*SigningAlg))
	}
	ac, dc, err := valueObjectChanges(ctx, origAm.PublicId, CertificateVO, am.Certificates, origAm.Certificates, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addCerts := []*Certificate{}
	for _, c := range ac {
		addCerts = append(addCerts, c.(*Certificate))
	}
	deleteCerts := []*Certificate{}
	for _, c := range dc {
		deleteCerts = append(deleteCerts, c.(*Certificate))
	}
	aa, ad, err = valueObjectChanges(ctx, origAm.PublicId, AudClaimVO, am.AudClaims, origAm.AudClaims, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addAuds := []*AudClaim{}
	for _, a := range aa {
		addAuds = append(addAuds, a.(*AudClaim))
	}
	deleteAuds := []*AudClaim{}
	for _, a := range ad {
		deleteAuds = append(deleteAuds, a.(*AudClaim))
	}
	as, ds, err := valueObjectChanges(ctx, origAm.PublicId, ClaimsScopesVO, am.ClaimsScopes, origAm.ClaimsScopes, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addScopes := []*ClaimsScope{}
	for _, s := range as {
		addScopes = append(addScopes, s.(*ClaimsScope))
	}
	deleteScopes := []*ClaimsScope{}
	for _, s := range ds {
		deleteScopes = append(deleteScopes, s.(*ClaimsScope))
	}
	aacm, dacm, err := valueObjectChanges(ctx, origAm.PublicId, AccountClaimMapsVO, am.AccountClaimMaps, origAm.AccountClaimMaps, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addMaps := []*AccountClaimMap{}
	for _, m := range aacm {
		addMaps = append(addMaps, m.(*AccountClaimMap))
	}
	deleteMaps := []*AccountClaimMap{}
	for _, m := range dacm {
		deleteMaps = append(deleteMaps, m.(*AccountClaimMap))
	}
	ap, dp, err := valueObjectChanges(ctx, origAm.PublicId, PromptsVO, am.Prompts, origAm.Prompts, dbMask, nullFields)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	addPrompts := []*Prompt{}
	for _, p := range ap {
		addPrompts = append(addPrompts, p.(*Prompt))
	}
	deletePrompts := []*Prompt{}
	for _, p := range dp {
		deletePrompts = append(deletePrompts, p.(*Prompt))
	}
	// we don't allow updates for "sub" claim maps, because we have no way to
	// determine if the updated "from" claim in the map might create collisions
	// with any existing account's subject.
	for _, cm := range addMaps {
		if cm.ToClaim == string(ToSubClaim) {
			return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("you cannot update account claim map %s=%s for the \"sub\" claim", cm.FromClaim, cm.ToClaim))
		}
	}
	for _, cm := range deleteMaps {
		if cm.ToClaim == string(ToSubClaim) {
			return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("you cannot update account claim map %s=%s for the \"sub\" claim", cm.FromClaim, cm.ToClaim))
		}
	}
	// Value-object fields are persisted as their own rows (handled above), so
	// strip them from the masks used to update the auth method row itself.
	var filteredDbMask, filteredNullFields []string
	for _, f := range dbMask {
		switch f {
		case SigningAlgsField, AudClaimsField, CertificatesField, ClaimsScopesField, AccountClaimMapsField, PromptsField:
			continue
		default:
			filteredDbMask = append(filteredDbMask, f)
		}
	}
	for _, f := range nullFields {
		switch f {
		case SigningAlgsField, AudClaimsField, CertificatesField, ClaimsScopesField, AccountClaimMapsField, PromptsField:
			continue
		default:
			filteredNullFields = append(filteredNullFields, f)
		}
	}
	// handle no changes: nothing to write, so return the stored auth method as-is.
	if len(filteredDbMask) == 0 &&
		len(filteredNullFields) == 0 &&
		len(addAlgs) == 0 &&
		len(deleteAlgs) == 0 &&
		len(addCerts) == 0 &&
		len(deleteCerts) == 0 &&
		len(addAuds) == 0 &&
		len(deleteAuds) == 0 &&
		len(addScopes) == 0 &&
		len(deleteScopes) == 0 &&
		len(addMaps) == 0 &&
		len(deleteMaps) == 0 &&
		len(addPrompts) == 0 &&
		len(deletePrompts) == 0 {
		return origAm, db.NoRowsAffected, nil
	}
	// ClientSecret is a bit odd, because it uses the Struct wrapping, we need
	// to add the encrypted fields to the dbMask or nullFields
	if strutil.StrListContains(filteredDbMask, ClientSecretField) {
		filteredDbMask = append(filteredDbMask, CtClientSecretField, ClientSecretHmacField, KeyIdField)
	}
	if strutil.StrListContains(filteredNullFields, ClientSecretField) {
		filteredNullFields = append(filteredNullFields, CtClientSecretField, ClientSecretHmacField, KeyIdField)
	}
	// Encrypt am with the scope's database wrapper before it is written
	// (presumably populating the encrypted client secret fields added above).
	databaseWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeDatabase)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
	}
	if err := am.encrypt(ctx, databaseWrapper); err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	oplogWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeOplog)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper"))
	}
	// we always set this to the current value of opts.withForce
	am.DisableDiscoveredConfigValidation = opts.withForce
	var updatedAm *AuthMethod
	var rowsUpdated int
	// All writes (auth method row, value-object rows, oplog entry) happen in
	// one transaction, retried with exponential backoff.
	_, err = r.writer.DoTx(
		ctx,
		db.StdRetryCnt,
		db.ExpBackoff{},
		func(reader db.Reader, w db.Writer) error {
			// the capacity is only a hint; more messages may be appended below
			// (one per added/deleted value-object row, plus the auth method's).
			msgs := make([]*oplog.Message, 0, 9) // AuthMethod, Algs*2, Certs*2, Audiences*2, Prompts*2
			ticket, err := w.GetTicket(ctx, am)
			if err != nil {
				return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get ticket"))
			}
			var authMethodOplogMsg oplog.Message
			switch {
			case len(filteredDbMask) == 0 && len(filteredNullFields) == 0:
				// the auth method's fields are not being updated, just its value objects, so we need to just update the auth
				// method's version.
				updatedAm = am.Clone()
				updatedAm.Version = uint32(version) + 1
				rowsUpdated, err = w.Update(ctx, updatedAm, []string{VersionField, DisableDiscoveredConfigValidationField}, nil, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method version"))
				}
				if rowsUpdated != 1 {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method version and %d rows updated", rowsUpdated))
				}
			default:
				// DisableDiscoveredConfigValidation is always written (see above).
				filteredDbMask = append(filteredDbMask, DisableDiscoveredConfigValidationField)
				updatedAm = am.Clone()
				rowsUpdated, err = w.Update(ctx, updatedAm, filteredDbMask, filteredNullFields, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method"))
				}
				if rowsUpdated != 1 {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method and %d rows updated", rowsUpdated))
				}
			}
			msgs = append(msgs, &authMethodOplogMsg)
			// For each value-object collection: delete removed rows (verifying
			// the deleted count) before creating new ones, collecting oplog
			// messages for each write.
			if len(deleteAlgs) > 0 {
				deleteAlgOplogMsgs := make([]*oplog.Message, 0, len(deleteAlgs))
				rowsDeleted, err := w.DeleteItems(ctx, deleteAlgs, db.NewOplogMsgs(&deleteAlgOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete signing algorithms"))
				}
				if rowsDeleted != len(deleteAlgs) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("signing algorithms deleted %d did not match request for %d", rowsDeleted, len(deleteAlgs)))
				}
				msgs = append(msgs, deleteAlgOplogMsgs...)
			}
			if len(addAlgs) > 0 {
				addAlgsOplogMsgs := make([]*oplog.Message, 0, len(addAlgs))
				if err := w.CreateItems(ctx, addAlgs, db.NewOplogMsgs(&addAlgsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add signing algorithms"))
				}
				msgs = append(msgs, addAlgsOplogMsgs...)
			}
			if len(deletePrompts) > 0 {
				deletePromptOplogMsgs := make([]*oplog.Message, 0, len(deletePrompts))
				rowsDeleted, err := w.DeleteItems(ctx, deletePrompts, db.NewOplogMsgs(&deletePromptOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete prompts"))
				}
				if rowsDeleted != len(deletePrompts) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("prompts deleted %d did not match request for %d", rowsDeleted, len(deletePrompts)))
				}
				msgs = append(msgs, deletePromptOplogMsgs...)
			}
			if len(addPrompts) > 0 {
				addPromptsOplogMsgs := make([]*oplog.Message, 0, len(addPrompts))
				if err := w.CreateItems(ctx, addPrompts, db.NewOplogMsgs(&addPromptsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add prompts"))
				}
				msgs = append(msgs, addPromptsOplogMsgs...)
			}
			if len(deleteCerts) > 0 {
				deleteCertOplogMsgs := make([]*oplog.Message, 0, len(deleteCerts))
				rowsDeleted, err := w.DeleteItems(ctx, deleteCerts, db.NewOplogMsgs(&deleteCertOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete certificates"))
				}
				if rowsDeleted != len(deleteCerts) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("certificates deleted %d did not match request for %d", rowsDeleted, len(deleteCerts)))
				}
				msgs = append(msgs, deleteCertOplogMsgs...)
			}
			if len(addCerts) > 0 {
				addCertsOplogMsgs := make([]*oplog.Message, 0, len(addCerts))
				if err := w.CreateItems(ctx, addCerts, db.NewOplogMsgs(&addCertsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add certificates"))
				}
				msgs = append(msgs, addCertsOplogMsgs...)
			}
			if len(deleteAuds) > 0 {
				deleteAudsOplogMsgs := make([]*oplog.Message, 0, len(deleteAuds))
				rowsDeleted, err := w.DeleteItems(ctx, deleteAuds, db.NewOplogMsgs(&deleteAudsOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete audiences URLs"))
				}
				if rowsDeleted != len(deleteAuds) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("audiences deleted %d did not match request for %d", rowsDeleted, len(deleteAuds)))
				}
				msgs = append(msgs, deleteAudsOplogMsgs...)
			}
			if len(addAuds) > 0 {
				addAudsOplogMsgs := make([]*oplog.Message, 0, len(addAuds))
				if err := w.CreateItems(ctx, addAuds, db.NewOplogMsgs(&addAudsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add audiences URLs"))
				}
				msgs = append(msgs, addAudsOplogMsgs...)
			}
			if len(deleteScopes) > 0 {
				deleteScopesOplogMsgs := make([]*oplog.Message, 0, len(deleteScopes))
				rowsDeleted, err := w.DeleteItems(ctx, deleteScopes, db.NewOplogMsgs(&deleteScopesOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete claims scopes"))
				}
				if rowsDeleted != len(deleteScopes) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("claims scopes deleted %d did not match request for %d", rowsDeleted, len(deleteScopes)))
				}
				msgs = append(msgs, deleteScopesOplogMsgs...)
			}
			if len(addScopes) > 0 {
				addScopesOplogMsgs := make([]*oplog.Message, 0, len(addScopes))
				if err := w.CreateItems(ctx, addScopes, db.NewOplogMsgs(&addScopesOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add claims scopes"))
				}
				msgs = append(msgs, addScopesOplogMsgs...)
			}
			if len(deleteMaps) > 0 {
				deleteMapsOplogMsgs := make([]*oplog.Message, 0, len(deleteMaps))
				rowsDeleted, err := w.DeleteItems(ctx, deleteMaps, db.NewOplogMsgs(&deleteMapsOplogMsgs))
				if err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete account claim maps"))
				}
				if rowsDeleted != len(deleteMaps) {
					return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("account claim maps deleted %d did not match request for %d", rowsDeleted, len(deleteMaps)))
				}
				msgs = append(msgs, deleteMapsOplogMsgs...)
			}
			if len(addMaps) > 0 {
				addMapsOplogMsgs := make([]*oplog.Message, 0, len(addMaps))
				if err := w.CreateItems(ctx, addMaps, db.NewOplogMsgs(&addMapsOplogMsgs)); err != nil {
					return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add account claim maps"))
				}
				msgs = append(msgs, addMapsOplogMsgs...)
			}
			// Record every write above as a single oplog entry.
			metadata := updatedAm.oplog(oplog.OpType_OP_TYPE_UPDATE)
			if err := w.WriteOplogEntryWith(ctx, oplogWrapper, ticket, metadata, msgs); err != nil {
				return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog"))
			}
			// we need a new repo, that's using the same reader/writer as this TxHandler
			txRepo := &Repository{
				reader: reader,
				writer: w,
				kms:    r.kms,
				// intentionally not setting the defaultLimit, so we'll get all
				// the account ids without a limit
			}
			// Re-read inside the transaction so the returned auth method
			// reflects the written row and value objects.
			updatedAm, err = txRepo.lookupAuthMethod(ctx, updatedAm.PublicId)
			if err != nil {
				return errors.Wrap(ctx, err, op, errors.WithMsg("unable to lookup auth method after update"))
			}
			if updatedAm == nil {
				return errors.New(ctx, errors.RecordNotFound, op, "unable to lookup auth method after update")
			}
			return nil
		},
	)
	if err != nil {
		return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
	}
	// drop any cached provider so the next use picks up the updated config.
	providerCache().delete(ctx, updatedAm.PublicId)
	return updatedAm, rowsUpdated, nil
}
// voName names one of an auth method's value-object collections.
type voName string

// The known value-object collections. Each value is identical to the
// corresponding *Field constant's value, which lets valueObjectChanges look
// the name up directly in the db update masks.
const (
	SigningAlgVO       voName = "SigningAlgs"
	CertificateVO      voName = "Certificates"
	AudClaimVO         voName = "AudClaims"
	ClaimsScopesVO     voName = "ClaimsScopes"
	AccountClaimMapsVO voName = "AccountClaimMaps"
	PromptsVO          voName = "Prompts"
)

// validVoName reports whether name is exactly one of the known value-object
// collection names (the comparison is case-sensitive).
func validVoName(name voName) bool {
	return name == SigningAlgVO ||
		name == CertificateVO ||
		name == AudClaimVO ||
		name == ClaimsScopesVO ||
		name == AccountClaimMapsVO ||
		name == PromptsVO
}
// factoryFunc builds a single value object belonging to the auth method
// identified by publicId. i is the value's raw form; every factory below
// formats it with %s before parsing it.
type factoryFunc func(ctx context.Context, publicId string, i any) (any, error)

// supportedFactories maps each value-object collection name to the factory
// that builds one value object of that kind.
var supportedFactories = map[voName]factoryFunc{
	SigningAlgVO: func(ctx context.Context, publicId string, i any) (any, error) {
		str := fmt.Sprintf("%s", i)
		return NewSigningAlg(ctx, publicId, Alg(str))
	},
	CertificateVO: func(ctx context.Context, publicId string, i any) (any, error) {
		str := fmt.Sprintf("%s", i)
		return NewCertificate(ctx, publicId, str)
	},
	AudClaimVO: func(ctx context.Context, publicId string, i any) (any, error) {
		str := fmt.Sprintf("%s", i)
		return NewAudClaim(ctx, publicId, str)
	},
	ClaimsScopesVO: func(ctx context.Context, publicId string, i any) (any, error) {
		str := fmt.Sprintf("%s", i)
		return NewClaimsScope(ctx, publicId, str)
	},
	AccountClaimMapsVO: func(ctx context.Context, publicId string, i any) (any, error) {
		const op = "oidc.AccountClaimMapsFactory"
		str := fmt.Sprintf("%s", i)
		acm, err := ParseAccountClaimMaps(ctx, str)
		if err != nil {
			return nil, errors.Wrap(ctx, err, op)
		}
		// A value object holds exactly one claim map. Reject any other count
		// explicitly rather than continuing with a zero-value ClaimMap when
		// nothing was parsed.
		if len(acm) != 1 {
			return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unable to parse account claim map %s", str))
		}
		var m ClaimMap
		for _, m = range acm {
			// copies out the single parsed element
		}
		to, err := ConvertToAccountToClaim(ctx, m.To)
		if err != nil {
			return nil, errors.Wrap(ctx, err, op)
		}
		return NewAccountClaimMap(ctx, publicId, m.From, to)
	},
	PromptsVO: func(ctx context.Context, publicId string, i any) (any, error) {
		str := fmt.Sprintf("%s", i)
		return NewPrompt(ctx, publicId, PromptParam(str))
	},
}
// valueObjectChanges reconciles one of an auth method's value-object (VO)
// collections. Given the new and old string forms of the VOs named by
// valueObjectName, plus the dbMask/nullFields built for the update, it
// returns the VOs that must be added and deleted so that the stored
// collection matches the requested one:
//
//   - name in neither dbMask nor nullFields: no changes (nil, nil, nil).
//   - name in nullFields: every old VO is deleted.
//   - name in dbMask: new VOs not already stored are added, and stored VOs
//     missing from newVOs are deleted.
//
// Deletions of stored VOs that are absent from newVOs are produced in
// oldVOs order, so the result is deterministic.
//
// An error is returned for a missing publicId, an unknown valueObjectName,
// duplicates in either list, or a value the VO factory cannot build.
func valueObjectChanges(
	ctx context.Context,
	publicId string,
	valueObjectName voName,
	newVOs,
	oldVOs,
	dbMask,
	nullFields []string,
) (add []any, del []any, e error) {
	const op = "valueObjectChanges"
	if publicId == "" {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing public id")
	}
	if !validVoName(valueObjectName) {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid value object name: %s", valueObjectName))
	}
	name := string(valueObjectName)
	clearing := strutil.StrListContains(nullFields, name)
	setting := strutil.StrListContains(dbMask, name)
	if !clearing && !setting {
		return nil, nil, nil
	}
	if len(strutil.RemoveDuplicates(newVOs, false)) != len(newVOs) {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate new %s", valueObjectName))
	}
	if len(strutil.RemoveDuplicates(oldVOs, false)) != len(oldVOs) {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate old %s", valueObjectName))
	}
	factory, ok := supportedFactories[valueObjectName]
	if !ok {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unsupported factory for value object: %s", valueObjectName))
	}

	// remaining tracks stored VOs that have been neither matched by a new VO
	// nor scheduled for deletion yet.
	remaining := make(map[string]bool, len(oldVOs))
	for _, v := range oldVOs {
		remaining[v] = true
	}

	var adds []any
	var deletes []any
	if clearing {
		deletes = make([]any, 0, len(oldVOs))
		for _, v := range oldVOs {
			deleteObj, err := factory(ctx, publicId, v)
			if err != nil {
				return nil, nil, errors.Wrap(ctx, err, op)
			}
			deletes = append(deletes, deleteObj)
			delete(remaining, v)
		}
	}
	if setting {
		adds = make([]any, 0, len(newVOs))
		for _, v := range newVOs {
			if remaining[v] {
				// already stored: keep it as-is.
				delete(remaining, v)
				continue
			}
			obj, err := factory(ctx, publicId, v)
			if err != nil {
				return nil, nil, errors.Wrap(ctx, err, op)
			}
			adds = append(adds, obj)
		}
	}
	// Anything still remaining was stored but is absent from newVOs. Walk
	// oldVOs rather than the map so the deletion order is deterministic.
	for _, v := range oldVOs {
		if !remaining[v] {
			continue
		}
		obj, err := factory(ctx, publicId, v)
		if err != nil {
			return nil, nil, errors.Wrap(ctx, err, op)
		}
		deletes = append(deletes, obj)
	}
	return adds, deletes, nil
}
// validateFieldMask checks that every path in fieldMaskPaths names an
// updatable auth method field, comparing case-insensitively. It returns an
// InvalidParameter error naming the first path that is not updatable, or nil
// when all paths are acceptable (including when fieldMaskPaths is empty).
func validateFieldMask(ctx context.Context, fieldMaskPaths []string) error {
	const op = "validateFieldMask"
	updatable := [...]string{
		NameField,
		DescriptionField,
		IssuerField,
		ClientIdField,
		ClientSecretField,
		MaxAgeField,
		SigningAlgsField,
		ApiUrlField,
		AudClaimsField,
		CertificatesField,
		ClaimsScopesField,
		AccountClaimMapsField,
		PromptsField,
	}
	isUpdatable := func(path string) bool {
		for _, field := range updatable {
			if strings.EqualFold(field, path) {
				return true
			}
		}
		return false
	}
	for _, path := range fieldMaskPaths {
		if !isUpdatable(path) {
			return errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid field mask: %s", path))
		}
	}
	return nil
}
// applyUpdate returns orig.Clone() with every field named in fieldMaskPaths
// overwritten by the corresponding value from new.
//
// Paths are matched case-insensitively. validateFieldMask accepts mask paths
// case-insensitively, so matching them exactly here would let a path such as
// "name" pass validation but be silently ignored when building the auth
// method that is completeness-checked and discovery-validated.
//
// For value-object lists, an empty list in new becomes nil in the result;
// otherwise the list is copied so the result does not share a backing array
// with new. Unknown paths are ignored.
func applyUpdate(new, orig *AuthMethod, fieldMaskPaths []string) *AuthMethod {
	// copyList returns nil for an empty list, otherwise an independent copy.
	copyList := func(src []string) []string {
		if len(src) == 0 {
			return nil
		}
		return append(make([]string, 0, len(src)), src...)
	}
	cp := orig.Clone()
	for _, f := range fieldMaskPaths {
		switch {
		case strings.EqualFold(f, NameField):
			cp.Name = new.Name
		case strings.EqualFold(f, DescriptionField):
			cp.Description = new.Description
		case strings.EqualFold(f, IssuerField):
			cp.Issuer = new.Issuer
		case strings.EqualFold(f, ClientIdField):
			cp.ClientId = new.ClientId
		case strings.EqualFold(f, ClientSecretField):
			cp.ClientSecret = new.ClientSecret
		case strings.EqualFold(f, MaxAgeField):
			cp.MaxAge = new.MaxAge
		case strings.EqualFold(f, ApiUrlField):
			cp.ApiUrl = new.ApiUrl
		case strings.EqualFold(f, SigningAlgsField):
			cp.SigningAlgs = copyList(new.SigningAlgs)
		case strings.EqualFold(f, AudClaimsField):
			cp.AudClaims = copyList(new.AudClaims)
		case strings.EqualFold(f, CertificatesField):
			cp.Certificates = copyList(new.Certificates)
		case strings.EqualFold(f, ClaimsScopesField):
			cp.ClaimsScopes = copyList(new.ClaimsScopes)
		case strings.EqualFold(f, AccountClaimMapsField):
			cp.AccountClaimMaps = copyList(new.AccountClaimMaps)
		case strings.EqualFold(f, PromptsField):
			cp.Prompts = copyList(new.Prompts)
		}
	}
	return cp
}
// ValidateDiscoveryInfo will test/validate the provided AuthMethod against
// the info from its discovery URL.
//
// The auth method is converted to an oidc.Provider. If that conversion fails
// and the auth method's OperationalState is InactiveState, nil is returned and
// no further validation is performed; otherwise the conversion error is
// returned.
//
// Using the retrieved OpenID Configuration document, it verifies that the
// auth method's issuer matches the discovered issuer, and that each of its
// signing algorithms is among the discovered supported id token signing
// algorithms. It then connects to the JWKs endpoint (which must answer 200),
// the authorization endpoint and the token endpoint (any status is accepted
// for these two). Any certificates in the configuration are used as trust
// anchors for these connections, and redirects are not followed. The userinfo
// endpoint is not checked.
//
// All problems found are returned together, joined with errors.Join, rather
// than stopping at the first one.
//
// Options supported are: WithPublicId, WithAuthMethod
func (r *Repository) ValidateDiscoveryInfo(ctx context.Context, opt ...Option) error {
	const op = "oidc.(Repository).ValidateDiscoveryInfo"
	opts := getOpts(opt...)
	var am *AuthMethod
	switch {
	case opts.withPublicId != "":
		var err error
		am, err = r.lookupAuthMethod(ctx, opts.withPublicId)
		if err != nil {
			return errors.Wrap(ctx, err, op)
		}
		if am == nil {
			return errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("unable to lookup auth method %s", opts.withPublicId))
		}
	case opts.withAuthMethod != nil:
		am = opts.withAuthMethod
	default:
		return errors.New(ctx, errors.InvalidParameter, op, "neither WithPublicId(...) nor WithAuthMethod(...) options were provided")
	}
	// FYI: once converted to an oidc.Provider, any certs configured will be used as trust anchors for all HTTP requests
	provider, err := convertToProvider(ctx, am)
	if err != nil && am.OperationalState == string(InactiveState) {
		return nil
	}
	if err != nil {
		return errors.Wrap(ctx, err, op)
	}
	// Release the provider's background resources once validation is done.
	// NOTE(review): assumes convertToProvider builds a fresh provider that is
	// not shared (e.g. with providerCache) - confirm.
	defer provider.Done()

	info, err := provider.DiscoveryInfo(ctx)
	if err != nil {
		return errors.Wrap(ctx, err, op)
	}
	var result error
	if info.Issuer != am.Issuer {
		result = stderrors.Join(result, errors.New(ctx, errors.InvalidParameter, op,
			fmt.Sprintf("auth method issuer doesn't match discovery issuer: expected %s and got %s", am.Issuer, info.Issuer)))
	}
	for _, a := range am.SigningAlgs {
		if !strutil.StrListContains(info.IdTokenSigningAlgsSupported, a) {
			result = stderrors.Join(result, errors.New(ctx, errors.InvalidParameter, op,
				fmt.Sprintf("auth method signing alg is not in discovered supported algs: expected %s and got %s", a, info.IdTokenSigningAlgsSupported)))
		}
	}
	providerClient, err := provider.HTTPClient()
	if err != nil {
		result = stderrors.Join(result, errors.New(ctx, errors.Unknown, op, "unable to get oidc http client", errors.WithWrap(err)))
		return result
	}
	// we need to prevent redirects during these tests... we don't want to have
	// redirects going to the controller's callback (aka the configured provider's callback)
	providerClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
		return http.ErrUseLastResponse
	}
	// test JWKs URL. A failed request and a non-200 status are distinct
	// outcomes: a failed request has no meaningful status, so it must not
	// also be reported as "non-200 status (0)".
	statusCode, err := pingEndpoint(ctx, providerClient, "JWKs", "GET", info.JWKSURL)
	switch {
	case err != nil:
		result = stderrors.Join(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to verify JWKs endpoint: %s", info.JWKSURL), errors.WithWrap(err)))
	case statusCode != http.StatusOK:
		result = stderrors.Join(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("non-200 status (%d) from JWKs endpoint: %s", statusCode, info.JWKSURL)))
	}
	// test Auth URL (reachability only; the status code is not checked)
	if _, err := pingEndpoint(ctx, providerClient, "AuthURL", "GET", info.AuthURL); err != nil {
		result = stderrors.Join(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to verify authorize endpoint: %s", info.AuthURL), errors.WithWrap(err)))
	}
	// test Token URL (reachability only; the status code is not checked)
	if _, err := pingEndpoint(ctx, providerClient, "TokenURL", "POST", info.TokenURL); err != nil {
		result = stderrors.Join(result, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to verify token endpoint: %s", info.TokenURL), errors.WithWrap(err)))
	}
	// we're not verifying the UserInfo URL, since it's not a required dependency.
	return result
}
// HTTPClient is the minimal client behavior pingEndpoint relies on: sending a
// single request and returning its response. *http.Client satisfies it.
type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}

// pingEndpoint sends one request with the given method to url through client
// and reports the response's status code. endpointType only labels errors.
//
// A non-nil error (with status 0) is returned when the request cannot be
// built or sent. Any HTTP status, including non-2xx, is returned with a nil
// error; the response body is closed without being read.
func pingEndpoint(ctx context.Context, client HTTPClient, endpointType, method, url string) (int, error) {
	const op = "oidc.pingEndpoint"
	request, buildErr := http.NewRequestWithContext(ctx, method, url, nil)
	if buildErr != nil {
		return 0, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("unable to create %s http request", endpointType), errors.WithWrap(buildErr))
	}
	response, sendErr := client.Do(request)
	if sendErr != nil {
		return 0, errors.New(ctx, errors.Unknown, op, fmt.Sprintf("request to %s endpoint failed", endpointType), errors.WithWrap(sendErr))
	}
	// Only the status matters here; release the body. The Close error is
	// deliberately ignored since there is nothing useful to do with it.
	if body := response.Body; body != nil {
		body.Close()
	}
	return response.StatusCode, nil
}