diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index 4e9a093a5b..859af9f367 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -13,7 +13,6 @@ import ( "github.com/hashicorp/watchtower/internal/oplog" "github.com/hashicorp/watchtower/internal/oplog/store" "github.com/jinzhu/gorm" - gormbulk "github.com/t-tiger/gorm-bulk-insert/v2" "google.golang.org/protobuf/proto" ) @@ -219,66 +218,50 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { return nil } -// Supported options: WithOplog -func (rw *Db) CreateMany(ctx context.Context, aggregateName string, createItems []interface{}, opt ...Option) error { +// CreateItems will create multiple items of the same type. +// Supported options: WithOplog. WithLookup is not a supported option. +func (rw *Db) CreateItems(ctx context.Context, createItems []interface{}, opt ...Option) error { if rw.underlying == nil { - return fmt.Errorf("create many: missing underlying db %w", ErrNilParameter) + return fmt.Errorf("create items: missing underlying db %w", ErrNilParameter) } if isNil(createItems) { - return fmt.Errorf("create many: interfaces is missing %w", ErrNilParameter) + return fmt.Errorf("create items: interfaces is missing %w", ErrNilParameter) } if len(createItems) == 0 { - return fmt.Errorf("create many: no interfaces to create %w", ErrInvalidParameter) + return fmt.Errorf("create items: no interfaces to create %w", ErrInvalidParameter) } opts := GetOpts(opt...) if opts.withLookup { - return fmt.Errorf("create many: withLookup not a supported option %w", ErrInvalidParameter) + return fmt.Errorf("create items: withLookup not a supported option %w", ErrInvalidParameter) + } + // verify that createItems are all the same type. + var foundType reflect.Type + for i, v := range createItems { + if i == 0 { + foundType = reflect.TypeOf(v) + } + currentType := reflect.TypeOf(v) + if foundType != currentType { + return fmt.Errorf("create items: create items contains disparate types. item %d is not a %s: %w", i, foundType.Name(), ErrInvalidParameter) + } } var ticket *store.Ticket if opts.withOplog { - var err error - ticket, err = rw.getTicketFor(aggregateName) + _, err := validateOplogArgs(createItems[0], opts) if err != nil { - return fmt.Errorf("create many: unable to get ticket: %w", err) + return fmt.Errorf("create items: oplog validation failed %w", err) + } + ticket, err = rw.getTicket(createItems[0]) + if err != nil { + return fmt.Errorf("create items: unable to get ticket: %w", err) } } - if err := gormbulk.BulkInsert(rw.underlying, createItems, 3000); err != nil { - return fmt.Errorf("create many: failed %w", err) + for _, item := range createItems { + rw.Create(ctx, item) } if opts.withOplog { - oplogArgs := opts.oplogOpts - ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true)) - if err != nil { - return fmt.Errorf("add oplog: unable to get Ticketer %w", err) - } - entry, err := oplog.NewEntry( - aggregateName, - oplogArgs.metadata, - oplogArgs.wrapper, - ticketer, - ) - if err != nil { - return fmt.Errorf("create many: unable to create oplog entry %w", err) - } - oplogMsgs := []*oplog.Message{} - for i, item := range createItems { - replayable, ok := item.(oplog.ReplayableMessage) - if !ok { - return fmt.Errorf("create many: item %d not a replayable oplog message %w", i, ErrInvalidParameter) - } - oplogMsgs = append(oplogMsgs, &oplog.Message{ - Message: item.(proto.Message), - TypeName: replayable.TableName(), - OpType: oplog.OpType_OP_TYPE_CREATE, - }) - } - if err := entry.WriteEntryWith( - ctx, - &oplog.GormWriter{Tx: rw.underlying}, - ticket, - oplogMsgs..., - ); err != nil { - return fmt.Errorf("create many: unable to write oplog entry %w", err) + if err := rw.addOplogForItems(ctx, CreateOp, opts, ticket, createItems); err != nil { + return fmt.Errorf("create items: unable to add oplog %w", err) } } return nil @@ -447,6 +430,57 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er return rowsDeleted, nil } +// DeleteItems will deletes multiple items of the same type. +// Supported options: WithOplog. +func (rw *Db) DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error) { + if rw.underlying == nil { + return NoRowsAffected, fmt.Errorf("delete items: missing underlying db %w", ErrNilParameter) + } + if isNil(deleteItems) { + return NoRowsAffected, fmt.Errorf("delete items: interfaces is missing %w", ErrNilParameter) + } + if len(deleteItems) == 0 { + return NoRowsAffected, fmt.Errorf("delete items: no interfaces to delete %w", ErrInvalidParameter) + } + // verify that createItems are all the same type. + var foundType reflect.Type + for i, v := range deleteItems { + if i == 0 { + foundType = reflect.TypeOf(v) + } + currentType := reflect.TypeOf(v) + if foundType != currentType { + return NoRowsAffected, fmt.Errorf("delete items: items contain disparate types. item %d is not a %s: %w", i, foundType.Name(), ErrInvalidParameter) + } + } + opts := GetOpts(opt...) + var ticket *store.Ticket + if opts.withOplog { + _, err := validateOplogArgs(deleteItems[0], opts) + if err != nil { + return NoRowsAffected, fmt.Errorf("delete items: oplog validation failed %w", err) + } + ticket, err = rw.getTicket(deleteItems[0]) + if err != nil { + return NoRowsAffected, fmt.Errorf("delete items: unable to get ticket: %w", err) + } + } + rowsDeleted := 0 + for _, item := range deleteItems { + underlying := rw.underlying.Delete(item) + if underlying.Error != nil { + return rowsDeleted, fmt.Errorf("delete: failed %w", underlying.Error) + } + rowsDeleted += int(underlying.RowsAffected) + } + if opts.withOplog && rowsDeleted > 0 { + if err := rw.addOplogForItems(ctx, DeleteOp, opts, ticket, deleteItems); err != nil { + return rowsDeleted, fmt.Errorf("delete items: unable to add oplog %w", err) + } + } + return rowsDeleted, nil +} + func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, error) { oplogArgs := opts.oplogOpts if oplogArgs.wrapper == nil { @@ -487,6 +521,82 @@ func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) { return rw.getTicketFor(replayable.TableName()) } +// addOplogForItems will add a multi-message oplog entry with one msg for each +// item. Items must all be of the same type. Only CreateOp and DeleteOp are +// currently supported operations. +func (rw *Db) addOplogForItems(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, items []interface{}) error { + oplogArgs := opts.oplogOpts + if ticket == nil { + return fmt.Errorf("oplog many: ticket is missing: %w", ErrNilParameter) + } + if items == nil { + return fmt.Errorf("oplog many: items are missing: %w", ErrNilParameter) + } + if len(items) == 0 { + return fmt.Errorf("oplog many: items is empty: %w", ErrInvalidParameter) + } + if oplogArgs.metadata == nil { + return fmt.Errorf("oplog many: metadata is missing: %w", ErrNilParameter) + } + if oplogArgs.wrapper == nil { + return fmt.Errorf("oplog many: wrapper is missing: %w", ErrNilParameter) + } + replayable, err := validateOplogArgs(items[0], opts) + if err != nil { + return fmt.Errorf("oplog many: oplog validation failed %w", err) + } + ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true)) + if err != nil { + return fmt.Errorf("oplog many: unable to get Ticketer %w", err) + } + entry, err := oplog.NewEntry( + replayable.TableName(), + oplogArgs.metadata, + oplogArgs.wrapper, + ticketer, + ) + if err != nil { + return fmt.Errorf("oplog many: unable to create oplog entry %w", err) + } + oplogMsgs := []*oplog.Message{} + var foundType reflect.Type + for i, item := range items { + if i == 0 { + foundType = reflect.TypeOf(item) + } + currentType := reflect.TypeOf(item) + if foundType != currentType { + return fmt.Errorf("oplog many: items contains disparate types. item %d is not a %s", i, foundType.Name()) + } + replayable, ok := item.(oplog.ReplayableMessage) + if !ok { + return fmt.Errorf("oplog many: item %d not a replayable oplog message %w", i, ErrInvalidParameter) + } + msg := &oplog.Message{ + Message: item.(proto.Message), + TypeName: replayable.TableName(), + OpType: oplog.OpType_OP_TYPE_CREATE, + } + switch opType { + case CreateOp: + msg.OpType = oplog.OpType_OP_TYPE_CREATE + case DeleteOp: + msg.OpType = oplog.OpType_OP_TYPE_DELETE + default: + return fmt.Errorf("oplog many: operation type %v is not supported", opType) + } + oplogMsgs = append(oplogMsgs, msg) + } + if err := entry.WriteEntryWith( + ctx, + &oplog.GormWriter{Tx: rw.underlying}, + ticket, + oplogMsgs..., + ); err != nil { + return fmt.Errorf("oplog many: unable to write oplog entry %w", err) + } + return nil +} func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, i interface{}) error { oplogArgs := opts.oplogOpts replayable, err := validateOplogArgs(i, opts) diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 614a65a2ff..80ef0de292 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -1237,6 +1237,345 @@ func TestDb_ScanRows(t *testing.T) { }) } +func TestDb_CreateItems(t *testing.T) { + cleanup, db, _ := TestSetup(t, "postgres") + defer func() { + err := cleanup() + assert.NoError(t, err) + err = db.Close() + assert.NoError(t, err) + }() + testOplogResourceId := testId(t) + + createFn := func() []interface{} { + results := []interface{}{} + for i := 0; i < 10; i++ { + u, err := db_test.NewTestUser() + require.NoError(t, err) + results = append(results, u) + } + return results + } + createMixedFn := func() []interface{} { + u, err := db_test.NewTestUser() + require.NoError(t, err) + c, err := db_test.NewTestCar() + require.NoError(t, err) + return []interface{}{ + u, + c, + } + + } + type args struct { + createItems []interface{} + opt []Option + } + tests := []struct { + name string + underlying *gorm.DB + args args + wantOplogId string + wantErr bool + wantErrIs error + }{ + { + name: "simple", + underlying: db, + args: args{ + createItems: createFn(), + }, + wantErr: false, + }, + { + name: "withOplog", + underlying: db, + args: args{ + createItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + oplog.Metadata{ + "resource-public-id": []string{testOplogResourceId}, + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + }, + ), + }, + }, + wantOplogId: testOplogResourceId, + wantErr: false, + }, + { + name: "mixed items", + underlying: db, + args: args{ + createItems: createMixedFn(), + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + { + name: "bad oplog opt: nil metadata", + underlying: nil, + args: args{ + createItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + nil, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad oplog opt: nil wrapper", + underlying: nil, + args: args{ + createItems: createFn(), + opt: []Option{ + WithOplog( + nil, + oplog.Metadata{ + "resource-public-id": []string{"doesn't matter since wrapper is nil"}, + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + }, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad opt: WithLookup", + underlying: nil, + args: args{ + createItems: createFn(), + opt: []Option{WithLookup(true)}, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "nil underlying", + underlying: nil, + args: args{ + createItems: createFn(), + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "empty items", + underlying: db, + args: args{ + createItems: []interface{}{}, + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + { + name: "nil items", + underlying: db, + args: args{ + createItems: nil, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rw := &Db{ + underlying: tt.underlying, + } + err := rw.CreateItems(context.Background(), tt.args.createItems, tt.args.opt...) + if tt.wantErr { + require.Error(err) + if tt.wantErrIs != nil { + assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error: %s", err.Error()) + } + return + } + require.NoError(err) + for _, item := range tt.args.createItems { + u := db_test.AllocTestUser() + u.PublicId = item.(*db_test.TestUser).PublicId + err := rw.LookupByPublicId(context.Background(), &u) + assert.NoError(err) + } + if tt.wantOplogId != "" { + err = TestVerifyOplog(t, rw, tt.wantOplogId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + } + }) + } +} + +func TestDb_DeleteItems(t *testing.T) { + cleanup, db, _ := TestSetup(t, "postgres") + defer func() { + err := cleanup() + assert.NoError(t, err) + err = db.Close() + assert.NoError(t, err) + }() + + testOplogResourceId := testId(t) + + createFn := func() []interface{} { + results := []interface{}{} + for i := 0; i < 10; i++ { + u := testUser(t, db, "", "", "") + results = append(results, u) + } + return results + } + type args struct { + deleteItems []interface{} + opt []Option + } + tests := []struct { + name string + underlying *gorm.DB + args args + wantRowsDeleted int + wantOplogId string + wantErr bool + wantErrIs error + }{ + { + name: "simple", + underlying: db, + args: args{ + deleteItems: createFn(), + }, + wantRowsDeleted: 10, + wantErr: false, + }, + { + name: "withOplog", + underlying: db, + args: args{ + deleteItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + oplog.Metadata{ + "resource-public-id": []string{testOplogResourceId}, + "op-type": []string{oplog.OpType_OP_TYPE_DELETE.String()}, + }, + ), + }, + }, + wantRowsDeleted: 10, + wantOplogId: testOplogResourceId, + wantErr: false, + }, + { + name: "bad oplog opt: nil metadata", + underlying: nil, + args: args{ + deleteItems: createFn(), + opt: []Option{ + WithOplog( + TestWrapper(t), + nil, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad oplog opt: nil wrapper", + underlying: nil, + args: args{ + deleteItems: createFn(), + opt: []Option{ + WithOplog( + nil, + oplog.Metadata{ + "resource-public-id": []string{"doesn't matter since wrapper is nil"}, + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + }, + ), + }, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "bad opt: WithLookup", + underlying: nil, + args: args{ + deleteItems: createFn(), + opt: []Option{WithLookup(true)}, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "nil underlying", + underlying: nil, + args: args{ + deleteItems: createFn(), + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + { + name: "empty items", + underlying: db, + args: args{ + deleteItems: []interface{}{}, + }, + wantErr: true, + wantErrIs: ErrInvalidParameter, + }, + { + name: "nil items", + underlying: db, + args: args{ + deleteItems: nil, + }, + wantErr: true, + wantErrIs: ErrNilParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rw := &Db{ + underlying: tt.underlying, + } + rowsDeleted, err := rw.DeleteItems(context.Background(), tt.args.deleteItems, tt.args.opt...) + if tt.wantErr { + require.Error(err) + if tt.wantErrIs != nil { + assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error: %s", err.Error()) + } + return + } + require.NoError(err) + assert.Equal(tt.wantRowsDeleted, rowsDeleted) + for _, item := range tt.args.deleteItems { + u := db_test.AllocTestUser() + u.PublicId = item.(*db_test.TestUser).PublicId + err := rw.LookupByPublicId(context.Background(), &u) + require.Error(err) + require.Truef(errors.Is(err, ErrRecordNotFound), "found item %s that should be deleted", u.PublicId) + } + if tt.wantOplogId != "" { + err = TestVerifyOplog(t, rw, tt.wantOplogId, WithOperation(oplog.OpType_OP_TYPE_DELETE), WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + } + }) + } +} + func testUser(t *testing.T, conn *gorm.DB, name, email, phoneNumber string) *db_test.TestUser { t.Helper() assert := assert.New(t)