diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index 459266d867..a1f2e10e33 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -21,6 +21,8 @@ var ( ErrRecordNotFound = errors.New("record not found") ) +const NoRowsAffected = 0 + // Reader interface defines lookups/searching for resources type Reader interface { // LookupByName will lookup resource by its friendly name which must be unique @@ -47,12 +49,13 @@ type Writer interface { // DoTx will wrap the TxHandler in a retryable transaction DoTx(ctx context.Context, retries uint, backOff Backoff, Handler TxHandler) (RetryInfo, error) - // Update an object in the db, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise - // it will send every field to the DB. options: WithOplog - // the caller is responsible for the transaction life cycle of the writer - // and if an error is returned the caller must decide what to do with - // the transaction, which almost always should be to rollback. - Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) error + // Update an object in the db, if there's a fieldMask then only the + // field_mask.proto paths are updated, otherwise it will send every field to + // the DB. options: WithOplog the caller is responsible for the transaction + // life cycle of the writer and if an error is returned the caller must + // decide what to do with the transaction, which almost always should be to + // rollback. Update returns the number of rows updated or an error. + Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error) // Create an object in the db with options: WithOplog // the caller is responsible for the transaction life cycle of the writer @@ -63,8 +66,9 @@ type Writer interface { // Delete an object in the db with options: WithOplog // the caller is responsible for the transaction life cycle of the writer // and if an error is returned the caller must decide what to do with - // the transaction, which almost always should be to rollback. - Delete(ctx context.Context, i interface{}, opt ...Option) error + // the transaction, which almost always should be to rollback. Delete + // returns the number of rows deleted or an error. + Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) // DB returns the sql.DB DB() (*sql.DB, error) @@ -159,6 +163,13 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { opts := GetOpts(opt...) withOplog := opts.withOplog withDebug := opts.withDebug + if withOplog { + // let's validate oplog options before we start writing to the database + _, err := validateOplogArgs(i, opts) + if err != nil { + return err + } + } if withDebug { rw.underlying.LogMode(true) defer rw.underlying.LogMode(false) @@ -185,30 +196,39 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { return nil } -// Update an object in the db, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise -// it will send every field to the DB. Update supports embedding a struct (or structPtr) one level deep for updating -func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) error { +// Update an object in the db, if there's a fieldMask then only the +// field_mask.proto paths are updated, otherwise it will send every field to the +// DB. Update supports embedding a struct (or structPtr) one level deep for +// updating. Update returns the number of rows updated and any errors. +func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error) { if rw.underlying == nil { - return errors.New("update underlying db is nil") + return NoRowsAffected, errors.New("update underlying db is nil") } opts := GetOpts(opt...) withDebug := opts.withDebug withOplog := opts.withOplog + if withOplog { + // let's validate oplog options before we start writing to the database + _, err := validateOplogArgs(i, opts) + if err != nil { + return NoRowsAffected, err + } + } if withDebug { rw.underlying.LogMode(true) defer rw.underlying.LogMode(false) } if i == nil { - return errors.New("update interface is nil") + return NoRowsAffected, errors.New("update interface is nil") } if vetter, ok := i.(VetForWriter); ok { if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths)); err != nil { - return fmt.Errorf("error on update %w", err) + return NoRowsAffected, fmt.Errorf("error on update %w", err) } } if len(fieldMaskPaths) == 0 { if err := rw.underlying.Save(i).Error; err != nil { - return fmt.Errorf("error updating: %w", err) + return NoRowsAffected, fmt.Errorf("error updating: %w", err) } } updateFields := map[string]interface{}{} @@ -246,62 +266,83 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string } } if len(updateFields) == 0 { - return fmt.Errorf("error no update fields matched using fieldMaskPaths: %s", fieldMaskPaths) + return NoRowsAffected, fmt.Errorf("error no update fields matched using fieldMaskPaths: %s", fieldMaskPaths) } - if err := rw.underlying.Model(i).Updates(updateFields).Error; err != nil { - return fmt.Errorf("error updating: %w", err) + underlying := rw.underlying.Model(i).Updates(updateFields) + if underlying.Error != nil { + return NoRowsAffected, fmt.Errorf("error updating: %w", underlying.Error) } + rowsUpdated := int(underlying.RowsAffected) if withOplog { if err := rw.addOplog(ctx, UpdateOp, opts, i); err != nil { - return err + return rowsUpdated, err } } // we need to force a lookupAfterWrite so the resource returned is correctly initialized // from the db opt = append(opt, WithLookup(true)) if err := rw.lookupAfterWrite(ctx, i, opt...); err != nil { - return fmt.Errorf("lookup error after update: %w", err) + return NoRowsAffected, fmt.Errorf("lookup error after update: %w", err) } - return nil + return rowsUpdated, nil } -// Delete an object in the db with options: WithOplog (which requires WithMetadata, WithWrapper) -func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) error { +// Delete an object in the db with options: WithOplog (which requires +// WithMetadata, WithWrapper). Delete returns the number of rows deleted and +// any errors. +func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) { if rw.underlying == nil { - return errors.New("delete underlying db is nil") + return NoRowsAffected, errors.New("delete underlying db is nil") } if i == nil { - return errors.New("delete interface is nil") + return NoRowsAffected, errors.New("delete interface is nil") } opts := GetOpts(opt...) withDebug := opts.withDebug withOplog := opts.withOplog + if withOplog { + _, err := validateOplogArgs(i, opts) + if err != nil { + return NoRowsAffected, err + } + } if withDebug { rw.underlying.LogMode(true) defer rw.underlying.LogMode(false) } - if err := rw.underlying.Delete(i).Error; err != nil { - return fmt.Errorf("error deleting: %w", err) + underlying := rw.underlying.Delete(i) + if underlying.Error != nil { + return NoRowsAffected, fmt.Errorf("error deleting: %w", underlying.Error) } + rowsDeleted := int(underlying.RowsAffected) if withOplog { if err := rw.addOplog(ctx, DeleteOp, opts, i); err != nil { - return err + return rowsDeleted, err } } - return nil + return rowsDeleted, nil } -func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i interface{}) error { +func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, error) { oplogArgs := opts.oplogOpts if oplogArgs.wrapper == nil { - return errors.New("error wrapper is nil for WithWrapper") + return nil, errors.New("error no wrapper WithOplog") } if len(oplogArgs.metadata) == 0 { - return errors.New("error no metadata for WithOplog") + return nil, errors.New("error no metadata for WithOplog") } replayable, ok := i.(oplog.ReplayableMessage) if !ok { - return errors.New("error not a replayable message for WithOplog") + return nil, errors.New("error not a replayable message for WithOplog") + } + return replayable, nil +} + +func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i interface{}) error { + oplogArgs := opts.oplogOpts + replayable, err := validateOplogArgs(i, opts) + if err != nil { + return err } gdb := rw.underlying withDebug := opts.withDebug diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 52aec02df5..5df5581afa 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -37,8 +37,9 @@ func TestDb_Update(t *testing.T) { assert.Equal(foundUser.Id, user.Id) user.Name = "friendly-" + id - err = w.Update(context.Background(), user, []string{"Name"}) + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}) assert.Nil(err) + assert.Equal(1, rowsUpdated) err = w.LookupByPublicId(context.Background(), foundUser) assert.Nil(err) @@ -63,7 +64,7 @@ func TestDb_Update(t *testing.T) { assert.Equal(foundUser.Id, user.Id) user.Name = "friendly-" + id - err = w.Update(context.Background(), user, []string{"Name"}, + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, // write oplogs for this update WithOplog( TestWrapper(t), @@ -75,6 +76,7 @@ func TestDb_Update(t *testing.T) { }), ) assert.Nil(err) + assert.Equal(1, rowsUpdated) err = w.LookupByPublicId(context.Background(), foundUser) assert.Nil(err) @@ -95,8 +97,9 @@ func TestDb_Update(t *testing.T) { user, err := db_test.NewTestUser() assert.Nil(err) user.Name = "foo-" + id - err = w.Update(context.Background(), user, nil) + rowsUpdated, err := w.Update(context.Background(), user, nil) assert.True(err != nil) + assert.Equal(0, rowsUpdated) assert.Equal("update underlying db is nil", err.Error()) }) t.Run("no-wrapper-WithOplog", func(t *testing.T) { @@ -118,7 +121,7 @@ func TestDb_Update(t *testing.T) { assert.Equal(foundUser.Id, user.Id) user.Name = "friendly-" + id - err = w.Update(context.Background(), user, []string{"Name"}, + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, WithOplog( nil, oplog.Metadata{ @@ -128,7 +131,8 @@ func TestDb_Update(t *testing.T) { }), ) assert.True(err != nil) - assert.Equal("error wrapper is nil for WithWrapper", err.Error()) + assert.Equal(0, rowsUpdated) + assert.Equal("error no wrapper WithOplog", err.Error()) }) t.Run("no-metadata-WithOplog", func(t *testing.T) { w := Db{underlying: db} @@ -149,13 +153,14 @@ func TestDb_Update(t *testing.T) { assert.Equal(foundUser.Id, user.Id) user.Name = "friendly-" + id - err = w.Update(context.Background(), user, []string{"Name"}, + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, WithOplog( TestWrapper(t), nil, ), ) assert.True(err != nil) + assert.Equal(0, rowsUpdated) assert.Equal("error no metadata for WithOplog", err.Error()) }) } @@ -235,7 +240,7 @@ func TestDb_Create(t *testing.T) { ), ) assert.True(err != nil) - assert.Equal("error wrapper is nil for WithWrapper", err.Error()) + assert.Equal("error no wrapper WithOplog", err.Error()) }) t.Run("no-metadata-WithOplog", func(t *testing.T) { w := Db{underlying: db} @@ -627,8 +632,9 @@ func TestDb_Delete(t *testing.T) { assert.Nil(err) assert.Equal(foundUser.Id, user.Id) - err = w.Delete(context.Background(), user) + rowsDeleted, err := w.Delete(context.Background(), user) assert.Nil(err) + assert.Equal(1, rowsDeleted) err = w.LookupByPublicId(context.Background(), foundUser) assert.True(err != nil) @@ -663,7 +669,7 @@ func TestDb_Delete(t *testing.T) { assert.Nil(err) assert.Equal(foundUser.Id, user.Id) - err = w.Delete( + rowsDeleted, err := w.Delete( context.Background(), user, WithOplog( @@ -676,6 +682,7 @@ func TestDb_Delete(t *testing.T) { ), ) assert.Nil(err) + assert.Equal(1, rowsDeleted) err = w.LookupByPublicId(context.Background(), foundUser) assert.True(err != nil) @@ -710,7 +717,7 @@ func TestDb_Delete(t *testing.T) { assert.Nil(err) assert.Equal(foundUser.Id, user.Id) - err = w.Delete( + rowsDeleted, err := w.Delete( context.Background(), user, WithOplog( @@ -723,7 +730,8 @@ func TestDb_Delete(t *testing.T) { ), ) assert.True(err != nil) - assert.Equal("error wrapper is nil for WithWrapper", err.Error()) + assert.Equal(0, rowsDeleted) + assert.Equal("error no wrapper WithOplog", err.Error()) }) t.Run("no-metadata-WithOplog", func(t *testing.T) { w := Db{underlying: db} @@ -754,7 +762,7 @@ func TestDb_Delete(t *testing.T) { assert.Nil(err) assert.Equal(foundUser.Id, user.Id) - err = w.Delete( + rowsDeleted, err := w.Delete( context.Background(), user, WithOplog( @@ -763,6 +771,7 @@ func TestDb_Delete(t *testing.T) { ), ) assert.True(err != nil) + assert.Equal(0, rowsDeleted) assert.Equal("error no metadata for WithOplog", err.Error()) }) t.Run("nil-tx", func(t *testing.T) { diff --git a/internal/iam/repository.go b/internal/iam/repository.go index 14fe59df4d..5ddbaf3eac 100644 --- a/internal/iam/repository.go +++ b/internal/iam/repository.go @@ -69,20 +69,21 @@ func (r *Repository) create(ctx context.Context, resource Resource, opt ...Optio } // Update will update an iam resource in the db repository with an oplog entry -func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, opt ...Option) (Resource, error) { +func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, opt ...Option) (Resource, int, error) { if resource == nil { - return nil, errors.New("error updating resource that is nil") + return nil, db.NoRowsAffected, errors.New("error updating resource that is nil") } resourceCloner, ok := resource.(Clonable) if !ok { - return nil, errors.New("error resource is not clonable for update") + return nil, db.NoRowsAffected, errors.New("error resource is not clonable for update") } metadata, err := r.stdMetadata(ctx, resource) if err != nil { - return nil, fmt.Errorf("error getting metadata for update: %w", err) + return nil, db.NoRowsAffected, fmt.Errorf("error getting metadata for update: %w", err) } metadata["op-type"] = []string{strconv.Itoa(int(oplog.OpType_OP_TYPE_UPDATE))} + var rowsUpdated int var returnedResource interface{} _, err = r.writer.DoTx( ctx, @@ -90,15 +91,17 @@ func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPat db.ExpBackoff{}, func(w db.Writer) error { returnedResource = resourceCloner.Clone() - return w.Update( + var err error + rowsUpdated, err = w.Update( ctx, returnedResource, fieldMaskPaths, db.WithOplog(r.wrapper, metadata), ) + return err }, ) - return returnedResource.(Resource), err + return returnedResource.(Resource), rowsUpdated, err } func (r *Repository) stdMetadata(ctx context.Context, resource Resource) (oplog.Metadata, error) { diff --git a/internal/iam/repository_scope.go b/internal/iam/repository_scope.go index 7d444e2886..1a2450db2a 100644 --- a/internal/iam/repository_scope.go +++ b/internal/iam/repository_scope.go @@ -21,15 +21,15 @@ func (r *Repository) CreateScope(ctx context.Context, scope *Scope, opt ...Optio } // UpdateScope will update a scope in the repository and return the written scope -func (r *Repository) UpdateScope(ctx context.Context, scope *Scope, fieldMaskPaths []string, opt ...Option) (*Scope, error) { +func (r *Repository) UpdateScope(ctx context.Context, scope *Scope, fieldMaskPaths []string, opt ...Option) (*Scope, int, error) { if scope == nil { - return nil, errors.New("error scope is nil for update") + return nil, db.NoRowsAffected, errors.New("error scope is nil for update") } - resource, err := r.update(ctx, scope, fieldMaskPaths) + resource, rowsUpdated, err := r.update(ctx, scope, fieldMaskPaths) if err != nil { - return nil, fmt.Errorf("failed to update scope: %w", err) + return nil, db.NoRowsAffected, fmt.Errorf("failed to update scope: %w", err) } - return resource.(*Scope), err + return resource.(*Scope), rowsUpdated, err } // LookupScope will look up a scope in the repository. If the scope is not diff --git a/internal/iam/repository_scope_test.go b/internal/iam/repository_scope_test.go index e545e8bbc2..01f04bd601 100644 --- a/internal/iam/repository_scope_test.go +++ b/internal/iam/repository_scope_test.go @@ -89,8 +89,9 @@ func Test_Repository_UpdateScope(t *testing.T) { s.Name = "fname-" + id s.Description = "desc-id" // not in the field mask paths - s, err = repo.UpdateScope(context.Background(), s, []string{"Name"}) + s, updatedRows, err := repo.UpdateScope(context.Background(), s, []string{"Name"}) assert.Nil(err) + assert.Equal(1, updatedRows) assert.True(s != nil) assert.Equal(s.GetName(), "fname-"+id) assert.Equal(foundScope.GetDescription(), "") // should be "" after update in db @@ -128,8 +129,9 @@ func Test_Repository_UpdateScope(t *testing.T) { assert.True(project != nil) project.ParentId = project.PublicId - project, err = repo.UpdateScope(context.Background(), project, []string{"ParentId"}) + project, updatedRows, err := repo.UpdateScope(context.Background(), project, []string{"ParentId"}) assert.True(err != nil) + assert.Equal(0, updatedRows) assert.Equal("failed to update scope: error on update you cannot change a scope's parent", err.Error()) }) } diff --git a/internal/iam/repository_test.go b/internal/iam/repository_test.go index 9888669dbd..410e7f4fdc 100644 --- a/internal/iam/repository_test.go +++ b/internal/iam/repository_test.go @@ -169,9 +169,10 @@ func Test_dbRepository_update(t *testing.T) { assert.Equal(retScope.GetName(), "") retScope.(*Scope).Name = "fname-" + id - retScope, err = repo.update(context.Background(), retScope, []string{"Name"}) + retScope, updatedRows, err := repo.update(context.Background(), retScope, []string{"Name"}) assert.Nil(err) assert.True(retScope != nil) + assert.Equal(1, updatedRows) assert.Equal(retScope.GetName(), "fname-"+id) foundScope, err := repo.LookupScope(context.Background(), WithName("fname-"+id)) @@ -190,9 +191,10 @@ func Test_dbRepository_update(t *testing.T) { rw := db.New(conn) wrapper := db.TestWrapper(t) repo, err := NewRepository(rw, rw, wrapper) - resource, err := repo.update(context.Background(), nil, nil) + resource, updatedRows, err := repo.update(context.Background(), nil, nil) assert.True(err != nil) assert.True(resource == nil) + assert.Equal(0, updatedRows) assert.Equal(err.Error(), "error updating resource that is nil") }) } diff --git a/internal/iam/scope_test.go b/internal/iam/scope_test.go index 1456e89a9d..c809fac896 100644 --- a/internal/iam/scope_test.go +++ b/internal/iam/scope_test.go @@ -106,8 +106,9 @@ func Test_ScopeUpdate(t *testing.T) { id, err := uuid.GenerateUUID() assert.Nil(err) s.Name = id - err = w.Update(context.Background(), s, []string{"Name"}) + updatedRows, err := w.Update(context.Background(), s, []string{"Name"}) assert.Nil(err) + assert.Equal(1, updatedRows) }) t.Run("type-update-not-allowed", func(t *testing.T) { w := db.New(conn) @@ -119,8 +120,9 @@ func Test_ScopeUpdate(t *testing.T) { assert.True(s.PublicId != "") s.Type = ProjectScope.String() - err = w.Update(context.Background(), s, []string{"Type"}) + updatedRows, err := w.Update(context.Background(), s, []string{"Type"}) assert.True(err != nil) + assert.Equal(0, updatedRows) }) } func Test_ScopeGetScope(t *testing.T) {