get oplog Ticket before writing data (#127)

pull/132/head
Jim 6 years ago committed by GitHub
parent 342bef7767
commit c413e013bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/watchtower/internal/db/common"
"github.com/hashicorp/watchtower/internal/oplog"
"github.com/hashicorp/watchtower/internal/oplog/store"
"github.com/jinzhu/gorm"
"google.golang.org/protobuf/proto"
)
@ -195,11 +196,19 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
return fmt.Errorf("create: vet for write failed %w", err)
}
}
var ticket *store.Ticket
if withOplog {
var err error
ticket, err = rw.getTicket(i)
if err != nil {
return fmt.Errorf("create: unable to get ticket: %w", err)
}
}
if err := rw.underlying.Create(i).Error; err != nil {
return fmt.Errorf("create: failed %w", err)
}
if withOplog {
if err := rw.addOplog(ctx, CreateOp, opts, i); err != nil {
if err := rw.addOplog(ctx, CreateOp, opts, ticket, i); err != nil {
return err
}
}
@ -272,7 +281,14 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
return NoRowsAffected, fmt.Errorf("update: vet for write failed %w", err)
}
}
var ticket *store.Ticket
if withOplog {
var err error
ticket, err = rw.getTicket(i)
if err != nil {
return NoRowsAffected, fmt.Errorf("update: unable to get ticket: %w", err)
}
}
underlying := rw.underlying.Model(i).Updates(updateFields)
if underlying.Error != nil {
if err == gorm.ErrRecordNotFound {
@ -294,7 +310,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
WithFieldMaskPaths: oplogFieldMasks,
WithNullPaths: oplogNullPaths,
}
if err := rw.addOplog(ctx, UpdateOp, oplogOpts, i); err != nil {
if err := rw.addOplog(ctx, UpdateOp, oplogOpts, ticket, i); err != nil {
return rowsUpdated, fmt.Errorf("update: add oplog failed %w", err)
}
}
@ -331,13 +347,21 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er
return NoRowsAffected, fmt.Errorf("delete: oplog validation failed %w", err)
}
}
var ticket *store.Ticket
if withOplog {
var err error
ticket, err = rw.getTicket(i)
if err != nil {
return NoRowsAffected, fmt.Errorf("delete: unable to get ticket: %w", err)
}
}
underlying := rw.underlying.Delete(i)
if underlying.Error != nil {
return NoRowsAffected, fmt.Errorf("delete: failed %w", underlying.Error)
}
rowsDeleted := int(underlying.RowsAffected)
if withOplog && rowsDeleted > 0 {
if err := rw.addOplog(ctx, DeleteOp, opts, i); err != nil {
if err := rw.addOplog(ctx, DeleteOp, opts, ticket, i); err != nil {
return rowsDeleted, fmt.Errorf("delete: add oplog failed %w", err)
}
}
@ -359,22 +383,38 @@ func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, er
return replayable, nil
}
func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i interface{}) error {
func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) {
if rw.underlying == nil {
return nil, fmt.Errorf("get ticket: underlying db missing: %w", ErrNilParameter)
}
replayable, ok := i.(oplog.ReplayableMessage)
if !ok {
return nil, fmt.Errorf("get ticket: not a replayable message %w", ErrInvalidParameter)
}
ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true))
if err != nil {
return nil, fmt.Errorf("get ticket: unable to get Ticketer %w", err)
}
ticket, err := ticketer.GetTicket(replayable.TableName())
if err != nil {
return nil, fmt.Errorf("get ticket: unable to get ticket %w", err)
}
return ticket, nil
}
func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, i interface{}) error {
oplogArgs := opts.oplogOpts
replayable, err := validateOplogArgs(i, opts)
if err != nil {
return err
}
gdb := rw.underlying
ticketer, err := oplog.NewGormTicketer(gdb, oplog.WithAggregateNames(true))
if err != nil {
return fmt.Errorf("error getting Ticketer %w for WithOplog", err)
if ticket == nil {
return fmt.Errorf("add oplog: missing ticket %w", ErrNilParameter)
}
ticket, err := ticketer.GetTicket(replayable.TableName())
ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true))
if err != nil {
return fmt.Errorf("error getting ticket %w for WithOplog", err)
return fmt.Errorf("add oplog: unable to get Ticketer %w", err)
}
entry, err := oplog.NewEntry(
replayable.TableName(),
oplogArgs.metadata,
@ -398,16 +438,16 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter
case DeleteOp:
msg.OpType = oplog.OpType_OP_TYPE_DELETE
default:
return fmt.Errorf("error operation type %v is not supported", opType)
return fmt.Errorf("add oplog: operation type %v is not supported", opType)
}
err = entry.WriteEntryWith(
ctx,
&oplog.GormWriter{Tx: gdb},
&oplog.GormWriter{Tx: rw.underlying},
ticket,
&msg,
)
if err != nil {
return fmt.Errorf("error creating oplog entry %w for WithOplog", err)
return fmt.Errorf("add oplog: unable to write oplog entry: %w", err)
}
return nil
}

Loading…
Cancel
Save