You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/auth/ldap/repository_auth_method_upda...

815 lines
30 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap
import (
"context"
"fmt"
"net/url"
"reflect"
"strings"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/go-dbw"
"github.com/hashicorp/go-secure-stdlib/strutil"
)
// Names of AuthMethod fields, as plain strings.
//
// Within this file most of them serve two purposes: they are the keys of the
// field/value map handed to dbw.BuildUpdatePaths in UpdateAuthMethod, and they
// are the set of paths accepted by validateFieldMask. VersionField is used as
// the only update mask when just the auth method's value objects change.
// IsPrimaryAuthMethodField and GroupNamesField are not referenced by the
// update code in this file.
const (
OperationalStateField = "OperationalState"
VersionField = "Version"
IsPrimaryAuthMethodField = "IsPrimaryAuthMethod"
NameField = "Name"
DescriptionField = "Description"
StartTlsField = "StartTls"
InsecureTlsField = "InsecureTls"
DiscoverDnField = "DiscoverDn"
AnonGroupSearchField = "AnonGroupSearch"
UpnDomainField = "UpnDomain"
UrlsField = "Urls"
UserDnField = "UserDn"
UserAttrField = "UserAttr"
UserFilterField = "UserFilter"
EnableGroupsField = "EnableGroups"
UseTokenGroupsField = "UseTokenGroups"
GroupDnField = "GroupDn"
GroupAttrField = "GroupAttr"
GroupFilterField = "GroupFilter"
CertificatesField = "Certificates"
ClientCertificateField = "ClientCertificate"
ClientCertificateKeyField = "ClientCertificateKey"
BindDnField = "BindDn"
BindPasswordField = "BindPassword"
AccountAttributeMapsField = "AccountAttributeMaps"
GroupNamesField = "GroupNames"
// DerefAliasesField is the one constant whose value ("DereferenceAliases")
// is spelled differently from the constant's own name.
DerefAliasesField = "DereferenceAliases"
MaximumPageSizeField = "MaximumPageSize"
)
// isEmpty reports whether every one of args is empty. A string is empty when
// it is "" and a pointer is empty when it is nil. Values of any other kind
// (including slices, numbers and an untyped nil) are not inspected and always
// count as empty. Calling it with no arguments returns true.
func isEmpty(args ...any) bool {
	for idx := range args {
		val := reflect.ValueOf(args[idx])
		kind := val.Kind()
		if kind == reflect.Pointer && !val.IsNil() {
			return false
		}
		if kind == reflect.String && val.String() != "" {
			return false
		}
	}
	return true
}
// UpdateAuthMethod will retrieve the auth method from the repository,
// and update it based on the field masks provided.
//
// fieldMaskPaths provides field_mask.proto paths for fields that should
// be updated. Fields will be set to NULL if the field is a
// zero value and included in fieldMask. Every path must be accepted by
// validateFieldMask, which allows OperationalState, Name, Description,
// StartTls, InsecureTls, DiscoverDn, AnonGroupSearch, UpnDomain, UserDn,
// UserAttr, UserFilter, EnableGroups, UseTokenGroups, GroupDn, GroupAttr,
// GroupFilter, ClientCertificate, ClientCertificateKey, BindDn, BindPassword,
// DereferenceAliases and MaximumPageSize, as well as the AuthMethod's value
// objects Urls, Certificates and AccountAttributeMaps. If no updatable fields
// are included in the fieldMaskPaths, then an error is returned. Urls cannot
// be cleared: a request that would null them is rejected.
//
// version must match the version of the stored auth method, otherwise an
// errors.VersionMismatch error is returned.
//
// On success the auth method as re-read after the update is returned together
// with the number of rows updated for the auth method itself. If the request
// results in no changes at all, the stored auth method is returned with
// db.NoRowsAffected and a nil error.
//
// No Options are currently supported.
func (r *Repository) UpdateAuthMethod(ctx context.Context, am *AuthMethod, version uint32, fieldMaskPaths []string, _ ...Option) (*AuthMethod, int, error) {
const op = "ldap.(AuthMethod).Update"
// reject requests that do not identify an auth method.
switch {
case am == nil:
return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method")
case am.AuthMethod == nil:
return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method store")
case am.PublicId == "":
return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing public id")
}
// every requested path must name an updatable field.
if err := validateFieldMask(ctx, fieldMaskPaths); err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
// split the requested paths into the fields being set (dbMask) and the
// fields being cleared (nullFields).
// NOTE(review): the third argument presumably lists fields that may hold a
// zero value without being treated as null -- confirm against
// dbw.BuildUpdatePaths.
dbMask, nullFields := dbw.BuildUpdatePaths(
map[string]any{
OperationalStateField: am.OperationalState,
NameField: am.Name,
DescriptionField: am.Description,
StartTlsField: am.StartTls,
InsecureTlsField: am.InsecureTls,
DiscoverDnField: am.DiscoverDn,
AnonGroupSearchField: am.AnonGroupSearch,
UpnDomainField: am.UpnDomain,
UserDnField: am.UserDn,
UserAttrField: am.UserAttr,
UserFilterField: am.UserFilter,
EnableGroupsField: am.EnableGroups,
UseTokenGroupsField: am.UseTokenGroups,
GroupDnField: am.GroupDn,
GroupAttrField: am.GroupAttr,
GroupFilterField: am.GroupFilter,
CertificatesField: am.Certificates,
ClientCertificateField: am.ClientCertificate,
ClientCertificateKeyField: am.ClientCertificateKey,
BindDnField: am.BindDn,
BindPasswordField: am.BindPassword,
UrlsField: am.Urls,
AccountAttributeMapsField: am.AccountAttributeMaps,
DerefAliasesField: am.DereferenceAliases,
MaximumPageSizeField: am.MaximumPageSize,
},
fieldMaskPaths,
[]string{
StartTlsField,
InsecureTlsField,
DiscoverDnField,
AnonGroupSearchField,
EnableGroupsField,
UseTokenGroupsField,
MaximumPageSizeField,
},
)
if len(dbMask) == 0 && len(nullFields) == 0 {
return nil, db.NoRowsAffected, errors.New(ctx, errors.EmptyFieldMask, op, "empty field mask")
}
// an auth method must always keep at least one url.
if strutil.StrListContains(nullFields, UrlsField) {
return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing urls (you cannot delete all of them; there must be at least one)")
}
// load the stored auth method: it supplies the current version, the scope id
// and the values of related fields that are not part of this update.
origAm, err := r.LookupAuthMethod(ctx, am.PublicId)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("%q auth method not found", am.PublicId))
}
if origAm == nil {
return nil, db.NoRowsAffected, errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("auth method %q", am.PublicId))
}
// there's no reason to continue if another controller has already updated this auth method.
if origAm.Version != version {
return nil, db.NoRowsAffected, errors.New(ctx, errors.VersionMismatch, op, fmt.Sprintf("update version %d doesn't match db version %d", version, origAm.Version))
}
// the database wrapper is used further down to encrypt a replacement client
// certificate and a replacement bind credential.
dbWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeDatabase)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
}
// work out which urls, certificates and account attribute maps have to be
// added or deleted. valueObjectChanges returns []any, so each result is
// converted back to its concrete type.
au, du, err := valueObjectChanges(ctx, origAm.PublicId, UrlVO, am.Urls, origAm.Urls, dbMask, nullFields)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to update urls"))
}
addUrls := []*Url{}
for _, u := range au {
addUrls = append(addUrls, u.(*Url))
}
deleteUrls := []*Url{}
for _, u := range du {
deleteUrls = append(deleteUrls, u.(*Url))
}
ac, dc, err := valueObjectChanges(ctx, origAm.PublicId, CertificateVO, am.Certificates, origAm.Certificates, dbMask, nullFields)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to update certificates"))
}
deleteCerts := []*Certificate{}
for _, c := range dc {
deleteCerts = append(deleteCerts, c.(*Certificate))
}
addCerts := []*Certificate{}
for _, c := range ac {
addCerts = append(addCerts, c.(*Certificate))
}
addM, deleteM, err := valueObjectChanges(ctx, origAm.PublicId, AccountAttributeMapsVO, am.AccountAttributeMaps, origAm.AccountAttributeMaps, dbMask, nullFields)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to update account attribute maps"))
}
addMaps := []*AccountAttributeMap{}
for _, m := range addM {
addMaps = append(addMaps, m.(*AccountAttributeMap))
}
deleteMaps := []*AccountAttributeMap{}
for _, m := range deleteM {
deleteMaps = append(deleteMaps, m.(*AccountAttributeMap))
}
// combinedMasks lists every field touched by this update, whether it is
// being set or cleared.
combinedMasks := append(dbMask, nullFields...)
// The user search configuration (UserDn, UserAttr, UserFilter) is replaced as
// a unit when any of its fields is touched: an existing one is scheduled for
// deletion, and a new one is built from the incoming value of each masked
// field, "" for each nulled field and the stored value for the rest. Nothing
// is added when all three end up empty.
var addUserSearchConf, deleteUserSearchConf any
if strListContainsOneOf(combinedMasks, UserDnField, UserAttrField, UserFilterField) {
if !isEmpty(origAm.UserDn, origAm.UserAttr, origAm.UserFilter) {
usc := allocUserEntrySearchConf()
usc.LdapMethodId = am.PublicId
deleteUserSearchConf = usc
}
userDn := origAm.UserDn
switch {
case strutil.StrListContains(dbMask, UserDnField):
userDn = am.UserDn
case strutil.StrListContains(nullFields, UserDnField):
userDn = ""
}
userAttr := origAm.UserAttr
switch {
case strutil.StrListContains(dbMask, UserAttrField):
userAttr = am.UserAttr
case strutil.StrListContains(nullFields, UserAttrField):
userAttr = ""
}
userFilter := origAm.UserFilter
switch {
case strutil.StrListContains(dbMask, UserFilterField):
userFilter = am.UserFilter
case strutil.StrListContains(nullFields, UserFilterField):
userFilter = ""
}
if !isEmpty(userDn, userAttr, userFilter) {
addUserSearchConf, err = NewUserEntrySearchConf(ctx, am.PublicId, WithUserDn(ctx, userDn), WithUserAttr(ctx, userAttr), WithUserFilter(ctx, userFilter))
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to update user search configuration"))
}
}
}
// The group search configuration (GroupDn, GroupAttr, GroupFilter) is
// replaced in the same delete-then-rebuild manner.
var addGroupSearchConf, deleteGroupSearchConf any
if strListContainsOneOf(combinedMasks, GroupDnField, GroupAttrField, GroupFilterField) {
if !isEmpty(origAm.GroupDn, origAm.GroupAttr, origAm.GroupFilter) {
gsc := allocGroupEntrySearchConf()
gsc.LdapMethodId = am.PublicId
deleteGroupSearchConf = gsc
}
groupDn := origAm.GroupDn
switch {
case strutil.StrListContains(dbMask, GroupDnField):
groupDn = am.GroupDn
case strutil.StrListContains(nullFields, GroupDnField):
groupDn = ""
}
groupAttr := origAm.GroupAttr
switch {
case strutil.StrListContains(dbMask, GroupAttrField):
groupAttr = am.GroupAttr
case strutil.StrListContains(nullFields, GroupAttrField):
groupAttr = ""
}
groupFilter := origAm.GroupFilter
switch {
case strutil.StrListContains(dbMask, GroupFilterField):
groupFilter = am.GroupFilter
case strutil.StrListContains(nullFields, GroupFilterField):
groupFilter = ""
}
if !isEmpty(groupDn, groupAttr, groupFilter) {
addGroupSearchConf, err = NewGroupEntrySearchConf(ctx, am.PublicId, WithGroupDn(ctx, groupDn), WithGroupAttr(ctx, groupAttr), WithGroupFilter(ctx, groupFilter))
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to update group search configuration"))
}
}
}
// The client certificate (ClientCertificate, ClientCertificateKey) is
// replaced in the same manner; a replacement is encrypted with dbWrapper.
// NOTE(review): isEmpty only inspects strings and pointers, so the key (which
// is assigned nil below, i.e. not a string) never influences these emptiness
// checks -- confirm that is intended.
var addClientCert, deleteClientCert any
if strListContainsOneOf(combinedMasks, ClientCertificateField, ClientCertificateKeyField) {
if !isEmpty(origAm.ClientCertificate, origAm.ClientCertificateKey) {
cc := allocClientCertificate()
cc.LdapMethodId = am.PublicId
deleteClientCert = cc
}
clientCertificate := origAm.ClientCertificate
switch {
case strutil.StrListContains(dbMask, ClientCertificateField):
clientCertificate = am.ClientCertificate
case strutil.StrListContains(nullFields, ClientCertificateField):
clientCertificate = ""
}
clientCertificateKey := origAm.ClientCertificateKey
switch {
case strutil.StrListContains(dbMask, ClientCertificateKeyField):
clientCertificateKey = am.ClientCertificateKey
case strutil.StrListContains(nullFields, ClientCertificateKeyField):
clientCertificateKey = nil
}
if !isEmpty(clientCertificate, clientCertificateKey) {
cc, err := NewClientCertificate(ctx, am.PublicId, clientCertificateKey, clientCertificate)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
if err := cc.encrypt(ctx, dbWrapper); err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
addClientCert = cc
}
}
// The bind credential (BindDn, BindPassword) is replaced in the same manner;
// a replacement is encrypted with dbWrapper.
var addBindCred, deleteBindCred any
if strListContainsOneOf(combinedMasks, BindDnField, BindPasswordField) {
if !isEmpty(origAm.BindDn, origAm.BindPassword) {
bc := allocBindCredential()
bc.LdapMethodId = am.PublicId
deleteBindCred = bc
}
bindDn := origAm.BindDn
switch {
case strutil.StrListContains(dbMask, BindDnField):
bindDn = am.BindDn
case strutil.StrListContains(nullFields, BindDnField):
bindDn = ""
}
bindPassword := origAm.BindPassword
switch {
case strutil.StrListContains(dbMask, BindPasswordField):
bindPassword = am.BindPassword
case strutil.StrListContains(nullFields, BindPasswordField):
bindPassword = ""
}
if !isEmpty(bindDn, bindPassword) {
bc, err := NewBindCredential(ctx, am.PublicId, bindDn, []byte(bindPassword))
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
if err := bc.encrypt(ctx, dbWrapper); err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("error wrapping updated bind credential"))
}
addBindCred = bc
}
}
// The dereference aliases setting is replaced in the same manner.
var addDerefAliases, deleteDerefAliases any
if strListContainsOneOf(combinedMasks, DerefAliasesField) {
if !isEmpty(origAm.DereferenceAliases) {
d := allocDerefAliases()
d.LdapMethodId = am.PublicId
deleteDerefAliases = d
}
derefAliases := origAm.DereferenceAliases
switch {
case strutil.StrListContains(dbMask, DerefAliasesField):
derefAliases = am.DereferenceAliases
case strutil.StrListContains(nullFields, DerefAliasesField):
derefAliases = ""
}
if !isEmpty(derefAliases) {
d, err := NewDerefAliases(ctx, am.PublicId, DerefAliasType(derefAliases))
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
addDerefAliases = d
}
}
// Everything handled above through separate add/delete objects is removed
// from the masks, so that only the remaining fields are passed to the update
// of the auth method itself.
var filteredDbMask, filteredNullFields []string
for _, f := range dbMask {
switch f {
case
UrlsField,
CertificatesField,
AccountAttributeMapsField,
UserDnField, UserAttrField, UserFilterField,
GroupDnField, GroupAttrField, GroupFilterField,
ClientCertificateField, ClientCertificateKeyField,
BindDnField, BindPasswordField,
DerefAliasesField:
continue
default:
filteredDbMask = append(filteredDbMask, f)
}
}
// NOTE(review): the flag fields on the first case line are additionally
// dropped from the null list; presumably they are never meant to be set to
// NULL -- confirm.
for _, f := range nullFields {
switch f {
case
StartTlsField, InsecureTlsField, DiscoverDnField, AnonGroupSearchField, EnableGroupsField, UseTokenGroupsField,
UrlsField,
CertificatesField,
AccountAttributeMapsField,
UserDnField, UserAttrField, UserFilterField,
GroupDnField, GroupAttrField, GroupFilterField,
ClientCertificateField, ClientCertificateKeyField,
BindDnField, BindPasswordField,
DerefAliasesField:
continue
default:
filteredNullFields = append(filteredNullFields, f)
}
}
// handle no changes: with nothing to update, add or delete, return the stored
// auth method without opening a transaction.
if len(filteredDbMask) == 0 &&
len(filteredNullFields) == 0 &&
len(addUrls) == 0 &&
len(deleteUrls) == 0 &&
len(addCerts) == 0 &&
len(deleteCerts) == 0 &&
addUserSearchConf == nil &&
deleteUserSearchConf == nil &&
addGroupSearchConf == nil &&
deleteGroupSearchConf == nil &&
addClientCert == nil &&
deleteClientCert == nil &&
addBindCred == nil &&
deleteBindCred == nil &&
len(addMaps) == 0 &&
len(deleteMaps) == 0 &&
addDerefAliases == nil &&
deleteDerefAliases == nil {
return origAm, db.NoRowsAffected, nil
}
oplogWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeOplog)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper"))
}
// updatedAm and rowsUpdated are assigned inside the transaction function and
// returned to the caller once it has completed without error.
var updatedAm *AuthMethod
var rowsUpdated int
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
msgs := make([]*oplog.Message, 0, 7) // initial capacity only; append grows the slice as messages are collected
ticket, err := w.GetTicket(ctx, am)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get ticket"))
}
// update the auth method itself; both branches pass the caller's version via
// db.WithVersion.
// NOTE(review): presumably this makes the write succeed only against that
// stored version -- confirm against db.Writer.
var authMethodOplogMsg oplog.Message
switch {
case len(filteredDbMask) == 0 && len(filteredNullFields) == 0:
// the auth method's fields are not being updated, just its value objects, so we need to just update the auth
// method's version.
updatedAm = am.clone()
updatedAm.Version = uint32(version) + 1
rowsUpdated, err = w.Update(ctx, updatedAm, []string{VersionField}, nil, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method version"))
}
if rowsUpdated != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method version and %d rows updated", rowsUpdated))
}
default:
updatedAm = am.clone()
rowsUpdated, err = w.Update(ctx, updatedAm, filteredDbMask, filteredNullFields, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method"))
}
if rowsUpdated != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method and %d rows updated", rowsUpdated))
}
}
msgs = append(msgs, &authMethodOplogMsg)
// apply the related changes; for each kind the deletes run before the
// creates, and every write contributes its oplog message(s) to msgs.
if len(deleteCerts) > 0 {
deleteCertOplogMsgs := make([]*oplog.Message, 0, len(deleteCerts))
rowsDeleted, err := w.DeleteItems(ctx, deleteCerts, db.NewOplogMsgs(&deleteCertOplogMsgs))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete certificates"))
}
if rowsDeleted != len(deleteCerts) {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("certificates deleted %d did not match request for %d", rowsDeleted, len(deleteCerts)))
}
msgs = append(msgs, deleteCertOplogMsgs...)
}
if len(addCerts) > 0 {
addCertsOplogMsgs := make([]*oplog.Message, 0, len(addCerts))
if err := w.CreateItems(ctx, addCerts, db.NewOplogMsgs(&addCertsOplogMsgs)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add certificates"))
}
msgs = append(msgs, addCertsOplogMsgs...)
}
if len(deleteUrls) > 0 {
deleteAudsOplogMsgs := make([]*oplog.Message, 0, len(deleteUrls))
rowsDeleted, err := w.DeleteItems(ctx, deleteUrls, db.NewOplogMsgs(&deleteAudsOplogMsgs))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete URLs"))
}
if rowsDeleted != len(deleteUrls) {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("urls deleted %d did not match request for %d", rowsDeleted, len(deleteUrls)))
}
msgs = append(msgs, deleteAudsOplogMsgs...)
}
if len(addUrls) > 0 {
addUrlsOplogMsgs := make([]*oplog.Message, 0, len(addUrls))
if err := w.CreateItems(ctx, addUrls, db.NewOplogMsgs(&addUrlsOplogMsgs)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add urls"))
}
msgs = append(msgs, addUrlsOplogMsgs...)
}
if deleteUserSearchConf != nil {
var deleteUserSearchConfMsg oplog.Message
rowsDeleted, err := w.Delete(ctx, deleteUserSearchConf, db.NewOplogMsg(&deleteUserSearchConfMsg))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete user search conf"))
}
if rowsDeleted != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("user search conf deleted %d did not match request for 1", rowsDeleted))
}
msgs = append(msgs, &deleteUserSearchConfMsg)
}
if addUserSearchConf != nil {
var addUserSearchConfOplogMsg oplog.Message
if err := w.Create(ctx, addUserSearchConf, db.NewOplogMsg(&addUserSearchConfOplogMsg)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add user search conf"))
}
msgs = append(msgs, &addUserSearchConfOplogMsg)
}
if deleteGroupSearchConf != nil {
var deleteGroupSearchConfMsg oplog.Message
rowsDeleted, err := w.Delete(ctx, deleteGroupSearchConf, db.NewOplogMsg(&deleteGroupSearchConfMsg))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete group search conf"))
}
if rowsDeleted != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("group search conf deleted %d did not match request for 1", rowsDeleted))
}
msgs = append(msgs, &deleteGroupSearchConfMsg)
}
if addGroupSearchConf != nil {
var addGroupSearchConfOplogMsg oplog.Message
if err := w.Create(ctx, addGroupSearchConf, db.NewOplogMsg(&addGroupSearchConfOplogMsg)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add group search conf"))
}
msgs = append(msgs, &addGroupSearchConfOplogMsg)
}
if deleteClientCert != nil {
var deleteClientCertMsg oplog.Message
rowsDeleted, err := w.Delete(ctx, deleteClientCert, db.NewOplogMsg(&deleteClientCertMsg))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete client cert"))
}
if rowsDeleted != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("client cert deleted %d did not match request for 1", rowsDeleted))
}
msgs = append(msgs, &deleteClientCertMsg)
}
if addClientCert != nil {
var addClientCertOplogMsg oplog.Message
if err := w.Create(ctx, addClientCert, db.NewOplogMsg(&addClientCertOplogMsg)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add client cert"))
}
msgs = append(msgs, &addClientCertOplogMsg)
}
if deleteBindCred != nil {
var deleteBindCredMsg oplog.Message
rowsDeleted, err := w.Delete(ctx, deleteBindCred, db.NewOplogMsg(&deleteBindCredMsg))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete bind credential conf"))
}
if rowsDeleted != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("bind credential deleted %d did not match request for 1", rowsDeleted))
}
msgs = append(msgs, &deleteBindCredMsg)
}
if addBindCred != nil {
var addBindCredOplogMsg oplog.Message
if err := w.Create(ctx, addBindCred, db.NewOplogMsg(&addBindCredOplogMsg)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add bind credential"))
}
msgs = append(msgs, &addBindCredOplogMsg)
}
if len(deleteMaps) > 0 {
deleteMapsOplogMsgs := make([]*oplog.Message, 0, len(deleteMaps))
rowsDeleted, err := w.DeleteItems(ctx, deleteMaps, db.NewOplogMsgs(&deleteMapsOplogMsgs))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete account attribute maps"))
}
if rowsDeleted != len(deleteMaps) {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("account attribute maps deleted %d did not match request for %d", rowsDeleted, len(deleteMaps)))
}
msgs = append(msgs, deleteMapsOplogMsgs...)
}
if len(addMaps) > 0 {
addMapsOplogMsgs := make([]*oplog.Message, 0, len(addMaps))
if err := w.CreateItems(ctx, addMaps, db.NewOplogMsgs(&addMapsOplogMsgs)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add account attribute maps"))
}
msgs = append(msgs, addMapsOplogMsgs...)
}
if deleteDerefAliases != nil {
var deleteDerefMsg oplog.Message
rowsDeleted, err := w.Delete(ctx, deleteDerefAliases, db.NewOplogMsg(&deleteDerefMsg))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete deref aliases"))
}
if rowsDeleted != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("deref aliases deleted %d did not match request for 1", rowsDeleted))
}
msgs = append(msgs, &deleteDerefMsg)
}
if addDerefAliases != nil {
var addDerefOplogMsg oplog.Message
if err := w.Create(ctx, addDerefAliases, db.NewOplogMsg(&addDerefOplogMsg)); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add deref aliases"))
}
msgs = append(msgs, &addDerefOplogMsg)
}
// all collected messages are written as a single oplog entry.
metadata, err := updatedAm.oplog(ctx, oplog.OpType_OP_TYPE_UPDATE)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to generate oplog metadata"))
}
if err := w.WriteOplogEntryWith(ctx, oplogWrapper, ticket, metadata, msgs); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog"))
}
// we need a new repo, that's using the same reader/writer as this TxHandler
txRepo := &Repository{
reader: reader,
writer: w,
kms: r.kms,
// intentionally not setting the defaultLimit, so we'll get all
// the account ids without a limit
}
// re-read the auth method through the transaction's reader/writer; the result
// replaces updatedAm and is what the caller receives.
updatedAm, err = txRepo.lookupAuthMethod(ctx, updatedAm.PublicId)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to lookup auth method after update"))
}
if updatedAm == nil {
return errors.New(ctx, errors.RecordNotFound, op, "unable to lookup auth method after update")
}
return nil
},
)
if err != nil {
return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
return updatedAm, rowsUpdated, nil
}
// validateFieldMask ensures that every path in fieldMaskPaths names an
// updatable auth method field. Paths are compared case-insensitively. The
// first path that matches no updatable field produces an
// errors.InvalidParameter error naming it; an empty or nil list is valid.
func validateFieldMask(ctx context.Context, fieldMaskPaths []string) error {
	const op = "ldap.validateFieldMasks"
	updatable := [...]string{
		OperationalStateField,
		NameField,
		DescriptionField,
		StartTlsField,
		InsecureTlsField,
		DiscoverDnField,
		AnonGroupSearchField,
		UpnDomainField,
		UserDnField,
		UserAttrField,
		UserFilterField,
		EnableGroupsField,
		UseTokenGroupsField,
		GroupDnField,
		GroupAttrField,
		GroupFilterField,
		CertificatesField,
		ClientCertificateField,
		ClientCertificateKeyField,
		BindDnField,
		BindPasswordField,
		UrlsField,
		AccountAttributeMapsField,
		DerefAliasesField,
		MaximumPageSizeField,
	}
	for _, path := range fieldMaskPaths {
		known := false
		for _, field := range updatable {
			if strings.EqualFold(field, path) {
				known = true
				break
			}
		}
		if !known {
			return errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid field mask: %q", path))
		}
	}
	return nil
}
// voName is the name of a kind of auth method value object.
type voName string

