Rename Db's Tx to underlying

pull/45/head
Jeff Mitchell 6 years ago
parent 062ce6079f
commit ad8d4e9c80

@ -384,7 +384,7 @@ func (b *Server) CreateDevDatabase(dialect string) error {
}
b.Database = dbase
rw := &db.Db{Tx: b.Database}
rw := db.New(b.Database)
repo, err := iam.NewRepository(rw, rw, b.ControllerKMS)
if err != nil {
c()

@ -109,24 +109,28 @@ type VetForWriter interface {
// Db uses a gorm DB connection for read/write
type Db struct {
Tx *gorm.DB
underlying *gorm.DB
}
// ensure that Db implements the interfaces of: Reader and Writer
var _ Reader = (*Db)(nil)
var _ Writer = (*Db)(nil)
func New(underlying *gorm.DB) *Db {
return &Db{underlying: underlying}
}
// DB returns the sql.DB
func (rw *Db) DB() (*sql.DB, error) {
if rw.Tx == nil {
if rw.underlying == nil {
return nil, errors.New("create tx is nil for db")
}
return rw.Tx.DB(), nil
return rw.underlying.DB(), nil
}
// Scan rows will scan the rows into the interface
func (rw *Db) ScanRows(rows *sql.Rows, result interface{}) error {
return rw.Tx.ScanRows(rows, result)
return rw.underlying.ScanRows(rows, result)
}
var ErrNotResourceWithId = errors.New("not a resource with an id")
@ -149,15 +153,15 @@ func (rw *Db) lookupAfterWrite(ctx context.Context, i interface{}, opt ...Option
// Create an object in the db with options: WithOplog and WithLookup (to force a lookup after create))
func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
if rw.Tx == nil {
if rw.underlying == nil {
return errors.New("create tx is nil")
}
opts := GetOpts(opt...)
withOplog := opts.withOplog
withDebug := opts.withDebug
if withDebug {
rw.Tx.LogMode(true)
defer rw.Tx.LogMode(false)
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if i == nil {
return errors.New("create interface is nil")
@ -167,7 +171,7 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
return fmt.Errorf("error on create %w", err)
}
}
if err := rw.Tx.Create(i).Error; err != nil {
if err := rw.underlying.Create(i).Error; err != nil {
return fmt.Errorf("error creating: %w", err)
}
if withOplog {
@ -184,15 +188,15 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
// Update an object in the db, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise
// it will send every field to the DB. Update supports embedding a struct (or structPtr) one level deep for updating
func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) error {
if rw.Tx == nil {
if rw.underlying == nil {
return errors.New("update tx is nil")
}
opts := GetOpts(opt...)
withDebug := opts.withDebug
withOplog := opts.withOplog
if withDebug {
rw.Tx.LogMode(true)
defer rw.Tx.LogMode(false)
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if i == nil {
return errors.New("update interface is nil")
@ -203,7 +207,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
}
}
if len(fieldMaskPaths) == 0 {
if err := rw.Tx.Save(i).Error; err != nil {
if err := rw.underlying.Save(i).Error; err != nil {
return fmt.Errorf("error updating: %w", err)
}
}
@ -244,7 +248,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
if len(updateFields) == 0 {
return fmt.Errorf("error no update fields matched using fieldMaskPaths: %s", fieldMaskPaths)
}
if err := rw.Tx.Model(i).Updates(updateFields).Error; err != nil {
if err := rw.underlying.Model(i).Updates(updateFields).Error; err != nil {
return fmt.Errorf("error updating: %w", err)
}
if withOplog {
@ -263,7 +267,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
// Delete an object in the db with options: WithOplog (which requires WithMetadata, WithWrapper)
func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) error {
if rw.Tx == nil {
if rw.underlying == nil {
return errors.New("delete tx is nil")
}
if i == nil {
@ -273,10 +277,10 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) error {
withDebug := opts.withDebug
withOplog := opts.withOplog
if withDebug {
rw.Tx.LogMode(true)
defer rw.Tx.LogMode(false)
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if err := rw.Tx.Delete(i).Error; err != nil {
if err := rw.underlying.Delete(i).Error; err != nil {
return fmt.Errorf("error deleting: %w", err)
}
if withOplog {
@ -299,7 +303,7 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter
if !ok {
return errors.New("error not a replayable message for WithOplog")
}
gdb := rw.Tx
gdb := rw.underlying
withDebug := opts.withDebug
if withDebug {
gdb.LogMode(true)
@ -348,7 +352,7 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter
// means that the object may be sent to the db several times (retried), so things like the primary key must
// be reset before retry
func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler TxHandler) (RetryInfo, error) {
if w.Tx == nil {
if w.underlying == nil {
return RetryInfo{}, errors.New("do tx is nil")
}
info := RetryInfo{}
@ -356,7 +360,7 @@ func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler Tx
if attempts > retries+1 {
return info, fmt.Errorf("Too many retries: %d of %d", attempts-1, retries+1)
}
newTx := w.Tx.BeginTx(ctx, nil)
newTx := w.underlying.BeginTx(ctx, nil)
if err := Handler(&Db{newTx}); err != nil {
if errors.Is(err, oplog.ErrTicketAlreadyRedeemed) {
newTx.Rollback() // need to still handle rollback errors
@ -388,14 +392,14 @@ func (w *Db) DoTx(ctx context.Context, retries uint, backOff Backoff, Handler Tx
// LookupByName will lookup resource my its friendly name which must be unique
func (rw *Db) LookupByName(ctx context.Context, resource ResourceNamer, opt ...Option) error {
if rw.Tx == nil {
if rw.underlying == nil {
return errors.New("error tx nil for lookup by name")
}
opts := GetOpts(opt...)
withDebug := opts.withDebug
if withDebug {
rw.Tx.LogMode(true)
defer rw.Tx.LogMode(false)
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if reflect.ValueOf(resource).Kind() != reflect.Ptr {
return errors.New("error interface parameter must to be a pointer for lookup by name")
@ -403,7 +407,7 @@ func (rw *Db) LookupByName(ctx context.Context, resource ResourceNamer, opt ...O
if resource.GetName() == "" {
return errors.New("error name empty string for lookup by name")
}
if err := rw.Tx.Where("name = ?", resource.GetName()).First(resource).Error; err != nil {
if err := rw.underlying.Where("name = ?", resource.GetName()).First(resource).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return ErrRecordNotFound
}
@ -414,14 +418,14 @@ func (rw *Db) LookupByName(ctx context.Context, resource ResourceNamer, opt ...O
// LookupByPublicId will lookup resource my its public_id which must be unique
func (rw *Db) LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error {
if rw.Tx == nil {
if rw.underlying == nil {
return errors.New("error tx nil for lookup by public id")
}
opts := GetOpts(opt...)
withDebug := opts.withDebug
if withDebug {
rw.Tx.LogMode(true)
defer rw.Tx.LogMode(false)
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if reflect.ValueOf(resource).Kind() != reflect.Ptr {
return errors.New("error interface parameter must to be a pointer for lookup by public id")
@ -429,7 +433,7 @@ func (rw *Db) LookupByPublicId(ctx context.Context, resource ResourcePublicIder,
if resource.GetPublicId() == "" {
return errors.New("error public id empty string for lookup by public id")
}
if err := rw.Tx.Where("public_id = ?", resource.GetPublicId()).First(resource).Error; err != nil {
if err := rw.underlying.Where("public_id = ?", resource.GetPublicId()).First(resource).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return ErrRecordNotFound
}
@ -440,22 +444,22 @@ func (rw *Db) LookupByPublicId(ctx context.Context, resource ResourcePublicIder,
// LookupWhere will lookup the first resource using a where clause with parameters (it only returns the first one)
func (rw *Db) LookupWhere(ctx context.Context, resource interface{}, where string, args ...interface{}) error {
if rw.Tx == nil {
if rw.underlying == nil {
return errors.New("error tx nil for lookup by")
}
if reflect.ValueOf(resource).Kind() != reflect.Ptr {
return errors.New("error interface parameter must to be a pointer for lookup by")
}
return rw.Tx.Where(where, args...).First(resource).Error
return rw.underlying.Where(where, args...).First(resource).Error
}
// SearchWhere will search for all the resources it can find using a where clause with parameters
func (rw *Db) SearchWhere(ctx context.Context, resources interface{}, where string, args ...interface{}) error {
if rw.Tx == nil {
if rw.underlying == nil {
return errors.New("error tx nil for search by")
}
if reflect.ValueOf(resources).Kind() != reflect.Ptr {
return errors.New("error interface parameter must to be a pointer for search by")
}
return rw.Tx.Where(where, args...).Find(resources).Error
return rw.underlying.Where(where, args...).Find(resources).Error
}

@ -19,7 +19,7 @@ func TestDb_Update(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -45,7 +45,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal(foundUser.Name, user.Name)
})
t.Run("valid-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -89,7 +89,7 @@ func TestDb_Update(t *testing.T) {
assert.Nil(err)
})
t.Run("nil-tx", func(t *testing.T) {
w := Db{Tx: nil}
w := Db{underlying: nil}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -100,7 +100,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal("update tx is nil", err.Error())
})
t.Run("no-wrapper-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -131,7 +131,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal("error wrapper is nil for WithWrapper", err.Error())
})
t.Run("no-metadata-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -167,7 +167,7 @@ func TestDb_Create(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -187,7 +187,7 @@ func TestDb_Create(t *testing.T) {
assert.Equal(foundUser.Id, user.Id)
})
t.Run("valid-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -216,7 +216,7 @@ func TestDb_Create(t *testing.T) {
assert.Equal(foundUser.Id, user.Id)
})
t.Run("no-wrapper-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -238,7 +238,7 @@ func TestDb_Create(t *testing.T) {
assert.Equal("error wrapper is nil for WithWrapper", err.Error())
})
t.Run("no-metadata-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -256,7 +256,7 @@ func TestDb_Create(t *testing.T) {
assert.Equal("error no metadata for WithOplog", err.Error())
})
t.Run("nil-tx", func(t *testing.T) {
w := Db{Tx: nil}
w := Db{underlying: nil}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -275,7 +275,7 @@ func TestDb_LookupByName(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -302,7 +302,7 @@ func TestDb_LookupByName(t *testing.T) {
assert.Equal("error tx nil for lookup by name", err.Error())
})
t.Run("no-friendly-name-set", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
foundUser, err := db_test.NewTestUser()
assert.Nil(err)
err = w.LookupByName(context.Background(), foundUser)
@ -310,7 +310,7 @@ func TestDb_LookupByName(t *testing.T) {
assert.Equal("error name empty string for lookup by name", err.Error())
})
t.Run("not-found", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
@ -330,7 +330,7 @@ func TestDb_LookupByPublicId(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -356,7 +356,7 @@ func TestDb_LookupByPublicId(t *testing.T) {
assert.Equal("error tx nil for lookup by public id", err.Error())
})
t.Run("no-public-id-set", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
foundUser, err := db_test.NewTestUser()
foundUser.PublicId = ""
assert.Nil(err)
@ -365,7 +365,7 @@ func TestDb_LookupByPublicId(t *testing.T) {
assert.Equal("error public id empty string for lookup by public id", err.Error())
})
t.Run("not-found", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
@ -385,7 +385,7 @@ func TestDb_LookupWhere(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -408,7 +408,7 @@ func TestDb_LookupWhere(t *testing.T) {
assert.Equal("error tx nil for lookup by", err.Error())
})
t.Run("not-found", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
@ -418,7 +418,7 @@ func TestDb_LookupWhere(t *testing.T) {
assert.Equal(ErrRecordNotFound, err)
})
t.Run("bad-where", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
@ -435,7 +435,7 @@ func TestDb_SearchWhere(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -458,7 +458,7 @@ func TestDb_SearchWhere(t *testing.T) {
assert.Equal("error tx nil for search by", err.Error())
})
t.Run("not-found", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
@ -468,7 +468,7 @@ func TestDb_SearchWhere(t *testing.T) {
assert.Equal(0, len(foundUsers))
})
t.Run("bad-where", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
@ -485,7 +485,7 @@ func TestDb_DB(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("valid", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
d, err := w.DB()
assert.Nil(err)
assert.True(d != nil)
@ -493,7 +493,7 @@ func TestDb_DB(t *testing.T) {
assert.Nil(err)
})
t.Run("nil-tx", func(t *testing.T) {
w := Db{Tx: nil}
w := Db{underlying: nil}
d, err := w.DB()
assert.True(err != nil)
assert.True(d == nil)
@ -509,7 +509,7 @@ func TestDb_DoTx(t *testing.T) {
defer db.Close()
t.Run("valid-with-10-retries", func(t *testing.T) {
w := &Db{Tx: db}
w := &Db{underlying: db}
attempts := 0
got, err := w.DoTx(context.Background(), 10, ExpBackoff{},
func(Writer) error {
@ -524,7 +524,7 @@ func TestDb_DoTx(t *testing.T) {
assert.Equal(9, attempts) // attempted 1 + 8 retries
})
t.Run("valid-with-1-retries", func(t *testing.T) {
w := &Db{Tx: db}
w := &Db{underlying: db}
attempts := 0
got, err := w.DoTx(context.Background(), 1, ExpBackoff{},
func(Writer) error {
@ -539,7 +539,7 @@ func TestDb_DoTx(t *testing.T) {
assert.Equal(2, attempts) // attempted 1 + 8 retries
})
t.Run("valid-with-2-retries", func(t *testing.T) {
w := &Db{Tx: db}
w := &Db{underlying: db}
attempts := 0
got, err := w.DoTx(context.Background(), 3, ExpBackoff{},
func(Writer) error {
@ -554,7 +554,7 @@ func TestDb_DoTx(t *testing.T) {
assert.Equal(3, attempts) // attempted 1 + 8 retries
})
t.Run("valid-with-4-retries", func(t *testing.T) {
w := &Db{Tx: db}
w := &Db{underlying: db}
attempts := 0
got, err := w.DoTx(context.Background(), 4, ExpBackoff{},
func(Writer) error {
@ -569,7 +569,7 @@ func TestDb_DoTx(t *testing.T) {
assert.Equal(4, attempts) // attempted 1 + 8 retries
})
t.Run("zero-retries", func(t *testing.T) {
w := &Db{Tx: db}
w := &Db{underlying: db}
attempts := 0
got, err := w.DoTx(context.Background(), 0, ExpBackoff{}, func(Writer) error { attempts += 1; return nil })
assert.Nil(err)
@ -585,14 +585,14 @@ func TestDb_DoTx(t *testing.T) {
assert.Equal("do tx is nil", err.Error())
})
t.Run("not-a-retry-err", func(t *testing.T) {
w := &Db{Tx: db}
w := &Db{underlying: db}
got, err := w.DoTx(context.Background(), 1, ExpBackoff{}, func(Writer) error { return errors.New("not a retry error") })
assert.True(err != nil)
assert.Equal(RetryInfo{}, got)
assert.True(err != oplog.ErrTicketAlreadyRedeemed)
})
t.Run("too-many-retries", func(t *testing.T) {
w := &Db{Tx: db}
w := &Db{underlying: db}
attempts := 0
got, err := w.DoTx(context.Background(), 2, ExpBackoff{}, func(Writer) error { attempts += 1; return oplog.ErrTicketAlreadyRedeemed })
assert.True(err != nil)
@ -608,7 +608,7 @@ func TestDb_Delete(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -635,7 +635,7 @@ func TestDb_Delete(t *testing.T) {
assert.Equal(ErrRecordNotFound, err)
})
t.Run("valid-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -682,7 +682,7 @@ func TestDb_Delete(t *testing.T) {
assert.Equal(ErrRecordNotFound, err)
})
t.Run("no-wrapper-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -726,7 +726,7 @@ func TestDb_Delete(t *testing.T) {
assert.Equal("error wrapper is nil for WithWrapper", err.Error())
})
t.Run("no-metadata-WithOplog", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -766,7 +766,7 @@ func TestDb_Delete(t *testing.T) {
assert.Equal("error no metadata for WithOplog", err.Error())
})
t.Run("nil-tx", func(t *testing.T) {
w := Db{Tx: nil}
w := Db{underlying: nil}
id, err := uuid.GenerateUUID()
assert.Nil(err)
user, err := db_test.NewTestUser()
@ -785,7 +785,7 @@ func TestDb_ScanRows(t *testing.T) {
assert := assert.New(t)
defer db.Close()
t.Run("valid", func(t *testing.T) {
w := Db{Tx: db}
w := Db{underlying: db}
user, err := db_test.NewTestUser()
assert.Nil(err)
err = w.Create(context.Background(), user)

Loading…
Cancel
Save