|
|
|
|
@ -11,6 +11,7 @@ import (
|
|
|
|
|
|
|
|
|
|
"github.com/hashicorp/watchtower/internal/db/common"
|
|
|
|
|
"github.com/hashicorp/watchtower/internal/oplog"
|
|
|
|
|
"github.com/hashicorp/watchtower/internal/oplog/store"
|
|
|
|
|
"github.com/jinzhu/gorm"
|
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
|
)
|
|
|
|
|
@ -195,11 +196,19 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
|
|
|
|
|
return fmt.Errorf("create: vet for write failed %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
var ticket *store.Ticket
|
|
|
|
|
if withOplog {
|
|
|
|
|
var err error
|
|
|
|
|
ticket, err = rw.getTicket(i)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("create: unable to get ticket: %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if err := rw.underlying.Create(i).Error; err != nil {
|
|
|
|
|
return fmt.Errorf("create: failed %w", err)
|
|
|
|
|
}
|
|
|
|
|
if withOplog {
|
|
|
|
|
if err := rw.addOplog(ctx, CreateOp, opts, i); err != nil {
|
|
|
|
|
if err := rw.addOplog(ctx, CreateOp, opts, ticket, i); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@ -272,7 +281,14 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
|
|
|
|
|
return NoRowsAffected, fmt.Errorf("update: vet for write failed %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var ticket *store.Ticket
|
|
|
|
|
if withOplog {
|
|
|
|
|
var err error
|
|
|
|
|
ticket, err = rw.getTicket(i)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return NoRowsAffected, fmt.Errorf("update: unable to get ticket: %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
underlying := rw.underlying.Model(i).Updates(updateFields)
|
|
|
|
|
if underlying.Error != nil {
|
|
|
|
|
if err == gorm.ErrRecordNotFound {
|
|
|
|
|
@ -294,7 +310,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
|
|
|
|
|
WithFieldMaskPaths: oplogFieldMasks,
|
|
|
|
|
WithNullPaths: oplogNullPaths,
|
|
|
|
|
}
|
|
|
|
|
if err := rw.addOplog(ctx, UpdateOp, oplogOpts, i); err != nil {
|
|
|
|
|
if err := rw.addOplog(ctx, UpdateOp, oplogOpts, ticket, i); err != nil {
|
|
|
|
|
return rowsUpdated, fmt.Errorf("update: add oplog failed %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@ -331,13 +347,21 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er
|
|
|
|
|
return NoRowsAffected, fmt.Errorf("delete: oplog validation failed %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
var ticket *store.Ticket
|
|
|
|
|
if withOplog {
|
|
|
|
|
var err error
|
|
|
|
|
ticket, err = rw.getTicket(i)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return NoRowsAffected, fmt.Errorf("delete: unable to get ticket: %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
underlying := rw.underlying.Delete(i)
|
|
|
|
|
if underlying.Error != nil {
|
|
|
|
|
return NoRowsAffected, fmt.Errorf("delete: failed %w", underlying.Error)
|
|
|
|
|
}
|
|
|
|
|
rowsDeleted := int(underlying.RowsAffected)
|
|
|
|
|
if withOplog && rowsDeleted > 0 {
|
|
|
|
|
if err := rw.addOplog(ctx, DeleteOp, opts, i); err != nil {
|
|
|
|
|
if err := rw.addOplog(ctx, DeleteOp, opts, ticket, i); err != nil {
|
|
|
|
|
return rowsDeleted, fmt.Errorf("delete: add oplog failed %w", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@ -359,22 +383,38 @@ func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, er
|
|
|
|
|
return replayable, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i interface{}) error {
|
|
|
|
|
func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) {
|
|
|
|
|
if rw.underlying == nil {
|
|
|
|
|
return nil, fmt.Errorf("get ticket: underlying db missing: %w", ErrNilParameter)
|
|
|
|
|
}
|
|
|
|
|
replayable, ok := i.(oplog.ReplayableMessage)
|
|
|
|
|
if !ok {
|
|
|
|
|
return nil, fmt.Errorf("get ticket: not a replayable message %w", ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, fmt.Errorf("get ticket: unable to get Ticketer %w", err)
|
|
|
|
|
}
|
|
|
|
|
ticket, err := ticketer.GetTicket(replayable.TableName())
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, fmt.Errorf("get ticket: unable to get ticket %w", err)
|
|
|
|
|
}
|
|
|
|
|
return ticket, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, i interface{}) error {
|
|
|
|
|
oplogArgs := opts.oplogOpts
|
|
|
|
|
replayable, err := validateOplogArgs(i, opts)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
gdb := rw.underlying
|
|
|
|
|
ticketer, err := oplog.NewGormTicketer(gdb, oplog.WithAggregateNames(true))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("error getting Ticketer %w for WithOplog", err)
|
|
|
|
|
if ticket == nil {
|
|
|
|
|
return fmt.Errorf("add oplog: missing ticket %w", ErrNilParameter)
|
|
|
|
|
}
|
|
|
|
|
ticket, err := ticketer.GetTicket(replayable.TableName())
|
|
|
|
|
ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("error getting ticket %w for WithOplog", err)
|
|
|
|
|
return fmt.Errorf("add oplog: unable to get Ticketer %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
entry, err := oplog.NewEntry(
|
|
|
|
|
replayable.TableName(),
|
|
|
|
|
oplogArgs.metadata,
|
|
|
|
|
@ -398,16 +438,16 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter
|
|
|
|
|
case DeleteOp:
|
|
|
|
|
msg.OpType = oplog.OpType_OP_TYPE_DELETE
|
|
|
|
|
default:
|
|
|
|
|
return fmt.Errorf("error operation type %v is not supported", opType)
|
|
|
|
|
return fmt.Errorf("add oplog: operation type %v is not supported", opType)
|
|
|
|
|
}
|
|
|
|
|
err = entry.WriteEntryWith(
|
|
|
|
|
ctx,
|
|
|
|
|
&oplog.GormWriter{Tx: gdb},
|
|
|
|
|
&oplog.GormWriter{Tx: rw.underlying},
|
|
|
|
|
ticket,
|
|
|
|
|
&msg,
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("error creating oplog entry %w for WithOplog", err)
|
|
|
|
|
return fmt.Errorf("add oplog: unable to write oplog entry: %w", err)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|