// The value object names supported by this package.
const (
	CertificateVO          voName = "Certificates"
	UrlVO                  voName = "Urls"
	AccountAttributeMapsVO voName = "AccountAttributeMaps"
)

// validVoName reports whether name is one of the supported value object
// names. The comparison is exact (case-sensitive).
func validVoName(name voName) bool {
	return name == CertificateVO || name == UrlVO || name == AccountAttributeMapsVO
}
// factoryFunc is the signature shared by all value object factories: given
// the auth method's publicId, the position idx of the value within its list
// and the value's string form s, it returns the constructed value object.
type factoryFunc func(ctx context.Context, publicId string, idx int, s string) (any, error)

// supportedFactories maps each supported value object name to the factoryFunc
// that builds that kind of value object.
var supportedFactories = map[voName]factoryFunc{
	// Certificates are built directly from the string; the index is unused.
	CertificateVO: func(ctx context.Context, publicId string, _ int, raw string) (any, error) {
		return NewCertificate(ctx, publicId, raw)
	},
	// Urls are parsed first and handed to NewUrl together with the one-based
	// position (the zero-based index plus one).
	UrlVO: func(ctx context.Context, publicId string, idx int, raw string) (any, error) {
		parsed, err := url.Parse(raw)
		if err != nil {
			return nil, errors.Wrap(ctx, err, "ldap.urlFactory")
		}
		position := idx + 1
		return NewUrl(ctx, publicId, position, parsed)
	},
	// Account attribute maps are parsed from the string, which must not yield
	// more than one map; the index is unused.
	AccountAttributeMapsVO: func(ctx context.Context, publicId string, _ int, raw string) (any, error) {
		const op = "ldap.AccountAttributeMapsFactory"
		parsed, err := ParseAccountAttributeMaps(ctx, raw)
		if err != nil {
			return nil, err
		}
		if len(parsed) > 1 {
			return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unable to parse account attribute map %s", raw))
		}
		// at most one entry remains at this point; pick it up, leaving the zero
		// value in place when there is none.
		var attrMap AttributeMap
		for _, entry := range parsed {
			attrMap = entry
		}
		toAttr, err := ConvertToAccountToAttribute(ctx, attrMap.To)
		if err != nil {
			return nil, errors.Wrap(ctx, err, op)
		}
		return NewAccountAttributeMap(ctx, publicId, attrMap.From, toAttr)
	},
}
// valueObjectChanges compares the new and old lists of value objects (VOs) of
// an auth method and, guided by dbMask and nullFields, returns the VOs that
// have to be added and the VOs that have to be deleted to reconcile them.
//
// If valueObjectName is in neither dbMask nor nullFields there is nothing to
// do and all results are nil. If it is in nullFields, every old VO is
// scheduled for deletion. If it is in dbMask, every new VO that is not among
// the remaining old ones is scheduled for addition. Any old VO that is still
// unaccounted for afterwards is scheduled for deletion.
//
// An error is returned when publicId is empty, valueObjectName is not valid or
// has no factory, the length of newVOs or oldVOs changes when run through
// strutil.RemoveDuplicates, or the factory fails to build a VO.
func valueObjectChanges(
	ctx context.Context,
	publicId string,
	valueObjectName voName,
	newVOs,
	oldVOs,
	dbMask,
	nullFields []string,
) (add []any, del []any, e error) {
	const op = "ldap.valueObjectChanges"
	fieldName := string(valueObjectName)

	if publicId == "" {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing public id")
	}
	if !validVoName(valueObjectName) {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid value object name: %s", valueObjectName))
	}
	beingSet := strutil.StrListContains(dbMask, fieldName)
	beingCleared := strutil.StrListContains(nullFields, fieldName)
	if !beingSet && !beingCleared {
		return nil, nil, nil
	}
	if len(strutil.RemoveDuplicates(newVOs, false)) != len(newVOs) {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate new %s", valueObjectName))
	}
	if len(strutil.RemoveDuplicates(oldVOs, false)) != len(oldVOs) {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate old %s", valueObjectName))
	}

	build, found := supportedFactories[valueObjectName]
	if !found {
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unsupported factory for value object: %s", valueObjectName))
	}

	// remaining tracks the old values not yet accounted for, keyed by value
	// and remembering each value's position in oldVOs.
	remaining := make(map[string]int, len(oldVOs))
	for pos, old := range oldVOs {
		remaining[old] = pos
	}

	var creates, removals []any
	if beingCleared {
		// the whole list is being cleared: every old value goes.
		removals = make([]any, 0, len(oldVOs))
		for pos, old := range oldVOs {
			obj, err := build(ctx, publicId, pos, old)
			if err != nil {
				return nil, nil, errors.Wrap(ctx, err, op)
			}
			removals = append(removals, obj)
			delete(remaining, old)
		}
	}
	if beingSet {
		creates = make([]any, 0, len(newVOs))
		for pos, candidate := range newVOs {
			if _, kept := remaining[candidate]; kept {
				// already present: neither added nor deleted.
				delete(remaining, candidate)
				continue
			}
			obj, err := build(ctx, publicId, pos, candidate)
			if err != nil {
				return nil, nil, errors.Wrap(ctx, err, op)
			}
			creates = append(creates, obj)
		}
	}
	// whatever old values are left over were not carried into the new list.
	for old, pos := range remaining {
		obj, err := build(ctx, publicId, pos, old)
		if err != nil {
			return nil, nil, errors.Wrap(ctx, err, op)
		}
		removals = append(removals, obj)
	}
	return creates, removals, nil
}
// strListContainsOneOf reports whether haystack contains at least one of the
// given needles. Comparison is exact (case-sensitive). It returns false when
// haystack is empty or no needles are given.
func strListContainsOneOf(haystack []string, needles ...string) bool {
	for n := 0; n < len(needles); n++ {
		wanted := needles[n]
		for h := 0; h < len(haystack); h++ {
			if haystack[h] == wanted {
				return true
			}
		}
	}
	return false
}