diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index 2f18146617..ed2ec12037 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -92,7 +92,7 @@ type RetryInfo struct { } // TxHandler defines a handler for a func that writes a transaction for use with DoTx -type TxHandler func(Writer) error +type TxHandler func(Reader, Writer) error // ResourcePublicIder defines an interface that LookupByPublicId() can use to get the resource's public id type ResourcePublicIder interface { @@ -469,7 +469,8 @@ func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler Tx // step one of this, start a transaction... newTx := w.underlying.BeginTx(ctx, nil) - if err := Handler(&Db{newTx}); err != nil { + rw := &Db{newTx} + if err := Handler(rw, rw); err != nil { if err := newTx.Rollback().Error; err != nil { return info, err } diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index e54637114c..37f73a5b39 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -828,7 +828,7 @@ func TestDb_DoTx(t *testing.T) { w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 10, ExpBackoff{}, - func(Writer) error { + func(Reader, Writer) error { attempts += 1 if attempts < 9 { return oplog.ErrTicketAlreadyRedeemed @@ -843,7 +843,7 @@ func TestDb_DoTx(t *testing.T) { w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, - func(Writer) error { + func(Reader, Writer) error { attempts += 1 if attempts < 2 { return oplog.ErrTicketAlreadyRedeemed @@ -858,7 +858,7 @@ func TestDb_DoTx(t *testing.T) { w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 3, ExpBackoff{}, - func(Writer) error { + func(Reader, Writer) error { attempts += 1 if attempts < 3 { return oplog.ErrTicketAlreadyRedeemed @@ -873,7 +873,7 @@ func TestDb_DoTx(t *testing.T) { w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 4, ExpBackoff{}, - func(Writer) error { + func(Reader, Writer) error { attempts += 1 if attempts < 4 { return oplog.ErrTicketAlreadyRedeemed @@ -887,7 +887,7 @@ func TestDb_DoTx(t *testing.T) { t.Run("zero-retries", func(t *testing.T) { w := &Db{underlying: db} attempts := 0 - got, err := w.DoTx(context.Background(), 0, ExpBackoff{}, func(Writer) error { attempts += 1; return nil }) + got, err := w.DoTx(context.Background(), 0, ExpBackoff{}, func(Reader, Writer) error { attempts += 1; return nil }) assert.NoError(err) assert.Equal(RetryInfo{}, got) assert.Equal(1, attempts) @@ -895,14 +895,14 @@ func TestDb_DoTx(t *testing.T) { t.Run("nil-tx", func(t *testing.T) { w := &Db{nil} attempts := 0 - got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, func(Writer) error { attempts += 1; return nil }) + got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, func(Reader, Writer) error { attempts += 1; return nil }) assert.Error(err) assert.Equal(RetryInfo{}, got) assert.Equal("do underlying db is nil", err.Error()) }) t.Run("not-a-retry-err", func(t *testing.T) { w := &Db{underlying: db} - got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, func(Writer) error { return errors.New("not a retry error") }) + got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, func(Reader, Writer) error { return errors.New("not a retry error") }) assert.Error(err) assert.Equal(RetryInfo{}, got) assert.NotEqual(err, oplog.ErrTicketAlreadyRedeemed) @@ -910,23 +910,23 @@ func TestDb_DoTx(t *testing.T) { t.Run("too-many-retries", func(t *testing.T) { w := &Db{underlying: db} attempts := 0 - got, err := w.DoTx(context.Background(), 2, ExpBackoff{}, func(Writer) error { attempts += 1; return oplog.ErrTicketAlreadyRedeemed }) + got, err := w.DoTx(context.Background(), 2, ExpBackoff{}, func(Reader, Writer) error { attempts += 1; return oplog.ErrTicketAlreadyRedeemed }) assert.Error(err) assert.Equal(3, got.Retries) assert.Equal("Too many retries: 3 of 3", err.Error()) }) t.Run("updating-good-bad-good", func(t *testing.T) { - w := Db{underlying: db} + rw := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) user, err := db_test.NewTestUser() assert.NoError(err) user.Name = "foo-" + id - err = w.Create(context.Background(), user) + err = rw.Create(context.Background(), user) assert.NoError(err) assert.NotZero(user.Id) - _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { + _, err = rw.DoTx(context.Background(), 10, ExpBackoff{}, func(r Reader, w Writer) error { user.Name = "friendly-" + id rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) if err != nil { @@ -942,13 +942,13 @@ func TestDb_DoTx(t *testing.T) { foundUser, err := db_test.NewTestUser() assert.NoError(err) foundUser.PublicId = user.PublicId - err = w.LookupByPublicId(context.Background(), foundUser) + err = rw.LookupByPublicId(context.Background(), foundUser) assert.NoError(err) assert.Equal(foundUser.Name, user.Name) user2, err := db_test.NewTestUser() assert.NoError(err) - _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { + _, err = rw.DoTx(context.Background(), 10, ExpBackoff{}, func(_ Reader, w Writer) error { user2.Name = "friendly2-" + id rowsUpdated, err := w.Update(context.Background(), user2, []string{"Name"}, nil) if err != nil { @@ -960,11 +960,11 @@ func TestDb_DoTx(t *testing.T) { return nil }) assert.Error(err) - err = w.LookupByPublicId(context.Background(), foundUser) + err = rw.LookupByPublicId(context.Background(), foundUser) assert.NoError(err) assert.NotEqual(foundUser.Name, user2.Name) - _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { + _, err = rw.DoTx(context.Background(), 10, ExpBackoff{}, func(r Reader, w Writer) error { user.Name = "friendly2-" + id rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) if err != nil { @@ -976,7 +976,7 @@ func TestDb_DoTx(t *testing.T) { return nil }) assert.NoError(err) - err = w.LookupByPublicId(context.Background(), foundUser) + err = rw.LookupByPublicId(context.Background(), foundUser) assert.NoError(err) assert.Equal(foundUser.Name, user.Name) }) diff --git a/internal/host/static/repository.go b/internal/host/static/repository.go index c2fcd90ad6..e8da636775 100644 --- a/internal/host/static/repository.go +++ b/internal/host/static/repository.go @@ -77,7 +77,7 @@ func (r *Repository) CreateCatalog(ctx context.Context, c *HostCatalog, opt ...O ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(w db.Writer) error { + func(_ db.Reader, w db.Writer) error { newHostCatalog = c.clone() return w.Create( ctx, @@ -148,7 +148,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, fieldMas ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(w db.Writer) error { + func(_ db.Reader, w db.Writer) error { returnedCatalog = c.clone() var err error rowsUpdated, err = w.Update( @@ -211,7 +211,7 @@ func (r *Repository) DeleteCatalog(ctx context.Context, id string, opt ...Option ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(w db.Writer) error { + func(_ db.Reader, w db.Writer) error { deleteCatalog = c.clone() var err error rowsDeleted, err = w.Delete( diff --git a/internal/iam/repository.go b/internal/iam/repository.go index bf526e90c8..f416fb1b99 100644 --- a/internal/iam/repository.go +++ b/internal/iam/repository.go @@ -83,7 +83,7 @@ func (r *Repository) create(ctx context.Context, resource Resource, opt ...Optio ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(w db.Writer) error { + func(_ db.Reader, w db.Writer) error { returnedResource = resourceCloner.Clone() return w.Create( ctx, @@ -116,7 +116,7 @@ func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPat ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(w db.Writer) error { + func(_ db.Reader, w db.Writer) error { returnedResource = resourceCloner.Clone() var err error rowsUpdated, err = w.Update( @@ -157,7 +157,7 @@ func (r *Repository) delete(ctx context.Context, resource Resource, opt ...Optio ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(w db.Writer) error { + func(_ db.Reader, w db.Writer) error { deleteResource = resourceCloner.Clone() var err error rowsDeleted, err = w.Delete(