diff --git a/internal/db/db_test/db.go b/internal/db/db_test/db.go index 4933d16532..79097a19c1 100644 --- a/internal/db/db_test/db.go +++ b/internal/db/db_test/db.go @@ -23,6 +23,12 @@ func NewTestUser() (*TestUser, error) { }, nil } +func AllocTestUser() TestUser { + return TestUser{ + StoreTestUser: &StoreTestUser{}, + } +} + // Clone is useful when you're retrying transactions and you need to send the user several times func (u *TestUser) Clone() *TestUser { s := proto.Clone(u.StoreTestUser) diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index b408e840d6..3c36c44055 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -218,6 +218,52 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { return nil } +// CreateItems will create multiple items of the same type. +// Supported options: WithOplog. WithLookup is not a supported option. +func (rw *Db) CreateItems(ctx context.Context, createItems []interface{}, opt ...Option) error { + if rw.underlying == nil { + return fmt.Errorf("create items: missing underlying db: %w", ErrNilParameter) + } + if len(createItems) == 0 { + return fmt.Errorf("create items: no interfaces to create: %w", ErrInvalidParameter) + } + opts := GetOpts(opt...) + if opts.withLookup { + return fmt.Errorf("create items: withLookup not a supported option: %w", ErrInvalidParameter) + } + // verify that createItems are all the same type. + var foundType reflect.Type + for i, v := range createItems { + if i == 0 { + foundType = reflect.TypeOf(v) + } + currentType := reflect.TypeOf(v) + if foundType != currentType { + return fmt.Errorf("create items: create items contains disparate types. item %d is not a %s: %w", i, foundType.Name(), ErrInvalidParameter) + } + } + var ticket *store.Ticket + if opts.withOplog { + _, err := validateOplogArgs(createItems[0], opts) + if err != nil { + return fmt.Errorf("create items: oplog validation failed: %w", err) + } + ticket, err = rw.getTicket(createItems[0]) + if err != nil { + return fmt.Errorf("create items: unable to get ticket: %w", err) + } + } + for _, item := range createItems { + rw.Create(ctx, item) + } + if opts.withOplog { + if err := rw.addOplogForItems(ctx, CreateOp, opts, ticket, createItems); err != nil { + return fmt.Errorf("create items: unable to add oplog: %w", err) + } + } + return nil +} + // Update an object in the db, fieldMask is required and provides // field_mask.proto paths for fields that should be updated. The i interface // parameter is the type the caller wants to update in the db and its @@ -381,6 +427,57 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er return rowsDeleted, nil } +// DeleteItems will delete multiple items of the same type. +// Supported options: WithOplog. +func (rw *Db) DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error) { + if rw.underlying == nil { + return NoRowsAffected, fmt.Errorf("delete items: missing underlying db: %w", ErrNilParameter) + } + if len(deleteItems) == 0 { + return NoRowsAffected, fmt.Errorf("delete items: no interfaces to delete: %w", ErrInvalidParameter) + } + // verify that createItems are all the same type. + var foundType reflect.Type + for i, v := range deleteItems { + if i == 0 { + foundType = reflect.TypeOf(v) + } + currentType := reflect.TypeOf(v) + if foundType != currentType { + return NoRowsAffected, fmt.Errorf("delete items: items contain disparate types. item %d is not a %s: %w", i, foundType.Name(), ErrInvalidParameter) + } + } + opts := GetOpts(opt...) + var ticket *store.Ticket + if opts.withOplog { + _, err := validateOplogArgs(deleteItems[0], opts) + if err != nil { + return NoRowsAffected, fmt.Errorf("delete items: oplog validation failed: %w", err) + } + ticket, err = rw.getTicket(deleteItems[0]) + if err != nil { + return NoRowsAffected, fmt.Errorf("delete items: unable to get ticket: %w", err) + } + } + rowsDeleted := 0 + for _, item := range deleteItems { + // calling delete directly on the underlying db, since the writer.Delete + // doesn't provide capabilities needed here (which is different from the + // relationship between Create and CreateItems). + underlying := rw.underlying.Delete(item) + if underlying.Error != nil { + return rowsDeleted, fmt.Errorf("delete: failed: %w", underlying.Error) + } + rowsDeleted += int(underlying.RowsAffected) + } + if opts.withOplog && rowsDeleted > 0 { + if err := rw.addOplogForItems(ctx, DeleteOp, opts, ticket, deleteItems); err != nil { + return rowsDeleted, fmt.Errorf("delete items: unable to add oplog: %w", err) + } + } + return rowsDeleted, nil +} + func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, error) { oplogArgs := opts.oplogOpts if oplogArgs.wrapper == nil { @@ -396,6 +493,20 @@ func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, er return replayable, nil } +func (rw *Db) getTicketFor(aggregateName string) (*store.Ticket, error) { + if rw.underlying == nil { + return nil, fmt.Errorf("get ticket for %s: underlying db missing: %w", aggregateName, ErrNilParameter) + } + ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true)) + if err != nil { + return nil, fmt.Errorf("get ticket for %s: unable to get Ticketer %w", aggregateName, err) + } + ticket, err := ticketer.GetTicket(aggregateName) + if err != nil { + return nil, fmt.Errorf("get ticket for %s: unable to get ticket %w", aggregateName, err) + } + return ticket, nil +} func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) { if rw.underlying == nil { return nil, fmt.Errorf("get ticket: underlying db missing: %w", ErrNilParameter) @@ -404,17 +515,84 @@ func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) { if !ok { return nil, fmt.Errorf("get ticket: not a replayable message %w", ErrInvalidParameter) } + return rw.getTicketFor(replayable.TableName()) +} + +// addOplogForItems will add a multi-message oplog entry with one msg for each +// item. Items must all be of the same type. Only CreateOp and DeleteOp are +// currently supported operations. +func (rw *Db) addOplogForItems(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, items []interface{}) error { + oplogArgs := opts.oplogOpts + if ticket == nil { + return fmt.Errorf("oplog many: ticket is missing: %w", ErrNilParameter) + } + if items == nil { + return fmt.Errorf("oplog many: items are missing: %w", ErrNilParameter) + } + if len(items) == 0 { + return fmt.Errorf("oplog many: items is empty: %w", ErrInvalidParameter) + } + if oplogArgs.metadata == nil { + return fmt.Errorf("oplog many: metadata is missing: %w", ErrNilParameter) + } + if oplogArgs.wrapper == nil { + return fmt.Errorf("oplog many: wrapper is missing: %w", ErrNilParameter) + } + replayable, err := validateOplogArgs(items[0], opts) + if err != nil { + return fmt.Errorf("oplog many: oplog validation failed %w", err) + } ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true)) if err != nil { - return nil, fmt.Errorf("get ticket: unable to get Ticketer %w", err) + return fmt.Errorf("oplog many: unable to get Ticketer %w", err) } - ticket, err := ticketer.GetTicket(replayable.TableName()) + entry, err := oplog.NewEntry( + replayable.TableName(), + oplogArgs.metadata, + oplogArgs.wrapper, + ticketer, + ) if err != nil { - return nil, fmt.Errorf("get ticket: unable to get ticket %w", err) + return fmt.Errorf("oplog many: unable to create oplog entry %w", err) } - return ticket, nil + oplogMsgs := []*oplog.Message{} + var foundType reflect.Type + for i, item := range items { + if i == 0 { + foundType = reflect.TypeOf(item) + } + currentType := reflect.TypeOf(item) + if foundType != currentType { + return fmt.Errorf("oplog many: items contains disparate types. item %d is not a %s", i, foundType.Name()) + } + replayable, ok := item.(oplog.ReplayableMessage) + if !ok { + return fmt.Errorf("oplog many: item %d not a replayable oplog message %w", i, ErrInvalidParameter) + } + msg := &oplog.Message{ + Message: item.(proto.Message), + TypeName: replayable.TableName(), + } + switch opType { + case CreateOp: + msg.OpType = oplog.OpType_OP_TYPE_CREATE + case DeleteOp: + msg.OpType = oplog.OpType_OP_TYPE_DELETE + default: + return fmt.Errorf("oplog many: operation type %v is not supported", opType) + } + oplogMsgs = append(oplogMsgs, msg) + } + if err := entry.WriteEntryWith( + ctx, + &oplog.GormWriter{Tx: rw.underlying}, + ticket, + oplogMsgs..., + ); err != nil { + return fmt.Errorf("oplog many: unable to write oplog entry %w", err) + } + return nil } - func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, i interface{}) error { oplogArgs := opts.oplogOpts replayable, err := validateOplogArgs(i, opts) diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 4ebfbdadcb..58bb41ff69 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -1238,6 +1238,345 @@ func TestDb_ScanRows(t *testing.T) { }) } +func TestDb_CreateItems(t *testing.T) { + cleanup, db, _ := TestSetup(t, "postgres") + defer func() { + err := cleanup() + assert.NoError(t, err) + err = db.Close() + assert.NoError(t, err) + }() + testOplogResourceId := testId(t) + + createFn := func() []interface{} { + results := []interface{}{} + for i := 0; i < 10; i++ { + u, err := db_test.NewTestUser() + require.NoError(t, err) + results = append(results, u) + } + return results + } + createMixedFn := func() []interface{} { + u, err := db_test.NewTestUser() + require.NoError(t, err) + c, err := db_test.NewTestCar() + require.NoError(t, err) + return []interface{}{ + u, + c, + } + + } + type args struct { + createItems []interface{} + opt []Option + } + tests := []struct { + name string + underlying *gorm.DB + args args + wantOplogId string + wantErr bool + wantErrIs error + }{ + { + name: "simple", + underlying: db, + args: args{ + createItems: createFn(), + }, + wantErr: false, + }, + { + name: "withOplog", + underlying: db, + args: args{ + createItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + oplog.Metadata{ + "resource-public-id": []string{testOplogResourceId}, + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + }, + ), + }, + }, + wantOplogId: testOplogResourceId, + wantErr: false, + }, + { + name: "mixed items", + underlying: db, + args: args{ + createItems: createMixedFn(), + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + { + name: "bad oplog opt: nil metadata", + underlying: nil, + args: args{ + createItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + nil, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad oplog opt: nil wrapper", + underlying: nil, + args: args{ + createItems: createFn(), + opt: []Option{ + WithOplog( + nil, + oplog.Metadata{ + "resource-public-id": []string{"doesn't matter since wrapper is nil"}, + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + }, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad opt: WithLookup", + underlying: nil, + args: args{ + createItems: createFn(), + opt: []Option{WithLookup(true)}, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "nil underlying", + underlying: nil, + args: args{ + createItems: createFn(), + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "empty items", + underlying: db, + args: args{ + createItems: []interface{}{}, + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + { + name: "nil items", + underlying: db, + args: args{ + createItems: nil, + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rw := &Db{ + underlying: tt.underlying, + } + err := rw.CreateItems(context.Background(), tt.args.createItems, tt.args.opt...) + if tt.wantErr { + require.Error(err) + if tt.wantErrIs != nil { + assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error: %s", err.Error()) + } + return + } + require.NoError(err) + for _, item := range tt.args.createItems { + u := db_test.AllocTestUser() + u.PublicId = item.(*db_test.TestUser).PublicId + err := rw.LookupByPublicId(context.Background(), &u) + assert.NoError(err) + } + if tt.wantOplogId != "" { + err = TestVerifyOplog(t, rw, tt.wantOplogId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + } + }) + } +} + +func TestDb_DeleteItems(t *testing.T) { + cleanup, db, _ := TestSetup(t, "postgres") + defer func() { + err := cleanup() + assert.NoError(t, err) + err = db.Close() + assert.NoError(t, err) + }() + + testOplogResourceId := testId(t) + + createFn := func() []interface{} { + results := []interface{}{} + for i := 0; i < 10; i++ { + u := testUser(t, db, "", "", "") + results = append(results, u) + } + return results + } + type args struct { + deleteItems []interface{} + opt []Option + } + tests := []struct { + name string + underlying *gorm.DB + args args + wantRowsDeleted int + wantOplogId string + wantErr bool + wantErrIs error + }{ + { + name: "simple", + underlying: db, + args: args{ + deleteItems: createFn(), + }, + wantRowsDeleted: 10, + wantErr: false, + }, + { + name: "withOplog", + underlying: db, + args: args{ + deleteItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + oplog.Metadata{ + "resource-public-id": []string{testOplogResourceId}, + "op-type": []string{oplog.OpType_OP_TYPE_DELETE.String()}, + }, + ), + }, + }, + wantRowsDeleted: 10, + wantOplogId: testOplogResourceId, + wantErr: false, + }, + { + name: "bad oplog opt: nil metadata", + underlying: nil, + args: args{ + deleteItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + nil, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad oplog opt: nil wrapper", + underlying: nil, + args: args{ + deleteItems: createFn(), + opt: []Option{ + WithOplog( + nil, + oplog.Metadata{ + "resource-public-id": []string{"doesn't matter since wrapper is nil"}, + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + }, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad opt: WithLookup", + underlying: nil, + args: args{ + deleteItems: createFn(), + opt: []Option{WithLookup(true)}, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "nil underlying", + underlying: nil, + args: args{ + deleteItems: createFn(), + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "empty items", + underlying: db, + args: args{ + deleteItems: []interface{}{}, + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + { + name: "nil items", + underlying: db, + args: args{ + deleteItems: nil, + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rw := &Db{ + underlying: tt.underlying, + } + rowsDeleted, err := rw.DeleteItems(context.Background(), tt.args.deleteItems, tt.args.opt...) + if tt.wantErr { + require.Error(err) + if tt.wantErrIs != nil { + assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error: %s", err.Error()) + } + return + } + require.NoError(err) + assert.Equal(tt.wantRowsDeleted, rowsDeleted) + for _, item := range tt.args.deleteItems { + u := db_test.AllocTestUser() + u.PublicId = item.(*db_test.TestUser).PublicId + err := rw.LookupByPublicId(context.Background(), &u) + require.Error(err) + require.Truef(errors.Is(err, ErrRecordNotFound), "found item %s that should be deleted", u.PublicId) + } + if tt.wantOplogId != "" { + err = TestVerifyOplog(t, rw, tt.wantOplogId, WithOperation(oplog.OpType_OP_TYPE_DELETE), WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + } + }) + } +} + func testUser(t *testing.T, conn *gorm.DB, name, email, phoneNumber string) *db_test.TestUser { t.Helper() assert := assert.New(t)