diff --git a/internal/credential/vault/fields.go b/internal/credential/vault/fields.go index bd2415dd8a..cad36b46bf 100644 --- a/internal/credential/vault/fields.go +++ b/internal/credential/vault/fields.go @@ -17,4 +17,6 @@ const ( tlsServerNameField = "TlsServerName" tlsSkipVerifyField = "TlsSkipVerify" tokenField = "Token" + + mappingOverrideField = "MappingOverride" ) diff --git a/internal/credential/vault/repository_credential_library.go b/internal/credential/vault/repository_credential_library.go index faff3f278e..1ded5ea07e 100644 --- a/internal/credential/vault/repository_credential_library.go +++ b/internal/credential/vault/repository_credential_library.go @@ -114,8 +114,8 @@ func (r *Repository) CreateCredentialLibrary(ctx context.Context, scopeId string // number of records updated. l is not changed. // // l must contain a valid PublicId. Only Name, Description, VaultPath, -// HttpMethod, and HttpRequestBody can be updated. If l.Name is set to a -// non-empty string, it must be unique within l.StoreId. +// HttpMethod, HttpRequestBody, and MappingOverride can be updated. If +// l.Name is set to a non-empty string, it must be unique within l.StoreId. // // An attribute of l will be set to NULL in the database if the attribute // in l is the zero value and it is included in fieldMaskPaths except for @@ -141,6 +141,7 @@ func (r *Repository) UpdateCredentialLibrary(ctx context.Context, scopeId string } l = l.clone() + var updateMappingOverride bool for _, f := range fieldMaskPaths { switch { case strings.EqualFold(nameField, f): @@ -148,6 +149,8 @@ func (r *Repository) UpdateCredentialLibrary(ctx context.Context, scopeId string case strings.EqualFold(vaultPathField, f): case strings.EqualFold(httpMethodField, f): case strings.EqualFold(httpRequestBodyField, f): + case strings.EqualFold(mappingOverrideField, f): + updateMappingOverride = true default: return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidFieldMask, op, f) } @@ -160,6 +163,7 @@ func (r *Repository) UpdateCredentialLibrary(ctx context.Context, scopeId string vaultPathField: l.VaultPath, httpMethodField: l.HttpMethod, httpRequestBodyField: l.HttpRequestBody, + mappingOverrideField: l.MappingOverride, }, fieldMaskPaths, nil, @@ -181,6 +185,32 @@ func (r *Repository) UpdateCredentialLibrary(ctx context.Context, scopeId string return nil, db.NoRowsAffected, errors.New(ctx, errors.EmptyFieldMask, op, "missing field mask") } + origLib, err := r.LookupCredentialLibrary(ctx, l.PublicId) + switch { + case err != nil: + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + case origLib == nil: + return nil, db.NoRowsAffected, errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("credential library %s", l.PublicId)) + case updateMappingOverride && !validMappingOverride(l.MappingOverride, origLib.CredentialType()): + return nil, db.NoRowsAffected, errors.New(ctx, errors.VaultInvalidMappingOverride, op, "invalid mapping override for credential type") + } + + var filteredDbMask, filteredNullFields []string + for _, f := range dbMask { + switch { + case strings.EqualFold(mappingOverrideField, f): + default: + filteredDbMask = append(filteredDbMask, f) + } + } + for _, f := range nullFields { + switch { + case strings.EqualFold(mappingOverrideField, f): + default: + filteredNullFields = append(filteredNullFields, f) + } + } + oplogWrapper, err := r.kms.GetWrapper(ctx, scopeId, kms.KeyPurposeOplog) if err != nil { return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt), @@ -190,16 +220,94 @@ func (r *Repository) UpdateCredentialLibrary(ctx context.Context, scopeId string var rowsUpdated int var returnedCredentialLibrary *CredentialLibrary _, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(_ db.Reader, w db.Writer) error { - returnedCredentialLibrary = l.clone() - var err error - rowsUpdated, err = w.Update(ctx, returnedCredentialLibrary, dbMask, nullFields, - db.WithOplog(oplogWrapper, l.oplog(oplog.OpType_OP_TYPE_UPDATE)), - db.WithVersion(&version)) - if err == nil && rowsUpdated > 1 { - return errors.New(ctx, errors.MultipleRecords, op, "more than 1 resource would have been updated") + func(rr db.Reader, w db.Writer) error { + var msgs []*oplog.Message + ticket, err := w.GetTicket(l) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get ticket")) } - return err + + l := l.clone() + var lOplogMsg oplog.Message + + // Update the credential library table + switch { + case len(filteredDbMask) == 0 && len(filteredNullFields) == 0: + // the credential library's fields are not being updated, + // just one of it's child objects, so we just need to + // update the library's version. + l.Version = version + 1 + rowsUpdated, err = w.Update(ctx, l, []string{"Version"}, nil, db.NewOplogMsg(&lOplogMsg), db.WithVersion(&version)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update credential library version")) + } + switch rowsUpdated { + case 1: + case 0: + return nil + default: + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated credential library version and %d rows updated", rowsUpdated)) + } + default: + rowsUpdated, err = w.Update(ctx, l, filteredDbMask, filteredNullFields, db.NewOplogMsg(&lOplogMsg), db.WithVersion(&version)) + if err != nil { + if errors.IsUniqueError(err) { + return errors.New(ctx, errors.NotUnique, op, + fmt.Sprintf("name %s already exists: %s", l.Name, l.PublicId)) + } + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update credential library")) + } + switch rowsUpdated { + case 1: + case 0: + return nil + default: + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated credential library and %d rows updated", rowsUpdated)) + } + } + msgs = append(msgs, &lOplogMsg) + + // Update a credential mapping override table if applicable + if updateMappingOverride { + // delete the current mapping override if it exists + if origLib.MappingOverride != nil { + var msg oplog.Message + rowsDeleted, err := w.Delete(ctx, origLib.MappingOverride, db.NewOplogMsg(&msg)) + switch { + case err != nil: + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete mapping override")) + default: + switch rowsDeleted { + case 0, 1: + default: + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("delete mapping override and %d rows deleted", rowsDeleted)) + } + } + msgs = append(msgs, &msg) + } + // insert a new mapping override if specified + if l.MappingOverride != nil { + var msg oplog.Message + l.MappingOverride.setLibraryId(l.PublicId) + if err := w.Create(ctx, l.MappingOverride, db.NewOplogMsg(&msg)); err != nil { + return errors.Wrap(ctx, err, op) + } + msgs = append(msgs, &msg) + } + } + + metadata := l.oplog(oplog.OpType_OP_TYPE_UPDATE) + if err := w.WriteOplogEntryWith(ctx, oplogWrapper, ticket, metadata, msgs); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog")) + } + + pl := allocPublicLibrary() + pl.PublicId = l.PublicId + if err := rr.LookupByPublicId(ctx, pl); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to retrieve updated credential library")) + } + returnedCredentialLibrary = pl.toCredentialLibrary() + return nil }, ) diff --git a/internal/credential/vault/repository_credential_library_test.go b/internal/credential/vault/repository_credential_library_test.go index 2b3b876f92..ed6ce3923c 100644 --- a/internal/credential/vault/repository_credential_library_test.go +++ b/internal/credential/vault/repository_credential_library_test.go @@ -464,6 +464,20 @@ func TestRepository_UpdateCredentialLibrary(t *testing.T) { } } + changeCredentialType := func(t credential.Type) func(*CredentialLibrary) *CredentialLibrary { + return func(l *CredentialLibrary) *CredentialLibrary { + l.CredentialLibrary.CredentialType = string(t) + return l + } + } + + changeMappingOverride := func(m MappingOverride) func(*CredentialLibrary) *CredentialLibrary { + return func(l *CredentialLibrary) *CredentialLibrary { + l.MappingOverride = m + return l + } + } + makeNil := func() func(*CredentialLibrary) *CredentialLibrary { return func(l *CredentialLibrary) *CredentialLibrary { return nil @@ -508,6 +522,7 @@ func TestRepository_UpdateCredentialLibrary(t *testing.T) { wantCount int wantErr errors.Code }{ + { name: "nil-credential-library", orig: &CredentialLibrary{ @@ -746,17 +761,21 @@ func TestRepository_UpdateCredentialLibrary(t *testing.T) { { name: "change-vault-path", orig: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride(WithOverrideUsernameAttribute("orig-username")), CredentialLibrary: &store.CredentialLibrary{ - HttpMethod: "GET", - VaultPath: "/old/path", + HttpMethod: "GET", + VaultPath: "/old/path", + CredentialType: string(credential.UserPasswordType), }, }, chgFn: changeVaultPath("/new/path"), masks: []string{vaultPathField}, want: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride(WithOverrideUsernameAttribute("orig-username")), CredentialLibrary: &store.CredentialLibrary{ - HttpMethod: "GET", - VaultPath: "/new/path", + HttpMethod: "GET", + VaultPath: "/new/path", + CredentialType: string(credential.UserPasswordType), }, }, wantCount: 1, @@ -899,6 +918,196 @@ func TestRepository_UpdateCredentialLibrary(t *testing.T) { }, wantCount: 1, }, + { + name: "read-only-credential-type-in-field-mask", + orig: &CredentialLibrary{ + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + chgFn: changeCredentialType(credential.UnspecifiedType), + masks: []string{"PublicId", "CreateTime", "UpdateTime", "StoreId", "CredentialType"}, + wantErr: errors.InvalidFieldMask, + }, + { + name: "user-password-attributes-change-username-attribute", + orig: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverrideUsernameAttribute("orig-username"), + WithOverridePasswordAttribute("orig-password"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + chgFn: changeMappingOverride( + NewUserPasswordOverride( + WithOverrideUsernameAttribute("changed-username"), + ), + ), + masks: []string{"MappingOverride"}, + want: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverrideUsernameAttribute("changed-username"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + wantCount: 1, + }, + { + name: "user-password-attributes-change-password-attribute", + orig: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverrideUsernameAttribute("orig-username"), + WithOverridePasswordAttribute("orig-password"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + chgFn: changeMappingOverride( + NewUserPasswordOverride( + WithOverridePasswordAttribute("changed-password"), + ), + ), + masks: []string{"MappingOverride"}, + want: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverridePasswordAttribute("changed-password"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + wantCount: 1, + }, + { + name: "user-password-attributes-change-username-and-password-attributes", + orig: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverrideUsernameAttribute("orig-username"), + WithOverridePasswordAttribute("orig-password"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + chgFn: changeMappingOverride( + NewUserPasswordOverride( + WithOverrideUsernameAttribute("changed-username"), + WithOverridePasswordAttribute("changed-password"), + ), + ), + masks: []string{"MappingOverride"}, + want: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverrideUsernameAttribute("changed-username"), + WithOverridePasswordAttribute("changed-password"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + wantCount: 1, + }, + { + name: "no-mapping-override-change-username-and-password-attributes", + orig: &CredentialLibrary{ + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + chgFn: changeMappingOverride( + NewUserPasswordOverride( + WithOverrideUsernameAttribute("changed-username"), + WithOverridePasswordAttribute("changed-password"), + ), + ), + masks: []string{"MappingOverride"}, + want: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverrideUsernameAttribute("changed-username"), + WithOverridePasswordAttribute("changed-password"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + wantCount: 1, + }, + { + name: "user-password-attributes-delete-mapping-override", + orig: &CredentialLibrary{ + MappingOverride: NewUserPasswordOverride( + WithOverrideUsernameAttribute("orig-username"), + WithOverridePasswordAttribute("orig-password"), + ), + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + chgFn: changeMappingOverride(nil), + masks: []string{"MappingOverride"}, + want: &CredentialLibrary{ + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + CredentialType: string(credential.UserPasswordType), + }, + }, + wantCount: 1, + }, + { + name: "set-mapping-override-on-unspecified-credential-type", + orig: &CredentialLibrary{ + CredentialLibrary: &store.CredentialLibrary{ + HttpMethod: "GET", + VaultPath: "/some/path", + Name: "test-name-repo", + }, + }, + chgFn: changeMappingOverride( + NewUserPasswordOverride( + WithOverrideUsernameAttribute("changed-username"), + WithOverridePasswordAttribute("changed-password"), + ), + ), + masks: []string{"MappingOverride"}, + wantErr: errors.VaultInvalidMappingOverride, + }, } for _, tt := range tests { @@ -940,19 +1149,36 @@ func TestRepository_UpdateCredentialLibrary(t *testing.T) { underlyingDB, err := conn.SqlDB(ctx) require.NoError(err) dbassert := dbassert.New(t, underlyingDB) - if tt.want.Name == "" { + + switch tt.want.Name { + case "": dbassert.IsNull(got, "name") - return + default: + assert.Equal(tt.want.Name, got.Name) } - assert.Equal(tt.want.Name, got.Name) - if tt.want.Description == "" { + + switch tt.want.Description { + case "": dbassert.IsNull(got, "description") - return + default: + assert.Equal(tt.want.Description, got.Description) } - assert.Equal(tt.want.Description, got.Description) + if tt.wantCount > 0 { assert.NoError(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))) } + + switch w := tt.want.MappingOverride.(type) { + case nil: + assert.Nil(got.MappingOverride) + case *UserPasswordOverride: + g, ok := got.MappingOverride.(*UserPasswordOverride) + require.True(ok) + assert.Equal(w.UsernameAttribute, g.UsernameAttribute) + assert.Equal(w.PasswordAttribute, g.PasswordAttribute) + default: + assert.Fail("Unknown mapping override") + } }) }