refactor Update/Delete to return the number of rows affected. (#46)

pull/48/head
Jim 6 years ago committed by GitHub
parent 61ca88dac9
commit c04c73453f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,6 +21,8 @@ var (
ErrRecordNotFound = errors.New("record not found")
)
const NoRowsAffected = 0
// Reader interface defines lookups/searching for resources
type Reader interface {
// LookupByName will lookup resource by its friendly name which must be unique
@ -47,12 +49,13 @@ type Writer interface {
// DoTx will wrap the TxHandler in a retryable transaction
DoTx(ctx context.Context, retries uint, backOff Backoff, Handler TxHandler) (RetryInfo, error)
// Update an object in the db, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise
// it will send every field to the DB. options: WithOplog
// the caller is responsible for the transaction life cycle of the writer
// and if an error is returned the caller must decide what to do with
// the transaction, which almost always should be to rollback.
Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) error
// Update an object in the db, if there's a fieldMask then only the
// field_mask.proto paths are updated, otherwise it will send every field to
// the DB. options: WithOplog the caller is responsible for the transaction
// life cycle of the writer and if an error is returned the caller must
// decide what to do with the transaction, which almost always should be to
// rollback. Update returns the number of rows updated or an error.
Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error)
// Create an object in the db with options: WithOplog
// the caller is responsible for the transaction life cycle of the writer
@ -63,8 +66,9 @@ type Writer interface {
// Delete an object in the db with options: WithOplog
// the caller is responsible for the transaction life cycle of the writer
// and if an error is returned the caller must decide what to do with
// the transaction, which almost always should be to rollback.
Delete(ctx context.Context, i interface{}, opt ...Option) error
// the transaction, which almost always should be to rollback. Delete
// returns the number of rows deleted or an error.
Delete(ctx context.Context, i interface{}, opt ...Option) (int, error)
// DB returns the sql.DB
DB() (*sql.DB, error)
@ -159,6 +163,13 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
opts := GetOpts(opt...)
withOplog := opts.withOplog
withDebug := opts.withDebug
if withOplog {
// let's validate oplog options before we start writing to the database
_, err := validateOplogArgs(i, opts)
if err != nil {
return err
}
}
if withDebug {
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
@ -185,30 +196,39 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
return nil
}
// Update an object in the db, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise
// it will send every field to the DB. Update supports embedding a struct (or structPtr) one level deep for updating
func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) error {
// Update an object in the db, if there's a fieldMask then only the
// field_mask.proto paths are updated, otherwise it will send every field to the
// DB. Update supports embedding a struct (or structPtr) one level deep for
// updating. Update returns the number of rows updated and any errors.
func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error) {
if rw.underlying == nil {
return errors.New("update underlying db is nil")
return NoRowsAffected, errors.New("update underlying db is nil")
}
opts := GetOpts(opt...)
withDebug := opts.withDebug
withOplog := opts.withOplog
if withOplog {
// let's validate oplog options before we start writing to the database
_, err := validateOplogArgs(i, opts)
if err != nil {
return NoRowsAffected, err
}
}
if withDebug {
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if i == nil {
return errors.New("update interface is nil")
return NoRowsAffected, errors.New("update interface is nil")
}
if vetter, ok := i.(VetForWriter); ok {
if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths)); err != nil {
return fmt.Errorf("error on update %w", err)
return NoRowsAffected, fmt.Errorf("error on update %w", err)
}
}
if len(fieldMaskPaths) == 0 {
if err := rw.underlying.Save(i).Error; err != nil {
return fmt.Errorf("error updating: %w", err)
return NoRowsAffected, fmt.Errorf("error updating: %w", err)
}
}
updateFields := map[string]interface{}{}
@ -246,62 +266,83 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
}
}
if len(updateFields) == 0 {
return fmt.Errorf("error no update fields matched using fieldMaskPaths: %s", fieldMaskPaths)
return NoRowsAffected, fmt.Errorf("error no update fields matched using fieldMaskPaths: %s", fieldMaskPaths)
}
if err := rw.underlying.Model(i).Updates(updateFields).Error; err != nil {
return fmt.Errorf("error updating: %w", err)
underlying := rw.underlying.Model(i).Updates(updateFields)
if underlying.Error != nil {
return NoRowsAffected, fmt.Errorf("error updating: %w", underlying.Error)
}
rowsUpdated := int(underlying.RowsAffected)
if withOplog {
if err := rw.addOplog(ctx, UpdateOp, opts, i); err != nil {
return err
return rowsUpdated, err
}
}
// we need to force a lookupAfterWrite so the resource returned is correctly initialized
// from the db
opt = append(opt, WithLookup(true))
if err := rw.lookupAfterWrite(ctx, i, opt...); err != nil {
return fmt.Errorf("lookup error after update: %w", err)
return NoRowsAffected, fmt.Errorf("lookup error after update: %w", err)
}
return nil
return rowsUpdated, nil
}
// Delete an object in the db with options: WithOplog (which requires WithMetadata, WithWrapper)
func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) error {
// Delete an object in the db with options: WithOplog (which requires
// WithMetadata, WithWrapper). Delete returns the number of rows deleted and
// any errors.
func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) {
if rw.underlying == nil {
return errors.New("delete underlying db is nil")
return NoRowsAffected, errors.New("delete underlying db is nil")
}
if i == nil {
return errors.New("delete interface is nil")
return NoRowsAffected, errors.New("delete interface is nil")
}
opts := GetOpts(opt...)
withDebug := opts.withDebug
withOplog := opts.withOplog
if withOplog {
_, err := validateOplogArgs(i, opts)
if err != nil {
return NoRowsAffected, err
}
}
if withDebug {
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if err := rw.underlying.Delete(i).Error; err != nil {
return fmt.Errorf("error deleting: %w", err)
underlying := rw.underlying.Delete(i)
if underlying.Error != nil {
return NoRowsAffected, fmt.Errorf("error deleting: %w", underlying.Error)
}
rowsDeleted := int(underlying.RowsAffected)
if withOplog {
if err := rw.addOplog(ctx, DeleteOp, opts, i); err != nil {
return err
return rowsDeleted, err
}
}
return nil
return rowsDeleted, nil
}
func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i interface{}) error {
func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, error) {
oplogArgs := opts.oplogOpts
if oplogArgs.wrapper == nil {
return errors.New("error wrapper is nil for WithWrapper")
return nil, errors.New("error no wrapper WithOplog")
}
if len(oplogArgs.metadata) == 0 {
return errors.New("error no metadata for WithOplog")
return nil, errors.New("error no metadata for WithOplog")
}
replayable, ok := i.(oplog.ReplayableMessage)
if !ok {
return errors.New("error not a replayable message for WithOplog")
return nil, errors.New("error not a replayable message for WithOplog")
}
return replayable, nil
}
func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i interface{}) error {
oplogArgs := opts.oplogOpts
replayable, err := validateOplogArgs(i, opts)
if err != nil {
return err
}
gdb := rw.underlying
withDebug := opts.withDebug

@ -37,8 +37,9 @@ func TestDb_Update(t *testing.T) {
assert.Equal(foundUser.Id, user.Id)
user.Name = "friendly-" + id
err = w.Update(context.Background(), user, []string{"Name"})
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"})
assert.Nil(err)
assert.Equal(1, rowsUpdated)
err = w.LookupByPublicId(context.Background(), foundUser)
assert.Nil(err)
@ -63,7 +64,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal(foundUser.Id, user.Id)
user.Name = "friendly-" + id
err = w.Update(context.Background(), user, []string{"Name"},
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"},
// write oplogs for this update
WithOplog(
TestWrapper(t),
@ -75,6 +76,7 @@ func TestDb_Update(t *testing.T) {
}),
)
assert.Nil(err)
assert.Equal(1, rowsUpdated)
err = w.LookupByPublicId(context.Background(), foundUser)
assert.Nil(err)
@ -95,8 +97,9 @@ func TestDb_Update(t *testing.T) {
user, err := db_test.NewTestUser()
assert.Nil(err)
user.Name = "foo-" + id
err = w.Update(context.Background(), user, nil)
rowsUpdated, err := w.Update(context.Background(), user, nil)
assert.True(err != nil)
assert.Equal(0, rowsUpdated)
assert.Equal("update underlying db is nil", err.Error())
})
t.Run("no-wrapper-WithOplog", func(t *testing.T) {
@ -118,7 +121,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal(foundUser.Id, user.Id)
user.Name = "friendly-" + id
err = w.Update(context.Background(), user, []string{"Name"},
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"},
WithOplog(
nil,
oplog.Metadata{
@ -128,7 +131,8 @@ func TestDb_Update(t *testing.T) {
}),
)
assert.True(err != nil)
assert.Equal("error wrapper is nil for WithWrapper", err.Error())
assert.Equal(0, rowsUpdated)
assert.Equal("error no wrapper WithOplog", err.Error())
})
t.Run("no-metadata-WithOplog", func(t *testing.T) {
w := Db{underlying: db}
@ -149,13 +153,14 @@ func TestDb_Update(t *testing.T) {
assert.Equal(foundUser.Id, user.Id)
user.Name = "friendly-" + id
err = w.Update(context.Background(), user, []string{"Name"},
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"},
WithOplog(
TestWrapper(t),
nil,
),
)
assert.True(err != nil)
assert.Equal(0, rowsUpdated)
assert.Equal("error no metadata for WithOplog", err.Error())
})
}
@ -235,7 +240,7 @@ func TestDb_Create(t *testing.T) {
),
)
assert.True(err != nil)
assert.Equal("error wrapper is nil for WithWrapper", err.Error())
assert.Equal("error no wrapper WithOplog", err.Error())
})
t.Run("no-metadata-WithOplog", func(t *testing.T) {
w := Db{underlying: db}
@ -627,8 +632,9 @@ func TestDb_Delete(t *testing.T) {
assert.Nil(err)
assert.Equal(foundUser.Id, user.Id)
err = w.Delete(context.Background(), user)
rowsDeleted, err := w.Delete(context.Background(), user)
assert.Nil(err)
assert.Equal(1, rowsDeleted)
err = w.LookupByPublicId(context.Background(), foundUser)
assert.True(err != nil)
@ -663,7 +669,7 @@ func TestDb_Delete(t *testing.T) {
assert.Nil(err)
assert.Equal(foundUser.Id, user.Id)
err = w.Delete(
rowsDeleted, err := w.Delete(
context.Background(),
user,
WithOplog(
@ -676,6 +682,7 @@ func TestDb_Delete(t *testing.T) {
),
)
assert.Nil(err)
assert.Equal(1, rowsDeleted)
err = w.LookupByPublicId(context.Background(), foundUser)
assert.True(err != nil)
@ -710,7 +717,7 @@ func TestDb_Delete(t *testing.T) {
assert.Nil(err)
assert.Equal(foundUser.Id, user.Id)
err = w.Delete(
rowsDeleted, err := w.Delete(
context.Background(),
user,
WithOplog(
@ -723,7 +730,8 @@ func TestDb_Delete(t *testing.T) {
),
)
assert.True(err != nil)
assert.Equal("error wrapper is nil for WithWrapper", err.Error())
assert.Equal(0, rowsDeleted)
assert.Equal("error no wrapper WithOplog", err.Error())
})
t.Run("no-metadata-WithOplog", func(t *testing.T) {
w := Db{underlying: db}
@ -754,7 +762,7 @@ func TestDb_Delete(t *testing.T) {
assert.Nil(err)
assert.Equal(foundUser.Id, user.Id)
err = w.Delete(
rowsDeleted, err := w.Delete(
context.Background(),
user,
WithOplog(
@ -763,6 +771,7 @@ func TestDb_Delete(t *testing.T) {
),
)
assert.True(err != nil)
assert.Equal(0, rowsDeleted)
assert.Equal("error no metadata for WithOplog", err.Error())
})
t.Run("nil-tx", func(t *testing.T) {

@ -69,20 +69,21 @@ func (r *Repository) create(ctx context.Context, resource Resource, opt ...Optio
}
// Update will update an iam resource in the db repository with an oplog entry
func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, opt ...Option) (Resource, error) {
func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, opt ...Option) (Resource, int, error) {
if resource == nil {
return nil, errors.New("error updating resource that is nil")
return nil, db.NoRowsAffected, errors.New("error updating resource that is nil")
}
resourceCloner, ok := resource.(Clonable)
if !ok {
return nil, errors.New("error resource is not clonable for update")
return nil, db.NoRowsAffected, errors.New("error resource is not clonable for update")
}
metadata, err := r.stdMetadata(ctx, resource)
if err != nil {
return nil, fmt.Errorf("error getting metadata for update: %w", err)
return nil, db.NoRowsAffected, fmt.Errorf("error getting metadata for update: %w", err)
}
metadata["op-type"] = []string{strconv.Itoa(int(oplog.OpType_OP_TYPE_UPDATE))}
var rowsUpdated int
var returnedResource interface{}
_, err = r.writer.DoTx(
ctx,
@ -90,15 +91,17 @@ func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPat
db.ExpBackoff{},
func(w db.Writer) error {
returnedResource = resourceCloner.Clone()
return w.Update(
var err error
rowsUpdated, err = w.Update(
ctx,
returnedResource,
fieldMaskPaths,
db.WithOplog(r.wrapper, metadata),
)
return err
},
)
return returnedResource.(Resource), err
return returnedResource.(Resource), rowsUpdated, err
}
func (r *Repository) stdMetadata(ctx context.Context, resource Resource) (oplog.Metadata, error) {

@ -21,15 +21,15 @@ func (r *Repository) CreateScope(ctx context.Context, scope *Scope, opt ...Optio
}
// UpdateScope will update a scope in the repository and return the written scope
func (r *Repository) UpdateScope(ctx context.Context, scope *Scope, fieldMaskPaths []string, opt ...Option) (*Scope, error) {
func (r *Repository) UpdateScope(ctx context.Context, scope *Scope, fieldMaskPaths []string, opt ...Option) (*Scope, int, error) {
if scope == nil {
return nil, errors.New("error scope is nil for update")
return nil, db.NoRowsAffected, errors.New("error scope is nil for update")
}
resource, err := r.update(ctx, scope, fieldMaskPaths)
resource, rowsUpdated, err := r.update(ctx, scope, fieldMaskPaths)
if err != nil {
return nil, fmt.Errorf("failed to update scope: %w", err)
return nil, db.NoRowsAffected, fmt.Errorf("failed to update scope: %w", err)
}
return resource.(*Scope), err
return resource.(*Scope), rowsUpdated, err
}
// LookupScope will look up a scope in the repository. If the scope is not

@ -89,8 +89,9 @@ func Test_Repository_UpdateScope(t *testing.T) {
s.Name = "fname-" + id
s.Description = "desc-id" // not in the field mask paths
s, err = repo.UpdateScope(context.Background(), s, []string{"Name"})
s, updatedRows, err := repo.UpdateScope(context.Background(), s, []string{"Name"})
assert.Nil(err)
assert.Equal(1, updatedRows)
assert.True(s != nil)
assert.Equal(s.GetName(), "fname-"+id)
assert.Equal(foundScope.GetDescription(), "") // should be "" after update in db
@ -128,8 +129,9 @@ func Test_Repository_UpdateScope(t *testing.T) {
assert.True(project != nil)
project.ParentId = project.PublicId
project, err = repo.UpdateScope(context.Background(), project, []string{"ParentId"})
project, updatedRows, err := repo.UpdateScope(context.Background(), project, []string{"ParentId"})
assert.True(err != nil)
assert.Equal(0, updatedRows)
assert.Equal("failed to update scope: error on update you cannot change a scope's parent", err.Error())
})
}

@ -169,9 +169,10 @@ func Test_dbRepository_update(t *testing.T) {
assert.Equal(retScope.GetName(), "")
retScope.(*Scope).Name = "fname-" + id
retScope, err = repo.update(context.Background(), retScope, []string{"Name"})
retScope, updatedRows, err := repo.update(context.Background(), retScope, []string{"Name"})
assert.Nil(err)
assert.True(retScope != nil)
assert.Equal(1, updatedRows)
assert.Equal(retScope.GetName(), "fname-"+id)
foundScope, err := repo.LookupScope(context.Background(), WithName("fname-"+id))
@ -190,9 +191,10 @@ func Test_dbRepository_update(t *testing.T) {
rw := db.New(conn)
wrapper := db.TestWrapper(t)
repo, err := NewRepository(rw, rw, wrapper)
resource, err := repo.update(context.Background(), nil, nil)
resource, updatedRows, err := repo.update(context.Background(), nil, nil)
assert.True(err != nil)
assert.True(resource == nil)
assert.Equal(0, updatedRows)
assert.Equal(err.Error(), "error updating resource that is nil")
})
}

@ -106,8 +106,9 @@ func Test_ScopeUpdate(t *testing.T) {
id, err := uuid.GenerateUUID()
assert.Nil(err)
s.Name = id
err = w.Update(context.Background(), s, []string{"Name"})
updatedRows, err := w.Update(context.Background(), s, []string{"Name"})
assert.Nil(err)
assert.Equal(1, updatedRows)
})
t.Run("type-update-not-allowed", func(t *testing.T) {
w := db.New(conn)
@ -119,8 +120,9 @@ func Test_ScopeUpdate(t *testing.T) {
assert.True(s.PublicId != "")
s.Type = ProjectScope.String()
err = w.Update(context.Background(), s, []string{"Type"})
updatedRows, err := w.Update(context.Background(), s, []string{"Type"})
assert.True(err != nil)
assert.Equal(0, updatedRows)
})
}
func Test_ScopeGetScope(t *testing.T) {

Loading…
Cancel
Save