diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index f368bd3767..42ea20f0ac 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -76,6 +76,13 @@ type Writer interface { // the transaction, which almost always should be to rollback. Create(ctx context.Context, i interface{}, opt ...Option) error + // CreateItems will create multiple items of the same type. + // Supported options: WithOplog. WithLookup is not a supported option. + // The caller is responsible for the transaction life cycle of the writer + // and if an error is returned the caller must decide what to do with + // the transaction, which almost always should be to rollback. + CreateItems(ctx context.Context, createItems []interface{}, opt ...Option) error + // Delete an object in the db with options: WithOplog // the caller is responsible for the transaction life cycle of the writer // and if an error is returned the caller must decide what to do with @@ -83,6 +90,13 @@ type Writer interface { // returns the number of rows deleted or an error. Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) + // DeleteItems will delete multiple items of the same type. + // Supported options: WithOplog. The caller is responsible for the + // transaction life cycle of the writer and if an error is returned the + // caller must decide what to do with the transaction, which almost always + // should be to rollback. Delete returns the number of rows deleted or an error. + DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error) + // DB returns the sql.DB DB() (*sql.DB, error) @@ -294,7 +308,10 @@ func (rw *Db) CreateItems(ctx context.Context, createItems []interface{}, opt .. } } for _, item := range createItems { - rw.Create(ctx, item) + if err := rw.Create(ctx, item); err != nil { + return fmt.Errorf("create items: %w", err) + } + } if opts.withOplog { if err := rw.addOplogForItems(ctx, CreateOp, opts, ticket, createItems); err != nil { diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 3ec5d26073..896771bb7a 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -1618,6 +1618,9 @@ func TestDb_CreateItems(t *testing.T) { u.PublicId = item.(*db_test.TestUser).PublicId err := rw.LookupByPublicId(context.Background(), &u) assert.NoError(err) + if _, ok := item.(*db_test.TestUser); ok { + assert.Truef(proto.Equal(item.(*db_test.TestUser).StoreTestUser, u.StoreTestUser), "%s and %s should be equal", item, u) + } } if tt.wantOplogId != "" { err = TestVerifyOplog(t, rw, tt.wantOplogId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(10*time.Second))