add CreateItems() and DeleteItems() (#146)

pull/144/head^2
Jim 6 years ago committed by GitHub
parent d69588131c
commit 9a87df38c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,6 +23,12 @@ func NewTestUser() (*TestUser, error) {
}, nil
}
func AllocTestUser() TestUser {
return TestUser{
StoreTestUser: &StoreTestUser{},
}
}
// Clone is useful when you're retrying transactions and you need to send the user several times
func (u *TestUser) Clone() *TestUser {
s := proto.Clone(u.StoreTestUser)

@ -218,6 +218,52 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
return nil
}
// CreateItems will create multiple items of the same type.
// Supported options: WithOplog. WithLookup is not a supported option.
func (rw *Db) CreateItems(ctx context.Context, createItems []interface{}, opt ...Option) error {
if rw.underlying == nil {
return fmt.Errorf("create items: missing underlying db: %w", ErrNilParameter)
}
if len(createItems) == 0 {
return fmt.Errorf("create items: no interfaces to create: %w", ErrInvalidParameter)
}
opts := GetOpts(opt...)
if opts.withLookup {
return fmt.Errorf("create items: withLookup not a supported option: %w", ErrInvalidParameter)
}
// verify that createItems are all the same type.
var foundType reflect.Type
for i, v := range createItems {
if i == 0 {
foundType = reflect.TypeOf(v)
}
currentType := reflect.TypeOf(v)
if foundType != currentType {
return fmt.Errorf("create items: create items contains disparate types. item %d is not a %s: %w", i, foundType.Name(), ErrInvalidParameter)
}
}
var ticket *store.Ticket
if opts.withOplog {
_, err := validateOplogArgs(createItems[0], opts)
if err != nil {
return fmt.Errorf("create items: oplog validation failed: %w", err)
}
ticket, err = rw.getTicket(createItems[0])
if err != nil {
return fmt.Errorf("create items: unable to get ticket: %w", err)
}
}
for _, item := range createItems {
rw.Create(ctx, item)
}
if opts.withOplog {
if err := rw.addOplogForItems(ctx, CreateOp, opts, ticket, createItems); err != nil {
return fmt.Errorf("create items: unable to add oplog: %w", err)
}
}
return nil
}
// Update an object in the db, fieldMask is required and provides
// field_mask.proto paths for fields that should be updated. The i interface
// parameter is the type the caller wants to update in the db and its
@ -381,6 +427,57 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er
return rowsDeleted, nil
}
// DeleteItems will delete multiple items of the same type.
// Supported options: WithOplog.
func (rw *Db) DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error) {
if rw.underlying == nil {
return NoRowsAffected, fmt.Errorf("delete items: missing underlying db: %w", ErrNilParameter)
}
if len(deleteItems) == 0 {
return NoRowsAffected, fmt.Errorf("delete items: no interfaces to delete: %w", ErrInvalidParameter)
}
// verify that createItems are all the same type.
var foundType reflect.Type
for i, v := range deleteItems {
if i == 0 {
foundType = reflect.TypeOf(v)
}
currentType := reflect.TypeOf(v)
if foundType != currentType {
return NoRowsAffected, fmt.Errorf("delete items: items contain disparate types. item %d is not a %s: %w", i, foundType.Name(), ErrInvalidParameter)
}
}
opts := GetOpts(opt...)
var ticket *store.Ticket
if opts.withOplog {
_, err := validateOplogArgs(deleteItems[0], opts)
if err != nil {
return NoRowsAffected, fmt.Errorf("delete items: oplog validation failed: %w", err)
}
ticket, err = rw.getTicket(deleteItems[0])
if err != nil {
return NoRowsAffected, fmt.Errorf("delete items: unable to get ticket: %w", err)
}
}
rowsDeleted := 0
for _, item := range deleteItems {
// calling delete directly on the underlying db, since the writer.Delete
// doesn't provide capabilities needed here (which is different from the
// relationship between Create and CreateItems).
underlying := rw.underlying.Delete(item)
if underlying.Error != nil {
return rowsDeleted, fmt.Errorf("delete: failed: %w", underlying.Error)
}
rowsDeleted += int(underlying.RowsAffected)
}
if opts.withOplog && rowsDeleted > 0 {
if err := rw.addOplogForItems(ctx, DeleteOp, opts, ticket, deleteItems); err != nil {
return rowsDeleted, fmt.Errorf("delete items: unable to add oplog: %w", err)
}
}
return rowsDeleted, nil
}
func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, error) {
oplogArgs := opts.oplogOpts
if oplogArgs.wrapper == nil {
@ -396,6 +493,20 @@ func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, er
return replayable, nil
}
func (rw *Db) getTicketFor(aggregateName string) (*store.Ticket, error) {
if rw.underlying == nil {
return nil, fmt.Errorf("get ticket for %s: underlying db missing: %w", aggregateName, ErrNilParameter)
}
ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true))
if err != nil {
return nil, fmt.Errorf("get ticket for %s: unable to get Ticketer %w", aggregateName, err)
}
ticket, err := ticketer.GetTicket(aggregateName)
if err != nil {
return nil, fmt.Errorf("get ticket for %s: unable to get ticket %w", aggregateName, err)
}
return ticket, nil
}
func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) {
if rw.underlying == nil {
return nil, fmt.Errorf("get ticket: underlying db missing: %w", ErrNilParameter)
@ -404,17 +515,84 @@ func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) {
if !ok {
return nil, fmt.Errorf("get ticket: not a replayable message %w", ErrInvalidParameter)
}
return rw.getTicketFor(replayable.TableName())
}
// addOplogForItems will add a multi-message oplog entry with one msg for each
// item. Items must all be of the same type. Only CreateOp and DeleteOp are
// currently supported operations.
func (rw *Db) addOplogForItems(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, items []interface{}) error {
oplogArgs := opts.oplogOpts
if ticket == nil {
return fmt.Errorf("oplog many: ticket is missing: %w", ErrNilParameter)
}
if items == nil {
return fmt.Errorf("oplog many: items are missing: %w", ErrNilParameter)
}
if len(items) == 0 {
return fmt.Errorf("oplog many: items is empty: %w", ErrInvalidParameter)
}
if oplogArgs.metadata == nil {
return fmt.Errorf("oplog many: metadata is missing: %w", ErrNilParameter)
}
if oplogArgs.wrapper == nil {
return fmt.Errorf("oplog many: wrapper is missing: %w", ErrNilParameter)
}
replayable, err := validateOplogArgs(items[0], opts)
if err != nil {
return fmt.Errorf("oplog many: oplog validation failed %w", err)
}
ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true))
if err != nil {
return nil, fmt.Errorf("get ticket: unable to get Ticketer %w", err)
return fmt.Errorf("oplog many: unable to get Ticketer %w", err)
}
ticket, err := ticketer.GetTicket(replayable.TableName())
entry, err := oplog.NewEntry(
replayable.TableName(),
oplogArgs.metadata,
oplogArgs.wrapper,
ticketer,
)
if err != nil {
return nil, fmt.Errorf("get ticket: unable to get ticket %w", err)
return fmt.Errorf("oplog many: unable to create oplog entry %w", err)
}
return ticket, nil
oplogMsgs := []*oplog.Message{}
var foundType reflect.Type
for i, item := range items {
if i == 0 {
foundType = reflect.TypeOf(item)
}
currentType := reflect.TypeOf(item)
if foundType != currentType {
return fmt.Errorf("oplog many: items contains disparate types. item %d is not a %s", i, foundType.Name())
}
replayable, ok := item.(oplog.ReplayableMessage)
if !ok {
return fmt.Errorf("oplog many: item %d not a replayable oplog message %w", i, ErrInvalidParameter)
}
msg := &oplog.Message{
Message: item.(proto.Message),
TypeName: replayable.TableName(),
}
switch opType {
case CreateOp:
msg.OpType = oplog.OpType_OP_TYPE_CREATE
case DeleteOp:
msg.OpType = oplog.OpType_OP_TYPE_DELETE
default:
return fmt.Errorf("oplog many: operation type %v is not supported", opType)
}
oplogMsgs = append(oplogMsgs, msg)
}
if err := entry.WriteEntryWith(
ctx,
&oplog.GormWriter{Tx: rw.underlying},
ticket,
oplogMsgs...,
); err != nil {
return fmt.Errorf("oplog many: unable to write oplog entry %w", err)
}
return nil
}
func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, i interface{}) error {
oplogArgs := opts.oplogOpts
replayable, err := validateOplogArgs(i, opts)

@ -1238,6 +1238,345 @@ func TestDb_ScanRows(t *testing.T) {
})
}
func TestDb_CreateItems(t *testing.T) {
cleanup, db, _ := TestSetup(t, "postgres")
defer func() {
err := cleanup()
assert.NoError(t, err)
err = db.Close()
assert.NoError(t, err)
}()
testOplogResourceId := testId(t)
createFn := func() []interface{} {
results := []interface{}{}
for i := 0; i < 10; i++ {
u, err := db_test.NewTestUser()
require.NoError(t, err)
results = append(results, u)
}
return results
}
createMixedFn := func() []interface{} {
u, err := db_test.NewTestUser()
require.NoError(t, err)
c, err := db_test.NewTestCar()
require.NoError(t, err)
return []interface{}{
u,
c,
}
}
type args struct {
createItems []interface{}
opt []Option
}
tests := []struct {
name string
underlying *gorm.DB
args args
wantOplogId string
wantErr bool
wantErrIs error
}{
{
name: "simple",
underlying: db,
args: args{
createItems: createFn(),
},
wantErr: false,
},
{
name: "withOplog",
underlying: db,
args: args{
createItems: createFn(),
opt: []Option{
WithOplog(
TestWrapper(t),
oplog.Metadata{
"resource-public-id": []string{testOplogResourceId},
"op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()},
},
),
},
},
wantOplogId: testOplogResourceId,
wantErr: false,
},
{
name: "mixed items",
underlying: db,
args: args{
createItems: createMixedFn(),
},
wantErr: true,
wantErrIs: ErrInvalidParameter,
},
{
name: "bad oplog opt: nil metadata",
underlying: nil,
args: args{
createItems: createFn(),
opt: []Option{
WithOplog(
TestWrapper(t),
nil,
),
},
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "bad oplog opt: nil wrapper",
underlying: nil,
args: args{
createItems: createFn(),
opt: []Option{
WithOplog(
nil,
oplog.Metadata{
"resource-public-id": []string{"doesn't matter since wrapper is nil"},
"op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()},
},
),
},
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "bad opt: WithLookup",
underlying: nil,
args: args{
createItems: createFn(),
opt: []Option{WithLookup(true)},
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "nil underlying",
underlying: nil,
args: args{
createItems: createFn(),
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "empty items",
underlying: db,
args: args{
createItems: []interface{}{},
},
wantErr: true,
wantErrIs: ErrInvalidParameter,
},
{
name: "nil items",
underlying: db,
args: args{
createItems: nil,
},
wantErr: true,
wantErrIs: ErrInvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
rw := &Db{
underlying: tt.underlying,
}
err := rw.CreateItems(context.Background(), tt.args.createItems, tt.args.opt...)
if tt.wantErr {
require.Error(err)
if tt.wantErrIs != nil {
assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error: %s", err.Error())
}
return
}
require.NoError(err)
for _, item := range tt.args.createItems {
u := db_test.AllocTestUser()
u.PublicId = item.(*db_test.TestUser).PublicId
err := rw.LookupByPublicId(context.Background(), &u)
assert.NoError(err)
}
if tt.wantOplogId != "" {
err = TestVerifyOplog(t, rw, tt.wantOplogId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(10*time.Second))
assert.NoError(err)
}
})
}
}
func TestDb_DeleteItems(t *testing.T) {
cleanup, db, _ := TestSetup(t, "postgres")
defer func() {
err := cleanup()
assert.NoError(t, err)
err = db.Close()
assert.NoError(t, err)
}()
testOplogResourceId := testId(t)
createFn := func() []interface{} {
results := []interface{}{}
for i := 0; i < 10; i++ {
u := testUser(t, db, "", "", "")
results = append(results, u)
}
return results
}
type args struct {
deleteItems []interface{}
opt []Option
}
tests := []struct {
name string
underlying *gorm.DB
args args
wantRowsDeleted int
wantOplogId string
wantErr bool
wantErrIs error
}{
{
name: "simple",
underlying: db,
args: args{
deleteItems: createFn(),
},
wantRowsDeleted: 10,
wantErr: false,
},
{
name: "withOplog",
underlying: db,
args: args{
deleteItems: createFn(),
opt: []Option{
WithOplog(
TestWrapper(t),
oplog.Metadata{
"resource-public-id": []string{testOplogResourceId},
"op-type": []string{oplog.OpType_OP_TYPE_DELETE.String()},
},
),
},
},
wantRowsDeleted: 10,
wantOplogId: testOplogResourceId,
wantErr: false,
},
{
name: "bad oplog opt: nil metadata",
underlying: nil,
args: args{
deleteItems: createFn(),
opt: []Option{
WithOplog(
TestWrapper(t),
nil,
),
},
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "bad oplog opt: nil wrapper",
underlying: nil,
args: args{
deleteItems: createFn(),
opt: []Option{
WithOplog(
nil,
oplog.Metadata{
"resource-public-id": []string{"doesn't matter since wrapper is nil"},
"op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()},
},
),
},
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "bad opt: WithLookup",
underlying: nil,
args: args{
deleteItems: createFn(),
opt: []Option{WithLookup(true)},
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "nil underlying",
underlying: nil,
args: args{
deleteItems: createFn(),
},
wantErr: true,
wantErrIs: ErrNilParameter,
},
{
name: "empty items",
underlying: db,
args: args{
deleteItems: []interface{}{},
},
wantErr: true,
wantErrIs: ErrInvalidParameter,
},
{
name: "nil items",
underlying: db,
args: args{
deleteItems: nil,
},
wantErr: true,
wantErrIs: ErrInvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
rw := &Db{
underlying: tt.underlying,
}
rowsDeleted, err := rw.DeleteItems(context.Background(), tt.args.deleteItems, tt.args.opt...)
if tt.wantErr {
require.Error(err)
if tt.wantErrIs != nil {
assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error: %s", err.Error())
}
return
}
require.NoError(err)
assert.Equal(tt.wantRowsDeleted, rowsDeleted)
for _, item := range tt.args.deleteItems {
u := db_test.AllocTestUser()
u.PublicId = item.(*db_test.TestUser).PublicId
err := rw.LookupByPublicId(context.Background(), &u)
require.Error(err)
require.Truef(errors.Is(err, ErrRecordNotFound), "found item %s that should be deleted", u.PublicId)
}
if tt.wantOplogId != "" {
err = TestVerifyOplog(t, rw, tt.wantOplogId, WithOperation(oplog.OpType_OP_TYPE_DELETE), WithCreateNotBefore(10*time.Second))
assert.NoError(err)
}
})
}
}
func testUser(t *testing.T, conn *gorm.DB, name, email, phoneNumber string) *db_test.TestUser {
t.Helper()
assert := assert.New(t)

Loading…
Cancel
Save