From ad8d4e9c8083fa986cc4e9a485763ebcbb415b43 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Mon, 11 May 2020 21:05:11 -0400 Subject: [PATCH] Rename Db's Tx to underlying --- internal/cmd/base/servers.go | 2 +- internal/db/read_writer.go | 68 ++++++++++++++++-------------- internal/db/read_writer_test.go | 74 ++++++++++++++++----------------- 3 files changed, 74 insertions(+), 70 deletions(-) diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index da191d7421..54ebf45144 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -384,7 +384,7 @@ func (b *Server) CreateDevDatabase(dialect string) error { } b.Database = dbase - rw := &db.Db{Tx: b.Database} + rw := db.New(b.Database) repo, err := iam.NewRepository(rw, rw, b.ControllerKMS) if err != nil { c() diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index 3d7b3b623b..54ad310da9 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -109,24 +109,28 @@ type VetForWriter interface { // Db uses a gorm DB connection for read/write type Db struct { - Tx *gorm.DB + underlying *gorm.DB } // ensure that Db implements the interfaces of: Reader and Writer var _ Reader = (*Db)(nil) var _ Writer = (*Db)(nil) +func New(underlying *gorm.DB) *Db { + return &Db{underlying: underlying} +} + // DB returns the sql.DB func (rw *Db) DB() (*sql.DB, error) { - if rw.Tx == nil { + if rw.underlying == nil { return nil, errors.New("create tx is nil for db") } - return rw.Tx.DB(), nil + return rw.underlying.DB(), nil } // Scan rows will scan the rows into the interface func (rw *Db) ScanRows(rows *sql.Rows, result interface{}) error { - return rw.Tx.ScanRows(rows, result) + return rw.underlying.ScanRows(rows, result) } var ErrNotResourceWithId = errors.New("not a resource with an id") @@ -149,15 +153,15 @@ func (rw *Db) lookupAfterWrite(ctx context.Context, i interface{}, opt ...Option // Create an object in the db with options: WithOplog and WithLookup (to force a lookup after create)) func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { - if rw.Tx == nil { + if rw.underlying == nil { return errors.New("create tx is nil") } opts := GetOpts(opt...) withOplog := opts.withOplog withDebug := opts.withDebug if withDebug { - rw.Tx.LogMode(true) - defer rw.Tx.LogMode(false) + rw.underlying.LogMode(true) + defer rw.underlying.LogMode(false) } if i == nil { return errors.New("create interface is nil") @@ -167,7 +171,7 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { return fmt.Errorf("error on create %w", err) } } - if err := rw.Tx.Create(i).Error; err != nil { + if err := rw.underlying.Create(i).Error; err != nil { return fmt.Errorf("error creating: %w", err) } if withOplog { @@ -184,15 +188,15 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { // Update an object in the db, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise // it will send every field to the DB. Update supports embedding a struct (or structPtr) one level deep for updating func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) error { - if rw.Tx == nil { + if rw.underlying == nil { return errors.New("update tx is nil") } opts := GetOpts(opt...) withDebug := opts.withDebug withOplog := opts.withOplog if withDebug { - rw.Tx.LogMode(true) - defer rw.Tx.LogMode(false) + rw.underlying.LogMode(true) + defer rw.underlying.LogMode(false) } if i == nil { return errors.New("update interface is nil") @@ -203,7 +207,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string } } if len(fieldMaskPaths) == 0 { - if err := rw.Tx.Save(i).Error; err != nil { + if err := rw.underlying.Save(i).Error; err != nil { return fmt.Errorf("error updating: %w", err) } } @@ -244,7 +248,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string if len(updateFields) == 0 { return fmt.Errorf("error no update fields matched using fieldMaskPaths: %s", fieldMaskPaths) } - if err := rw.Tx.Model(i).Updates(updateFields).Error; err != nil { + if err := rw.underlying.Model(i).Updates(updateFields).Error; err != nil { return fmt.Errorf("error updating: %w", err) } if withOplog { @@ -263,7 +267,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string // Delete an object in the db with options: WithOplog (which requires WithMetadata, WithWrapper) func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) error { - if rw.Tx == nil { + if rw.underlying == nil { return errors.New("delete tx is nil") } if i == nil { @@ -273,10 +277,10 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) error { withDebug := opts.withDebug withOplog := opts.withOplog if withDebug { - rw.Tx.LogMode(true) - defer rw.Tx.LogMode(false) + rw.underlying.LogMode(true) + defer rw.underlying.LogMode(false) } - if err := rw.Tx.Delete(i).Error; err != nil { + if err := rw.underlying.Delete(i).Error; err != nil { return fmt.Errorf("error deleting: %w", err) } if withOplog { @@ -299,7 +303,7 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter if !ok { return errors.New("error not a replayable message for WithOplog") } - gdb := rw.Tx + gdb := rw.underlying withDebug := opts.withDebug if withDebug { gdb.LogMode(true) @@ -348,7 +352,7 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter // means that the object may be sent to the db several times (retried), so things like the primary key must // be reset before retry func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler TxHandler) (RetryInfo, error) { - if w.Tx == nil { + if w.underlying == nil { return RetryInfo{}, errors.New("do tx is nil") } info := RetryInfo{} @@ -356,7 +360,7 @@ func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler Tx if attempts > retries+1 { return info, fmt.Errorf("Too many retries: %d of %d", attempts-1, retries+1) } - newTx := w.Tx.BeginTx(ctx, nil) + newTx := w.underlying.BeginTx(ctx, nil) if err := Handler(&Db{newTx}); err != nil { if errors.Is(err, oplog.ErrTicketAlreadyRedeemed) { newTx.Rollback() // need to still handle rollback errors @@ -388,14 +392,14 @@ func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler Tx // LookupByName will lookup resource my its friendly name which must be unique func (rw *Db) LookupByName(ctx context.Context, resource ResourceNamer, opt ...Option) error { - if rw.Tx == nil { + if rw.underlying == nil { return errors.New("error tx nil for lookup by name") } opts := GetOpts(opt...) withDebug := opts.withDebug if withDebug { - rw.Tx.LogMode(true) - defer rw.Tx.LogMode(false) + rw.underlying.LogMode(true) + defer rw.underlying.LogMode(false) } if reflect.ValueOf(resource).Kind() != reflect.Ptr { return errors.New("error interface parameter must to be a pointer for lookup by name") @@ -403,7 +407,7 @@ func (rw *Db) LookupByName(ctx context.Context, resource ResourceNamer, opt ...O if resource.GetName() == "" { return errors.New("error name empty string for lookup by name") } - if err := rw.Tx.Where("name = ?", resource.GetName()).First(resource).Error; err != nil { + if err := rw.underlying.Where("name = ?", resource.GetName()).First(resource).Error; err != nil { if err == gorm.ErrRecordNotFound { return ErrRecordNotFound } @@ -414,14 +418,14 @@ func (rw *Db) LookupByName(ctx context.Context, resource ResourceNamer, opt ...O // LookupByPublicId will lookup resource my its public_id which must be unique func (rw *Db) LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error { - if rw.Tx == nil { + if rw.underlying == nil { return errors.New("error tx nil for lookup by public id") } opts := GetOpts(opt...) withDebug := opts.withDebug if withDebug { - rw.Tx.LogMode(true) - defer rw.Tx.LogMode(false) + rw.underlying.LogMode(true) + defer rw.underlying.LogMode(false) } if reflect.ValueOf(resource).Kind() != reflect.Ptr { return errors.New("error interface parameter must to be a pointer for lookup by public id") @@ -429,7 +433,7 @@ func (rw *Db) LookupByPublicId(ctx context.Context, resource ResourcePublicIder, if resource.GetPublicId() == "" { return errors.New("error public id empty string for lookup by public id") } - if err := rw.Tx.Where("public_id = ?", resource.GetPublicId()).First(resource).Error; err != nil { + if err := rw.underlying.Where("public_id = ?", resource.GetPublicId()).First(resource).Error; err != nil { if err == gorm.ErrRecordNotFound { return ErrRecordNotFound } @@ -440,22 +444,22 @@ func (rw *Db) LookupByPublicId(ctx context.Context, resource ResourcePublicIder, // LookupWhere will lookup the first resource using a where clause with parameters (it only returns the first one) func (rw *Db) LookupWhere(ctx context.Context, resource interface{}, where string, args ...interface{}) error { - if rw.Tx == nil { + if rw.underlying == nil { return errors.New("error tx nil for lookup by") } if reflect.ValueOf(resource).Kind() != reflect.Ptr { return errors.New("error interface parameter must to be a pointer for lookup by") } - return rw.Tx.Where(where, args...).First(resource).Error + return rw.underlying.Where(where, args...).First(resource).Error } // SearchWhere will search for all the resources it can find using a where clause with parameters func (rw *Db) SearchWhere(ctx context.Context, resources interface{}, where string, args ...interface{}) error { - if rw.Tx == nil { + if rw.underlying == nil { return errors.New("error tx nil for search by") } if reflect.ValueOf(resources).Kind() != reflect.Ptr { return errors.New("error interface parameter must to be a pointer for search by") } - return rw.Tx.Where(where, args...).Find(resources).Error + return rw.underlying.Where(where, args...).Find(resources).Error } diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 0fbdc56792..c1d3337add 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -19,7 +19,7 @@ func TestDb_Update(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("simple", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -45,7 +45,7 @@ func TestDb_Update(t *testing.T) { assert.Equal(foundUser.Name, user.Name) }) t.Run("valid-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -89,7 +89,7 @@ func TestDb_Update(t *testing.T) { assert.Nil(err) }) t.Run("nil-tx", func(t *testing.T) { - w := Db{Tx: nil} + w := Db{underlying: nil} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -100,7 +100,7 @@ func TestDb_Update(t *testing.T) { assert.Equal("update tx is nil", err.Error()) }) t.Run("no-wrapper-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -131,7 +131,7 @@ func TestDb_Update(t *testing.T) { assert.Equal("error wrapper is nil for WithWrapper", err.Error()) }) t.Run("no-metadata-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -167,7 +167,7 @@ func TestDb_Create(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("simple", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -187,7 +187,7 @@ func TestDb_Create(t *testing.T) { assert.Equal(foundUser.Id, user.Id) }) t.Run("valid-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -216,7 +216,7 @@ func TestDb_Create(t *testing.T) { assert.Equal(foundUser.Id, user.Id) }) t.Run("no-wrapper-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -238,7 +238,7 @@ func TestDb_Create(t *testing.T) { assert.Equal("error wrapper is nil for WithWrapper", err.Error()) }) t.Run("no-metadata-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -256,7 +256,7 @@ func TestDb_Create(t *testing.T) { assert.Equal("error no metadata for WithOplog", err.Error()) }) t.Run("nil-tx", func(t *testing.T) { - w := Db{Tx: nil} + w := Db{underlying: nil} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -275,7 +275,7 @@ func TestDb_LookupByName(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("simple", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -302,7 +302,7 @@ func TestDb_LookupByName(t *testing.T) { assert.Equal("error tx nil for lookup by name", err.Error()) }) t.Run("no-friendly-name-set", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} foundUser, err := db_test.NewTestUser() assert.Nil(err) err = w.LookupByName(context.Background(), foundUser) @@ -310,7 +310,7 @@ func TestDb_LookupByName(t *testing.T) { assert.Equal("error name empty string for lookup by name", err.Error()) }) t.Run("not-found", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) @@ -330,7 +330,7 @@ func TestDb_LookupByPublicId(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("simple", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -356,7 +356,7 @@ func TestDb_LookupByPublicId(t *testing.T) { assert.Equal("error tx nil for lookup by public id", err.Error()) }) t.Run("no-public-id-set", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} foundUser, err := db_test.NewTestUser() foundUser.PublicId = "" assert.Nil(err) @@ -365,7 +365,7 @@ func TestDb_LookupByPublicId(t *testing.T) { assert.Equal("error public id empty string for lookup by public id", err.Error()) }) t.Run("not-found", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) @@ -385,7 +385,7 @@ func TestDb_LookupWhere(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("simple", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -408,7 +408,7 @@ func TestDb_LookupWhere(t *testing.T) { assert.Equal("error tx nil for lookup by", err.Error()) }) t.Run("not-found", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) @@ -418,7 +418,7 @@ func TestDb_LookupWhere(t *testing.T) { assert.Equal(ErrRecordNotFound, err) }) t.Run("bad-where", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) @@ -435,7 +435,7 @@ func TestDb_SearchWhere(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("simple", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -458,7 +458,7 @@ func TestDb_SearchWhere(t *testing.T) { assert.Equal("error tx nil for search by", err.Error()) }) t.Run("not-found", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) @@ -468,7 +468,7 @@ func TestDb_SearchWhere(t *testing.T) { assert.Equal(0, len(foundUsers)) }) t.Run("bad-where", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) @@ -485,7 +485,7 @@ func TestDb_DB(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("valid", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} d, err := w.DB() assert.Nil(err) assert.True(d != nil) @@ -493,7 +493,7 @@ func TestDb_DB(t *testing.T) { assert.Nil(err) }) t.Run("nil-tx", func(t *testing.T) { - w := Db{Tx: nil} + w := Db{underlying: nil} d, err := w.DB() assert.True(err != nil) assert.True(d == nil) @@ -509,7 +509,7 @@ func TestDb_DoTx(t *testing.T) { defer db.Close() t.Run("valid-with-10-retries", func(t *testing.T) { - w := &Db{Tx: db} + w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 10, ExpBackoff{}, func(Writer) error { @@ -524,7 +524,7 @@ func TestDb_DoTx(t *testing.T) { assert.Equal(9, attempts) // attempted 1 + 8 retries }) t.Run("valid-with-1-retries", func(t *testing.T) { - w := &Db{Tx: db} + w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, func(Writer) error { @@ -539,7 +539,7 @@ func TestDb_DoTx(t *testing.T) { assert.Equal(2, attempts) // attempted 1 + 8 retries }) t.Run("valid-with-2-retries", func(t *testing.T) { - w := &Db{Tx: db} + w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 3, ExpBackoff{}, func(Writer) error { @@ -554,7 +554,7 @@ func TestDb_DoTx(t *testing.T) { assert.Equal(3, attempts) // attempted 1 + 8 retries }) t.Run("valid-with-4-retries", func(t *testing.T) { - w := &Db{Tx: db} + w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 4, ExpBackoff{}, func(Writer) error { @@ -569,7 +569,7 @@ func TestDb_DoTx(t *testing.T) { assert.Equal(4, attempts) // attempted 1 + 8 retries }) t.Run("zero-retries", func(t *testing.T) { - w := &Db{Tx: db} + w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 0, ExpBackoff{}, func(Writer) error { attempts += 1; return nil }) assert.Nil(err) @@ -585,14 +585,14 @@ func TestDb_DoTx(t *testing.T) { assert.Equal("do tx is nil", err.Error()) }) t.Run("not-a-retry-err", func(t *testing.T) { - w := &Db{Tx: db} + w := &Db{underlying: db} got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, func(Writer) error { return errors.New("not a retry error") }) assert.True(err != nil) assert.Equal(RetryInfo{}, got) assert.True(err != oplog.ErrTicketAlreadyRedeemed) }) t.Run("too-many-retries", func(t *testing.T) { - w := &Db{Tx: db} + w := &Db{underlying: db} attempts := 0 got, err := w.DoTx(context.Background(), 2, ExpBackoff{}, func(Writer) error { attempts += 1; return oplog.ErrTicketAlreadyRedeemed }) assert.True(err != nil) @@ -608,7 +608,7 @@ func TestDb_Delete(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("simple", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -635,7 +635,7 @@ func TestDb_Delete(t *testing.T) { assert.Equal(ErrRecordNotFound, err) }) t.Run("valid-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -682,7 +682,7 @@ func TestDb_Delete(t *testing.T) { assert.Equal(ErrRecordNotFound, err) }) t.Run("no-wrapper-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -726,7 +726,7 @@ func TestDb_Delete(t *testing.T) { assert.Equal("error wrapper is nil for WithWrapper", err.Error()) }) t.Run("no-metadata-WithOplog", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -766,7 +766,7 @@ func TestDb_Delete(t *testing.T) { assert.Equal("error no metadata for WithOplog", err.Error()) }) t.Run("nil-tx", func(t *testing.T) { - w := Db{Tx: nil} + w := Db{underlying: nil} id, err := uuid.GenerateUUID() assert.Nil(err) user, err := db_test.NewTestUser() @@ -785,7 +785,7 @@ func TestDb_ScanRows(t *testing.T) { assert := assert.New(t) defer db.Close() t.Run("valid", func(t *testing.T) { - w := Db{Tx: db} + w := Db{underlying: db} user, err := db_test.NewTestUser() assert.Nil(err) err = w.Create(context.Background(), user)