From 2aca5fff030919399cea066f827e134fbb722fa5 Mon Sep 17 00:00:00 2001 From: Jim Date: Fri, 15 May 2020 08:38:36 -0400 Subject: [PATCH] fix transaction rollbacks in DoTx (#55) * fix transaction rollbacks in DoTx * better assertion * Update internal/db/read_writer_test.go Co-authored-by: Michael Gaffney Co-authored-by: Jim Lambert Co-authored-by: Michael Gaffney --- internal/db/read_writer.go | 34 ++++++++--------- internal/db/read_writer_test.go | 66 +++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index a1f2e10e33..b6bd29fdd9 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -273,7 +273,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string return NoRowsAffected, fmt.Errorf("error updating: %w", underlying.Error) } rowsUpdated := int(underlying.RowsAffected) - if withOplog { + if withOplog && rowsUpdated > 0 { if err := rw.addOplog(ctx, UpdateOp, opts, i); err != nil { return rowsUpdated, err } @@ -315,7 +315,7 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er return NoRowsAffected, fmt.Errorf("error deleting: %w", underlying.Error) } rowsDeleted := int(underlying.RowsAffected) - if withOplog { + if withOplog && rowsDeleted > 0 { if err := rw.addOplog(ctx, DeleteOp, opts, i); err != nil { return rowsDeleted, err } @@ -401,34 +401,32 @@ func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler Tx if attempts > retries+1 { return info, fmt.Errorf("Too many retries: %d of %d", attempts-1, retries+1) } + + // step one of this, start a transaction... newTx := w.underlying.BeginTx(ctx, nil) + if err := Handler(&Db{newTx}); err != nil { + if err := newTx.Rollback().Error; err != nil { + return info, err + } if errors.Is(err, oplog.ErrTicketAlreadyRedeemed) { - newTx.Rollback() // need to still handle rollback errors d := backOff.Duration(attempts) info.Retries++ info.Backoff = info.Backoff + d time.Sleep(d) continue - } else { - return info, err } - } else { - if err := newTx.Commit().Error; err != nil { - if errors.Is(err, oplog.ErrTicketAlreadyRedeemed) { - d := backOff.Duration(attempts) - info.Retries++ - info.Backoff = info.Backoff + d - time.Sleep(d) - continue - } else { - return info, err - } + return info, err + } + + if err := newTx.Commit().Error; err != nil { + if err := newTx.Rollback().Error; err != nil { + return info, err } - break // it all worked!!! + return info, err } + return info, nil // it all worked!!! } - return info, nil } // LookupByName will lookup resource my its friendly name which must be unique diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index e604f69b51..c8f53238a7 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -3,6 +3,7 @@ package db import ( "context" "errors" + "fmt" "testing" "github.com/hashicorp/go-uuid" @@ -604,6 +605,71 @@ func TestDb_DoTx(t *testing.T) { assert.Equal(3, got.Retries) assert.Equal("Too many retries: 3 of 3", err.Error()) }) + t.Run("updating-good-bad-good", func(t *testing.T) { + w := Db{underlying: db} + id, err := uuid.GenerateUUID() + assert.NoError(err) + user, err := db_test.NewTestUser() + assert.NoError(err) + user.Name = "foo-" + id + err = w.Create(context.Background(), user) + assert.NoError(err) + assert.NotZero(user.Id) + + _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { + user.Name = "friendly-" + id + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}) + if err != nil { + return err + } + if rowsUpdated != 1 { + return fmt.Errorf("error in number of rows updated %d", rowsUpdated) + } + return nil + }) + assert.NoError(err) + + foundUser, err := db_test.NewTestUser() + assert.NoError(err) + foundUser.PublicId = user.PublicId + err = w.LookupByPublicId(context.Background(), foundUser) + assert.NoError(err) + assert.Equal(foundUser.Name, user.Name) + + user2, err := db_test.NewTestUser() + assert.NoError(err) + _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { + user2.Name = "friendly2-" + id + rowsUpdated, err := w.Update(context.Background(), user2, []string{"Name"}) + if err != nil { + return err + } + if rowsUpdated != 1 { + return fmt.Errorf("error in number of rows updated %d", rowsUpdated) + } + return nil + }) + assert.Error(err) + err = w.LookupByPublicId(context.Background(), foundUser) + assert.NoError(err) + assert.NotEqual(foundUser.Name, user2.Name) + + _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { + user.Name = "friendly2-" + id + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}) + if err != nil { + return err + } + if rowsUpdated != 1 { + return fmt.Errorf("error in number of rows updated %d", rowsUpdated) + } + return nil + }) + assert.NoError(err) + err = w.LookupByPublicId(context.Background(), foundUser) + assert.NoError(err) + assert.Equal(foundUser.Name, user.Name) + }) } func TestDb_Delete(t *testing.T) {